diff --git a/meshparty/skeleton_io.py b/meshparty/skeleton_io.py index 885f35b..4180640 100644 --- a/meshparty/skeleton_io.py +++ b/meshparty/skeleton_io.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import h5py import orjson import json @@ -77,11 +78,15 @@ def write_skeleton_h5_by_part( """ - if os.path.isfile(filename): - if overwrite: - os.remove(filename) - else: - return + if isinstance(filename, (str, Path)): + if os.path.isfile(filename): + if overwrite: + os.remove(filename) + else: + raise FileExistsError( + f"File {filename} already exists, use overwrite=True to overwrite" + ) + with h5py.File(filename, "w") as f: f.attrs["file_version"] = FILE_VERSION @@ -141,7 +146,10 @@ def read_skeleton_h5_by_part(filename): overwrite, whether to overwrite file """ - assert os.path.isfile(filename) + # if filename is a string test that it is a file + if isinstance(filename, (str, Path)): + if not os.path.isfile(filename): + raise FileNotFoundError(f"File {filename} not found") with h5py.File(filename, "r") as f: vertices = f["vertices"][()] diff --git a/meshparty/trimesh_io.py b/meshparty/trimesh_io.py index 40ba863..438c5fb 100644 --- a/meshparty/trimesh_io.py +++ b/meshparty/trimesh_io.py @@ -1,8 +1,10 @@ import collections +from pathlib import Path import numpy as np import h5py from scipy import spatial, sparse from sklearn import decomposition + try: from pykdtree.kdtree import KDTree except: @@ -17,11 +19,14 @@ import logging from functools import wraps import cloudvolume -from cloudvolume.datasource.precomputed.mesh.multilod import ShardedMultiLevelPrecomputedMeshSource +from cloudvolume.datasource.precomputed.mesh.multilod import ( + ShardedMultiLevelPrecomputedMeshSource, +) from multiwrapper import multiprocessing_utils as mu import trimesh from trimesh import caching + try: from trimesh import exchange except ImportError: @@ -34,33 +39,36 @@ try: from caveclient import infoservice + allow_framework_client = True except ImportError: - logging.warning( - "Need to pip install caveclient to use dataset_name parameters") + logging.warning("Need to pip install caveclient to use dataset_name parameters") allow_framework_client = False class EmptyMaskException(Exception): """Raised when applying a mask that has all zeros""" + pass -def _get_cv_path_from_info(dataset_name, server_address=None, segmentation_type='graphene'): +def _get_cv_path_from_info( + dataset_name, server_address=None, segmentation_type="graphene" +): """Get the cloudvolume path from a dataset name. Segmentation type should be - either `graphene` or `flat`. + either `graphene` or `flat`. 
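The `isinstance(filename, (str, Path))` gate above is what lets file-like objects reach h5py unchecked. A minimal sketch of the pattern this enables, assuming h5py >= 2.9 (the first release that accepts Python file objects); the attribute written is illustrative only:

```python
import io

import h5py

# In-memory buffers skip the os.path.isfile()/os.remove() logic entirely,
# because the existence check now only runs for str or Path filenames.
buf = io.BytesIO()
with h5py.File(buf, "w") as f:
    f.attrs["file_version"] = 2  # placeholder value standing in for FILE_VERSION
buf.seek(0)  # rewind so the buffer can be re-read, e.g. h5py.File(buf, "r")
```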
""" if allow_framework_client is False: - logging.warning( - "Need to pip install caveclient to use dataset_name parameters") + logging.warning("Need to pip install caveclient to use dataset_name parameters") return None info = infoservice.InfoServiceClient( - dataset_name=dataset_name, server_address=server_address) - if segmentation_type == 'graphene': - cv_path = info.graphene_source(format_for='cloudvolume') - elif segmentation_type == 'flat': - cv_path = info.flat_segmentation_source(format_for='cloudvolume') + dataset_name=dataset_name, server_address=server_address + ) + if segmentation_type == "graphene": + cv_path = info.graphene_source(format_for="cloudvolume") + elif segmentation_type == "flat": + cv_path = info.flat_segmentation_source(format_for="cloudvolume") else: cv_path = None return cv_path @@ -97,16 +105,17 @@ def read_mesh_h5(filename): AssertionError if the filename is not a file """ - assert os.path.isfile(filename) + if isinstance(filename, (str, Path)): + if not os.path.isfile(filename): + raise FileNotFoundError(f"File {filename} not found") with h5py.File(filename, "r") as f: if "draco" in f.keys(): - mesh_object = DracoPy.decode_buffer_to_mesh( - f["draco"][()].tostring()) + mesh_object = DracoPy.decode_buffer_to_mesh(f["draco"][()].tostring()) vertices = np.array(mesh_object.points).astype(np.float32) if len(vertices.shape) == 1: N = len(vertices) - vertices = vertices.reshape((N//3, 3)) + vertices = vertices.reshape((N // 3, 3)) faces = np.array(mesh_object.faces).astype(np.uint32) else: vertices = f["vertices"][()] @@ -132,9 +141,16 @@ def read_mesh_h5(filename): return vertices, faces, normals, link_edges, node_mask -def write_mesh_h5(filename, vertices, faces, - normals=None, link_edges=None, node_mask=None, - draco=False, overwrite=False): +def write_mesh_h5( + filename, + vertices, + faces, + normals=None, + link_edges=None, + node_mask=None, + draco=False, + overwrite=False, +): """Writes a mesh's vertices, faces (and normals) to an hdf5 file Parameters @@ -144,7 +160,7 @@ def write_mesh_h5(filename, vertices, faces, vertices : np.array a Nx3 x,y,z coordinates (float) faces: np.array - a Mx3 a,b,c index into vertices for triangle faces np.int32 + a Mx3 a,b,c index into vertices for triangle faces np.int32 normals: np.array a Mx3 x,y,z direction for face normals, np.float32 if it doesn't exist (default None) @@ -158,18 +174,20 @@ def write_mesh_h5(filename, vertices, faces, whether to overwrite the file, will return silently if mesh file exists already """ - - if os.path.isfile(filename): - if overwrite: - os.remove(filename) - else: - return + if isinstance(filename, (str, Path)): + if os.path.isfile(filename): + if overwrite: + os.remove(filename) + else: + raise FileExistsError( + f"File {filename} already exists, user overwrite=Ture to overwrite" + ) with h5py.File(filename, "w") as f: if draco: - - buf = DracoPy.encode_mesh_to_buffer(vertices.flatten('C'), - faces.flatten('C')) + buf = DracoPy.encode_mesh_to_buffer( + vertices.flatten("C"), faces.flatten("C") + ) f.create_dataset("draco", data=np.void(buf)) else: f.create_dataset("vertices", data=vertices, compression="gzip") @@ -212,11 +230,11 @@ def read_mesh(filename): """ if filename.endswith(".obj"): - with open(filename, 'r') as fp: + with open(filename, "r") as fp: mesh_d = exchange.obj.load_obj(fp) - vertices = mesh_d['vertices'] - faces = mesh_d['faces'] - normals = mesh_d.get('normals', None) + vertices = mesh_d["vertices"] + faces = mesh_d["faces"] + normals = mesh_d.get("normals", None) 
link_edges = None node_mask = None elif filename.endswith(".h5"): @@ -228,7 +246,7 @@ def read_mesh(filename): def _download_meshes_thread_graphene(args): - """ Helper to Download meshes into target directory from graphene sources. + """Helper to Download meshes into target directory from graphene sources. Parameters ---------- args : tuple @@ -260,43 +278,55 @@ def _download_meshes_thread_graphene(args): progress: bool show progress bars - """ - seg_ids, cv_path, target_dir, fmt, overwrite, \ - merge_large_components, stitch_mesh_chunks, map_gs_to_https, \ - remove_duplicate_vertices, progress, chunk_size, save_draco = args + """ + ( + seg_ids, + cv_path, + target_dir, + fmt, + overwrite, + merge_large_components, + stitch_mesh_chunks, + map_gs_to_https, + remove_duplicate_vertices, + progress, + chunk_size, + save_draco, + ) = args cv = cloudvolume.CloudVolume(cv_path, use_https=map_gs_to_https) for seg_id in seg_ids: - print('downloading {}'.format(seg_id)) + print("downloading {}".format(seg_id)) target_file = os.path.join(target_dir, f"{seg_id}.h5") if not overwrite and os.path.exists(target_file): - print('file exists {}'.format(target_file)) + print("file exists {}".format(target_file)) continue - print('file does not exist {}'.format(target_file)) + print("file does not exist {}".format(target_file)) try: cv_mesh = cv.mesh.get( - seg_id, remove_duplicate_vertices=remove_duplicate_vertices)[seg_id] + seg_id, remove_duplicate_vertices=remove_duplicate_vertices + )[seg_id] faces = np.array(cv_mesh.faces) if len(faces.shape) == 1: faces = faces.reshape(-1, 3) - mesh = Mesh(vertices=cv_mesh.vertices, - faces=faces, - process=False) + mesh = Mesh(vertices=cv_mesh.vertices, faces=faces, process=False) if merge_large_components: mesh.merge_large_components() if fmt == "hdf5": - write_mesh_h5(f"{target_dir}/{seg_id}.h5", - mesh.vertices, - mesh.faces.flatten(), - link_edges=mesh.link_edges, - draco=save_draco, - overwrite=overwrite) + write_mesh_h5( + f"{target_dir}/{seg_id}.h5", + mesh.vertices, + mesh.faces.flatten(), + link_edges=mesh.link_edges, + draco=save_draco, + overwrite=overwrite, + ) else: mesh.write_to_file(f"{target_dir}/{seg_id}.{fmt}") except Exception as e: @@ -304,7 +334,7 @@ def _download_meshes_thread_graphene(args): def _download_meshes_thread_precomputed(args): - """ Helper to Download meshes into target directory + """Helper to Download meshes into target directory Parameters ---------- @@ -334,22 +364,32 @@ def _download_meshes_thread_precomputed(args): chuck size for deduplification progress: bool show progress bars - """ - seg_ids, cv_path, target_dir, fmt, overwrite, \ - merge_large_components, stitch_mesh_chunks, \ - map_gs_to_https, remove_duplicate_vertices, \ - progress, chunk_size, save_draco = args + """ + ( + seg_ids, + cv_path, + target_dir, + fmt, + overwrite, + merge_large_components, + stitch_mesh_chunks, + map_gs_to_https, + remove_duplicate_vertices, + progress, + chunk_size, + save_draco, + ) = args cv = cloudvolume.CloudVolume( - cv_path, use_https=map_gs_to_https, + cv_path, + use_https=map_gs_to_https, progress=progress, ) download_segids = [ - segid for segid in seg_ids - if overwrite or not os.path.exists( - os.path.join(target_dir, f"{segid}.h5") - ) + segid + for segid in seg_ids + if overwrite or not os.path.exists(os.path.join(target_dir, f"{segid}.h5")) ] already_have = list(set(seg_ids).difference(set(download_segids))) @@ -361,14 +401,14 @@ def _download_meshes_thread_precomputed(args): while len(download_segids): download_now = 
download_segids[:100] - download_segids = download_segids[len(download_now):] + download_segids = download_segids[len(download_now) :] if isinstance(cv.mesh, ShardedMultiLevelPrecomputedMeshSource): cv_meshes = cv.mesh.get(download_now) - else: + else: cv_meshes = cv.mesh.get( download_now, remove_duplicate_vertices=remove_duplicate_vertices, - fuse=False + fuse=False, ) for segid, cv_mesh in cv_meshes.items(): @@ -382,26 +422,35 @@ def _download_meshes_thread_precomputed(args): mesh.merge_large_components() if fmt == "hdf5": - write_mesh_h5(f"{target_dir}/{segid}.h5", - mesh.vertices, - mesh.faces.flatten(), - link_edges=mesh.link_edges, - draco=save_draco, - overwrite=overwrite) + write_mesh_h5( + f"{target_dir}/{segid}.h5", + mesh.vertices, + mesh.faces.flatten(), + link_edges=mesh.link_edges, + draco=save_draco, + overwrite=overwrite, + ) else: mesh.write_to_file(f"{target_dir}/{segid}.{fmt}") -def download_meshes(seg_ids, target_dir, cv_path, overwrite=True, - n_threads=1, verbose=False, - stitch_mesh_chunks=True, - merge_large_components=False, - remove_duplicate_vertices=False, - map_gs_to_https=True, fmt="hdf5", - save_draco=False, - chunk_size=None, - progress=False): - """ Downloads meshes in target directory (in parallel) +def download_meshes( + seg_ids, + target_dir, + cv_path, + overwrite=True, + n_threads=1, + verbose=False, + stitch_mesh_chunks=True, + merge_large_components=False, + remove_duplicate_vertices=False, + map_gs_to_https=True, + fmt="hdf5", + save_draco=False, + chunk_size=None, + progress=False, +): + """Downloads meshes in target directory (in parallel) will break up the seg_ids into n_threads*3 job blocks or fewer and download them all Parameters @@ -443,7 +492,7 @@ def download_meshes(seg_ids, target_dir, cv_path, overwrite=True, n_jobs = len(seg_ids) # Use the cv path to establish if the source is graphene or not. 
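download_meshes keeps its keyword interface through the reformat; a hedged usage sketch, where the cloudvolume path and seg id are placeholders for a real source:

```python
from meshparty import trimesh_io

trimesh_io.download_meshes(
    seg_ids=[648518346349539898],  # placeholder segment id
    target_dir="meshes/",
    cv_path="precomputed://gs://example-bucket/segmentation",  # placeholder
    n_threads=4,
    fmt="hdf5",
)
```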
- if re.search('^graphene://', cv_path) is not None: + if re.search("^graphene://", cv_path) is not None: _download_meshes_thread = _download_meshes_thread_graphene else: _download_meshes_thread = _download_meshes_thread_precomputed @@ -452,53 +501,85 @@ def download_meshes(seg_ids, target_dir, cv_path, overwrite=True, multi_args = [] for seg_id_block in seg_id_blocks: - multi_args.append([seg_id_block, cv_path, target_dir, fmt, - overwrite, merge_large_components, stitch_mesh_chunks, - map_gs_to_https, remove_duplicate_vertices, progress, chunk_size, save_draco]) + multi_args.append( + [ + seg_id_block, + cv_path, + target_dir, + fmt, + overwrite, + merge_large_components, + stitch_mesh_chunks, + map_gs_to_https, + remove_duplicate_vertices, + progress, + chunk_size, + save_draco, + ] + ) if n_jobs == 1: - mu.multiprocess_func(_download_meshes_thread, - multi_args, debug=True, - verbose=verbose, n_threads=n_threads) + mu.multiprocess_func( + _download_meshes_thread, + multi_args, + debug=True, + verbose=verbose, + n_threads=n_threads, + ) else: - mu.multisubprocess_func(_download_meshes_thread, - multi_args, n_threads=n_threads, - package_name="meshparty", n_retries=40) + mu.multisubprocess_func( + _download_meshes_thread, + multi_args, + n_threads=n_threads, + package_name="meshparty", + n_retries=40, + ) class MeshMeta(object): - """ Manager class to keep meshes in memory and seemingless download them + """Manager class to keep meshes in memory and seamlessly download them - Parameters - ---------- - cache_size: int - number of meshes to keep in memory adapt this to your available memory and size of meshes - set to zero to use less memory but read from disk cache - cv_path: str - path to pass to cloudvolume.CloudVolume - dataset_name: str - Dataset name to use to get cloudvolume path via infoservice - server_address: str - Server address for the infoservice. Uses a default value if None. - segmentation_type: 'graphene' or 'flat' - Selects which type of segmentation to use. Graphene is for proofreadable segmentations, flat is for static segmentations. - disk_cache_path: str - meshes are dumped to this directory => should be equal to target_dir - in download_meshes (default None will not cache meshes) - map_gs_to_https: bool - whether to change gs paths to https paths, via cloudvolume's use_https option - voxel_scaling: 3x1 numeric - Allows a post-facto multiplicative scaling of vertex locations. These values are NOT saved, just used for analysis and visualization. - """ - - def __init__(self, cache_size=400, cv_path=None, dataset_name=None, server_address=None, segmentation_type='graphene', - disk_cache_path=None, map_gs_to_https=True, voxel_scaling=None): + Parameters + ---------- + cache_size: int + number of meshes to keep in memory adapt this to your available memory and size of meshes + set to zero to use less memory but read from disk cache + cv_path: str + path to pass to cloudvolume.CloudVolume + dataset_name: str + Dataset name to use to get cloudvolume path via infoservice + server_address: str + Server address for the infoservice. Uses a default value if None. + segmentation_type: 'graphene' or 'flat' + Selects which type of segmentation to use. Graphene is for proofreadable segmentations, flat is for static segmentations.
+ disk_cache_path: str + meshes are dumped to this directory => should be equal to target_dir + in download_meshes (default None will not cache meshes) + map_gs_to_https: bool + whether to change gs paths to https paths, via cloudvolume's use_https option + voxel_scaling: 3x1 numeric + Allows a post-facto multiplicative scaling of vertex locations. These values are NOT saved, just used for analysis and visualization. + """ + def __init__( + self, + cache_size=400, + cv_path=None, + dataset_name=None, + server_address=None, + segmentation_type="graphene", + disk_cache_path=None, + map_gs_to_https=True, + voxel_scaling=None, + ): self._mesh_cache = {} self._cache_size = cache_size if cv_path is None and dataset_name is not None: cv_path = _get_cv_path_from_info( - dataset_name=dataset_name, server_address=server_address, segmentation_type=segmentation_type) + dataset_name=dataset_name, + server_address=server_address, + segmentation_type=segmentation_type, + ) self._cv_path = cv_path self._cv = None self._map_gs_to_https = map_gs_to_https @@ -526,10 +607,11 @@ def disk_cache_path(self): @property def cv(self): - """ cloudvoume.CloudVolume : the cloudvolume object""" + """cloudvolume.CloudVolume : the cloudvolume object""" if self._cv is None and self.cv_path is not None: - self._cv = cloudvolume.CloudVolume(self.cv_path, parallel=10, - use_https=self._map_gs_to_https) + self._cv = cloudvolume.CloudVolume( + self.cv_path, parallel=10, use_https=self._map_gs_to_https + ) return self._cv @@ -539,7 +621,7 @@ def voxel_scaling(self): return self._voxel_scaling def _filename(self, seg_id, lod=None): - """ a method to define what path this seg_id will or is saved to + """a method to define what path this seg_id will or is saved to Parameters ---------- @@ -555,15 +637,20 @@ def _filename(self, seg_id, lod=None): else: return "%s/%d.h5" % (self.disk_cache_path, seg_id) - def mesh(self, filename=None, seg_id=None, cache_mesh=True, - merge_large_components=False, - stitch_mesh_chunks=True, - overwrite_merge_large_components=False, - remove_duplicate_vertices=False, - force_download=False, - lod=0, - voxel_scaling='default'): + def mesh( + self, + filename=None, + seg_id=None, + cache_mesh=True, + merge_large_components=False, + stitch_mesh_chunks=True, + overwrite_merge_large_components=False, + remove_duplicate_vertices=False, + force_download=False, + lod=0, + voxel_scaling="default", + ): - """ Loads mesh either from cache, disk or google storage + """Loads mesh either from cache, disk or google storage Note, if the mesh is in a cache (memory or disk) you will get exactly what was in the cache @@ -596,12 +683,12 @@ def mesh(filename=None, seg_id=None, cache_mesh=True, whether to force the mesh to be redownloaded from cloudvolume voxel_scaling: 3 element numeric or None Allows a post-facto multiplicative scaling of vertex locations. These values are NOT saved, just used for analysis and visualization. - By default, pulls from the value in the meshmeta. + By default, pulls from the value in the meshmeta.
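MeshMeta and MeshMeta.mesh keep the same keyword interface after the reformat; a sketch assuming a local disk cache and a placeholder segmentation source:

```python
from meshparty import trimesh_io

mm = trimesh_io.MeshMeta(
    cv_path="precomputed://gs://example-bucket/segmentation",  # placeholder
    disk_cache_path="meshes/",  # downloaded meshes are cached here as .h5
    cache_size=100,
)
mesh = mm.mesh(seg_id=648518346349539898)  # placeholder segment id
```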
Returns ------- :obj:`Mesh` - The mesh object of this seg_id + The mesh object of this seg_id Raises ------ @@ -612,66 +699,73 @@ def mesh(self, filename=None, seg_id=None, cache_mesh=True, if not isinstance(self.cv.mesh, ShardedMultiLevelPrecomputedMeshSource): lod = None - if voxel_scaling == 'default': + if voxel_scaling == "default": voxel_scaling = self.voxel_scaling if filename is not None: if filename not in self._mesh_cache: mesh_data = read_mesh(filename) vertices, faces, normals, link_edges, node_mask = mesh_data - mesh = Mesh(vertices=vertices, faces=faces, normals=normals, - link_edges=link_edges, node_mask=node_mask) + mesh = Mesh( + vertices=vertices, + faces=faces, + normals=normals, + link_edges=link_edges, + node_mask=node_mask, + ) if cache_mesh and len(self._mesh_cache) < self.cache_size: self._mesh_cache[filename] = mesh else: mesh = self._mesh_cache[filename] - if self.disk_cache_path is not None and \ - overwrite_merge_large_components: + if self.disk_cache_path is not None and overwrite_merge_large_components: mesh.write_to_file(self._filename(seg_id, lod=lod)) else: if self.disk_cache_path is not None and force_download is False: if os.path.exists(self._filename(seg_id, lod=lod)): - mesh = self.mesh(filename=self._filename(seg_id, lod=lod), - cache_mesh=cache_mesh, - merge_large_components=merge_large_components, - overwrite_merge_large_components=overwrite_merge_large_components, - voxel_scaling=voxel_scaling) + mesh = self.mesh( + filename=self._filename(seg_id, lod=lod), + cache_mesh=cache_mesh, + merge_large_components=merge_large_components, + overwrite_merge_large_components=overwrite_merge_large_components, + voxel_scaling=voxel_scaling, + ) return mesh - assert (seg_id is not None and self.cv is not None) + assert seg_id is not None and self.cv is not None if seg_id not in self._mesh_cache or force_download is True: - if isinstance(self.cv.mesh, ShardedMultiLevelPrecomputedMeshSource): cv_mesh_d = self.cv.mesh.get(seg_id, lod=lod) else: cv_mesh_d = self.cv.mesh.get( - seg_id, remove_duplicate_vertices=remove_duplicate_vertices) + seg_id, remove_duplicate_vertices=remove_duplicate_vertices + ) if isinstance(cv_mesh_d, (dict, collections.defaultdict)): cv_mesh = cv_mesh_d[seg_id] else: cv_mesh = cv_mesh_d faces = np.array(cv_mesh.faces) - if (len(faces.shape) == 1): + if len(faces.shape) == 1: faces = faces.reshape(-1, 3) - mesh = Mesh(vertices=cv_mesh.vertices, - faces=faces) + mesh = Mesh(vertices=cv_mesh.vertices, faces=faces) if isinstance(self.cv.mesh, ShardedMultiLevelPrecomputedMeshSource): - mesh=mesh.process() + mesh = mesh.process() if cache_mesh and len(self._mesh_cache) < self.cache_size: self._mesh_cache[seg_id] = mesh if self.disk_cache_path is not None: - mesh.write_to_file(self._filename( - seg_id, lod=lod), overwrite=force_download) + mesh.write_to_file( + self._filename(seg_id, lod=lod), overwrite=force_download + ) else: mesh = self._mesh_cache[seg_id] mesh.voxel_scaling = voxel_scaling - if (merge_large_components and (len(mesh.link_edges) == 0)) or \ - overwrite_merge_large_components: + if ( + merge_large_components and (len(mesh.link_edges) == 0) + ) or overwrite_merge_large_components: mesh.merge_large_components() return mesh @@ -696,21 +790,30 @@ class Mesh(trimesh.Trimesh): apply_mask: bool whether to apply the node_mask to the result link_edges: np.array - a Kx2 array of indices into vertices that represent extra edges you + a Kx2 array of indices into vertices that represent extra edges you want to store in the mesh graph **kwargs: 
all the other keyword args you want to pass to :class:`trimesh.Trimesh` """ - def __init__(self, *args, node_mask=None, unmasked_size=None, apply_mask=False, link_edges=None, voxel_scaling=None, **kwargs): - if 'vertices' in kwargs: - vertices_all = kwargs.pop('vertices') + def __init__( + self, + *args, + node_mask=None, + unmasked_size=None, + apply_mask=False, + link_edges=None, + voxel_scaling=None, + **kwargs, + ): + if "vertices" in kwargs: + vertices_all = kwargs.pop("vertices") else: vertices_all = args[0] - if 'faces' in kwargs: - faces_all = kwargs.pop('faces') + if "faces" in kwargs: + faces_all = kwargs.pop("faces") else: # If faces are in args, vertices must also have been in args faces_all = args[1] @@ -721,28 +824,27 @@ def __init__(self, *args, node_mask=None, unmasked_size=None, apply_mask=False, else: unmasked_size = len(vertices_all) if unmasked_size < len(vertices_all): - raise ValueError( - 'Original size cannot be smaller than current size') + raise ValueError("Original size cannot be smaller than current size") self._unmasked_size = unmasked_size if node_mask is None: node_mask = np.full(self.unmasked_size, True, dtype=bool) - elif node_mask.dtype is not np.dtype('bool'): + elif node_mask.dtype is not np.dtype("bool"): node_mask_inds = node_mask.copy() node_mask = np.full(self.unmasked_size, False, dtype=bool) node_mask[node_mask_inds] = True if len(node_mask) != unmasked_size: raise ValueError( - 'The node mask must be the same length as the unmasked size') + "The node mask must be the same length as the unmasked size" + ) self._node_mask = node_mask if apply_mask: if any(self.node_mask == False): nodes_f = vertices_all[self.node_mask] - faces_f = utils.filter_shapes( - np.flatnonzero(node_mask), faces_all)[0] + faces_f = utils.filter_shapes(np.flatnonzero(node_mask), faces_all)[0] else: nodes_f, faces_f = vertices_all, faces_all else: @@ -751,9 +853,9 @@ def __init__(self, *args, node_mask=None, unmasked_size=None, apply_mask=False, new_args = (nodes_f, faces_f) if len(args) > 2: new_args += args[2:] - if kwargs.get('process', False): - print('No silent changing of the mesh is allowed') - kwargs['process'] = False + if kwargs.get("process", False): + print("No silent changing of the mesh is allowed") + kwargs["process"] = False self._voxel_scaling = None self._MeshIndex = None @@ -763,7 +865,8 @@ def __init__(self, *args, node_mask=None, unmasked_size=None, apply_mask=False, if link_edges is not None: if any(self.node_mask == False): self.link_edges = utils.filter_shapes( - np.flatnonzero(node_mask), link_edges)[0] + np.flatnonzero(node_mask), link_edges + )[0] else: self.link_edges = link_edges else: @@ -786,6 +889,7 @@ def wrapper(self, *args, **kwargs): self.voxel_scaling = None func(self, *args, **kwargs) self.voxel_scaling = original_scaling + return wrapper @property @@ -799,7 +903,7 @@ def voxel_scaling(self, new_scaling): @property def inverse_voxel_scaling(self): if self.voxel_scaling is not None: - return 1/self.voxel_scaling + return 1 / self.voxel_scaling else: return None @@ -808,7 +912,7 @@ def _update_voxel_scaling(self, new_scaling): Parameters ---------- - new_scale : 3-element vector + new_scale : 3-element vector Sets the new xyz scale relative to the resolution from the mesh source """ if self.voxel_scaling is not None: @@ -822,7 +926,9 @@ def _update_voxel_scaling(self, new_scaling): self._clear_extra_cached_vertex_keys() - def _clear_extra_cached_vertex_keys(self, keys=['nxgraph', 'csgraph', 'pykdtree', 'kdtree']): + def 
_clear_extra_cached_vertex_keys( + self, keys=["nxgraph", "csgraph", "pykdtree", "kdtree"] + ): for k in keys: self._cache.delete(k) @@ -830,7 +936,7 @@ def _clear_extra_cached_vertex_keys(self, keys=['nxgraph', 'csgraph', 'pykdtree' def link_edges(self): """numpy.array : a Kx2 set of extra edges you want to store in the mesh graph, :func:`edges` will return this plus :func:`face_edges`""" - return self._data['link_edges'] + return self._data["link_edges"] @link_edges.setter def link_edges(self, values): @@ -840,28 +946,32 @@ def link_edges(self, values): values = np.asanyarray(values, dtype=np.int64) # prevents cache from being invalidated with self._cache: - self._data['link_edges'] = values + self._data["link_edges"] = values # now invalidate all items affected # not sure this is all of them that are not affected # by adding link_edges - self._cache.clear(exclude=['face_normals', - 'vertex_normals', - 'faces_sparse', - 'bounds', - 'extents', - 'scale', - 'centroid', - 'principal_inertia_components', - 'principal_inertia_transform', - 'symmetry', - 'triangles', - 'triangles_tree', - 'triangles_center', - 'triangles_cross', - 'edges', - 'edges_face', - 'edges_unique', - 'edges_unique_length']) + self._cache.clear( + exclude=[ + "face_normals", + "vertex_normals", + "faces_sparse", + "bounds", + "extents", + "scale", + "centroid", + "principal_inertia_components", + "principal_inertia_transform", + "symmetry", + "triangles", + "triangles_tree", + "triangles_center", + "triangles_cross", + "edges", + "edges_face", + "edges_unique", + "edges_unique_length", + ] + ) @caching.cache_decorator def nxgraph(self): @@ -903,15 +1013,14 @@ def n_faces(self): def graph_edges(self): # mesh.edges has bidirectional edges, so we need to pass bidirectional link_edges. if len(self.link_edges) > 0: - link_edges_sym = np.vstack( - (self.link_edges, self.link_edges[:, [1, 0]])) + link_edges_sym = np.vstack((self.link_edges, self.link_edges[:, [1, 0]])) link_edges_sym_unique = np.unique(link_edges_sym, axis=1) else: link_edges_sym_unique = self.link_edges return np.vstack([self.edges, link_edges_sym_unique]) def fix_mesh(self, wiggle_vertices=False, verbose=False): - """ Executes rudimentary fixing function from pymeshfix + """Executes rudimentary fixing function from pymeshfix Good for closing holes, fixes mesh in place will recalculate normals @@ -947,24 +1056,28 @@ def fix_mesh(self, wiggle_vertices=False, verbose=False): # self.faces, # verbose=verbose) self.vertices, self.faces = _meshfix.clean_from_arrays( - self.vertices, self.faces, verbose=verbose) + self.vertices, self.faces, verbose=verbose + ) self.fix_normals() - def get_local_views(self, n_points=None, - max_dist=np.inf, - sample_n_points=None, - fisheye=False, - pc_align=False, - center_node_ids=None, - center_coords=None, - verbose=False, - return_node_ids=False, - svd_solver="auto", - return_faces=False, - adapt_unit_sphere_norm=False, - pc_norm=False): - """ Extracts a local view (points) + def get_local_views( + self, + n_points=None, + max_dist=np.inf, + sample_n_points=None, + fisheye=False, + pc_align=False, + center_node_ids=None, + center_coords=None, + verbose=False, + return_node_ids=False, + svd_solver="auto", + return_faces=False, + adapt_unit_sphere_norm=False, + pc_norm=False, + ): + """Extracts a local view (points) Parameters ---------- @@ -1010,9 +1123,9 @@ def get_local_views(self, n_points=None, center_node_ids, a K long array of center node ids.. 
useful if you had it choose random points if you had passed center_coords this will not accurately reflect the centers used np.array - return_node_ids, Optional depending on whether return_node_ids. A K long list of + return_node_ids, Optional depending on whether return_node_ids. A K long list of np.array - return_faces, Optional depending on return_faces. faces on the local views, a K list of mx3 triangle faces. + return_faces, Optional depending on return_faces. faces on the local views, a K list of mx3 triangle faces. """ if center_node_ids is None and center_coords is None: @@ -1038,8 +1151,9 @@ def get_local_views(self, n_points=None, else: sample_n_points = np.min([sample_n_points, len(self.vertices)]) - dists, node_ids = self.kdtree.query(center_coords, sample_n_points, - distance_upper_bound=max_dist) + dists, node_ids = self.kdtree.query( + center_coords, sample_n_points, distance_upper_bound=max_dist + ) if n_points is not None: if sample_n_points > n_points: @@ -1050,9 +1164,9 @@ def get_local_views(self, n_points=None, new_node_ids = [] ids = np.arange(0, sample_n_points, dtype=int) for i_sample in range(len(center_coords)): - sample_ids = np.random.choice(ids, n_points, - replace=False, - p=probs[i_sample]) + sample_ids = np.random.choice( + ids, n_points, replace=False, p=probs[i_sample] + ) new_dists.append(dists[i_sample, sample_ids]) new_node_ids.append(node_ids[i_sample, sample_ids]) @@ -1066,8 +1180,7 @@ def get_local_views(self, n_points=None, node_ids = node_ids[:, sample_ids] if verbose: - print(np.mean(dists, axis=1), np.max(dists, axis=1), - np.min(dists, axis=1)) + print(np.mean(dists, axis=1), np.max(dists, axis=1), np.min(dists, axis=1)) if max_dist < np.inf: node_ids = list(node_ids) @@ -1083,9 +1196,9 @@ def get_local_views(self, n_points=None, if pc_align: for i_lv in range(len(local_vertices)): - local_vertices[i_lv] = self._calc_pc_align(local_vertices[i_lv], - svd_solver, - pc_norm=pc_norm) + local_vertices[i_lv] = self._calc_pc_align( + local_vertices[i_lv], svd_solver, pc_norm=pc_norm + ) if adapt_unit_sphere_norm: local_vertices -= center_coords @@ -1095,22 +1208,31 @@ def get_local_views(self, n_points=None, return_tuple = (local_vertices, center_node_ids) if return_node_ids: - return_tuple += (node_ids, ) + return_tuple += (node_ids,) if return_faces: - return_tuple += (self._filter_faces(node_ids), ) + return_tuple += (self._filter_faces(node_ids),) return return_tuple - def get_local_view(self, n_points=None, max_dist=np.inf, - sample_n_points=None, - adapt_unit_sphere_norm=False, - fisheye=False, - pc_align=False, center_node_id=None, - center_coord=None, method="kdtree", verbose=False, - return_node_ids=False, svd_solver="auto", - return_faces=False, pc_norm=False): - """ Single version of get_local_views for backwards compatibility """ + def get_local_view( + self, + n_points=None, + max_dist=np.inf, + sample_n_points=None, + adapt_unit_sphere_norm=False, + fisheye=False, + pc_align=False, + center_node_id=None, + center_coord=None, + method="kdtree", + verbose=False, + return_node_ids=False, + svd_solver="auto", + return_faces=False, + pc_norm=False, + ): + """Single version of get_local_views for backwards compatibility""" assert method == "kdtree" @@ -1120,64 +1242,75 @@ def get_local_view(self, n_points=None, max_dist=np.inf, if center_coord is None: center_coord = self.vertices[center_node_id] - return self.get_local_views(n_points=n_points, - sample_n_points=sample_n_points, - max_dist=max_dist, - pc_align=pc_align, - 
center_node_ids=[center_node_id], - center_coords=[center_coord], - adapt_unit_sphere_norm=adapt_unit_sphere_norm, - fisheye=fisheye, - verbose=verbose, - return_node_ids=return_node_ids, - svd_solver=svd_solver, - return_faces=return_faces, - pc_norm=pc_norm) + return self.get_local_views( + n_points=n_points, + sample_n_points=sample_n_points, + max_dist=max_dist, + pc_align=pc_align, + center_node_ids=[center_node_id], + center_coords=[center_coord], + adapt_unit_sphere_norm=adapt_unit_sphere_norm, + fisheye=fisheye, + verbose=verbose, + return_node_ids=return_node_ids, + svd_solver=svd_solver, + return_faces=return_faces, + pc_norm=pc_norm, + ) def _filter_faces(self, node_ids): - """ method to return reindexed faces that involve only certain vertices + """method to return reindexed faces that involve only certain vertices Parameters ---------- node_ids: np.array a M long set of indices into vertices that you want to filter faces by - so only return faces that involve these vertices. node_ids has to be sorted! + so only return faces that involve these vertices. node_ids has to be sorted! Returns ------- - np.array + np.array a Kx3 matrix that is a proper faces for a mesh whose vertices = mesh.vertices[node_ids] """ return utils.filter_shapes(node_ids, self.faces) def _filter_graph_edges(self, node_ids): - """ method to return reindexed edges that involve only certain vertices + """method to return reindexed edges that involve only certain vertices Parameters ---------- node_ids: np.array a M long set of indices into vertices that you want to filter graph_edges by - so only return faces that involve these vertices. node_ids has to be sorted! + so only return faces that involve these vertices. node_ids has to be sorted! Returns ------- - np.array + np.array a Kx2 matrix of edges for a mesh whose vertices = mesh.vertices[node_ids] """ return utils.filter_shapes(node_ids, self.graph_edges) @ScalingManagement.original_scaling - def add_link_edges(self, seg_id=None, merge_log=None, datastack_name=None, server_address=None, - close_map_distance=300, client=None, verbose=False, base_resolution=None): - """ add a set of link edges to this mesh from a PyChunkedGraph endpoint - This will ask the pcg server where merges were done and try to calculate + def add_link_edges( + self, + seg_id=None, + merge_log=None, + datastack_name=None, + server_address=None, + close_map_distance=300, + client=None, + verbose=False, + base_resolution=None, + ): + """add a set of link edges to this mesh from a PyChunkedGraph endpoint + This will ask the pcg server where merges were done and try to calculate where edges should be added to reflect the merge operations that have been done on this mesh, linking disconnected portions of the mesh. Parameters ---------- - seg_id: int + seg_id: int the seg_id of this mesh merge_log : dict JSON dict of merge log as it comes out of the chunkedgraph client. If used, must @@ -1199,29 +1332,41 @@ def add_link_edges(self, seg_id=None, merge_log=None, datastack_name=None, serve Resolution of the supervoxel segmentation at its lowest mip. 
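add_link_edges is only reformatted here; a sketch continuing from the MeshMeta example, assuming a reachable PyChunkedGraph/CAVE service (the datastack name and seg id are placeholders):

```python
# Pull merge operations from the chunkedgraph and add stitching edges.
mesh.add_link_edges(
    seg_id=648518346349539898,      # placeholder
    datastack_name="my_datastack",  # placeholder
)
print(mesh.link_edges.shape)  # Kx2 array of extra graph edges
```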
""" if seg_id is None and merge_log is None: - raise ValueError( - 'Must set either seg id or pre-determined merge log') + raise ValueError("Must set either seg id or pre-determined merge log") if merge_log is not None: - link_edges = trimesh_repair.merge_log_edges(self, - merge_log=merge_log, - base_resolution=base_resolution, - close_map_distance=close_map_distance, - verbose=verbose) + link_edges = trimesh_repair.merge_log_edges( + self, + merge_log=merge_log, + base_resolution=base_resolution, + close_map_distance=close_map_distance, + verbose=verbose, + ) else: # Use the get_link_edges approach - link_edges = trimesh_repair.get_link_edges(self, seg_id, datastack_name=datastack_name, - close_map_distance=close_map_distance, - server_address=server_address, - verbose=verbose, - client=client) + link_edges = trimesh_repair.get_link_edges( + self, + seg_id, + datastack_name=datastack_name, + close_map_distance=close_map_distance, + server_address=server_address, + verbose=verbose, + client=client, + ) self.link_edges = np.vstack([self.link_edges, link_edges]) - def get_local_meshes(self, n_points, max_dist=np.inf, center_node_ids=None, - center_coords=None, pc_align=False, pc_norm=False, - fix_meshes=False): - """ Extracts a local mesh + def get_local_meshes( + self, + n_points, + max_dist=np.inf, + center_node_ids=None, + center_coords=None, + pc_align=False, + pc_norm=False, + fix_meshes=False, + ): + """Extracts a local mesh Parameters ---------- @@ -1233,17 +1378,18 @@ def get_local_meshes(self, n_points, max_dist=np.inf, center_node_ids=None, pc_norm: bool fix_meshes: bool """ - local_view_tuple = self.get_local_views(n_points=n_points, - max_dist=max_dist, - center_node_ids=center_node_ids, - center_coords=center_coords, - return_faces=True, - pc_align=pc_align, - pc_norm=pc_norm) + local_view_tuple = self.get_local_views( + n_points=n_points, + max_dist=max_dist, + center_node_ids=center_node_ids, + center_coords=center_coords, + return_faces=True, + pc_align=pc_align, + pc_norm=pc_norm, + ) vertices, _, faces = local_view_tuple - meshes = [Mesh(vertices=v, faces=f) for v, f in - zip(vertices, faces)] + meshes = [Mesh(vertices=v, faces=f) for v, f in zip(vertices, faces)] if fix_meshes: for mesh in meshes: @@ -1252,10 +1398,16 @@ def get_local_meshes(self, n_points, max_dist=np.inf, center_node_ids=None, return meshes - def get_local_mesh(self, n_points=None, max_dist=np.inf, - center_node_id=None, - center_coord=None, pc_align=True, pc_norm=False): - """ Single version of get_local_meshes """ + def get_local_mesh( + self, + n_points=None, + max_dist=np.inf, + center_node_id=None, + center_coord=None, + pc_align=True, + pc_norm=False, + ): + """Single version of get_local_meshes""" if center_node_id is not None: center_node_id = [center_node_id] @@ -1263,28 +1415,29 @@ def get_local_mesh(self, n_points=None, max_dist=np.inf, if center_coord is not None: center_coord = [center_coord] - return self.get_local_meshes(n_points, max_dist=max_dist, - center_node_ids=center_node_id, - center_coords=center_coord, - pc_align=pc_align, pc_norm=pc_norm)[0] + return self.get_local_meshes( + n_points, + max_dist=max_dist, + center_node_ids=center_node_id, + center_coords=center_coord, + pc_align=pc_align, + pc_norm=pc_norm, + )[0] def _calc_pc_align(self, vertices, svd_solver, pc_norm=False): - """ Calculates PC alignment """ + """Calculates PC alignment""" vertices = vertices.copy() if pc_norm: vertices -= vertices.mean(axis=0) vertices /= vertices.std(axis=0) - pca = 
decomposition.PCA(n_components=3, - svd_solver=svd_solver, - copy=False) + pca = decomposition.PCA(n_components=3, svd_solver=svd_solver, copy=False) return pca.fit_transform(vertices) - def merge_large_components(self, size_threshold=100, max_dist=1000, - dist_step=100): - """ Finds edges between disconnected components + def merge_large_components(self, size_threshold=100, max_dist=1000, dist_step=100): + """Finds edges between disconnected components will add the edges to the existing set of link_edges or start a set of link_edges if there are None Note: can cause self-contacts to be innapropriately merged @@ -1325,16 +1478,22 @@ def merge_large_components(self, size_threshold=100, max_dist=1000, if np.any(kdtrees[i_tree].query_ball_tree(kdtrees[j_tree], max_dist)): for this_dist in range(dist_step, max_dist + dist_step, dist_step): - - pairs = kdtrees[i_tree].query_ball_tree(kdtrees[j_tree], - this_dist) + pairs = kdtrees[i_tree].query_ball_tree( + kdtrees[j_tree], this_dist + ) if np.any(pairs): for i_p, p in enumerate(pairs): if len(p) > 0: - add_edges.extend([[vertex_ids[i_tree][i_p], - vertex_ids[j_tree][v]] - for v in p]) + add_edges.extend( + [ + [ + vertex_ids[i_tree][i_p], + vertex_ids[j_tree][v], + ] + for v in p + ] + ) break print(f"Adding {len(add_edges)} new edges.") @@ -1344,44 +1503,46 @@ def merge_large_components(self, size_threshold=100, max_dist=1000, print("TIME MERGING: %.3fs" % (time.time() - time_start)) def _create_nxgraph(self): - """ Computes networkx graph for this mesh + """Computes networkx graph for this mesh Returns ------- :class:`networkx.Graph` """ - return utils.create_nxgraph(self.vertices, self.graph_edges, euclidean_weight=True, - directed=False) + return utils.create_nxgraph( + self.vertices, self.graph_edges, euclidean_weight=True, directed=False + ) def _create_csgraph(self): - """ Computes scipy.sparse.csgraph with weights equal to euclidean distance + """Computes scipy.sparse.csgraph with weights equal to euclidean distance with directed=False""" - return utils.create_csgraph(self.vertices, self.graph_edges, euclidean_weight=True, - directed=True) + return utils.create_csgraph( + self.vertices, self.graph_edges, euclidean_weight=True, directed=True + ) @property def node_mask(self): - ''' + """ np.array: Returns the node/vertex mask currently applied to the data - ''' + """ return self._node_mask @property def indices_unmasked(self): - ''' + """ np.array: Gets the indices of nodes in the filtered mesh in the unmasked index array - ''' + """ return np.flatnonzero(self.node_mask) @property def unmasked_size(self): - ''' + """ Returns the unmasked number of nodes in the mesh - ''' + """ return self._unmasked_size def apply_mask(self, new_mask, **kwargs): - ''' + """ Makes a new Mesh by adding a new mask to the existing one. new_mask is a boolean array, either of the original vertex space or the current masked length (in which case it is padded with zeros appropriately). 
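merge_large_components is likewise unchanged apart from formatting; its distances are in the mesh's native units (typically nanometers). Continuing from the earlier sketch:

```python
# Stitch nearby disconnected components together with artificial link edges.
mesh.merge_large_components(size_threshold=100, max_dist=1000, dist_step=100)
print(len(mesh.link_edges), "link edges after merging")
```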
@@ -1392,50 +1553,51 @@ def apply_mask(self, new_mask, **kwargs): a N long array of bool where False correponds to vertices that should be masked N needs to equal to mesh.vertices.shape[0] (or the original vertex shape if you are operating on an already masked mesh) - kwargs: + kwargs: keyword arguments to pass on to the new Mesh.__init__ function Returns ------- trimesh_io.Mesh the mesh with the mask applied - ''' + """ if not np.any(new_mask): - raise(EmptyMaskException("new_mask is all False, mesh will be empty")) + raise (EmptyMaskException("new_mask is all False, mesh will be empty")) # We need to express the mask in the current vertex indices if np.size(new_mask) == np.size(self.node_mask): joint_mask = self.node_mask & new_mask new_mask = self.filter_unmasked_boolean(new_mask) elif np.size(new_mask) == self.vertices.shape[0]: - joint_mask = self.node_mask & self.map_boolean_to_unmasked( - new_mask) + joint_mask = self.node_mask & self.map_boolean_to_unmasked(new_mask) else: raise ValueError( - 'Incompatible shape. Must be either original length or current length of vertices.') + "Incompatible shape. Must be either original length or current length of vertices." + ) if self.voxel_scaling is None: new_vertices = self.vertices else: new_vertices = self.vertices * self.inverse_voxel_scaling - new_mesh = Mesh(new_vertices, - self.faces, - node_mask=joint_mask, - unmasked_size=self.unmasked_size, - voxel_scaling=self.voxel_scaling, - **kwargs) + new_mesh = Mesh( + new_vertices, + self.faces, + node_mask=joint_mask, + unmasked_size=self.unmasked_size, + voxel_scaling=self.voxel_scaling, + **kwargs, + ) link_edge_unmask = self.map_indices_to_unmasked(self.link_edges) new_mesh._apply_new_mask_in_place(new_mask, link_edge_unmask) return new_mesh def _apply_new_mask_in_place(self, mask, link_edge_unmask): - """ Internal function for applying masks.. use apply_mask + """Internal function for applying masks.. use apply_mask Use builtin Trimesh tools for masking The new 0 index is the first nonzero element of the mask. Unfortunately, update_vertices maps all masked face values to 0 as well. """ - num_zero_expected = np.sum( - self.faces == np.flatnonzero(mask)[0], axis=1) + num_zero_expected = np.sum(self.faces == np.flatnonzero(mask)[0], axis=1) self.update_vertices(mask) num_zero_new = np.sum(self.faces == 0, axis=1) @@ -1444,7 +1606,7 @@ def _apply_new_mask_in_place(self, mask, link_edge_unmask): self.link_edges = self.filter_unmasked_indices(link_edge_unmask) def map_indices_to_unmasked(self, unmapped_indices): - ''' + """ For a set of masked indices, returns the corresponding unmasked indices Parameters @@ -1456,11 +1618,11 @@ def map_indices_to_unmasked(self, unmapped_indices): ------- np.array the indices mapped back to the original mesh index space - ''' + """ return utils.map_indices_to_unmasked(self.indices_unmasked, unmapped_indices) def map_boolean_to_unmasked(self, unmapped_boolean): - ''' + """ For a boolean index in the masked indices, returns the corresponding unmasked boolean index Parameters @@ -1472,11 +1634,13 @@ def map_boolean_to_unmasked(self, unmapped_boolean): ------- np.array a bool array in the original index space. Is True if the unmapped_boolean suggests it should be. 
- ''' - return utils.map_boolean_to_unmasked(self.unmasked_size, self.node_mask, unmapped_boolean) + """ + return utils.map_boolean_to_unmasked( + self.unmasked_size, self.node_mask, unmapped_boolean + ) def filter_unmasked_boolean(self, unmasked_boolean): - ''' + """ For an unmasked boolean slice, returns a boolean slice filtered to the masked mesh Parameters @@ -1488,7 +1652,7 @@ def filter_unmasked_boolean(self, unmasked_boolean): ------- np.array returns the elements of unmasked_boolean that are still relevant in the masked index space - ''' + """ return utils.filter_unmasked_boolean(self.node_mask, unmasked_boolean) def filter_unmasked_indices(self, unmasked_shape, mask=None): @@ -1536,7 +1700,7 @@ def filter_unmasked_indices_padded(self, unmasked_shape, mask=None): @ScalingManagement.original_scaling def write_to_file(self, filename, overwrite=True, draco=False): - """ Exports the mesh to any format supported by trimesh + """Exports the mesh to any format supported by trimesh Parameters ---------- @@ -1546,23 +1710,25 @@ def write_to_file(self, filename, overwrite=True, draco=False): '.obj' for wavefront all others supported by :func:`trimesh.exchange.export.export_mesh` """ - if os.path.splitext(filename)[1] == '.h5': - write_mesh_h5(filename, - self.vertices, - self.faces, - normals=self.face_normals, - link_edges=self.link_edges, - node_mask=self.node_mask, - draco=draco, - overwrite=overwrite) + if os.path.splitext(filename)[1] == ".h5": + write_mesh_h5( + filename, + self.vertices, + self.faces, + normals=self.face_normals, + link_edges=self.link_edges, + node_mask=self.node_mask, + draco=draco, + overwrite=overwrite, + ) else: exchange.export.export_mesh(self, filename) @property def index_map(self): - ''' + """ A dict mapping global indices into the masked mesh indices. - ''' + """ if self._index_map is None: self._index_map = defaultdict(lambda: np.nan) for ii, index in enumerate(self.indices_unmasked): @@ -1571,8 +1737,9 @@ def index_map(self): class MaskedMesh(Mesh): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): warnings.warn( "use of MaskedMesh deprecated, Mesh now contains all MaskedMesh functionality", - DeprecationWarning) + DeprecationWarning, + ) super(MaskedMesh, self).__init__(*args, **kwargs)
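The masking helpers reformatted above still round-trip indices the same way; a closing sketch, again assuming `mesh` from the MeshMeta example:

```python
import numpy as np

# Keep only vertices above the mean y coordinate; apply_mask accepts a boolean
# mask in either the original (unmasked) or the current vertex index space.
keep = mesh.vertices[:, 1] > mesh.vertices[:, 1].mean()
sub = mesh.apply_mask(keep)

# Map masked indices back to the original index space.
orig_idx = sub.map_indices_to_unmasked(np.arange(len(sub.vertices)))

sub.write_to_file("masked_mesh.h5", overwrite=True)  # placeholder filename
```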