From d50ac4ee03289d9c1a6f76fffcc5ae115f16ea0b Mon Sep 17 00:00:00 2001 From: rtviii Date: Fri, 22 Nov 2024 19:47:58 -0600 Subject: [PATCH] improve voxel indexing via kdtree --- cylinder.py | 169 +++++++++++++-------------- cylinder_parallel.py | 108 +++++++++++++++++ kdtree_approach.py | 90 ++++++++++++++ ribctl/lib/landmarks/constriction.py | 5 +- 4 files changed, 282 insertions(+), 90 deletions(-) create mode 100644 cylinder_parallel.py create mode 100644 kdtree_approach.py diff --git a/cylinder.py b/cylinder.py index 0ff6d3f9..225a55c6 100644 --- a/cylinder.py +++ b/cylinder.py @@ -88,89 +88,86 @@ def transform_points_to_C0(points: np.ndarray, base_point: np.ndarray, axis_poin return points_transformed -RCSB_ID = '3J7Z' -radius = 40 -height = 80 -voxel_size = 1 - -# residues, base, axis = get_npet_cylinder_residues(RCSB_ID, radius=radius, height=height) - -base_point = np.array(PTC_location(RCSB_ID).location) -axis_point = np.array( get_constriction(RCSB_ID) ) -# translation, rotation = get_transformation_to_C0(base, axis) -# t_base = ( base + translation ) @ rotation.T -# t_axis = ( axis + translation ) @ rotation.T - -if os.path.exists('points.npy'): - points = np.load('points.npy') - print("Loaded") -else: - residues= filter_residues_parallel( ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height, ) - points = np.array([atom.get_coord() for residue in residues for atom in residue.child_list]) - np.save('points.npy', points) - print("Saved") - ... 
- - -nx = ny = int(2 * radius / voxel_size) + 1 -nz = int(height / voxel_size) + 1 -x = np.linspace(-radius, radius, nx) -y = np.linspace(-radius, radius, ny) -z = np.linspace(0, height, nz) -X, Y, Z = np.meshgrid(x, y, z, indexing='ij') - - -transformed = transform_points_to_C0(points, base, axis) -X_I = np.round(transformed[:,0]) -Y_I = np.round(transformed[:,1]) -Z_I = np.round(transformed[:,2]) - -cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius) -hollow_cylinder = ~cylinder_mask - -# !------------- -# 3. Create point cloud mask -# point_cloud_mask = np.zeros_like(X, dtype=bool) -# for point in zip(X_I, Y_I, Z_I): -# point_cloud_mask |= (X == point[0]) & (Y == point[1]) & (Z == point[2]) - - -# !------------- -radius_around_point = 2.0 # radius of sphere around each point -# point_cloud_mask = np.zeros_like(X, dtype=bool) -# for point in zip(X_I, Y_I, Z_I): -# distance_to_point = np.sqrt( -# (X - point[0])**2 + -# (Y - point[1])**2 + -# (Z - point[2])**2 -# ) -# point_cloud_mask |= (distance_to_point <= radius_around_point) - - -# !------------- -points = np.column_stack((X_I, Y_I, Z_I)) # Shape: (N, 3) -point_cloud_mask = np.zeros_like(X, dtype=bool) - -# Reshape grid coordinates for broadcasting -grid_coords = np.stack([X, Y, Z]) # Shape: (3, nx, ny, nz) -grid_coords = grid_coords.reshape(3, -1) # Shape: (3, nx*ny*nz) - -for point in points: - # Calculate distances using broadcasting - distances = np.sqrt(np.sum((grid_coords.T - point)**2, axis=1)) - # Reshape back to grid shape and add to mask - point_cloud_mask |= (distances.reshape(X.shape) <= radius_around_point) - - -# !------------- - -final_mask = hollow_cylinder | point_cloud_mask -occupied = np.where(final_mask) - -points = np.column_stack(( - x[occupied[0]], - y[occupied[1]], - z[occupied[2]] -)) -occupied_points = pv.PolyData(points) -visualize_pointcloud(occupied_points) \ No newline at end of file +if __name__ == '__main__': + RCSB_ID = '3J7Z' + radius = 40 + height = 80 + voxel_size = 1 + + + 
base_point = np.array(PTC_location(RCSB_ID).location) + axis_point = np.array( get_constriction(RCSB_ID) ) + + if os.path.exists('points.npy'): + points = np.load('points.npy') + print("Loaded") + else: + residues= filter_residues_parallel( ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height, ) + points = np.array([atom.get_coord() for residue in residues for atom in residue.child_list]) + np.save('points.npy', points) + print("Saved") + ... + + + nx = ny = int(2 * radius / voxel_size) + 1 + nz = int(height / voxel_size) + 1 + x = np.linspace(-radius, radius, nx) + y = np.linspace(-radius, radius, ny) + z = np.linspace(0, height, nz) + X, Y, Z = np.meshgrid(x, y, z, indexing='ij') + + + transformed = transform_points_to_C0(points, base_point, axis_point) + X_I = np.round(transformed[:,0]) + Y_I = np.round(transformed[:,1]) + Z_I = np.round(transformed[:,2]) + + cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius) + hollow_cylinder = ~cylinder_mask + + # !------------- + # 3. 
Create point cloud mask + # point_cloud_mask = np.zeros_like(X, dtype=bool) + # for point in zip(X_I, Y_I, Z_I): + # point_cloud_mask |= (X == point[0]) & (Y == point[1]) & (Z == point[2]) + + + # !------------- + radius_around_point = 2.0 # radius of sphere around each point + # point_cloud_mask = np.zeros_like(X, dtype=bool) + # for point in zip(X_I, Y_I, Z_I): + # distance_to_point = np.sqrt( + # (X - point[0])**2 + + # (Y - point[1])**2 + + # (Z - point[2])**2 + # ) + # point_cloud_mask |= (distance_to_point <= radius_around_point) + + + # !------------- + points = np.column_stack((X_I, Y_I, Z_I)) # Shape: (N, 3) + point_cloud_mask = np.zeros_like(X, dtype=bool) + + # Reshape grid coordinates for broadcasting + grid_coords = np.stack([X, Y, Z]) # Shape: (3, nx, ny, nz) + grid_coords = grid_coords.reshape(3, -1) # Shape: (3, nx*ny*nz) + + for point in points: + # Calculate distances using broadcasting + distances = np.sqrt(np.sum((grid_coords.T - point)**2, axis=1)) + # Reshape back to grid shape and add to mask + point_cloud_mask |= (distances.reshape(X.shape) <= radius_around_point) + + + # !------------- + + final_mask = hollow_cylinder | point_cloud_mask + occupied = np.where(final_mask) + + points = np.column_stack(( + x[occupied[0]], + y[occupied[1]], + z[occupied[2]] + )) + occupied_points = pv.PolyData(points) + visualize_pointcloud(occupied_points) \ No newline at end of file diff --git a/cylinder_parallel.py b/cylinder_parallel.py new file mode 100644 index 00000000..36939cee --- /dev/null +++ b/cylinder_parallel.py @@ -0,0 +1,108 @@ +import os +import numpy as np +from typing import Tuple +import multiprocessing as mp +from cylinder import transform_points_to_C0 +from mesh_generation.mes_visualization import visualize_pointcloud +from numba import jit +import pyvista as pv + +from ribctl.lib.landmarks.constriction import get_constriction +from ribctl.lib.landmarks.ptc_via_trna import PTC_location +from ribctl.lib.npet.tunnel_bbox_ptc_constriction 
import filter_residues_parallel, ribosome_entities
+
+def chunk_points(points: np.ndarray, n_chunks: int) -> list:
+    """Split points along axis 0 into at most n_chunks contiguous, near-equal chunks."""
+    n_chunks = max(1, min(int(n_chunks), len(points)))  # guards against a zero/negative chunk count and a zero slice step
+    return np.array_split(points, n_chunks)
+
+@jit(nopython=True)
+def process_points_chunk(points: np.ndarray, grid_coords: np.ndarray, radius: float) -> np.ndarray:
+    """Return a flat boolean mask over the grid_coords columns lying within radius of any point."""
+    mask = np.zeros(grid_coords.shape[1], dtype=np.bool_)
+    for point in points:
+        distances = np.sqrt(np.sum((grid_coords.T - point)**2, axis=1))
+        mask |= (distances <= radius)
+    return mask
+
+def parallel_point_cloud_mask(points: np.ndarray, X: np.ndarray, Y: np.ndarray, Z: np.ndarray,
+                            radius_around_point: float) -> np.ndarray:
+    """Return a boolean mask shaped like X marking grid nodes within radius_around_point of any point."""
+    # Flatten the grid once into a (3, n_nodes) array that is handed to every chunk
+    grid_coords = np.stack([X, Y, Z])
+    original_shape = X.shape
+    grid_coords = grid_coords.reshape(3, -1)
+
+    # Leave 4 cores free, but never drop below one worker (cpu_count() can be <= 4)
+    n_cores = max(1, mp.cpu_count() - 4)
+    chunks = chunk_points(points, n_cores)
+
+    # Process chunks in parallel
+    with mp.Pool(n_cores) as pool:
+        results = pool.starmap(process_points_chunk,
+                             [(chunk, grid_coords, radius_around_point) for chunk in chunks])
+
+    # A node is set in the final mask when any chunk set it
+    final_mask = np.any(results, axis=0)
+    return final_mask.reshape(original_shape)
+
+def main():
+    # Your existing setup code...
+    RCSB_ID = '3J7Z'
+    radius = 40
+    height = 80
+    voxel_size = 1
+    ATOM_RADIUS = 2
+
+    # residues, base, axis = get_npet_cylinder_residues(RCSB_ID, radius=radius, height=height)
+
+    base_point = np.array(PTC_location(RCSB_ID).location)
+    axis_point = np.array( get_constriction(RCSB_ID) )
+    # translation, rotation = get_transformation_to_C0(base, axis)
+    # t_base = ( base + translation ) @ rotation.T
+    # t_axis = ( axis + translation ) @ rotation.T
+
+    if os.path.exists('points.npy'):
+        points = np.load('points.npy')
+        print("Loaded")
+    else:
+        residues= filter_residues_parallel( ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height, )
+        points = np.array([atom.get_coord() for residue in residues for atom in residue.child_list])
+        np.save('points.npy', points)
+        print("Saved")
+
+    nx = ny = int(2 * radius / voxel_size) + 1
+    nz = int(height / voxel_size) + 1
+    x = np.linspace(-radius, radius, nx)
+    y = np.linspace(-radius, radius, ny)
+    z = np.linspace(0, height, nz)
+    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
+
+    # Transform points (vectorized)
+    transformed = transform_points_to_C0(points, base_point, axis_point)
+    X_I, Y_I, Z_I = transformed.T
+    points = np.column_stack((X_I, Y_I, Z_I))
+
+    # Generate cylinder mask (vectorized)
+    cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius)
+    hollow_cylinder = ~cylinder_mask
+
+    # Generate point cloud mask in parallel
+    point_cloud_mask = parallel_point_cloud_mask(points, X, Y, Z, ATOM_RADIUS)
+
+    # Combine masks
+    final_mask = hollow_cylinder | point_cloud_mask
+
+    # Visualize results
+    occupied = np.where(~final_mask)
+    visualization_points = np.column_stack((
+        x[occupied[0]],
+        y[occupied[1]],
+        z[occupied[2]]
+    ))
+    occupied_points = pv.PolyData(visualization_points)
+    visualize_pointcloud(occupied_points)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/kdtree_approach.py b/kdtree_approach.py
new file mode 100644
index 00000000..fb678c9b
--- /dev/null
+++
b/kdtree_approach.py
@@ -0,0 +1,90 @@
+import numpy as np
+from scipy.spatial import cKDTree
+import pyvista as pv
+
+from cylinder import transform_points_to_C0
+from mesh_generation.mes_visualization import visualize_pointcloud
+from ribctl.lib.landmarks.constriction import get_constriction
+from ribctl.lib.landmarks.ptc_via_trna import PTC_location
+
+def generate_voxel_centers(radius: float, height: float, voxel_size: float) -> tuple:
+    """Return (voxel_centers, (grid_shape, x, y, z)) for the box [-radius, radius]^2 x [0, height]."""
+    nx = ny = int(2 * radius / voxel_size) + 1
+    nz = int(height / voxel_size) + 1
+
+    x = np.linspace(-radius, radius, nx)
+    y = np.linspace(-radius, radius, ny)
+    z = np.linspace(0, height, nz)
+
+    # Flatten the 'ij'-indexed grid into one (n_voxels, 3) row per voxel centre
+    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
+    voxel_centers = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))
+
+    return voxel_centers, (X.shape, x, y, z)
+
+def create_point_cloud_mask(points: np.ndarray,
+                          radius: float,
+                          height: float,
+                          voxel_size: float = 1.0,
+                          radius_around_point: float = 2.0):
+    """
+    Return (final_mask, (x, y, z)): True for voxels outside the cylinder or within radius_around_point of a point
+    """
+    # Generate voxel centers
+    voxel_centers, (grid_shape, x, y, z) = generate_voxel_centers(radius, height, voxel_size)
+
+    # Create KDTree from the transformed points
+    tree = cKDTree(points)
+
+    # Count the points within radius_around_point of every voxel centre.
+    # return_length=True yields plain counts instead of one index list per voxel
+    counts = tree.query_ball_point(voxel_centers, radius_around_point, return_length=True)
+
+    # A voxel is part of the point cloud when at least one point is in range
+    # (same result as testing each neighbour list for non-emptiness)
+    point_cloud_mask = np.asarray(counts) > 0
+
+    # Reshape mask back to grid shape
+    point_cloud_mask = point_cloud_mask.reshape(grid_shape)
+
+    # Create cylinder mask
+    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
+    cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius)
+    hollow_cylinder = ~cylinder_mask
+
+    # Combine masks
+    final_mask = hollow_cylinder | point_cloud_mask
+
+    return final_mask, (x, y, z)
+
+def main():
+    # Load your points and transform them as before
+    points = np.load('points.npy')
+    RCSB_ID = '3J7Z'
+    base_point = np.array(PTC_location(RCSB_ID).location)
+    axis_point = np.array(get_constriction(RCSB_ID) )
+    print("loaded and got axis")
+    transformed_points = transform_points_to_C0(points, base_point, axis_point)
+
+    final_mask, (x, y, z) = create_point_cloud_mask(
+        transformed_points,
+        radius = 40,
+        height = 80,
+        voxel_size = 1.0,
+        radius_around_point = 2.0
+    )
+
+    # Extract points for visualization
+    occupied = np.where(~final_mask)
+    visualization_points = np.column_stack((
+        x[occupied[0]],
+        y[occupied[1]],
+        z[occupied[2]]
+    ))
+
+    # Visualize results
+    occupied_points = pv.PolyData(visualization_points)
+    visualize_pointcloud(occupied_points)
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/ribctl/lib/landmarks/constriction.py b/ribctl/lib/landmarks/constriction.py
index 9c051e87..3144d37c 100644
--- a/ribctl/lib/landmarks/constriction.py
+++ b/ribctl/lib/landmarks/constriction.py
@@ -12,13 +12,10 @@ def get_constriction(rcsb_id: str)->np.ndarray: else: uL4 = ro.get_poly_by_polyclass('uL4') uL22 = ro.get_poly_by_polyclass('uL22') - if uL4 is None or uL22 is None: raise ValueError("Could not find uL4 or uL22 in {}".format(rcsb_id)) - structure = ro.assets.biopython_structure() - - uL4_c :Chain = structure[0][uL4.auth_asym_id] + uL4_c :Chain = structure[0][uL4.auth_asym_id] uL22_c :Chain = structure[0][uL22.auth_asym_id] uL4_coords = [(r.center_of_mass() ) for r in uL4_c.child_list]