diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index f838976ce..e683596a5 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -193,19 +193,37 @@ def make(self, key): } ) - def _make_model_name(self, key: dict = None): + def _make_model_name(self, key: dict): # make a unique model name based on the key - if key is None: - key = {} key = (MoseqModelSelection & key).fetch1("KEY") return dj.hash.key_hash(key) @staticmethod def _initialize_model( - data, metadata, project_dir, model_name, config, model_params + data: dict, + metadata: tuple, + project_dir: str, + model_name: str, + config: dict, + model_params: dict, ): """Method to initialize a model. Creates model and runs initional ARHMM fit + Parameters + ---------- + data : dict + data dictionary (get from kpms.format_data) + metadata : tuple + metadata tuple (get from kpms.format_data) + project_dir : str + path to the project directory + model_name : str + name of the model + config : dict + keypoint moseq config + model_params : dict + params dictionary fetched from spyglass parameter table entry + Returns ------- tuple @@ -229,20 +247,20 @@ def _initialize_model( model_name=model_name + "_ar", ) - def analyze_pca(self, key: dict = None): + def analyze_pca(self, key: dict, explained_variace: float = 0.9): """Method to analyze the PCA of a model Parameters ---------- key : dict key to a single MoseqModel table entry + explained_variace : float, optional + minimum explained variance to print, by default 0.9 """ - if key is None: - key = {} project_dir = (self & key).fetch1("project_dir") pca = kpms.load_pca(project_dir) config = kpms.load_config(project_dir) - kpms.print_dims_to_explain_variance(pca, 0.9) + kpms.print_dims_to_explain_variance(pca, explained_variace) kpms.plot_scree(pca, project_dir=project_dir) kpms.plot_pcs(pca, project_dir=project_dir, **config) @@ -280,8 +298,9 @@ def get_training_progress_path(self, key: dict = None): """ if key is None: key = {} - project_dir = (self & key).fetch1("project_dir") - model_name = (self & key).fetch1("model_name") + project_dir, model_name = (self & key).fetch1( + "project_dir", "model_name" + ) return f"{project_dir}/{model_name}/fitting_progress.pdf" @@ -306,7 +325,9 @@ def validate_bodyparts(self, key): model_bodyparts = (PoseGroup & key).fetch1("bodyparts") merge_key = {"merge_id": key["pose_merge_id"]} bodyparts_df = (PositionOutput & merge_key).fetch_pose_dataframe() - data_bodyparts = bodyparts_df.keys().get_level_values(0).unique().values + data_bodyparts = MoseqSyllable.get_bodyparts_from_dataframe( + bodyparts_df + ) missing = [bp for bp in model_bodyparts if bp not in data_bodyparts] if missing: @@ -338,11 +359,11 @@ def make(self, key): # load data and format for moseq merge_query = PositionOutput & merge_key video_path = merge_query.fetch_video_path() - video_name = Path(video_path).stem + ".mp4" + video_name = Path(video_path).name bodyparts_df = merge_query.fetch_pose_dataframe() if bodyparts is None: - bodyparts = bodyparts_df.keys().get_level_values(0).unique().values + bodyparts = self.get_bodyparts_from_dataframe(bodyparts_df) datasets = {video_name: bodyparts_df[bodyparts]} coordinates, confidences = format_dataset_for_moseq(datasets, bodyparts) data, metadata = kpms.format_data(coordinates, confidences, **config) @@ -379,3 +400,19 @@ def fetch1_dataframe(self): dataframe = self.fetch_nwb()[0]["moseq"] dataframe.set_index("time", inplace=True) return dataframe + + @staticmethod + def get_bodyparts_from_dataframe(dataframe): + """Method to get the list of bodyparts from a dataframe + + Parameters + ---------- + dataframe : pd.DataFrame + dataframe with bodypart data from PositionOutput + + Returns + ------- + List[str] + list of bodyparts + """ + return dataframe.keys().get_level_values(0).unique().values