Skip to content

Commit

Permalink
WIP: error on upscaled data
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Nov 6, 2024
1 parent 46818b2 commit 3e5c1c9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/spyglass/position/v1/dlc_utils_makevid.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def __init__(
logger.info(f"Finished video: {self.output_video_filename}")
logger.debug(f"Dropped frames: {self.dropped_frames}")

shutil.rmtree(self.temp_dir) # Clean up temp directory
if not debug:
shutil.rmtree(self.temp_dir) # Clean up temp directory

def _set_frame_info(self):
"""Set the frame information for the video."""
Expand Down
48 changes: 41 additions & 7 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import os
from pathlib import Path

import datajoint as dj
import numpy as np
Expand Down Expand Up @@ -304,10 +305,28 @@ def make(self, key):
{"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
)

# Check if video exists
if not video_path:
self.insert1(dict(**key, has_video=False))
return

# Check timepoints overlap
if not set(video_time).intersection(set(pos_df.index)):
raise ValueError(
"No overlapping time points between video and position data"
)

params_pk = "trodes_pos_params_name"
params = (TrodesPosParams() & {params_pk: key[params_pk]}).fetch1(
"params"
)

# Check if upsampled
if params["is_upsampled"]:
raise NotImplementedError(
"Upsampled position data not supported for video creation"
)

video_path = find_mp4(
video_path=os.path.dirname(video_path) + "/",
video_filename=video_filename,
Expand All @@ -316,14 +335,17 @@ def make(self, key):
output_video_filename = (
key["nwb_file_name"].replace(".nwb", "")
+ f"_{epoch:02d}_"
+ f'{key["trodes_pos_params_name"]}.mp4'
+ f"{key[params_pk]}.mp4"
)

adj_df = _fix_col_names(raw_df) # adjust 'xloc1' to 'xloc'

if test_mode:
# if limit := params.get("limit", None):
limit = 550
if limit or test_mode:
output_video_filename = Path(".") / f"TEST_VID_{limit}.mp4"
# pytest video data has mismatched shapes in some cases
min_len = min(len(adj_df), len(pos_df), len(video_time))
min_len = limit or min(len(adj_df), len(pos_df), len(video_time))
adj_df = adj_df[:min_len]
pos_df = pos_df[:min_len]
video_time = video_time[:min_len]
Expand All @@ -335,6 +357,7 @@ def make(self, key):
position_mean = np.asarray(pos_df[["position_x", "position_y"]])
orientation_mean = np.asarray(pos_df[["orientation"]])
position_time = np.asarray(pos_df.index)

if np.any(video_time):
centroids = {
color: fill_nan(
Expand All @@ -344,12 +367,18 @@ def make(self, key):
)
for color, data in centroids.items()
}
position_mean = fill_nan(position_mean, video_time, position_time)
position_mean = fill_nan(
variable=position_mean,
video_time=video_time,
variable_time=position_time,
)
orientation_mean = fill_nan(
orientation_mean, video_time, position_time
variable=orientation_mean,
video_time=video_time,
variable_time=position_time,
)

make_video(
vid_maker = make_video(
video_filename=video_path,
centroids=centroids,
video_time=video_time,
Expand All @@ -358,7 +387,12 @@ def make(self, key):
position_time=position_time,
output_video_filename=output_video_filename,
cm_to_pixels=meters_per_pixel * M_TO_CM,
debug=params.get("debug", False),
key_hash=dj.hash.key_hash(key),
**params,
)

# self.insert1(dict(**key, has_video=True)) # INTENTIONAL FAIL
if limit:
return vid_maker

self.insert1(dict(**key, has_video=True))

0 comments on commit 3e5c1c9

Please sign in to comment.