From 2c5d57196dab408d7248065d54aa3a134c396c52 Mon Sep 17 00:00:00 2001 From: Marco Pascucci Date: Mon, 10 Jul 2023 18:20:45 +0200 Subject: [PATCH] recipe(meshes) shift subject meshes on embedding position --- dico_toolbox/mesh.py | 2 +- dico_toolbox/recipes/__init__.py | 2 +- dico_toolbox/recipes/meshes.py | 38 +++++++++++++++++++++++++------- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/dico_toolbox/mesh.py b/dico_toolbox/mesh.py index c4f729b..91b65d3 100644 --- a/dico_toolbox/mesh.py +++ b/dico_toolbox/mesh.py @@ -124,4 +124,4 @@ def apply_Talairach_to_mesh(mesh, dxyz, rotation, translation, flip=False): if flip: flip_mesh(mesh) - return mesh + return mesh \ No newline at end of file diff --git a/dico_toolbox/recipes/__init__.py b/dico_toolbox/recipes/__init__.py index 890d01a..2f26724 100644 --- a/dico_toolbox/recipes/__init__.py +++ b/dico_toolbox/recipes/__init__.py @@ -1 +1 @@ -from .meshes import mesh_of_average, mesh_of_averages, mesh_one_point_cloud, mesh_of_point_clouds +from .meshes import mesh_of_average, mesh_of_averages, mesh_one_point_cloud, mesh_of_point_clouds, shift_meshes_in_embedding diff --git a/dico_toolbox/recipes/meshes.py b/dico_toolbox/recipes/meshes.py index a337bc3..3c00279 100644 --- a/dico_toolbox/recipes/meshes.py +++ b/dico_toolbox/recipes/meshes.py @@ -71,7 +71,6 @@ def mesh_one_point_cloud(data): tal = data["talairach"] flip = data['flip'] align = data["align"] - shift = data["shift"] meshing_parameters = data['meshing_parameters'] # generate mesh @@ -110,7 +109,7 @@ def _parse_pool_results(results): return meshes -def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transformation=None, embedding=None, embedding_scale=1, **meshing_parameters): +def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transformation=None, **meshing_parameters): """Build the mesh of the pointclouds. Args: @@ -118,24 +117,18 @@ def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transfor pre_transformation (collection of dict, optional): This transformation is applied before flip. keys = {dxy, rot, tra}. Defaults to None. flip (bool, optional): flip the data. Defaults to False. post_transformation (collection of dict, optional): This transformation is applied after flip. keys = {rot, tra}. Defaults to None. - embedding (DataFrame, optional): The embedding. Defaults to None. - embedding_scale (int, optional): Scale for the final offset of the mesh. Defaults to 1. Returns: _type_: _description_ """ data = [] for name, pc in pcs.items(): - shift = None - if embedding is not None: - shift = embedding.loc[name].values*embedding_scale data.append(dict( name=name, pc=pc, talairach=pre_transformation, # {dxy, rot, tra} flip=flip, align=post_transformation, # {rot, tra}, - shift=shift, # [x,y,z] meshing_parameters=meshing_parameters )) @@ -147,3 +140,32 @@ def mesh_of_point_clouds(pcs, pre_transformation=None, flip=False, post_transfor res = _parse_pool_results(res) return res + + +def shift_meshes_in_embedding(meshes:dict, embedding, scale=1): + """Shift the meshes to an amount proportional to their position in the embedding. + + Args: + -meshes:dict(name,ndarray) name:point-cloud dictionary + -embedding:pandas.DataFrame the embedding of the point-clouds (e.g. isomap axis) + -scale: multiplicative factor for the actual shift. + + Return: + -name:dict(shifted_meshes) + """ + shifted_meshes = dict() + + for name, mesh in tqdm(meshes.items(), desc="shifting"): + shift_v = np.zeros(3) + assert len(embedding.shape)<3, "For displaying, the embedding dimension must be < 3" + + if len(embedding.shape) == 1: + shift_v[0] = embedding.loc[name] + else: + for i,v in enumerate(embedding.loc[name]): + shift_v[i]=v + + + shifted_mesh = shift_aims_mesh(mesh, shift_v, scale=scale) + shifted_meshes[name] = shifted_mesh + return shifted_meshes \ No newline at end of file