Skip to content

Commit

Permalink
recipe(meshes) shift subject meshes on embedding position
Browse files Browse the repository at this point in the history
  • Loading branch information
mpascucci committed Jul 10, 2023
1 parent 261afc0 commit 2c5d571
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
2 changes: 1 addition & 1 deletion dico_toolbox/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,4 @@ def apply_Talairach_to_mesh(mesh, dxyz, rotation, translation, flip=False):
if flip:
flip_mesh(mesh)

return mesh
return mesh
2 changes: 1 addition & 1 deletion dico_toolbox/recipes/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .meshes import mesh_of_average, mesh_of_averages, mesh_one_point_cloud, mesh_of_point_clouds
from .meshes import mesh_of_average, mesh_of_averages, mesh_one_point_cloud, mesh_of_point_clouds, shift_meshes_in_embedding
38 changes: 30 additions & 8 deletions dico_toolbox/recipes/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def mesh_one_point_cloud(data):
tal = data["talairach"]
flip = data['flip']
align = data["align"]
shift = data["shift"]
meshing_parameters = data['meshing_parameters']

# generate mesh
Expand Down Expand Up @@ -110,32 +109,26 @@ def _parse_pool_results(results):
return meshes


def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transformation=None, embedding=None, embedding_scale=1, **meshing_parameters):
def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transformation=None, **meshing_parameters):
"""Build the mesh of the pointclouds.
Args:
pcs (dict): the point clouds
pre_transformation (collection of dict, optional): This transformation is applied before flip. keys = {dxy, rot, tra}. Defaults to None.
flip (bool, optional): flip the data. Defaults to False.
post_transformation (collection of dict, optional): This transformation is applied after flip. keys = {rot, tra}. Defaults to None.
embedding (DataFrame, optional): The embedding. Defaults to None.
embedding_scale (int, optional): Scale for the final offset of the mesh. Defaults to 1.
Returns:
_type_: _description_
"""
data = []
for name, pc in pcs.items():
shift = None
if embedding is not None:
shift = embedding.loc[name].values*embedding_scale
data.append(dict(
name=name,
pc=pc,
talairach=pre_transformation, # {dxy, rot, tra}
flip=flip,
align=post_transformation, # {rot, tra},
shift=shift, # [x,y,z]
meshing_parameters=meshing_parameters
))

Expand All @@ -147,3 +140,32 @@ def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transfor
res = _parse_pool_results(res)

return res


def shift_meshes_in_embedding(meshes:dict, embedding, scale=1):
"""Shift the meshes to an amount proportional to their position in the embedding.
Args:
-meshes:dict(name,ndarray) name:point-cloud dictionary
-embedding:pandas.DataFrame the embedding of the point-clouds (e.g. isomap axis)
-scale: multiplicative factor for the actual shift.
Return:
-name:dict(shifted_meshes)
"""
shifted_meshes = dict()

for name, mesh in tqdm(meshes.items(), desc="shifting"):
shift_v = np.zeros(3)
assert len(embedding.shape)<3, "For displaying, the embedding dimension must be < 3"

if len(embedding.shape) == 1:
shift_v[0] = embedding.loc[name]
else:
for i,v in enumerate(embedding.loc[name]):
shift_v[i]=v


shifted_mesh = shift_aims_mesh(mesh, shift_v, scale=scale)
shifted_meshes[name] = shifted_mesh
return shifted_meshes

0 comments on commit 2c5d571

Please sign in to comment.