From 84da68dc4eb5c3665b83ca8c98b11b7017d1e702 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Nov 2023 10:31:23 -0600 Subject: [PATCH] V0 migration model. Renaming for dropped, no surgery --- src/spyglass/common/__init__.py | 80 ++- src/spyglass/common/common_position.py | 213 +----- src/spyglass/linearization/__init__.py | 1 + .../merge.py} | 14 +- src/spyglass/linearization/v0/__init__.py | 6 + src/spyglass/linearization/v0/main.py | 616 ++++++++++++++++++ src/spyglass/linearization/v1/__init__.py | 6 + .../v1/main.py} | 99 +-- .../position_linearization/__init__.py | 3 - .../position_linearization/v1/__init__.py | 7 - src/spyglass/utils/dj_helper_fn.py | 59 ++ src/spyglass/utils/dj_merge_tables.py | 2 +- 12 files changed, 856 insertions(+), 250 deletions(-) create mode 100644 src/spyglass/linearization/__init__.py rename src/spyglass/{position_linearization/position_linearization_merge.py => linearization/merge.py} (52%) create mode 100644 src/spyglass/linearization/v0/__init__.py create mode 100644 src/spyglass/linearization/v0/main.py create mode 100644 src/spyglass/linearization/v1/__init__.py rename src/spyglass/{position_linearization/v1/linearization.py => linearization/v1/main.py} (61%) delete mode 100644 src/spyglass/position_linearization/__init__.py delete mode 100644 src/spyglass/position_linearization/v1/__init__.py diff --git a/src/spyglass/common/__init__.py b/src/spyglass/common/__init__.py index 47c00beb0..540bcad30 100644 --- a/src/spyglass/common/__init__.py +++ b/src/spyglass/common/__init__.py @@ -56,14 +56,10 @@ NwbfileKachery, ) from .common_position import ( - IntervalLinearizationSelection, - IntervalLinearizedPosition, IntervalPositionInfo, IntervalPositionInfoSelection, - LinearizationParameters, PositionInfoParameters, PositionVideo, - TrackGraph, ) from .common_region import BrainRegion from .common_sensors import SensorData @@ -73,5 +69,81 @@ from .populate_all_common import populate_all_common from .prepopulate import populate_from_yaml, prepopulate_default +from spyglass.linearization.v0 import ( # isort:skip + IntervalLinearizationSelection, + IntervalLinearizedPosition, + LinearizationParameters, + TrackGraph, +) + +__all__ = [ + "AnalysisNwbfile", + "AnalysisNwbfileKachery", + "BrainRegion", + "CameraDevice", + "DIOEvents", + "DataAcquisitionDevice", + "DataAcquisitionDeviceAmplifier", + "DataAcquisitionDeviceSystem", + "Electrode", + "ElectrodeGroup", + "FirFilterParameters", + "Institution", + "IntervalLinearizationSelection", + "IntervalLinearizedPosition", + "IntervalList", + "IntervalPositionInfo", + "IntervalPositionInfoSelection", + "LFP", + "LFPBand", + "LFPBandSelection", + "LFPSelection", + "Lab", + "LabMember", + "LabTeam", + "LinearizationParameters", + "Nwbfile", + "NwbfileKachery", + "PositionInfoParameters", + "PositionIntervalMap", + "PositionSource", + "PositionVideo", + "Probe", + "ProbeType", + "Raw", + "RawPosition", + "SampleCount", + "SensorData", + "Session", + "SessionGroup", + "StateScriptFile", + "Subject", + "Task", + "TaskEpoch", + "TrackGraph", + "VideoFile", + "close_nwb_files", + "convert_epoch_interval_name_to_position_interval_name", + "estimate_sampling_rate", + "get_data_interface", + "get_electrode_indices", + "get_nwb_file", + "get_raw_eseries", + "get_valid_intervals", + "interval_list_censor", + "interval_list_contains", + "interval_list_contains_ind", + "interval_list_excludes", + "interval_list_excludes_ind", + "interval_list_intersect", + "interval_list_union", + "intervals_by_length", + "os", + "populate_all_common", + "populate_from_yaml", + "prepopulate_default", + "sg", +] + if sg.config["prepopulate"]: prepopulate_default() diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 67b3ce95d..93441983a 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -16,18 +16,11 @@ ) from position_tools.core import gaussian_smooth from tqdm import tqdm_notebook as tqdm -from track_linearization import ( - get_linearized_position, - make_track_graph, - plot_graph_as_1D, - plot_track_graph, -) -from ..settings import raw_dir, video_dir -from ..utils.dj_helper_fn import fetch_nwb -from .common_behav import RawPosition, VideoFile -from .common_interval import IntervalList # noqa F401 -from .common_nwbfile import AnalysisNwbfile +from spyglass.common.common_behav import RawPosition, VideoFile +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.settings import raw_dir, video_dir +from spyglass.utils.dj_helper_fn import deprecated_factory, fetch_nwb schema = dj.schema("common_position") @@ -503,184 +496,30 @@ def _data_to_df(data, prefix="head_", add_frame_ind=False): return df -@schema -class LinearizationParameters(dj.Lookup): - """Choose whether to use an HMM to linearize position. - - This can help when the euclidean distances between separate arms are too - close and the previous position has some information about which arm the - animal is on. - - route_euclidean_distance_scaling: How much to prefer route distances between - successive time points that are closer to the euclidean distance. Smaller - numbers mean the route distance is more likely to be close to the euclidean - distance. - """ - - definition = """ - linearization_param_name : varchar(80) # name for this set of parameters - --- - use_hmm = 0 : int # use HMM to determine linearization - route_euclidean_distance_scaling = 1.0 : float # Preference for euclidean. - sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm). - # Biases the transition matrix to prefer the current track segment. - diagonal_bias = 0.5 : float - """ - - -@schema -class TrackGraph(dj.Manual): - """Graph representation of track representing the spatial environment. - - Used for linearizing position. - """ - - definition = """ - track_graph_name : varchar(80) - ---- - environment : varchar(80) # Type of Environment - node_positions : blob # 2D position of nodes, (n_nodes, 2) - edges: blob # shape (n_edges, 2) - linear_edge_order : blob # order of edges in linear space, (n_edges, 2) - linear_edge_spacing : blob # space btwn edges in linear space, (n_edges,) - """ - - def get_networkx_track_graph(self, track_graph_parameters=None): - if track_graph_parameters is None: - track_graph_parameters = self.fetch1() - return make_track_graph( - node_positions=track_graph_parameters["node_positions"], - edges=track_graph_parameters["edges"], - ) - - def plot_track_graph(self, ax=None, draw_edge_labels=False, **kwds): - """Plot the track graph in 2D position space.""" - track_graph = self.get_networkx_track_graph() - plot_track_graph( - track_graph, ax=ax, draw_edge_labels=draw_edge_labels, **kwds - ) - - def plot_track_graph_as_1D( - self, - ax=None, - axis="x", - other_axis_start=0.0, - draw_edge_labels=False, - node_size=300, - node_color="#1f77b4", - ): - """Plot the track graph in 1D to see how the linearization is set up.""" - track_graph_parameters = self.fetch1() - track_graph = self.get_networkx_track_graph( - track_graph_parameters=track_graph_parameters - ) - plot_graph_as_1D( - track_graph, - edge_order=track_graph_parameters["linear_edge_order"], - edge_spacing=track_graph_parameters["linear_edge_spacing"], - ax=ax, - axis=axis, - other_axis_start=other_axis_start, - draw_edge_labels=draw_edge_labels, - node_size=node_size, - node_color=node_color, - ) - - -@schema -class IntervalLinearizationSelection(dj.Lookup): - definition = """ - -> IntervalPositionInfo - -> TrackGraph - -> LinearizationParameters - --- - """ - - -@schema -class IntervalLinearizedPosition(dj.Computed): - """Linearized position for a given interval""" - - definition = """ - -> IntervalLinearizationSelection - --- - -> AnalysisNwbfile - linearized_position_object_id : varchar(40) - """ - - def make(self, key): - print(f"Computing linear position for: {key}") - - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] - ) - - position_nwb = ( - IntervalPositionInfo - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - "position_info_param_name": key["position_info_param_name"], - } - ).fetch_nwb()[0] - - position = np.asarray( - position_nwb["head_position"].get_spatial_series().data - ) - time = np.asarray( - position_nwb["head_position"].get_spatial_series().timestamps - ) - - linearization_parameters = ( - LinearizationParameters() - & {"linearization_param_name": key["linearization_param_name"]} - ).fetch1() - track_graph_info = ( - TrackGraph() & {"track_graph_name": key["track_graph_name"]} - ).fetch1() +# ------------------------------ Migrated Tables ------------------------------ - track_graph = make_track_graph( - node_positions=track_graph_info["node_positions"], - edges=track_graph_info["edges"], - ) +from spyglass.linearization.v0 import main as linV0 # noqa: E402 - linear_position_df = get_linearized_position( - position=position, - track_graph=track_graph, - edge_spacing=track_graph_info["linear_edge_spacing"], - edge_order=track_graph_info["linear_edge_order"], - use_HMM=linearization_parameters["use_hmm"], - route_euclidean_distance_scaling=linearization_parameters[ - "route_euclidean_distance_scaling" - ], - sensor_std_dev=linearization_parameters["sensor_std_dev"], - diagonal_bias=linearization_parameters["diagonal_bias"], - ) - - linear_position_df["time"] = time - - # Insert into analysis nwb file - nwb_analysis_file = AnalysisNwbfile() - - key["linearized_position_object_id"] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=linear_position_df, - ) - - nwb_analysis_file.add( - nwb_file_name=key["nwb_file_name"], - analysis_file_name=key["analysis_file_name"], - ) - - self.insert1(key) - - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - - def fetch1_dataframe(self): - return self.fetch_nwb()[0]["linearized_position"].set_index("time") +( + LinearizationParameters, + TrackGraph, + IntervalLinearizationSelection, + IntervalLinearizedPosition, +) = deprecated_factory( + [ + ("LinearizationParameters", linV0.LinearizationParameters), + ("TrackGraph", linV0.TrackGraph), + ( + "IntervalLinearizationSelection", + linV0.IntervalLinearizationSelection, + ), + ( + "IntervalLinearizedPosition", + linV0.IntervalLinearizedPosition, + ), + ], + old_module=__name__, +) class NodePicker: diff --git a/src/spyglass/linearization/__init__.py b/src/spyglass/linearization/__init__.py new file mode 100644 index 000000000..6bf01cb8b --- /dev/null +++ b/src/spyglass/linearization/__init__.py @@ -0,0 +1 @@ +# from spyglass.linearization.merge import LinearizedOutput diff --git a/src/spyglass/position_linearization/position_linearization_merge.py b/src/spyglass/linearization/merge.py similarity index 52% rename from src/spyglass/position_linearization/position_linearization_merge.py rename to src/spyglass/linearization/merge.py index 1efd38afc..e084037a2 100644 --- a/src/spyglass/position_linearization/position_linearization_merge.py +++ b/src/spyglass/linearization/merge.py @@ -1,27 +1,25 @@ import datajoint as dj -from spyglass.position_linearization.v1.linearization import ( # noqa F401 - LinearizedPositionV1, -) +from spyglass.linearization.v1.main import LinearizedV1 # noqa F401 from ..utils.dj_merge_tables import _Merge -schema = dj.schema("position_linearization_merge") +schema = dj.schema("linearization_merge") @schema -class LinearizedPositionOutput(_Merge): +class LinearizedOutput(_Merge): definition = """ merge_id: uuid --- source: varchar(32) """ - class LinearizedPositionV1(dj.Part): # noqa: F811 + class LinearizedV1(dj.Part): # noqa F811 definition = """ - -> LinearizedPositionOutput + -> master --- - -> LinearizedPositionV1 + -> LinearizedV1 """ def fetch1_dataframe(self): diff --git a/src/spyglass/linearization/v0/__init__.py b/src/spyglass/linearization/v0/__init__.py new file mode 100644 index 000000000..438552bcf --- /dev/null +++ b/src/spyglass/linearization/v0/__init__.py @@ -0,0 +1,6 @@ +from .main import ( + IntervalLinearizationSelection, + IntervalLinearizedPosition, + LinearizationParameters, + TrackGraph, +) diff --git a/src/spyglass/linearization/v0/main.py b/src/spyglass/linearization/v0/main.py new file mode 100644 index 000000000..e4b5a4932 --- /dev/null +++ b/src/spyglass/linearization/v0/main.py @@ -0,0 +1,616 @@ +import cv2 +import datajoint as dj +import matplotlib.pyplot as plt +import numpy as np +import pynwb +import pynwb.behavior +from tqdm import tqdm_notebook as tqdm +from track_linearization import ( + get_linearized_position, + make_track_graph, + plot_graph_as_1D, + plot_track_graph, +) + +from spyglass.common.common_behav import RawPosition, VideoFile +from spyglass.common.common_interval import IntervalList # noqa F401 +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.common.common_position import IntervalPositionInfo # noqa F401 +from spyglass.settings import raw_dir, video_dir +from spyglass.utils.dj_helper_fn import fetch_nwb + +schema = dj.schema("common_position") +# CBroz: I would rename 'linearization_v0', but would require db surgery +# Similarly, I would rename tables below and transfer contents +# - LinearizationParameters -> LinearizationParams +# - IntervalLinearizedSelection -> LinerarizedSeledtion +# - IntervalLinearizedPosition -> LinearizedV0 + + +@schema +class LinearizationParameters(dj.Lookup): + """Choose whether to use an HMM to linearize position. + + This can help when the euclidean distances between separate arms are too + close and the previous position has some information about which arm the + animal is on. + + route_euclidean_distance_scaling: How much to prefer route distances between + successive time points that are closer to the euclidean distance. Smaller + numbers mean the route distance is more likely to be close to the euclidean + distance. + """ + + definition = """ + linearization_param_name : varchar(80) # name for this set of parameters + --- + use_hmm = 0 : int # use HMM to determine linearization + route_euclidean_distance_scaling = 1.0 : float # Preference for euclidean. + sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm). + # Biases the transition matrix to prefer the current track segment. + diagonal_bias = 0.5 : float + """ + + +@schema +class TrackGraph(dj.Manual): + """Graph representation of track representing the spatial environment. + + Used for linearizing position. + """ + + definition = """ + track_graph_name : varchar(80) + ---- + environment : varchar(80) # Type of Environment + node_positions : blob # 2D position of nodes, (n_nodes, 2) + edges: blob # shape (n_edges, 2) + linear_edge_order : blob # order of edges in linear space, (n_edges, 2) + linear_edge_spacing : blob # space btwn edges in linear space, (n_edges,) + """ + + def get_networkx_track_graph(self, track_graph_parameters=None): + if track_graph_parameters is None: + track_graph_parameters = self.fetch1() + return make_track_graph( + node_positions=track_graph_parameters["node_positions"], + edges=track_graph_parameters["edges"], + ) + + def plot_track_graph(self, ax=None, draw_edge_labels=False, **kwds): + """Plot the track graph in 2D position space.""" + track_graph = self.get_networkx_track_graph() + plot_track_graph( + track_graph, ax=ax, draw_edge_labels=draw_edge_labels, **kwds + ) + + def plot_track_graph_as_1D( + self, + ax=None, + axis="x", + other_axis_start=0.0, + draw_edge_labels=False, + node_size=300, + node_color="#1f77b4", + ): + """Plot the track graph in 1D to see how the linearization is set up.""" + track_graph_parameters = self.fetch1() + track_graph = self.get_networkx_track_graph( + track_graph_parameters=track_graph_parameters + ) + plot_graph_as_1D( + track_graph, + edge_order=track_graph_parameters["linear_edge_order"], + edge_spacing=track_graph_parameters["linear_edge_spacing"], + ax=ax, + axis=axis, + other_axis_start=other_axis_start, + draw_edge_labels=draw_edge_labels, + node_size=node_size, + node_color=node_color, + ) + + +@schema +class IntervalLinearizationSelection(dj.Lookup): + definition = """ + -> IntervalPositionInfo + -> TrackGraph + -> LinearizationParameters + """ + + +@schema +class IntervalLinearizedPosition(dj.Computed): + """Linearized position for a given interval""" + + definition = """ + -> IntervalLinearizationSelection + --- + -> AnalysisNwbfile + linearized_position_object_id : varchar(40) + """ + + def make(self, key): + print(f"Computing linear position for: {key}") + + key["analysis_file_name"] = AnalysisNwbfile().create( + key["nwb_file_name"] + ) + + position_nwb = ( + IntervalPositionInfo + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": key["interval_list_name"], + "position_info_param_name": key["position_info_param_name"], + } + ).fetch_nwb()[0] + + position = np.asarray( + position_nwb["head_position"].get_spatial_series().data + ) + time = np.asarray( + position_nwb["head_position"].get_spatial_series().timestamps + ) + + linearization_parameters = ( + LinearizationParameters() + & {"linearization_param_name": key["linearization_param_name"]} + ).fetch1() + track_graph_info = ( + TrackGraph() & {"track_graph_name": key["track_graph_name"]} + ).fetch1() + + track_graph = make_track_graph( + node_positions=track_graph_info["node_positions"], + edges=track_graph_info["edges"], + ) + + linear_position_df = get_linearized_position( + position=position, + track_graph=track_graph, + edge_spacing=track_graph_info["linear_edge_spacing"], + edge_order=track_graph_info["linear_edge_order"], + use_HMM=linearization_parameters["use_hmm"], + route_euclidean_distance_scaling=linearization_parameters[ + "route_euclidean_distance_scaling" + ], + sensor_std_dev=linearization_parameters["sensor_std_dev"], + diagonal_bias=linearization_parameters["diagonal_bias"], + ) + + linear_position_df["time"] = time + + # Insert into analysis nwb file + nwb_analysis_file = AnalysisNwbfile() + + key["linearized_position_object_id"] = nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=linear_position_df, + ) + + nwb_analysis_file.add( + nwb_file_name=key["nwb_file_name"], + analysis_file_name=key["analysis_file_name"], + ) + + self.insert1(key) + + def fetch_nwb(self, *attrs, **kwargs): + return fetch_nwb( + self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs + ) + + def fetch1_dataframe(self): + return self.fetch_nwb()[0]["linearized_position"].set_index("time") + + +class NodePicker: + """Interactive creation of track graph by looking at video frames.""" + + def __init__( + self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100 + ): + if ax is None: + ax = plt.gca() + self.ax = ax + self.canvas = ax.get_figure().canvas + self.cid = None + self._nodes = [] + self.node_color = node_color + self._nodes_plot = ax.scatter( + [], [], zorder=5, s=node_size, color=node_color + ) + self.edges = [[]] + self.video_filename = video_filename + + if video_filename is not None: + self.video = cv2.VideoCapture(video_filename) + frame = self.get_video_frame() + ax.imshow(frame, picker=True) + ax.set_title( + "Left click to place node.\nRight click to remove node." + "\nShift+Left click to clear nodes." + "\nCntrl+Left click two nodes to place an edge" + ) + + self.connect() + + @property + def node_positions(self): + return np.asarray(self._nodes) + + def connect(self): + if self.cid is None: + self.cid = self.canvas.mpl_connect( + "button_press_event", self.click_event + ) + + def disconnect(self): + if self.cid is not None: + self.canvas.mpl_disconnect(self.cid) + self.cid = None + + def click_event(self, event): + if not event.inaxes: + return + if (event.key not in ["control", "shift"]) & ( + event.button == 1 + ): # left click + self._nodes.append((event.xdata, event.ydata)) + if (event.key not in ["control", "shift"]) & ( + event.button == 3 + ): # right click + self.remove_point((event.xdata, event.ydata)) + if (event.key == "shift") & (event.button == 1): + self.clear() + if (event.key == "control") & (event.button == 1): + point = (event.xdata, event.ydata) + distance_to_nodes = np.linalg.norm( + self.node_positions - point, axis=1 + ) + closest_node_ind = np.argmin(distance_to_nodes) + if len(self.edges[-1]) < 2: + self.edges[-1].append(closest_node_ind) + else: + self.edges.append([closest_node_ind]) + + self.redraw() + + def redraw(self): + # Draw Node Circles + if len(self.node_positions) > 0: + self._nodes_plot.set_offsets(self.node_positions) + else: + self._nodes_plot.set_offsets([]) + + # Draw Node Numbers + self.ax.texts = [] + for ind, (x, y) in enumerate(self.node_positions): + self.ax.text( + x, + y, + ind, + zorder=6, + fontsize=12, + horizontalalignment="center", + verticalalignment="center", + clip_on=True, + bbox=None, + transform=self.ax.transData, + ) + # Draw Edges + self.ax.lines = [] # clears the existing lines + for edge in self.edges: + if len(edge) > 1: + x1, y1 = self.node_positions[edge[0]] + x2, y2 = self.node_positions[edge[1]] + self.ax.plot( + [x1, x2], [y1, y2], color=self.node_color, linewidth=2 + ) + + self.canvas.draw_idle() + + def remove_point(self, point): + if len(self._nodes) > 0: + distance_to_nodes = np.linalg.norm( + self.node_positions - point, axis=1 + ) + closest_node_ind = np.argmin(distance_to_nodes) + self._nodes.pop(closest_node_ind) + + def clear(self): + self._nodes = [] + self.edges = [[]] + self.redraw() + + def get_video_frame(self): + is_grabbed, frame = self.video.read() + if is_grabbed: + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + +@schema +class PositionVideo(dj.Computed): + """Creates a video of the computed head position and orientation as well as + the original LED positions overlaid on the video of the animal. + + Use for debugging the effect of position extraction parameters.""" + + definition = """ + -> IntervalPositionInfo + --- + """ + + def make(self, key): + M_TO_CM = 100 + + print("Loading position data...") + raw_position_df = ( + RawPosition() + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": key["interval_list_name"], + } + ).fetch1_dataframe() + position_info_df = ( + IntervalPositionInfo() + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": key["interval_list_name"], + "position_info_param_name": key["position_info_param_name"], + } + ).fetch1_dataframe() + + print("Loading video data...") + epoch = ( + int( + key["interval_list_name"] + .replace("pos ", "") + .replace(" valid times", "") + ) + + 1 + ) + video_info = ( + VideoFile() + & {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} + ).fetch1() + io = pynwb.NWBHDF5IO(raw_dir + "/" + video_info["nwb_file_name"], "r") + nwb_file = io.read() + nwb_video = nwb_file.objects[video_info["video_file_object_id"]] + video_filename = nwb_video.external_file[0] + + nwb_base_filename = key["nwb_file_name"].replace(".nwb", "") + output_video_filename = ( + f"{nwb_base_filename}_{epoch:02d}_" + f'{key["position_info_param_name"]}.mp4' + ) + + # ensure standardized column names + raw_position_df = _fix_col_names(raw_position_df) + # if IntervalPositionInfo supersampled position, downsample to video + if position_info_df.shape[0] > raw_position_df.shape[0]: + ind = np.digitize( + raw_position_df.index, position_info_df.index, right=True + ) + position_info_df = position_info_df.iloc[ind] + + centroids = { + "red": np.asarray(raw_position_df[["xloc", "yloc"]]), + "green": np.asarray(raw_position_df[["xloc2", "yloc2"]]), + } + head_position_mean = np.asarray( + position_info_df[["head_position_x", "head_position_y"]] + ) + head_orientation_mean = np.asarray( + position_info_df[["head_orientation"]] + ) + video_time = np.asarray(nwb_video.timestamps) + position_time = np.asarray(position_info_df.index) + cm_per_pixel = nwb_video.device.meters_per_pixel * M_TO_CM + + print("Making video...") + self.make_video( + f"{video_dir}/{video_filename}", + centroids, + head_position_mean, + head_orientation_mean, + video_time, + position_time, + output_video_filename=output_video_filename, + cm_to_pixels=cm_per_pixel, + disable_progressbar=False, + ) + + @staticmethod + def convert_to_pixels(data, frame_size, cm_to_pixels=1.0): + """Converts from cm to pixels and flips the y-axis. + Parameters + ---------- + data : ndarray, shape (n_time, 2) + frame_size : array_like, shape (2,) + cm_to_pixels : float + + Returns + ------- + converted_data : ndarray, shape (n_time, 2) + """ + return data / cm_to_pixels + + @staticmethod + def fill_nan(variable, video_time, variable_time): + video_ind = np.digitize(variable_time, video_time[1:]) + + n_video_time = len(video_time) + try: + n_variable_dims = variable.shape[1] + filled_variable = np.full((n_video_time, n_variable_dims), np.nan) + except IndexError: + filled_variable = np.full((n_video_time,), np.nan) + filled_variable[video_ind] = variable + + return filled_variable + + def make_video( + self, + video_filename, + centroids, + head_position_mean, + head_orientation_mean, + video_time, + position_time, + output_video_filename="output.mp4", + cm_to_pixels=1.0, + disable_progressbar=False, + arrow_radius=15, + circle_radius=8, + ): + RGB_PINK = (234, 82, 111) + RGB_YELLOW = (253, 231, 76) + RGB_WHITE = (255, 255, 255) + + video = cv2.VideoCapture(video_filename) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + frame_size = (int(video.get(3)), int(video.get(4))) + frame_rate = video.get(5) + n_frames = int(head_orientation_mean.shape[0]) + + out = cv2.VideoWriter( + output_video_filename, fourcc, frame_rate, frame_size, True + ) + + centroids = { + color: self.fill_nan(data, video_time, position_time) + for color, data in centroids.items() + } + head_position_mean = self.fill_nan( + head_position_mean, video_time, position_time + ) + head_orientation_mean = self.fill_nan( + head_orientation_mean, video_time, position_time + ) + + for time_ind in tqdm( + range(n_frames - 1), desc="frames", disable=disable_progressbar + ): + is_grabbed, frame = video.read() + if is_grabbed: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + red_centroid = centroids["red"][time_ind] + green_centroid = centroids["green"][time_ind] + + head_position = head_position_mean[time_ind] + head_position = self.convert_to_pixels( + head_position, frame_size, cm_to_pixels + ) + head_orientation = head_orientation_mean[time_ind] + + if np.all(~np.isnan(red_centroid)): + cv2.circle( + img=frame, + center=tuple(red_centroid.astype(int)), + radius=circle_radius, + color=RGB_YELLOW, + thickness=-1, + shift=cv2.CV_8U, + ) + + if np.all(~np.isnan(green_centroid)): + cv2.circle( + img=frame, + center=tuple(green_centroid.astype(int)), + radius=circle_radius, + color=RGB_PINK, + thickness=-1, + shift=cv2.CV_8U, + ) + + if np.all(~np.isnan(head_position)) & np.all( + ~np.isnan(head_orientation) + ): + arrow_tip = ( + int( + head_position[0] + + arrow_radius * np.cos(head_orientation) + ), + int( + head_position[1] + + arrow_radius * np.sin(head_orientation) + ), + ) + cv2.arrowedLine( + img=frame, + pt1=tuple(head_position.astype(int)), + pt2=arrow_tip, + color=RGB_WHITE, + thickness=4, + line_type=8, + shift=cv2.CV_8U, + tipLength=0.25, + ) + + if np.all(~np.isnan(head_position)): + cv2.circle( + img=frame, + center=tuple(head_position.astype(int)), + radius=circle_radius, + color=RGB_WHITE, + thickness=-1, + shift=cv2.CV_8U, + ) + + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + out.write(frame) + else: + break + + video.release() + out.release() + cv2.destroyAllWindows() + + +def _fix_col_names(spatial_df): + """Renames columns in spatial dataframe according to previous norm + + Accepts unnamed first led, 1 or 0 indexed. + Prompts user for confirmation of renaming unexpected columns. + For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2" + """ + + DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"] + ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"] + ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"] + + input_cols = list(spatial_df.columns) + + has_default = all([c in input_cols for c in DEFAULT_COLS]) + has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS]) + has_1_idx = all([c in input_cols for c in ONE_IDX_COLS]) + + if has_default: + # move the 4 position columns to front, continue + spatial_df = spatial_df[DEFAULT_COLS] + elif has_0_idx: + # move the 4 position columns to front, rename to default, continue + spatial_df = spatial_df[ZERO_IDX_COLS] + spatial_df.columns = DEFAULT_COLS + elif has_1_idx: + # move the 4 position columns to front, rename to default, continue + spatial_df = spatial_df[ONE_IDX_COLS] + spatial_df.columns = DEFAULT_COLS + else: + if len(input_cols) != 4 or not has_default: + choice = dj.utils.user_choice( + "Unexpected columns in raw position. Assume " + + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" + ) + if choice.lower() not in ["yes", "y"]: + raise ValueError( + f"Unexpected columns in raw position: {input_cols}" + ) + # rename first 4 columns, keep rest. Rest dropped below + spatial_df.columns = DEFAULT_COLS + input_cols[4:] + + return spatial_df diff --git a/src/spyglass/linearization/v1/__init__.py b/src/spyglass/linearization/v1/__init__.py new file mode 100644 index 000000000..9a8415311 --- /dev/null +++ b/src/spyglass/linearization/v1/__init__.py @@ -0,0 +1,6 @@ +from spyglass.linearization.v1.main import ( + LinearizationParams, + LinearizationSelection, + LinearizedV1, + TrackGraph, +) diff --git a/src/spyglass/position_linearization/v1/linearization.py b/src/spyglass/linearization/v1/main.py similarity index 61% rename from src/spyglass/position_linearization/v1/linearization.py rename to src/spyglass/linearization/v1/main.py index 7f30be8c9..a02df4b11 100644 --- a/src/spyglass/position_linearization/v1/linearization.py +++ b/src/spyglass/linearization/v1/main.py @@ -1,7 +1,8 @@ import copy + import datajoint as dj -from datajoint.utils import to_camel_case import numpy as np +from datajoint.utils import to_camel_case from track_linearization import ( get_linearized_position, make_track_graph, @@ -13,44 +14,55 @@ from spyglass.position.position_merge import PositionOutput from spyglass.utils.dj_helper_fn import fetch_nwb -schema = dj.schema("position_linearization_v1") +schema = dj.schema("linearization_v1") @schema -class LinearizationParameters(dj.Lookup): - """Choose whether to use an HMM to linearize position. This can help when - the eucledian distances between separate arms are too close and the previous - position has some information about which arm the animal is on.""" +class LinearizationParams(dj.Lookup): + """Choose whether to use an HMM to linearize position. + + HMM can help when the eucledian distances between separate arms are too + close and the previous position has some information about which arm the + animal is on. + + route_euclidean_distance_scaling : float + How much to prefer route distances between successive time points that + are closer to the euclidean distance. Smaller numbers mean the route + distance is more likely to be close to the euclidean distance. + + """ definition = """ - linearization_param_name : varchar(80) # name for this set of parameters + linearization_param_name : varchar(32) # name for this set of parameters --- - use_hmm = 0 : int # use HMM to determine linearization - # How much to prefer route distances between successive time points that are closer to the euclidean distance. Smaller numbers mean the route distance is more likely to be close to the euclidean distance. + use_hmm = 0 : int # use HMM to linearize route_euclidean_distance_scaling = 1.0 : float - sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm). - # Biases the transition matrix to prefer the current track segment. - diagonal_bias = 0.5 : float + sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm) + diagonal_bias = 0.5 : float # Biases transition matrix, prefer current segment """ @schema class TrackGraph(dj.Manual): - """Graph representation of track representing the spatial environment. - Used for linearizing position.""" + """Graph representation of track in the spatial environment.""" definition = """ track_graph_name : varchar(80) ---- - environment : varchar(80) # Type of Environment - node_positions : blob # 2D position of track_graph nodes, shape (n_nodes, 2) - edges: blob # shape (n_edges, 2) - linear_edge_order : blob # order of track graph edges in the linear space, shape (n_edges, 2) - linear_edge_spacing : blob # amount of space between edges in the linear space, shape (n_edges,) + environment : varchar(80) # Type of Environment + node_positions : blob # 2D position of nodes, (n_nodes, 2) + edges: blob # shape (n_edges, 2) + linear_edge_order : blob # order of edges in linear space, (n_edges, 2) + linear_edge_spacing : blob # space btwn edges in linear space, (n_edges,) """ def get_networkx_track_graph(self, track_graph_parameters=None): if track_graph_parameters is None: + if len(self) > 1: + raise ValueError( + "More than one track graph found." + "Please specify track_graph_parameters." + ) track_graph_parameters = self.fetch1() return make_track_graph( node_positions=track_graph_parameters["node_positions"], @@ -72,9 +84,16 @@ def plot_track_graph_as_1D( draw_edge_labels=False, node_size=300, node_color="#1f77b4", + track_graph_parameters=None, ): - """Plot the track graph in 1D to see how the linearization is set up.""" - track_graph_parameters = self.fetch1() + """Plot the track graph in 1D to see linearization set up.""" + if track_graph_parameters is None: + if len(self) > 1: + raise ValueError( + "More than one track graph found." + "Please specify track_graph_parameters." + ) + track_graph_parameters = self.fetch1() track_graph = self.get_networkx_track_graph( track_graph_parameters=track_graph_parameters ) @@ -94,15 +113,14 @@ def plot_track_graph_as_1D( @schema class LinearizationSelection(dj.Lookup): definition = """ - -> PositionOutput + -> PositionOutput.proj(pos_merge_id='merge_id') -> TrackGraph - -> LinearizationParameters - --- + -> LinearizationParams """ @schema -class LinearizedPositionV1(dj.Computed): +class LinearizedV1(dj.Computed): """Linearized position for a given interval""" definition = """ @@ -116,21 +134,17 @@ def make(self, key): orig_key = copy.deepcopy(key) print(f"Computing linear position for: {key}") - position_nwb = PositionOutput.fetch_nwb(key)[0] - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] - ) - position = np.asarray( - position_nwb["position"].get_spatial_series().data - ) - time = np.asarray( - position_nwb["position"].get_spatial_series().timestamps - ) + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + + position_obj = PositionOutput.fetch_nwb(key)[0]["position"] + position = np.asarray(position_obj.get_spatial_series().data) + time = np.asarray(position_obj.get_spatial_series().timestamps) linearization_parameters = ( - LinearizationParameters() + LinearizationParams() & {"linearization_param_name": key["linearization_param_name"]} ).fetch1() + track_graph_info = ( TrackGraph() & {"track_graph_name": key["track_graph_name"]} ).fetch1() @@ -152,17 +166,22 @@ def make(self, key): sensor_std_dev=linearization_parameters["sensor_std_dev"], diagonal_bias=linearization_parameters["diagonal_bias"], ) - linear_position_df["time"] = time # Insert into analysis nwb file nwb_analysis_file = AnalysisNwbfile() - key["linearized_position_object_id"] = nwb_analysis_file.add_nwb_object( + lin_pos_obj_id = nwb_analysis_file.add_nwb_object( analysis_file_name=key["analysis_file_name"], nwb_object=linear_position_df, ) + key.update( + { + "linearized_position_object_id": lin_pos_obj_id, + "analysis_file_name": analysis_file_name, + } + ) nwb_analysis_file.add( nwb_file_name=key["nwb_file_name"], analysis_file_name=key["analysis_file_name"], @@ -170,11 +189,11 @@ def make(self, key): self.insert1(key) - from ..position_linearization_merge import LinearizedPositionOutput + from ..merge import LinearizedOutput part_name = to_camel_case(self.table_name.split("__")[-1]) - LinearizedPositionOutput._merge_insert( + LinearizedOutput._merge_insert( [orig_key], part_name=part_name, skip_duplicates=True ) diff --git a/src/spyglass/position_linearization/__init__.py b/src/spyglass/position_linearization/__init__.py deleted file mode 100644 index 49103ec16..000000000 --- a/src/spyglass/position_linearization/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from spyglass.position_linearization.position_linearization_merge import ( - LinearizedPositionOutput, -) diff --git a/src/spyglass/position_linearization/v1/__init__.py b/src/spyglass/position_linearization/v1/__init__.py deleted file mode 100644 index 46d287a45..000000000 --- a/src/spyglass/position_linearization/v1/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from spyglass.common.common_position import NodePicker -from spyglass.position_linearization.v1.linearization import ( - LinearizationParameters, - LinearizationSelection, - LinearizedPositionV1, - TrackGraph, -) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 8aaf40313..4ad2d58ce 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -1,6 +1,7 @@ """Helper functions for manipulating information from DataJoint fetch calls.""" import inspect import os +from typing import Type import datajoint as dj import numpy as np @@ -8,6 +9,64 @@ from .nwb_helper_fn import get_nwb_file +def deprecated_factory(classes: list, old_module: str = "") -> list: + """Creates a list of classes and prints a warning when instantiated + + Parameters + --------- + classes : list + list of tuples containing old_class, new_class + + Returns + ------ + list + list of classes that will print a warning when instantiated + """ + + if not isinstance(classes, list): + classes = [classes] + + ret = [ + _subclass_factory(old_name=c[0], new_class=c[1], old_module=old_module) + for c in classes + ] + + return ret[0] if len(ret) == 1 else ret + + +def _subclass_factory( + old_name: str, new_class: Type, old_module: str = "" +) -> Type: + """Creates a sublcass with a deprecation warning on __init__ + + Old class is a subclass of new class, so it will inherit all of the new + class's methods. Old class retains its original name and module. Use + __name__ to get the module name of the caller. + + Usage: OldClass = _subclass_factory('OldClass', __name__, NewClass) + """ + + new_module = new_class().__class__.__module__ + + # Define the __call__ method for the new class + def init_override(self, *args, **kwargs): + print( + "Deprecation Warning: this class has been moved out of " + + f"{old_module}\n" + + f"\t{old_name} -> {new_module}.{new_class.__name__}" + + "\nPlease use the new location." + ) + return super(self.__class__, self).__init__(*args, **kwargs) + + class_dict = { + "__module__": old_module or new_class.__class__.__module__, + "__init__": init_override, + "_is_deprecated": True, + } + + return type(old_name, (new_class,), class_dict) + + def dj_replace(original_table, new_values, key_column, replace_column): """Given the output of a fetch() call from a schema and a 2D array made up of (key_value, replace_value) tuples, find each instance of key_value in diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 2e184c66a..e200f3602 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -204,7 +204,7 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union: for p in cls._merge_restrict_parts( restriction=restriction, add_invalid_restrict=False, - return_empties=False, + return_empties=True, ) ]