diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9ca5b47e1..781f66fe9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -44,6 +44,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
 - Merge table delete removes orphaned master entries #1164
 - Edit `merge_fetch` to expect positional before keyword arguments #1181
 - Allow part restriction `SpyglassMixinPart.delete` #1192
+- Add mixin method `get_fully_defined_key` #1198

 ### Pipelines

@@ -62,6 +63,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
 - Decoding

   - Fix edge case errors in spike time loading #1083
+  - Allow fetch of partial key from `DecodingParameters` #1198
+  - Allow data fetching with partial but unique key #1198

 - Linearization

diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py
index fbfba2183..ae49f388e 100644
--- a/src/spyglass/decoding/v1/clusterless.py
+++ b/src/spyglass/decoding/v1/clusterless.py
@@ -316,8 +316,8 @@ def fetch_model(self):
         """Retrieve the decoding model"""
         return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

-    @staticmethod
-    def fetch_environments(key):
+    @classmethod
+    def fetch_environments(cls, key):
         """Fetch the environments for the decoding model

         Parameters
@@ -330,6 +330,9 @@ def fetch_environments(key):
         List[TrackGraph]
             list of track graphs in the trained model
         """
+        key = cls.get_fully_defined_key(
+            key, required_fields=["decoding_param_name"]
+        )
         model_params = (
             DecodingParameters
             & {"decoding_param_name": key["decoding_param_name"]}
@@ -355,8 +358,8 @@ def fetch_environments(key):

         return classifier.environments

-    @staticmethod
-    def fetch_position_info(key):
+    @classmethod
+    def fetch_position_info(cls, key):
         """Fetch the position information for the decoding model

         Parameters
@@ -369,6 +372,15 @@ def fetch_position_info(key):
         Tuple[pd.DataFrame, List[str]]
             The position information and the names of the position variables
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "nwb_file_name",
+                "position_group_name",
+                "encoding_interval",
+                "decoding_interval",
+            ],
+        )
         position_group_key = {
             "position_group_name": key["position_group_name"],
             "nwb_file_name": key["nwb_file_name"],
@@ -381,8 +393,8 @@ def fetch_position_info(key):

         return position_info, position_variable_names

-    @staticmethod
-    def fetch_linear_position_info(key):
+    @classmethod
+    def fetch_linear_position_info(cls, key):
         """Fetch the position information and project it onto the track graph

         Parameters
@@ -395,6 +407,16 @@ def fetch_linear_position_info(key):
         pd.DataFrame
             The linearized position information
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "nwb_file_name",
+                "position_group_name",
+                "encoding_interval",
+                "decoding_interval",
+            ],
+        )
+
         environment = ClusterlessDecodingV1.fetch_environments(key)[0]

         position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
@@ -417,8 +439,8 @@ def fetch_linear_position_info(key):
             axis=1,
         ).loc[min_time:max_time]

-    @staticmethod
-    def fetch_spike_data(key, filter_by_interval=True):
+    @classmethod
+    def fetch_spike_data(cls, key, filter_by_interval=True):
         """Fetch the spike times for the decoding model

         Parameters
@@ -434,6 +456,14 @@ def fetch_spike_data(key, filter_by_interval=True):
         list[np.ndarray]
             List of spike times for each unit in the model's spike group
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "nwb_file_name",
+                "waveform_features_group_name",
+            ],
+        )
+
         waveform_keys = (
             (
                 UnitWaveformFeaturesGroup.UnitFeatures
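With the clusterless fetch helpers promoted from `@staticmethod` to `@classmethod`, each one can now complete a partial key through the new mixin method before querying. A minimal usage sketch, assuming an existing entry; the `nwb_file_name` value is hypothetical:

```python
from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1

# Hypothetical partial key: any restriction matching exactly one row works.
key = {"nwb_file_name": "example20240101_.nwb"}

# Each helper now fills in the remaining fields via get_fully_defined_key,
# so callers no longer need to spell out the full primary key.
environments = ClusterlessDecodingV1.fetch_environments(key)
position_df, position_names = ClusterlessDecodingV1.fetch_position_info(key)
spike_data = ClusterlessDecodingV1.fetch_spike_data(key)
```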
diff --git a/src/spyglass/decoding/v1/core.py b/src/spyglass/decoding/v1/core.py
index 177a87d22..0e3d0fee4 100644
--- a/src/spyglass/decoding/v1/core.py
+++ b/src/spyglass/decoding/v1/core.py
@@ -70,20 +70,27 @@ def insert(self, rows, *args, **kwargs):
     def fetch(self, *args, **kwargs):
         """Return decoding parameters as a list of classes."""
         rows = super().fetch(*args, **kwargs)
-        if len(rows) > 0 and len(rows[0]) > 1:
+        if kwargs.get("format", None) == "array":
+            # called by dj.fetch(); class conversion is performed later in the stack
+            return rows
+
+        if not len(args):
+            # infer args from the table heading
+            args = tuple(self.heading)
+
+        if "decoding_params" not in args:
+            return rows
+
+        params_index = args.index("decoding_params")
+        if len(args) == 1:
+            # only fetching decoding_params
+            content = [restore_classes(r) for r in rows]
+        elif len(rows):
             content = []
-            for (
-                decoding_param_name,
-                decoding_params,
-                decoding_kwargs,
-            ) in rows:
-                content.append(
-                    (
-                        decoding_param_name,
-                        restore_classes(decoding_params),
-                        decoding_kwargs,
-                    )
-                )
+            for row in zip(*rows):
+                row = list(row)
+                row[params_index] = restore_classes(row[params_index])
+                content.append(tuple(row))
         else:
             content = rows
         return content
@@ -91,7 +98,20 @@ def fetch1(self, *args, **kwargs):
         """Return one decoding paramset as a class."""
         row = super().fetch1(*args, **kwargs)
-        row["decoding_params"] = restore_classes(row["decoding_params"])
+
+        if len(args) == 0:
+            row["decoding_params"] = restore_classes(row["decoding_params"])
+            return row
+
+        if "decoding_params" in args:
+            if len(args) == 1:
+                return restore_classes(row)
+            row = list(row)
+            row[args.index("decoding_params")] = restore_classes(
+                row[args.index("decoding_params")]
+            )
+            return tuple(row)
+        return row
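Because `fetch`/`fetch1` now inspect the requested attributes, `decoding_params` is passed through `restore_classes` whether it is fetched alone, alongside other attributes, or as part of a full row. A sketch of the call patterns the override handles; the paramset name is hypothetical:

```python
from spyglass.decoding.v1.core import DecodingParameters

# Hypothetical paramset name; restrict to a single entry.
paramset = DecodingParameters & {"decoding_param_name": "my_params"}

row = paramset.fetch1()                      # full row: dict, decoding_params restored in place
params = paramset.fetch1("decoding_params")  # single attribute: restored class returned directly
name, params = paramset.fetch1(
    "decoding_param_name", "decoding_params"
)                                            # tuple: only the decoding_params entry restored
names = paramset.fetch("decoding_param_name")  # no decoding_params requested: returned unchanged
```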
diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py
index 9e4c2c3ba..7b4ede194 100644
--- a/src/spyglass/decoding/v1/sorted_spikes.py
+++ b/src/spyglass/decoding/v1/sorted_spikes.py
@@ -275,8 +275,8 @@ def fetch_model(self):
         """Retrieve the decoding model"""
         return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

-    @staticmethod
-    def fetch_environments(key):
+    @classmethod
+    def fetch_environments(cls, key):
         """Fetch the environments for the decoding model

         Parameters
@@ -289,6 +289,10 @@ def fetch_environments(key):
         List[TrackGraph]
             list of track graphs in the trained model
         """
+        key = cls.get_fully_defined_key(
+            key, required_fields=["decoding_param_name"]
+        )
+
         model_params = (
             DecodingParameters
             & {"decoding_param_name": key["decoding_param_name"]}
@@ -314,8 +318,8 @@ def fetch_environments(key):

         return classifier.environments

-    @staticmethod
-    def fetch_position_info(key):
+    @classmethod
+    def fetch_position_info(cls, key):
         """Fetch the position information for the decoding model

         Parameters
@@ -328,6 +332,16 @@ def fetch_position_info(key):
         Tuple[pd.DataFrame, List[str]]
             The position information and the names of the position variables
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "position_group_name",
+                "nwb_file_name",
+                "encoding_interval",
+                "decoding_interval",
+            ],
+        )
+
         position_group_key = {
             "position_group_name": key["position_group_name"],
             "nwb_file_name": key["nwb_file_name"],
@@ -339,8 +353,8 @@ def fetch_position_info(key):

         return position_info, position_variable_names

-    @staticmethod
-    def fetch_linear_position_info(key):
+    @classmethod
+    def fetch_linear_position_info(cls, key):
         """Fetch the position information and project it onto the track graph

         Parameters
@@ -353,6 +367,16 @@ def fetch_linear_position_info(key):
         pd.DataFrame
             The linearized position information
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "position_group_name",
+                "nwb_file_name",
+                "encoding_interval",
+                "decoding_interval",
+            ],
+        )
+
         environment = SortedSpikesDecodingV1.fetch_environments(key)[0]

         position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]
@@ -374,9 +398,13 @@ def fetch_linear_position_info(key):
             axis=1,
         ).loc[min_time:max_time]

-    @staticmethod
+    @classmethod
     def fetch_spike_data(
-        key, filter_by_interval=True, time_slice=None, return_unit_ids=False
+        cls,
+        key,
+        filter_by_interval=True,
+        time_slice=None,
+        return_unit_ids=False,
     ) -> Union[list[np.ndarray], Optional[list[dict]]]:
         """Fetch the spike times for the decoding model

         Parameters
@@ -399,6 +427,14 @@ def fetch_spike_data(
         list[np.ndarray]
             List of spike times for each unit in the model's spike group
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "encoding_interval",
+                "decoding_interval",
+            ],
+        )
+
         spike_times, unit_ids = SortedSpikesGroup.fetch_spike_data(
             key, return_unit_ids=True
         )
diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py
index 2f862c4fb..34041117b 100644
--- a/src/spyglass/spikesorting/analysis/v1/group.py
+++ b/src/spyglass/spikesorting/analysis/v1/group.py
@@ -3,7 +3,6 @@
 import datajoint as dj
 import numpy as np
-from ripple_detection import get_multiunit_population_firing_rate

 from spyglass.common import Session  # noqa: F401
 from spyglass.settings import test_mode
@@ -127,9 +126,12 @@ def filter_units(
                 include_mask[ind] = True
         return include_mask

-    @staticmethod
+    @classmethod
     def fetch_spike_data(
-        key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
+        cls,
+        key: dict,
+        time_slice: list[float] = None,
+        return_unit_ids: bool = False,
     ) -> Union[list[np.ndarray], Optional[list[dict]]]:
         """fetch spike times for units in the group

@@ -148,6 +150,8 @@ def fetch_spike_data(
         list of np.ndarray
             list of spike times for each unit in the group
         """
+        key = cls.get_fully_defined_key(key)
+
         # get merge_ids for SpikeSortingOutput
         merge_ids = (
             (
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 91cf35870..2df4844a4 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -137,6 +137,28 @@ def _safe_context(cls):
             else nullcontext()
         )

+    @classmethod
+    def get_fully_defined_key(
+        cls, key: dict = None, required_fields: list[str] = None
+    ) -> dict:
+        if key is None:
+            key = dict()
+
+        required_fields = required_fields or cls.primary_key
+        if isinstance(key, (str, dict)):  # key is a dict or a restriction string
+            if not all(
+                field in key for field in required_fields
+            ):  # check if all required fields are in key
+                if not len(query := cls() & key) == 1:  # check if key is unique
+                    raise KeyError(
+                        "Key is neither fully specified nor a unique entry "
+                        + f"in table.\n\tTable: {cls.full_table_name}\n\tKey: {key}"
+                        + f"\n\tRequired fields: {required_fields}\n\tResult: {query}"
+                    )
+                key = query.fetch1("KEY")
+
+        return key
+
     # ------------------------------- fetch_nwb -------------------------------

     @cached_property
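For reference, the new mixin method has three outcomes: a key that already contains all required fields is returned unchanged; a partial key (dict or restriction string) matching exactly one row is completed via `fetch1("KEY")`; anything else raises `KeyError`. A sketch against one of the tables touched above; the key values are hypothetical:

```python
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1

# Hypothetical partial key; the value is a placeholder.
partial_key = {"nwb_file_name": "example20240101_.nwb"}

# If the restriction matches exactly one row, the completed key is returned.
key = SortedSpikesDecodingV1.get_fully_defined_key(
    partial_key, required_fields=["encoding_interval", "decoding_interval"]
)

# A key already containing all required fields is returned unchanged;
# a restriction matching zero or several rows raises KeyError.
```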