Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Nov 17, 2023
2 parents ff63090 + eb3b050 commit ddeaf92
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 50 deletions.
109 changes: 59 additions & 50 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
plot_track_graph,
)

from ..settings import raw_dir
from ..settings import raw_dir, video_dir
from ..utils.dj_helper_fn import fetch_nwb
from .common_behav import RawPosition, VideoFile
from .common_interval import IntervalList # noqa F401
Expand Down Expand Up @@ -248,51 +248,6 @@ def _fix_kwargs(
max_plausible_speed,
)

@staticmethod
def _fix_col_names(spatial_df):
"""Renames columns in spatial dataframe according to previous norm
Accepts unnamed first led, 1 or 0 indexed.
Prompts user for confirmation of renaming unexpected columns.
For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2"
"""

DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"]
ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"]
ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"]

input_cols = list(spatial_df.columns)

has_default = all([c in input_cols for c in DEFAULT_COLS])
has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS])
has_1_idx = all([c in input_cols for c in ONE_IDX_COLS])

# if unexpected columns, ask user to confirm
if len(input_cols) != 4 or not (has_default or has_0_idx or has_1_idx):
choice = dj.utils.user_choice(
"Unexpected columns in raw position. Assume "
+ f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n"
)
if choice.lower() not in ["yes", "y"]:
raise ValueError(
f"Unexpected columns in raw position: {input_cols}"
)
spatial_df.columns = DEFAULT_COLS + input_cols[4:]

# Ensure data order, only 4 col
spatial_df = (
spatial_df[DEFAULT_COLS]
if has_default
else spatial_df[ZERO_IDX_COLS]
if has_0_idx
else spatial_df[ONE_IDX_COLS]
)

# rename to default
spatial_df.columns = DEFAULT_COLS

return spatial_df

@staticmethod
def _upsample(
front_LED,
Expand Down Expand Up @@ -378,7 +333,7 @@ def calculate_position_info(
**kwargs,
)

spatial_df = self._fix_col_names(spatial_df)
spatial_df = _fix_col_names(spatial_df)
# Get spatial series properties
time = np.asarray(spatial_df.index) # seconds
position = np.asarray(spatial_df.iloc[:, :4]) # meters
Expand Down Expand Up @@ -898,17 +853,26 @@ def make(self, key):
VideoFile()
& {"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
).fetch1()
io = pynwb.NWBHDF5IO(raw_dir() + video_info["nwb_file_name"], "r")
io = pynwb.NWBHDF5IO(raw_dir + "/" + video_info["nwb_file_name"], "r")
nwb_file = io.read()
nwb_video = nwb_file.objects[video_info["video_file_object_id"]]
video_filename = nwb_video.external_file.value[0]
video_filename = nwb_video.external_file[0]

nwb_base_filename = key["nwb_file_name"].replace(".nwb", "")
output_video_filename = (
f"{nwb_base_filename}_{epoch:02d}_"
f'{key["position_info_param_name"]}.mp4'
)

# ensure standardized column names
raw_position_df = _fix_col_names(raw_position_df)
# if IntervalPositionInfo supersampled position, downsample to video
if position_info_df.shape[0] > raw_position_df.shape[0]:
ind = np.digitize(
raw_position_df.index, position_info_df.index, right=True
)
position_info_df = position_info_df.iloc[ind]

centroids = {
"red": np.asarray(raw_position_df[["xloc", "yloc"]]),
"green": np.asarray(raw_position_df[["xloc2", "yloc2"]]),
Expand All @@ -925,7 +889,7 @@ def make(self, key):

print("Making video...")
self.make_video(
video_filename,
f"{video_dir}/{video_filename}",
centroids,
head_position_mean,
head_orientation_mean,
Expand Down Expand Up @@ -1082,3 +1046,48 @@ def make_video(
video.release()
out.release()
cv2.destroyAllWindows()


def _fix_col_names(spatial_df):
"""Renames columns in spatial dataframe according to previous norm
Accepts unnamed first led, 1 or 0 indexed.
Prompts user for confirmation of renaming unexpected columns.
For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2"
"""

DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"]
ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"]
ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"]

input_cols = list(spatial_df.columns)

has_default = all([c in input_cols for c in DEFAULT_COLS])
has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS])
has_1_idx = all([c in input_cols for c in ONE_IDX_COLS])

if has_default:
# move the 4 position columns to front, continue
spatial_df = spatial_df[DEFAULT_COLS]
elif has_0_idx:
# move the 4 position columns to front, rename to default, continue
spatial_df = spatial_df[ZERO_IDX_COLS]
spatial_df.columns = DEFAULT_COLS
elif has_1_idx:
# move the 4 position columns to front, rename to default, continue
spatial_df = spatial_df[ONE_IDX_COLS]
spatial_df.columns = DEFAULT_COLS
else:
if len(input_cols) != 4 or not has_default:
choice = dj.utils.user_choice(
"Unexpected columns in raw position. Assume "
+ f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n"
)
if choice.lower() not in ["yes", "y"]:
raise ValueError(
f"Unexpected columns in raw position: {input_cols}"
)
# rename first 4 columns, keep rest. Rest dropped below
spatial_df.columns = DEFAULT_COLS + input_cols[4:]

return spatial_df
1 change: 1 addition & 0 deletions src/spyglass/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,4 +438,5 @@ def debug_mode(self) -> bool:
analysis_dir = sg_config.analysis_dir
sorting_dir = sg_config.sorting_dir
waveform_dir = sg_config.waveform_dir
video_dir = sg_config.video_dir
debug_mode = sg_config.debug_mode

0 comments on commit ddeaf92

Please sign in to comment.