Skip to content

Commit

Permalink
WIP: fix TrodesPosVideo
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Oct 31, 2024
1 parent 010d098 commit 58bb47f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/spyglass/position/v1/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def find_mp4(
.rsplit(video_filepath.parent.as_posix(), maxsplit=1)[-1]
.split("/")[-1]
)
return _convert_mp4(video_file, video_path, output_path, videotype="mp4")
return _convert_mp4(video_file, video_path, output_path, videotype="mp4", count_frames=True)


def _convert_mp4(
Expand Down
12 changes: 9 additions & 3 deletions src/spyglass/position/v1/dlc_utils_makevid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,16 @@ def __init__(
if not Path(video_filename).exists():
raise FileNotFoundError(f"Video not found: {video_filename}")

try:
position_mean = position_mean["DLC"]
orientation_mean = orientation_mean["DLC"]
except IndexError:
pass # trodes data provides bare arrays

self.video_filename = video_filename
self.video_frame_inds = video_frame_inds
self.position_mean = position_mean["DLC"]
self.orientation_mean = orientation_mean["DLC"]
self.position_mean = position_mean
self.orientation_mean = orientation_mean
self.centroids = centroids
self.likelihoods = likelihoods
self.position_time = position_time
Expand Down Expand Up @@ -163,7 +169,7 @@ def _get_input_stats(self, video_filename=None) -> Tuple[int, int]:
"stream=width,height,r_frame_rate",
"-of",
"csv=p=0:s=x",
video_filename,
str(video_filename),
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
Expand Down
35 changes: 26 additions & 9 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from spyglass.position.v1.dlc_utils_makevid import make_video
from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.position import fill_nan

schema = dj.schema("position_v1_trodes_position")

Expand Down Expand Up @@ -327,19 +328,35 @@ def make(self, key):
pos_df = pos_df[:min_len]
video_time = video_time[:min_len]

centroids = {
"red": np.asarray(adj_df[["xloc", "yloc"]]),
"green": np.asarray(adj_df[["xloc2", "yloc2"]]),
}
position_mean = np.asarray(pos_df[["position_x", "position_y"]])
orientation_mean = np.asarray(pos_df[["orientation"]])
position_time = np.asarray(pos_df.index)
if np.any(video_time):
centroids = {
color: fill_nan(
variable=data,
video_time=video_time,
variable_time=position_time,
)
for color, data in centroids.items()
}
position_mean = fill_nan(position_mean, video_time, position_time)
orientation_mean = fill_nan(
orientation_mean, video_time, position_time
)

make_video(
processor="opencv-trodes",
video_filename=video_path,
centroids={
"red": np.asarray(adj_df[["xloc", "yloc"]]),
"green": np.asarray(adj_df[["xloc2", "yloc2"]]),
},
position_mean=np.asarray(pos_df[["position_x", "position_y"]]),
orientation_mean=np.asarray(pos_df[["orientation"]]),
centroids=centroids,
video_time=video_time,
position_time=np.asarray(pos_df.index),
position_mean=position_mean,
orientation_mean=orientation_mean,
position_time=position_time,
output_video_filename=output_video_filename,
cm_to_pixels=meters_per_pixel * M_TO_CM,
disable_progressbar=False,
)
self.insert1(dict(**key, has_video=True))

0 comments on commit 58bb47f

Please sign in to comment.