Moseq Pipeline #1056

Open · wants to merge 40 commits into base: master

Changes from 10 commits

Commits (40):
9d6d3e1  Add organization tables for keypoint pose data (samuelbray32, Jun 10, 2024)
28748f1  Add model training tables (samuelbray32, Jun 11, 2024)
7d5d66c  apply fit moseq model to data (samuelbray32, Jun 13, 2024)
09779a6  add ability to initialize model training from pre-existing model (samuelbray32, Jun 13, 2024)
983f179  make model names unique (samuelbray32, Jun 13, 2024)
fdaff80  add initial tutorial (samuelbray32, Jun 13, 2024)
c2b4d4f  add moseq pipeline diagram (Jun 14, 2024)
2e9ea46  get rid of PoseOutput and fetch pose data from PositionOutput (Jun 26, 2024)
62c8eef  remove PoseOutput and corresponding references (samuelbray32, Aug 6, 2024)
36ad016  allow passing of null entries from DLCPoseEstimation to PositionOutpu… (samuelbray32, Aug 6, 2024)
fe3c6d4  Apply suggestions from code review (samuelbray32, Dec 20, 2024)
ccf370f  Merge branch 'master' into moseq (samuelbray32, Dec 20, 2024)
b8dfd42  lint (samuelbray32, Dec 20, 2024)
310abf7  implement changes from review (samuelbray32, Dec 20, 2024)
69335a8  rename fetch_video_name to fetch_video_path (samuelbray32, Dec 20, 2024)
fc46b37  cleanup outdated usages of PoseOutput (samuelbray32, Dec 20, 2024)
c3c7b84  Cleanup errors from removal of PoseOutput (samuelbray32, Dec 20, 2024)
e5b2c1b  add get_position_interval_epoch (samuelbray32, Dec 20, 2024)
dc8435c  Add method to get training results pdf (samuelbray32, Dec 27, 2024)
f92320b  Add description and example of hyperparameter sweep (samuelbray32, Dec 27, 2024)
d663e0e  move moseq dir definition to spyglass config (samuelbray32, Dec 27, 2024)
3a1c079  move moseq config function to method (samuelbray32, Dec 27, 2024)
589bdb2  Update changelog (samuelbray32, Dec 27, 2024)
7d30fc1  update pipeline diagram (samuelbray32, Dec 30, 2024)
81f3dfa  Apply suggestions from code review (samuelbray32, Jan 2, 2025)
3bfcb9f  move moseq into v1 folder (samuelbray32, Jan 2, 2025)
8fd900b  add docstrings (samuelbray32, Jan 2, 2025)
5a33090  make video symlink more robust (samuelbray32, Jan 3, 2025)
bd29c76  cleanup setup_project call (samuelbray32, Jan 3, 2025)
c7ea1f5  cleanup config method (samuelbray32, Jan 3, 2025)
5901edc  Implement suggestions from code review (samuelbray32, Jan 3, 2025)
e5af455  cleanup readability of DLCPosV1 make conditions (samuelbray32, Jan 3, 2025)
04c2a37  allow key argument when fetching video path (samuelbray32, Jan 3, 2025)
8106929  add moseq dependencies (samuelbray32, Jan 16, 2025)
2bff705  Add moseq install instructions to tutorials (samuelbray32, Jan 16, 2025)
31014f9  accept key in all get video path functions (samuelbray32, Jan 16, 2025)
3dacc40  Merge branch 'master' into moseq (samuelbray32, Jan 16, 2025)
f61f10c  fix spelling (samuelbray32, Jan 16, 2025)
719486b  Merge branch 'moseq' of https://github.com/LorenFrankLab/spyglass int… (samuelbray32, Jan 16, 2025)
24e1150  fix spelling (samuelbray32, Jan 16, 2025)
Binary file added notebook-images/moseq_outline.png
1,844 changes: 1,844 additions & 0 deletions notebooks/60_MoSeq.ipynb

Large diffs are not rendered by default.

Empty file.
200 changes: 200 additions & 0 deletions src/spyglass/behavior/core.py
@@ -0,0 +1,200 @@
from pathlib import Path
from uuid import uuid4

import datajoint as dj
import numpy as np
import pandas as pd

from spyglass.position.position_merge import PositionOutput
from spyglass.utils import SpyglassMixin, SpyglassMixinPart

schema = dj.schema("behavior_core_v1")


@schema
class PoseGroup(SpyglassMixin, dj.Manual):
    definition = """
    pose_group_name: varchar(80)
    ----
    bodyparts = NULL: longblob # list of body parts to include in the pose
    """

    class Pose(SpyglassMixinPart):
        definition = """
        -> PoseGroup
        -> PositionOutput.proj(pose_merge_id='merge_id')
        """

    def create_group(
        self,
        group_name: str,
        merge_ids: list[str],
        bodyparts: list[str] = None,
    ):
        """Create a group of pose information

        Parameters
        ----------
        group_name : str
            name of the group
        merge_ids : list[str]
            list of merge ids from PositionOutput to include in the group
        bodyparts : list[str], optional
            body parts to include in the group, by default None includes all from every set
        """
        group_key = {
            "pose_group_name": group_name,
        }
        self.insert1(
            {
                **group_key,
                "bodyparts": bodyparts,
            },
            skip_duplicates=True,
        )
        for merge_id in merge_ids:
            self.Pose.insert1(
                {
                    **group_key,
                    "pose_merge_id": merge_id,
                },
                skip_duplicates=True,
            )

    def fetch_pose_datasets(
        self, key: dict = None, format_for_moseq: bool = False
    ):
        """Fetch pose information for a group of videos

        Parameters
        ----------
        key : dict, optional
            group key, by default None (all groups)
        format_for_moseq : bool, optional
            format for MoSeq, by default False

        Returns
        -------
        dict
            dictionary of video name to pose dataset
        """
        if key is None:
            key = {}

        bodyparts = (self & key).fetch1("bodyparts")
        datasets = {}
        for merge_key in (self.Pose & key).proj(merge_id="pose_merge_id"):
            video_name = (
                Path((PositionOutput & merge_key).fetch_video_name()).stem
                + ".mp4"
            )
            bodyparts_df = (PositionOutput & merge_key).fetch_dataframe()
            if bodyparts is None:
                bodyparts = (
                    bodyparts_df.keys().get_level_values(0).unique().values
                )
            bodyparts_df = bodyparts_df[bodyparts]
            datasets[video_name] = bodyparts_df
        if format_for_moseq:
            datasets = format_dataset_for_moseq(datasets, bodyparts)
        return datasets

    def fetch_video_paths(self, key: dict = None):
        """Fetch video paths for a group of videos

        Parameters
        ----------
        key : dict, optional
            group key, by default None (all groups)

        Returns
        -------
        list[Path]
            list of video paths
        """
        if key is None:
            key = {}
        video_paths = [
            Path((PositionOutput & merge_key).fetch_video_name())
            for merge_key in (self.Pose & key).proj(merge_id="pose_merge_id")
        ]
        return video_paths


def format_dataset_for_moseq(
    datasets: dict[str, pd.DataFrame],
    bodyparts: list[str],
    coordinate_axes: list[str] = ["x", "y"],
):
    """Format pose datasets for MoSeq

    Parameters
    ----------
    datasets : dict[str, pd.DataFrame]
        dictionary of video name to pose dataset
    bodyparts : list[str]
        list of body parts to include in the pose
    coordinate_axes : list[str], optional
        coordinate axes to include, by default ["x", "y"]

    Returns
    -------
    tuple[dict[str, np.ndarray], dict[str, np.ndarray]]
        coordinates and confidences for each video
    """
    num_keypoints = len(bodyparts)
    num_dimensions = len(coordinate_axes)
    coordinates = {}
    confidences = {}

    for video_name, bodyparts_df in datasets.items():
        coordinates_i = None
        confidences_i = None
        for i, bodypart in enumerate(bodyparts):
            part_df = bodyparts_df[bodypart]
            if coordinates_i is None:
                # allocate output arrays once the frame count is known
                num_frames = len(part_df)
                coordinates_i = np.empty(
                    (num_frames, num_keypoints, num_dimensions)
                )
                confidences_i = np.empty((num_frames, num_keypoints))
            coordinates_i[:, i, :] = part_df[coordinate_axes].values
            confidences_i[:, i] = part_df["likelihood"].values
        coordinates[video_name] = coordinates_i
        confidences[video_name] = confidences_i
    return coordinates, confidences


def results_to_df(results):
    for key in results.keys():
        column_names, data = [], []

        if "syllable" in results[key].keys():
Member: Is this looking for substrings? Is this results dict structured as `{'my_syllables': {'syllable': data_obj}}`?
            column_names.append(["syllable"])
            data.append(results[key]["syllable"].reshape(-1, 1))

        if "centroid" in results[key].keys():
            d = results[key]["centroid"].shape[1]
            column_names.append(["centroid x", "centroid y", "centroid z"][:d])
            data.append(results[key]["centroid"])

        if "heading" in results[key].keys():
            column_names.append(["heading"])
            data.append(results[key]["heading"].reshape(-1, 1))

        if "latent_state" in results[key].keys():
            latent_dim = results[key]["latent_state"].shape[1]
            column_names.append(
                [f"latent_state {i}" for i in range(latent_dim)]
            )
            data.append(results[key]["latent_state"])

        dfs = [
            pd.DataFrame(arr, columns=cols)
            for arr, cols in zip(data, column_names)
        ]
        df = pd.concat(dfs, axis=1)

        for col in df.select_dtypes(include=[np.floating]).columns:
            df[col] = df[col].astype(float).round(4)

        return df
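Aside: a small, self-contained illustration of the per-recording `results` structure this function appears to assume, and the column assembly it performs; the session and field values here are hypothetical:

```python
import numpy as np
import pandas as pd

# Assumed structure: one entry per recording, each holding some subset of
# the arrays results_to_df inspects ("syllable", "centroid", "heading", ...).
n = 4
results = {
    "session_1": {
        "syllable": np.arange(n),
        "centroid": np.zeros((n, 2)),
        "heading": np.linspace(0.0, 1.0, n),
    }
}

# Assemble one recording's DataFrame the same way results_to_df does:
# reshape 1-D arrays to columns, slice centroid labels to the data's width.
val = results["session_1"]
df = pd.concat(
    [
        pd.DataFrame(val["syllable"].reshape(-1, 1), columns=["syllable"]),
        pd.DataFrame(
            val["centroid"],
            columns=["centroid x", "centroid y", "centroid z"][
                : val["centroid"].shape[1]
            ],
        ),
        pd.DataFrame(val["heading"].reshape(-1, 1), columns=["heading"]),
    ],
    axis=1,
)
print(list(df.columns))  # ['syllable', 'centroid x', 'centroid y', 'heading']
```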
Member: The indenting here makes it look like you're only working with the first key in results and ignoring the rest.
Member: I took a stab at a version that removes conditionals to make it easier to see how the data types are differentially handled.

def results_to_df(results):
    for key, val in results.items():
        centroid_dim = val.get("centroid", np.empty((0, 0))).shape[1]
        latent_dim = val.get("latent_state", np.empty((0, 0))).shape[1]

        type_to_col = {
            "syllable": ["syllable"],
            "centroid": ["centroid x", "centroid y", "centroid z"][:centroid_dim],
            "heading": ["heading"],
            "latent_state": [f"latent_state {i}" for i in range(latent_dim)],
        }
        type_to_data = {
            "syllable": val.get("syllable", np.array([])).reshape(-1, 1),
            "centroid": val.get("centroid", np.empty((0, 0))),
            "heading": val.get("heading", np.array([])).reshape(-1, 1),
            "latent_state": val.get("latent_state", np.empty((0, 0))),
        }

        column_names, data = [], []
        for data_type, cols in type_to_col.items():
            if data_type in val:
                column_names.append(cols)
                data.append(type_to_data[data_type])

        df = pd.concat(
            [
                pd.DataFrame(arr, columns=cols)
                for arr, cols in zip(data, column_names)
            ],
            axis=1,
        )

        for col in df.select_dtypes(include=[np.floating]).columns:
            df[col] = df[col].astype(float).round(4)

        return df
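Both versions above still `return df` inside the loop, so only the first recording is ever produced. One possible follow-up (a sketch, not part of this PR; the helper name `results_to_dfs` is made up) that builds one DataFrame per recording and returns them all:

```python
import numpy as np
import pandas as pd


def results_to_dfs(results: dict) -> dict:
    """Sketch: one DataFrame per results entry instead of returning on the
    first key. Assumes each value may hold 'syllable', 'centroid',
    'heading', and/or 'latent_state' arrays."""
    out = {}
    for key, val in results.items():
        column_names, data = [], []
        if "syllable" in val:
            column_names.append(["syllable"])
            data.append(np.asarray(val["syllable"]).reshape(-1, 1))
        if "centroid" in val:
            d = val["centroid"].shape[1]
            column_names.append(["centroid x", "centroid y", "centroid z"][:d])
            data.append(val["centroid"])
        if "heading" in val:
            column_names.append(["heading"])
            data.append(np.asarray(val["heading"]).reshape(-1, 1))
        if "latent_state" in val:
            latent_dim = val["latent_state"].shape[1]
            column_names.append(
                [f"latent_state {i}" for i in range(latent_dim)]
            )
            data.append(val["latent_state"])
        df = pd.concat(
            [
                pd.DataFrame(arr, columns=cols)
                for arr, cols in zip(data, column_names)
            ],
            axis=1,
        )
        for col in df.select_dtypes(include=[np.floating]).columns:
            df[col] = df[col].astype(float).round(4)
        out[key] = df  # keep every recording, not just the first
    return out


# Example with two recordings; both appear in the output.
res = {
    "a": {"syllable": np.array([0, 1]), "heading": np.array([0.5, 1.5])},
    "b": {"centroid": np.ones((3, 2))},
}
dfs = results_to_dfs(res)
print(sorted(dfs))  # ['a', 'b']
```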
