Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decoding qol updates #1198

Merged
merged 11 commits into from
Dec 5, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
- Edit `merge_fetch` to expect positional before keyword arguments #1181
- Add mixin method `get_fully_defined_key` #1198

### Pipelines

Expand All @@ -57,6 +58,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Decoding

- Fix edge case errors in spike time loading #1083
- Allow fetch of partial key from `DecodingParameters` #1198
- Allow data fetching with partial but unique key #1198

- Linearization

Expand Down
46 changes: 38 additions & 8 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model

Parameters
Expand All @@ -329,6 +329,9 @@ def fetch_environments(key):
List[TrackGraph]
list of track graphs in the trained model
"""
key = cls.get_fully_defined_key(
key, required_fields=["decoding_param_name"]
)
model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
Expand All @@ -354,8 +357,8 @@ def fetch_environments(key):

return classifier.environments

@staticmethod
def fetch_position_info(key):
@classmethod
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model

Parameters
Expand All @@ -368,6 +371,15 @@ def fetch_position_info(key):
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)
position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
Expand All @@ -380,8 +392,8 @@ def fetch_position_info(key):

return position_info, position_variable_names

@staticmethod
def fetch_linear_position_info(key):
@classmethod
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph

Parameters
Expand All @@ -394,6 +406,16 @@ def fetch_linear_position_info(key):
pd.DataFrame
The linearized position information
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)

environment = ClusterlessDecodingV1.fetch_environments(key)[0]

position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
Expand All @@ -416,8 +438,8 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
@classmethod
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the spike times for the decoding model

Parameters
Expand All @@ -433,6 +455,14 @@ def fetch_spike_data(key, filter_by_interval=True):
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"waveform_features_group_name",
],
)

waveform_keys = (
(
UnitWaveformFeaturesGroup.UnitFeatures
Expand Down
48 changes: 34 additions & 14 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,48 @@ def insert(self, rows, *args, **kwargs):
def fetch(self, *args, **kwargs):
    """Return decoding parameters with stored classes restored.

    Wraps `datajoint` fetch so that the `decoding_params` attribute is
    converted back from its serialized form via `restore_classes`.

    Parameters
    ----------
    *args : str
        Attribute names to fetch. If empty, all heading attributes are
        fetched (datajoint default).
    **kwargs : dict
        Passed through to `datajoint` fetch (e.g. `format`).

    Returns
    -------
    list or array
        Fetched rows; `decoding_params` entries are restored to classes
        unless `format="array"` was requested.
    """
    rows = super().fetch(*args, **kwargs)

    if kwargs.get("format", None) == "array":
        # Called internally by dj.fetch(); class conversion is
        # performed later in the stack.
        return rows

    if not len(args):
        # No explicit attributes requested: infer from table heading.
        args = tuple(self.heading)

    if "decoding_params" not in args:
        # Nothing to restore.
        return rows

    params_index = args.index("decoding_params")
    if len(args) == 1:
        # Only decoding_params fetched: rows is a sequence of values.
        content = [restore_classes(r) for r in rows]
    elif len(rows):
        # Multiple attributes: rows is one array per attribute;
        # zip(*rows) iterates per-entry tuples.
        content = []
        for entry in zip(*rows):
            entry = list(entry)
            entry[params_index] = restore_classes(entry[params_index])
            content.append(tuple(entry))
    else:
        content = rows
    return content

def fetch1(self, *args, **kwargs):
    """Return one decoding paramset with stored classes restored.

    Parameters
    ----------
    *args : str
        Attribute names to fetch. If empty, the full row dict is
        returned with `decoding_params` restored.
    **kwargs : dict
        Passed through to `datajoint` fetch1.

    Returns
    -------
    dict, tuple, or object
        Full row dict (no args), the bare restored value (only
        `decoding_params` requested), or a tuple of values with
        `decoding_params` restored (multiple attributes).
    """
    row = super().fetch1(*args, **kwargs)

    if len(args) == 0:
        # Full-row fetch: row is a dict keyed by attribute name.
        row["decoding_params"] = restore_classes(row["decoding_params"])
        return row

    if "decoding_params" in args:
        if len(args) == 1:
            # Single-attribute fetch returns the bare value.
            return restore_classes(row)
        # Multi-attribute fetch returns a tuple; restore in place.
        params_index = args.index("decoding_params")
        row = list(row)
        row[params_index] = restore_classes(row[params_index])
        return tuple(row)

    return row


Expand Down
52 changes: 44 additions & 8 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model

Parameters
Expand All @@ -289,6 +289,10 @@ def fetch_environments(key):
List[TrackGraph]
list of track graphs in the trained model
"""
key = cls.get_fully_defined_key(
key, required_fields=["decoding_param_name"]
)

