From 19930b77d089ca5c71c87c34c368f53dbd4ae10c Mon Sep 17 00:00:00 2001 From: Sam Bray Date: Tue, 26 Nov 2024 15:42:01 -0800 Subject: [PATCH] initial key decorator --- src/spyglass/decoding/v1/clusterless.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 7e0711ad9..fd95b9bee 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -11,6 +11,7 @@ import copy import uuid +from functools import wraps from pathlib import Path import datajoint as dj @@ -37,6 +38,26 @@ schema = dj.schema("decoding_clusterless_v1") +def classmethod_full_key_decorator(required_keys=[]): + def decorator(method): + @wraps(method) + def wrapper(cls, key=None, *args, **kwargs): + # Ensure key is not None + if key is None: + key = {} + + # Check if required keys are in key, and fetch if not + if not all([k in key for k in required_keys]): + key = (cls() & key).fetch1("KEY") + + # Call the original method with the modified key + return method(cls, key, *args, **kwargs) + + return wrapper + + return decorator + + @schema class UnitWaveformFeaturesGroup(SpyglassMixin, dj.Manual): definition = """ @@ -467,6 +488,9 @@ def fetch_spike_data(key, filter_by_interval=True): return new_spike_times, new_waveform_features @classmethod + @classmethod_full_key_decorator( + required_keys=["nwb_file_name", "waveform_features_group_name"] + ) def get_spike_indicator(cls, key, time): """get spike indicator matrix for the group