Skip to content

Commit

Permalink
finishing touches
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 22, 2024
1 parent cb7e4cf commit 7ec347d
Showing 1 changed file with 108 additions and 79 deletions.
187 changes: 108 additions & 79 deletions examples/convert_file_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,64 @@
# In the following example, we will load a dataset from a
# SLEAP file, modify the keypoints (rename, delete, reorder),
# and save the modified dataset as a DeepLabCut file.
#
# We'll first walk through each step separately, and then
# combine them into a single function that can be applied
# to multiple files at once.

# %%
# Imports
# -------
import pathlib
import tempfile
from pathlib import Path

from movement import sample_data
from movement.io import load_poses, save_poses

# %%
# Load the dataset
# --------------------
# This should the location of a file output by one of
# our supported pose estimation
# supported pose estimation
# frameworks (e.g., DeepLabCut, SLEAP), containing predicted pose tracks.
# ----------------
# We'll start with the path to a file output by one of
# our :ref:`supported pose estimation frameworks<target-formats>`.
# For example, the path could be something like:

# uncomment and edit the following line to point to your own local file
# file_path = "/path/to/my/data.h5"

# %%
# For the sake of this example, we will use the path to one of
# the sample datasets provided with ``movement``.

fpath = sample_data.fetch_dataset_paths("SLEAP_single-mouse_EPM.analysis.h5")[
"poses"
]
print(fpath)
file_path = sample_data.fetch_dataset_paths(
"SLEAP_single-mouse_EPM.analysis.h5"
)["poses"]
print(file_path)

# %%
# Now let's load this file into an xarray dataset, which we can then
# modify to our liking.
ds = load_poses.from_sleap_file(fpath, fps=30)
print(ds)
# Now let's load this file into a
# :ref:`movement poses dataset<target-poses-and-bboxes-dataset>`,
# which we can then modify to our liking.

ds = load_poses.from_sleap_file(file_path, fps=30)
print(ds, "\n")
print("Individuals:", ds.coords["individuals"].values)
print("Keypoints:", ds.coords["keypoints"].values)


# %%
# .. note::
# If you're running this code in a Jupyter notebook,
# you can just type ``ds`` (instead of printing it)
# to explore the dataset interactively.

# %%
# Rename keypoints
# --------------------------------
# Create a dictionary that maps old keypoint names to new ones
# ----------------
# We start with a dictionary that maps old keypoint names to new ones.
# Next, we define a function that takes that dictionary and a dataset
# as inputs, and returns a modified dataset. Notice that under the hood
# this function calls :meth:`xarray.Dataset.assign_coords`.

rename_dict = {
"snout": "nose",
"left_ear": "earL",
Expand All @@ -64,83 +84,65 @@
}


# %%
# Now we can run the following function,
# to rename the keypoints as defined in ``rename_dict``.


# the keypoints have been renamed.
# this function takes the dataset and the rename_dict as input.
def rename_keypoints(ds, rename_dict):
# get the current names of the keypoints
keypoint_names = ds.coords["keypoints"].values
print("Original keypoints:", keypoint_names)

# rename the keypoints
if not rename_dict:
print("No keypoints to rename. Skipping renaming step.")
else:
new_keypoints = [rename_dict.get(kp, str(kp)) for kp in keypoint_names]
print("New keypoints:", new_keypoints)
# Assign the modified values back to the Dataset
ds = ds.assign_coords(keypoints=new_keypoints)
return ds


# %%
# To prove to ourselves that the keypoints have been renamed,
# we can print the keypoints in the modified dataset.
# Let's apply the function to our dataset and see the results.
ds_renamed = rename_keypoints(ds, rename_dict)
print("Keypoints in modified dataset:", ds_renamed.coords["keypoints"].values)


# %%
# Delete Keypoints
# Delete keypoints
# -----------------
# Let's create a list of keypoints to delete.
# to delete modify this list accordingly
# In this case, we choose to get rid of the ``tailend`` keypoint,
# which is often hard to reliably track.
# We delete it using :meth:`xarray.Dataset.drop_sel`,
# wrapped in an appropriately named function.

kps_to_delete = ["tailend"]
keypoints_to_delete = ["tailend"]


# %%
# Now we can go ahead and delete these keypoints
# using an appropriate function.
def delete_keypoints(ds, delete_keypoints):
if not delete_keypoints:
print("No keypoints to delete. Skipping deleting step.")
else:
# Delete the specified keypoints
# and their corresponding data
# Delete the specified keypoints and their corresponding data
ds = ds.drop_sel(keypoints=delete_keypoints)
return ds


# %%
# To prove to ourselves that the keypoints have been deleted,
# we can print the keypoints in the modified dataset.

