Skip to content

Commit

Permalink
cleanup readability of DLCPosV1 make conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Jan 3, 2025
1 parent 5901edc commit e5af455
Showing 1 changed file with 90 additions and 84 deletions.
174 changes: 90 additions & 84 deletions src/spyglass/position/v1/position_dlc_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,93 +69,15 @@ def make(self, key):

pos_nwb = (DLCCentroid & key).fetch_nwb()[0]
ori_nwb = (DLCOrientation & key).fetch_nwb()[0]

if isinstance(pos_nwb["dlc_position"], pd.DataFrame):
# Null entry case
key["analysis_file_name"] = AnalysisNwbfile().create(
nwb_file_name=key["nwb_file_name"]
)
obj_id = AnalysisNwbfile().add_nwb_object(
key["analysis_file_name"], pd.DataFrame()
)
key["position_object_id"] = obj_id
key["orientation_object_id"] = obj_id
key["velocity_object_id"] = obj_id

else:
pos_obj = pos_nwb["dlc_position"].spatial_series["position"]
vel_obj = pos_nwb["dlc_velocity"].time_series["velocity"]
vid_frame_obj = pos_nwb["dlc_velocity"].time_series[
"video_frame_ind"
]
ori_obj = ori_nwb["dlc_orientation"].spatial_series["orientation"]

position = pynwb.behavior.Position()
orientation = pynwb.behavior.CompassDirection()
velocity = pynwb.behavior.BehavioralTimeSeries()

position.create_spatial_series(
name=pos_obj.name,
timestamps=np.asarray(pos_obj.timestamps),
conversion=pos_obj.conversion,
data=np.asarray(pos_obj.data),
reference_frame=pos_obj.reference_frame,
comments=pos_obj.comments,
description=pos_obj.description,
)

orientation.create_spatial_series(
name=ori_obj.name,
timestamps=np.asarray(ori_obj.timestamps),
conversion=ori_obj.conversion,
data=np.asarray(ori_obj.data),
reference_frame=ori_obj.reference_frame,
comments=ori_obj.comments,
description=ori_obj.description,
)

velocity.create_timeseries(
name=vel_obj.name,
timestamps=np.asarray(vel_obj.timestamps),
conversion=vel_obj.conversion,
unit=vel_obj.unit,
data=np.asarray(vel_obj.data),
comments=vel_obj.comments,
description=vel_obj.description,
)

velocity.create_timeseries(
name=vid_frame_obj.name,
timestamps=np.asarray(vid_frame_obj.timestamps),
unit=vid_frame_obj.unit,
data=np.asarray(vid_frame_obj.data),
description=vid_frame_obj.description,
comments=vid_frame_obj.comments,
)

# Add to Analysis NWB file
analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])
key["analysis_file_name"] = analysis_file_name
nwb_analysis_file = AnalysisNwbfile()

key.update(
{
"analysis_file_name": analysis_file_name,
"position_object_id": nwb_analysis_file.add_nwb_object(
analysis_file_name, position
),
"orientation_object_id": nwb_analysis_file.add_nwb_object(
analysis_file_name, orientation
),
"velocity_object_id": nwb_analysis_file.add_nwb_object(
analysis_file_name, velocity
),
}
)
key = (
self.make_null_position_nwb(key)
if isinstance(pos_nwb["dlc_position"], pd.DataFrame) # null case
else self.make_dlc_pos_nwb(key, pos_nwb, ori_nwb) # normal case
)

AnalysisNwbfile().add(
nwb_file_name=key["nwb_file_name"],
analysis_file_name=analysis_file_name,
analysis_file_name=key["analysis_file_name"],
)
self.insert1(key)

Expand All @@ -169,6 +91,90 @@ def make(self, key):
)
AnalysisNwbfile().log(key, table=self.full_table_name)

@staticmethod
def make_null_position_nwb(key):
key["analysis_file_name"] = AnalysisNwbfile().create(
nwb_file_name=key["nwb_file_name"]
)
obj_id = AnalysisNwbfile().add_nwb_object(
key["analysis_file_name"], pd.DataFrame()
)
key["position_object_id"] = obj_id
key["orientation_object_id"] = obj_id
key["velocity_object_id"] = obj_id
return key

@staticmethod
def make_dlc_pos_nwb(key, pos_nwb, ori_nwb):
pos_obj = pos_nwb["dlc_position"].spatial_series["position"]
vel_obj = pos_nwb["dlc_velocity"].time_series["velocity"]
vid_frame_obj = pos_nwb["dlc_velocity"].time_series["video_frame_ind"]
ori_obj = ori_nwb["dlc_orientation"].spatial_series["orientation"]

position = pynwb.behavior.Position()
orientation = pynwb.behavior.CompassDirection()
velocity = pynwb.behavior.BehavioralTimeSeries()

position.create_spatial_series(
name=pos_obj.name,
timestamps=np.asarray(pos_obj.timestamps),
conversion=pos_obj.conversion,
data=np.asarray(pos_obj.data),
reference_frame=pos_obj.reference_frame,
comments=pos_obj.comments,
description=pos_obj.description,
)

orientation.create_spatial_series(
name=ori_obj.name,
timestamps=np.asarray(ori_obj.timestamps),
conversion=ori_obj.conversion,
data=np.asarray(ori_obj.data),
reference_frame=ori_obj.reference_frame,
comments=ori_obj.comments,
description=ori_obj.description,
)

velocity.create_timeseries(
name=vel_obj.name,
timestamps=np.asarray(vel_obj.timestamps),
conversion=vel_obj.conversion,
unit=vel_obj.unit,
data=np.asarray(vel_obj.data),
comments=vel_obj.comments,
description=vel_obj.description,
)

velocity.create_timeseries(
name=vid_frame_obj.name,
timestamps=np.asarray(vid_frame_obj.timestamps),
unit=vid_frame_obj.unit,
data=np.asarray(vid_frame_obj.data),
description=vid_frame_obj.description,
comments=vid_frame_obj.comments,
)

# Add to Analysis NWB file
analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])
key["analysis_file_name"] = analysis_file_name
nwb_analysis_file = AnalysisNwbfile()

key.update(
{
"analysis_file_name": analysis_file_name,
"position_object_id": nwb_analysis_file.add_nwb_object(
analysis_file_name, position
),
"orientation_object_id": nwb_analysis_file.add_nwb_object(
analysis_file_name, orientation
),
"velocity_object_id": nwb_analysis_file.add_nwb_object(
analysis_file_name, velocity
),
}
)
return key

def fetch1_dataframe(self) -> pd.DataFrame:
"""Return the position data as a DataFrame."""
nwb_data = self.fetch_nwb()[0]
Expand Down

0 comments on commit e5af455

Please sign in to comment.