Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Oct 24, 2024
1 parent 0a4ce82 commit 0ef92ba
Showing 1 changed file with 9 additions and 211 deletions.
220 changes: 9 additions & 211 deletions examples/notebook_egocentric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Import data
# Import sample data
# one individual, 6 keypoints

ds_path = sample_data.fetch_dataset_paths(
"SLEAP_single-mouse_EPM.analysis.h5"
Expand All @@ -32,7 +33,6 @@
# Compute centroids

# get position data array
# (we dont use any of the other data arrays in the dataset)
position = ds.position

# Compute centroid per individual
Expand Down Expand Up @@ -89,16 +89,9 @@
ds_egocentric["position"] = (
position_egocentric # keypoint positions in egocentric coord system
)
# add centroid and heading angle of the posterior2anterior vector in the
# image coordinate system (ICS)
ds_egocentric["centroid_ics"] = centroid
ds_egocentric["heading_angle_ics"] = posterior2anterior_pol.sel(
space_pol="phi"
)


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Check by plotting keypoint trajectories in egocentric coordinate system
# Check by plotting the keypoint trajectories in the egocentric coordinate system

fig, ax = plt.subplots(1, 1)
for kpt in ds_egocentric.coords["keypoints"].data:
Expand Down Expand Up @@ -165,47 +158,17 @@


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Check if position rotated matches the result by rotations

# extend posterior2anterior vector to 3D
posterior2anterior_3d = np.pad(
posterior2anterior.data.squeeze(), ((0, 0), (0, 1))
)

# compute rotations to align posterior2anterior vector to x-axis of ECS
# ideally, if nan return nan?
list_rotations = []
list_rssd = []
for vec in posterior2anterior_3d:
# add nan to list if no vector defined
if np.isnan(vec).any():
# list_rotations.append(np.nan)
# list_rssd.append(np.nan)
list_rotations.append(R.from_matrix(np.zeros((3, 3))))
continue

# else compute rotation to x-axis
rrot, rssd = R.align_vectors(
np.array([[1, 0, 0]]), vec, return_sensitivity=False
)
list_rotations.append(rrot)
list_rssd.append(rssd)

#

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Add rotation matrices to dataset
# Compare to alternative approach using scipy Rotation objects

# expand posterior2anterior data array to 3d space
# ideally, reference vector defined everywhere --- slerp?
# ideally, reference vector defined everywhere -- interpolate with slerp?
posterior2anterior_3d = posterior2anterior.pad(
{"space": (0, 1)}, constant_values={"space": (0)}
)
posterior2anterior_3d.coords["space"] = ["x", "y", "z"]


# compute array of rotation matrices from ICS to ECS
# ideally, reference vector defined everywhere
def compute_rotation_to_align_x_axis(vec):
if np.isnan(vec).any():
return R.from_matrix(np.eye(3)) # ---> identity, maybe not the best?
Expand Down Expand Up @@ -238,7 +201,7 @@ def compute_rotation_to_align_x_axis(vec):

# compute keypoints in ECS (translated and rotated)
position_ego_3d = xr.apply_ufunc(
lambda rot, trans, vec: rot.apply(vec - trans, inverse=False),
lambda rot, trans, vec: rot.apply(vec - trans),
rotation2egocentric, # rot
centroid_3d, # trans
position_3d, # vec
Expand All @@ -248,175 +211,10 @@ def compute_rotation_to_align_x_axis(vec):
)

# compare to other approach
print(position_3d.sel(keypoints="snout").data)
print(position_3d.sel(keypoints="snout").data[-3:, :, :]) # ICS
print("-----")
print(position_ego_3d.sel(keypoints="snout").data)
print(position_ego_3d.sel(keypoints="snout").data[-3:, :, :]) # ECS
print("-----")
print(ds_egocentric.position.sel(keypoints="snout").data)
print(ds_egocentric.position.sel(keypoints="snout").data[-3:, :, :]) # ECS-2D

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# %%
# add rotation matrices from ICS to ECS to dataset
req_shape = tuple(ds.sizes[d] for d in ["time", "individuals"])
ds["rotation_matrices"] = xr.DataArray(
np.array(list_rotations).reshape(req_shape),
dims=("time", "individuals"),
)

# pad position
ds["position"] = position.pad({"space": (0, 1)})
ds["position"].coords["space"] = ["x", "y", "z"]

# %%
# compute rotated coordinates
ds_ego = ds.copy()

# ds_ego['position'] = ds['rotation_matrices'].apply(ds['position'])

# %%
xr.apply_ufunc(
lambda r, p: r.apply(p),
ds_ego["rotation_matrices"],
ds_ego["position"],
input_core_dims=[[], ["keypoints"]],
output_core_dims=[["keypoints"]],
vectorize=True,
)


# %%%%%
import xarray as xr

# expand data array to 3d space
posterior2anterior_3d = posterior2anterior.pad({"space": (0, 1)})
posterior2anterior_3d.coords["space"] = ["x", "y", "z"]


# %%


def align_vectors_modif(u, v):
if np.isnan(v).any():
return R.from_quat([0, 0, 0, 1])
else:
return R.align_vectors(u, v, return_sensitivity=False)


xr.apply_ufunc(
align_vectors_modif, # lambda u,v: R.align_vectors(u, v),
np.broadcast_to(
np.array([[1, 0, 0]]), posterior2anterior_3d.squeeze().shape
),
posterior2anterior_3d.squeeze(),
input_core_dims=[[], ["individuals"]],
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# a = np.array([[1, 0, 0]])
# b = np.array([[0, 1, 0]])

# rrot, rssd = R.align_vectors(
# a,
# b,
# return_sensitivity=False
# )

# print(rrot.as_matrix())

# print(np.testing.assert_allclose(a, rrot.apply(b), atol=1e-10))


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# plot for a small time window

# time_window = range(1650, 1700) # frames

# fig, ax = plt.subplots(1, 1)
# for mouse_name, col in zip(
# position.individuals.values, ["r", "g"], strict=False
# ):
# # plot centroid
# ax.plot(
# centroid.sel(individuals=mouse_name, time=time_window, space="x"),
# centroid.sel(individuals=mouse_name, time=time_window, space="y"),
# label=mouse_name,
# color=col,
# linestyle="-",
# marker=".",
# markersize=10,
# linewidth=0.5,
# )
# # plot centroid anterior
# ax.plot(
# centroid_anterior.sel(
# individuals=mouse_name, time=time_window, space="x"
# ),
# centroid_anterior.sel(
# individuals=mouse_name, time=time_window, space="y"
# ),
# label=mouse_name,
# color=col,
# linestyle="-",
# marker="x",
# markersize=10,
# linewidth=0.5,
# )
# # plot centroid posterior
# ax.plot(
# centroid_posterior.sel(
# individuals=mouse_name, time=time_window, space="x"
# ),
# centroid_posterior.sel(
# individuals=mouse_name, time=time_window, space="y"
# ),
# label=mouse_name,
# color=col,
# linestyle="-",
# marker="*",
# markersize=10,
# linewidth=0.5,
# )
# # plot keypoints
# ax.scatter(
# x=position.sel(individuals=mouse_name, time=time_window, space="x"),
# y=position.sel(individuals=mouse_name, time=time_window, space="y"),
# s=1,
# )

# # plot vector
# ax.quiver(
# centroid_posterior.sel(
# individuals=mouse_name, time=time_window, space="x"
# ),
# centroid_posterior.sel(
# individuals=mouse_name, time=time_window, space="y"
# ),
# posterior2anterior.sel(
# individuals=mouse_name, time=time_window, space="x"
# ),
# posterior2anterior.sel(
# individuals=mouse_name, time=time_window, space="y"
# ),
# angles="xy",
# scale=1,
# scale_units="xy",
# headwidth=7,
# headlength=9,
# headaxislength=9,
# color="gray",
# )
# # # add text
# # for kpt in position.keypoints.values:
# # ax.text(
# # position.sel(
# # individuals=mouse_name, time=time_window, space='x', keypoints=kpt
# # ).data,
# # position.sel(
# # individuals=mouse_name, time=time_window, space='y', keypoints=kpt
# # ).data,
# # str(kpt),
# # )
# ax.legend()
# ax.axis("equal")
# ax.set_xlabel("x (pixels)")
# ax.set_ylabel("y (pixels)")
# ax.invert_yaxis()

0 comments on commit 0ef92ba

Please sign in to comment.