ds_deleted = delete_keypoints(ds_renamed, kps_to_delete)
ds_deleted = delete_keypoints(ds_renamed, keypoints_to_delete)
print("Keypoints in modified dataset:", ds_deleted.coords["keypoints"].values)


# %%
# Reorder keypoints
# ------------------
# Again create a list with the
# Let's list the keypoints in the desired order.
# We start with a list of keypoints in the desired order
# (in this case, we'll just swap the order of the left and right ears).
# We then use :meth:`xarray.Dataset.reindex`, wrapped in yet another function.

ordered_keypoints = ["nose", "earR", "earL", "middle", "tailbase"]


# %%
# Now we can go ahead and reorder
# those keypoints
def reorder_keypoints(ds, ordered_keypoints):
# reorder the keypoints
if not ordered_keypoints:
print("No keypoints to reorder. Skipping reordering step.")
else:
# Reorder the keypoints in the Dataset
ds = ds.reindex(keypoints=ordered_keypoints)
return ds

Expand All @@ -151,52 +153,79 @@ def reorder_keypoints(ds, ordered_keypoints):
)

# %%
# # One function to rule them all
# # -----------------------------
# # Now that we know how to rename, delete, and reorder keypoints,
# # let's put it all together in a single function,
# # and see how we'd use this in a real-world scenario.
# #
# # The following function will convert all files in a folder
# # (that end with a specified suffix) to the desired format.
# # Each file will be loaded, modified, and saved to a new file.
# Save the modified dataset
# ---------------------------
# Now that we have modified the dataset to our liking,
# let's save it to a .csv file in the DeepLabCut format.
# In this case, we save the file to a temporary
# directory, and we use the same file name
# as the original, but ending in ``_dlc.csv``.
# You will need to specify a different ``target_dir`` and edit
# the ``dest_path`` variable to your liking.

target_dir = tempfile.mkdtemp()
dest_path = Path(target_dir) / f"{file_path.stem}_dlc.csv"

save_poses.to_dlc_file(ds_reordered, dest_path, split_individuals=False)
print(f"Saved modified dataset to {dest_path}.")

# %%
# .. note::
# The ``split_individuals`` argument allows you to save
# a dataset with multiple individuals as separate files,
# with the individual ID appended to each file name.
# In this case, we set it to ``False`` because we only have
# one individual in the dataset, and we don't need its name
# appended to the file name.


# %%
# One function to rule them all
# -----------------------------
# Since we know how to rename, delete, and reorder keypoints,
# let's put it all together in a single function
# and see how we could apply it to multiple files at once,
# as we might do in a real-world scenario.
#
# The following function will convert all files in a folder
# (that end with a specified suffix) from SLEAP to DeepLabCut format.
# Each file will be loaded, modified according to the
# ``rename_dict``, ``keypoints_to_delete``, and ``ordered_keypoints``
# we've defined above, and saved to the target directory.


data_dir = "/path/to/your/data/"
target_dir = "/path/to/your/target/data/"


def convert_all(data_dir, target_dir, suffix=".slp"):
source_folder = pathlib.Path(data_dir)
fpaths = list(source_folder.rglob(f"*{suffix}"))
source_folder = Path(data_dir)
file_paths = list(source_folder.rglob(f"*{suffix}"))

for fpath in fpaths:
fpath = pathlib.Path(fpath)
target_path = pathlib.Path(target_dir)
for file_path in file_paths:
file_path = Path(file_path)

# this determines the filename of your modified file
# change it if you like to change the filename
dest_path = target_path / f"{fpath.stem}_dlc.csv"
# this determines the file names for the modified files
dest_path = Path(target_dir) / f"{file_path.stem}_dlc.csv"

if dest_path.exists():
print(f"Skipping {fpath} as {dest_path} already exists.")
return

if fpath.exists():
print(f"processing: {fpath}")
# load the data
ds = load_poses.from_sleap_file(fpath)
print(f"Skipping {file_path} as {dest_path} already exists.")
continue

if file_path.exists():
print(f"Processing: {file_path}")
# load the data from SLEAP file
ds = load_poses.from_sleap_file(file_path)
# modify the data
ds_renamed = rename_keypoints(ds, rename_dict)
ds_deleted = delete_keypoints(ds_renamed, kps_to_delete)
ds_deleted = delete_keypoints(ds_renamed, keypoints_to_delete)
ds_reordered = reorder_keypoints(ds_deleted, ordered_keypoints)
# save poses to dlc file format
save_poses.to_dlc_file(ds_reordered, dest_path)

# save modified data to a DeepLabCut file
save_poses.to_dlc_file(
ds_reordered, dest_path, split_individuals=False
)
else:
raise ValueError(
f"File '{fpath}' does not exist. "
f"File '{file_path}' does not exist. "
f"Please check the file path and try again."
)


# %%

0 comments on commit 7ec347d

Please sign in to comment.