model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
Expand All @@ -314,8 +318,8 @@ def fetch_environments(key):

return classifier.environments

@staticmethod
def fetch_position_info(key):
@classmethod
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model

Parameters
Expand All @@ -328,6 +332,16 @@ def fetch_position_info(key):
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
],
)

position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
Expand All @@ -339,8 +353,8 @@ def fetch_position_info(key):

return position_info, position_variable_names

@staticmethod
def fetch_linear_position_info(key):
@classmethod
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph

Parameters
Expand All @@ -353,6 +367,16 @@ def fetch_linear_position_info(key):
pd.DataFrame
The linearized position information
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
],
)

environment = SortedSpikesDecodingV1.fetch_environments(key)[0]

position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]
Expand All @@ -374,9 +398,13 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

@staticmethod
@classmethod
def fetch_spike_data(
key, filter_by_interval=True, time_slice=None, return_unit_ids=False
cls,
key,
filter_by_interval=True,
time_slice=None,
return_unit_ids=False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""Fetch the spike times for the decoding model

Expand All @@ -399,6 +427,14 @@ def fetch_spike_data(
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"encoding_interval",
"decoding_interval",
],
)

spike_times, unit_ids = SortedSpikesGroup.fetch_spike_data(
key, return_unit_ids=True
)
Expand Down
10 changes: 7 additions & 3 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import datajoint as dj
import numpy as np
from ripple_detection import get_multiunit_population_firing_rate

from spyglass.common import Session # noqa: F401
from spyglass.settings import test_mode
Expand Down Expand Up @@ -127,9 +126,12 @@ def filter_units(
include_mask[ind] = True
return include_mask

@staticmethod
@classmethod
def fetch_spike_data(
key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
cls,
key: dict,
time_slice: list[float] = None,
return_unit_ids: bool = False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""fetch spike times for units in the group

Expand All @@ -148,6 +150,8 @@ def fetch_spike_data(
list of np.ndarray
list of spike times for each unit in the group
"""
key = cls.get_fully_defined_key(key)

# get merge_ids for SpikeSortingOutput
merge_ids = (
(
Expand Down
21 changes: 21 additions & 0 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ def _safe_context(cls):
else nullcontext()
)

@classmethod
def get_fully_defined_key(
    cls, key: dict = None, required_fields: list[str] = None
) -> dict:
    """Resolve a partial key into a fully specified key for this table.

    If `key` already contains all required fields it is returned
    unchanged. Otherwise, the key is accepted only if it restricts the
    table to exactly one entry, in which case the full primary key of
    that entry is fetched and returned.

    Parameters
    ----------
    key : dict, optional
        Partial or complete restriction (dict of field values, or a
        restriction string). Defaults to an empty dict.
    required_fields : list[str], optional
        Fields that must be present in the key. Defaults to this
        table's primary key.

    Returns
    -------
    dict
        A key containing all required fields.

    Raises
    ------
    KeyError
        If the key lacks required fields and does not uniquely
        identify a single entry in the table.
    """
    if key is None:
        key = dict()

    required_fields = required_fields or cls.primary_key
    if isinstance(key, (str, dict)):  # check is either keys or substrings
        if not all(field in key for field in required_fields):
            # Underspecified key: only acceptable if it restricts the
            # table to exactly one row, whose primary key we fetch.
            if len(query := cls() & key) != 1:
                raise KeyError(
                    f"Key {key} is neither fully specified nor a "
                    f"unique entry in {cls.full_table_name}"
                )
            key = query.fetch1("KEY")

    return key

# ------------------------------- fetch_nwb -------------------------------

@cached_property
Expand Down
Loading