Skip to content

Commit

Permalink
improve voxel indexing via kdtree
Browse files Browse the repository at this point in the history
  • Loading branch information
rtviii committed Nov 23, 2024
1 parent de890e7 commit d50ac4e
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 90 deletions.
169 changes: 83 additions & 86 deletions cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,89 +88,86 @@ def transform_points_to_C0(points: np.ndarray, base_point: np.ndarray, axis_poin
return points_transformed


# Voxelise the atoms found inside a cylinder around one structure and show
# the resulting mask as a point cloud.
RCSB_ID = '3J7Z'
radius = 40
height = 80
voxel_size = 1

# Two landmark points; they are handed to both the residue filter and the
# coordinate transform below.
base_point = np.array(PTC_location(RCSB_ID).location)
axis_point = np.array(get_constriction(RCSB_ID))

# Atom coordinates are cached on disk so the residue filtering only has to
# run once per working directory.
if os.path.exists('points.npy'):
    points = np.load('points.npy')
    print("Loaded")
else:
    residues = filter_residues_parallel(
        ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height,
    )
    points = np.array([atom.get_coord() for residue in residues for atom in residue.child_list])
    np.save('points.npy', points)
    print("Saved")

# Regular grid spanning [-radius, radius] in x and y and [0, height] in z.
nx = ny = int(2 * radius / voxel_size) + 1
nz = int(height / voxel_size) + 1
x = np.linspace(-radius, radius, nx)
y = np.linspace(-radius, radius, ny)
z = np.linspace(0, height, nz)
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

# The landmarks are named base_point / axis_point in this script; the previous
# call passed the undefined names `base` and `axis` and raised NameError.
# NOTE(review): presumably this maps the atoms into the frame of the grid
# above (axis along z, base at the origin) -- confirm against the helper.
transformed = transform_points_to_C0(points, base_point, axis_point)
X_I = np.round(transformed[:, 0])
Y_I = np.round(transformed[:, 1])
Z_I = np.round(transformed[:, 2])

# True for grid nodes lying outside the cylinder of the given radius.
cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius)
hollow_cylinder = ~cylinder_mask

radius_around_point = 2.0  # radius of the sphere marked around each point

points = np.column_stack((X_I, Y_I, Z_I))  # Shape: (N, 3)
point_cloud_mask = np.zeros_like(X, dtype=bool)

# Flatten the grid once so each point can be compared against every node.
grid_coords = np.stack([X, Y, Z])         # Shape: (3, nx, ny, nz)
grid_coords = grid_coords.reshape(3, -1)  # Shape: (3, nx*ny*nz)

for point in points:
    # Distance from this point to every grid node.
    distances = np.sqrt(np.sum((grid_coords.T - point)**2, axis=1))
    # Mark every node that falls inside the sphere around the point.
    point_cloud_mask |= (distances.reshape(X.shape) <= radius_around_point)

# A node is kept when it is outside the cylinder or close to an atom.
final_mask = hollow_cylinder | point_cloud_mask
occupied = np.where(final_mask)

points = np.column_stack((
    x[occupied[0]],
    y[occupied[1]],
    z[occupied[2]]
))
occupied_points = pv.PolyData(points)
visualize_pointcloud(occupied_points)
if __name__ == '__main__':
    # Script entry point: voxelise the atoms around one structure and display
    # the resulting mask as a point cloud.
    RCSB_ID, radius, height, voxel_size = '3J7Z', 40, 80, 1

    # Landmark points used by the residue filter and the coordinate transform.
    base_point = np.array(PTC_location(RCSB_ID).location)
    axis_point = np.array(get_constriction(RCSB_ID))

    # Reuse cached atom coordinates when present; otherwise compute and cache.
    if not os.path.exists('points.npy'):
        residues = filter_residues_parallel(
            ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height,
        )
        points = np.array(
            [atom.get_coord() for residue in residues for atom in residue.child_list]
        )
        np.save('points.npy', points)
        print("Saved")
    else:
        points = np.load('points.npy')
        print("Loaded")

    # Grid nodes: x, y in [-radius, radius], z in [0, height].
    nx = ny = int(2 * radius / voxel_size) + 1
    nz = int(height / voxel_size) + 1
    x = np.linspace(-radius, radius, nx)
    y = np.linspace(-radius, radius, ny)
    z = np.linspace(0, height, nz)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    # Transform the atoms, then round each coordinate to the nearest integer.
    transformed = transform_points_to_C0(points, base_point, axis_point)
    X_I, Y_I, Z_I = (np.round(transformed[:, k]) for k in range(3))

    # True for nodes outside the cylinder of the given radius.
    cylinder_mask = np.sqrt(X**2 + Y**2) <= radius
    hollow_cylinder = np.logical_not(cylinder_mask)

    radius_around_point = 2.0  # radius of the sphere marked around each point

    points = np.column_stack((X_I, Y_I, Z_I))  # Shape: (N, 3)
    point_cloud_mask = np.zeros_like(X, dtype=bool)

    # Flattened grid coordinates, shape (3, nx*ny*nz).
    grid_coords = np.stack([X, Y, Z]).reshape(3, -1)

    for point in points:
        offsets = grid_coords.T - point
        distances = np.sqrt(np.sum(offsets**2, axis=1))
        within_sphere = distances.reshape(X.shape) <= radius_around_point
        point_cloud_mask |= within_sphere

    # Keep nodes that are outside the cylinder or close to an atom.
    final_mask = hollow_cylinder | point_cloud_mask
    occupied = np.where(final_mask)

    points = np.column_stack((x[occupied[0]], y[occupied[1]], z[occupied[2]]))
    occupied_points = pv.PolyData(points)
    visualize_pointcloud(occupied_points)
108 changes: 108 additions & 0 deletions cylinder_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
import numpy as np
from typing import Tuple
import multiprocessing as mp
from cylinder import transform_points_to_C0
from mesh_generation.mes_visualization import visualize_pointcloud
from numba import jit
import pyvista as pv

from ribctl.lib.landmarks.constriction import get_constriction
from ribctl.lib.landmarks.ptc_via_trna import PTC_location
from ribctl.lib.npet.tunnel_bbox_ptc_constriction import filter_residues_parallel, ribosome_entities

def chunk_points(points: np.ndarray, n_chunks: int) -> list:
    """Split ``points`` into at most ``n_chunks`` contiguous, non-empty chunks.

    Chunks preserve the original order, so concatenating them reproduces
    ``points``. When there are fewer points than requested chunks, one chunk
    per point is returned instead of failing.

    Args:
        points: Array (or any sliceable sequence) to split along its first axis.
        n_chunks: Maximum number of chunks to produce; must be >= 1.

    Returns:
        A list of slices of ``points``; empty when ``points`` is empty.

    Raises:
        ValueError: If ``n_chunks`` is smaller than 1.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be a positive integer, got {n_chunks}")

    n_points = len(points)
    if n_points == 0:
        return []

    # Ceil division: floor division could give a zero step (fewer points than
    # chunks) or produce more than n_chunks chunks when there is a remainder.
    chunk_size = -(-n_points // n_chunks)
    return [points[start:start + chunk_size] for start in range(0, n_points, chunk_size)]

@jit(nopython=True)
def process_points_chunk(points: np.ndarray, grid_coords: np.ndarray, radius: float) -> np.ndarray:
    """Flag every grid position lying within ``radius`` of any point in the chunk.

    ``grid_coords`` is indexed as (coordinate, position): its second axis sets
    the length of the returned flat boolean mask.
    """
    n_positions = grid_coords.shape[1]
    hit = np.zeros(n_positions, dtype=np.bool_)
    for point in points:
        offsets = grid_coords.T - point
        dist = np.sqrt(np.sum(offsets**2, axis=1))
        hit |= (dist <= radius)
    return hit

def parallel_point_cloud_mask(points: np.ndarray, X: np.ndarray, Y: np.ndarray, Z: np.ndarray,
                              radius_around_point: float) -> np.ndarray:
    """Generate the point cloud mask using a pool of worker processes.

    Args:
        points: Point coordinates, one row per point, in the same frame as the
            grid.
        X, Y, Z: Coordinate grids of identical shape (e.g. from ``np.meshgrid``).
        radius_around_point: Grid positions within this distance of any point
            are marked True.

    Returns:
        Boolean array with the shape of ``X``; all False when ``points`` is empty.
    """
    original_shape = X.shape

    # Nothing to mark: skip pool start-up and avoid chunking an empty array.
    if len(points) == 0:
        return np.zeros(original_shape, dtype=bool)

    # Prepare flattened grid coordinates once, shape (3, n_positions).
    grid_coords = np.stack([X, Y, Z]).reshape(3, -1)

    # Leave four cores free, but never drop below one worker (machines with
    # <= 4 cores) and never use more workers than there are points, so that
    # every chunk is non-empty.
    n_cores = max(1, min(mp.cpu_count() - 4, len(points)))
    chunks = chunk_points(points, n_cores)

    # Process chunks in parallel; each worker returns a flat boolean mask.
    with mp.Pool(n_cores) as pool:
        results = pool.starmap(
            process_points_chunk,
            [(chunk, grid_coords, radius_around_point) for chunk in chunks],
        )

    # A position is marked if any chunk marked it.
    final_mask = np.any(results, axis=0)
    return final_mask.reshape(original_shape)

def main():
    """Build the voxel mask for structure 3J7Z in parallel and visualise it.

    Atom coordinates are read from 'points.npy' when that file exists;
    otherwise they are computed from the structure and written there. The
    displayed points are the grid positions where the combined mask is False.
    """
    RCSB_ID = '3J7Z'
    radius, height = 40, 80
    voxel_size = 1
    ATOM_RADIUS = 2

    # Landmark points used by the residue filter and the coordinate transform.
    base_point = np.array(PTC_location(RCSB_ID).location)
    axis_point = np.array(get_constriction(RCSB_ID))

    if not os.path.exists('points.npy'):
        residues = filter_residues_parallel(
            ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height,
        )
        points = np.array(
            [atom.get_coord() for residue in residues for atom in residue.child_list]
        )
        np.save('points.npy', points)
        print("Saved")
    else:
        points = np.load('points.npy')
        print("Loaded")

    # Grid: x, y in [-radius, radius], z in [0, height].
    nx = ny = int(2 * radius / voxel_size) + 1
    nz = int(height / voxel_size) + 1
    x = np.linspace(-radius, radius, nx)
    y = np.linspace(-radius, radius, ny)
    z = np.linspace(0, height, nz)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    # Transform the atoms and re-pack the three coordinate columns.
    transformed = transform_points_to_C0(points, base_point, axis_point)
    col_x, col_y, col_z = transformed.T
    points = np.column_stack((col_x, col_y, col_z))

    # True for grid positions outside the cylinder radius.
    inside_cylinder = np.sqrt(X**2 + Y**2) <= radius
    hollow_cylinder = np.logical_not(inside_cylinder)

    # True for grid positions within ATOM_RADIUS of any atom.
    point_cloud_mask = parallel_point_cloud_mask(points, X, Y, Z, ATOM_RADIUS)

    final_mask = hollow_cylinder | point_cloud_mask

    # Show the positions that are neither outside the cylinder nor near an atom.
    ix, iy, iz = np.nonzero(np.logical_not(final_mask))
    visualization_points = np.column_stack((x[ix], y[iy], z[iz]))
    visualize_pointcloud(pv.PolyData(visualization_points))


if __name__ == '__main__':
    main()
90 changes: 90 additions & 0 deletions kdtree_approach.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as np
from scipy.spatial import cKDTree
import pyvista as pv

from cylinder import transform_points_to_C0
from mesh_generation.mes_visualization import visualize_pointcloud
from ribctl.lib.landmarks.constriction import get_constriction
from ribctl.lib.landmarks.ptc_via_trna import PTC_location

def generate_voxel_centers(radius: float, height: float, voxel_size: float) -> tuple:
    """Return the coordinates of every node of the voxel grid.

    The grid spans [-radius, radius] along x and y and [0, height] along z.
    The result is ``(voxel_centers, (grid_shape, x, y, z))``: ``voxel_centers``
    has one (x, y, z) row per node in C order of an 'ij'-indexed meshgrid,
    ``grid_shape`` is that meshgrid's shape and ``x``, ``y``, ``z`` are the
    1-D axis coordinates.
    """
    lateral_count = int(2 * radius / voxel_size) + 1
    axial_count = int(height / voxel_size) + 1

    x = np.linspace(-radius, radius, lateral_count)
    y = np.linspace(-radius, radius, lateral_count)
    z = np.linspace(0, height, axial_count)

    # One flattened coordinate array per axis, stacked as columns.
    grids = np.meshgrid(x, y, z, indexing='ij')
    voxel_centers = np.stack([grid.ravel() for grid in grids], axis=1)

    return voxel_centers, (grids[0].shape, x, y, z)

def create_point_cloud_mask(points: np.ndarray,
                            radius: float,
                            height: float,
                            voxel_size: float = 1.0,
                            radius_around_point: float = 2.0):
    """Create the combined cylinder / point cloud mask using a KD-tree.

    Args:
        points: Point coordinates, one (x, y, z) row per point, expressed in
            the same frame as the voxel grid.
        radius: Cylinder radius; also the half-extent of the grid in x and y.
        height: Extent of the grid along z, starting at 0.
        voxel_size: Requested grid spacing.
        radius_around_point: A grid position is marked when at least one point
            lies within this distance of it.

    Returns:
        ``(final_mask, (x, y, z))`` where ``final_mask`` is a boolean array over
        the grid that is True for positions outside the cylinder radius or near
        a point, and ``x``, ``y``, ``z`` are the 1-D axis coordinates.
    """
    voxel_centers, (grid_shape, x, y, z) = generate_voxel_centers(radius, height, voxel_size)

    # Index the points once, then ask for every voxel how many points fall
    # within radius_around_point. return_length=True yields an integer array
    # directly instead of one Python list of indices per voxel.
    tree = cKDTree(points)
    neighbor_counts = tree.query_ball_point(
        voxel_centers, radius_around_point, return_length=True
    )
    point_cloud_mask = (np.asarray(neighbor_counts) > 0).reshape(grid_shape)

    # Columns 0 and 1 of voxel_centers are the flattened X and Y grids, so the
    # cylinder test can reuse them instead of rebuilding the meshgrid.
    radial_distance = np.sqrt(voxel_centers[:, 0]**2 + voxel_centers[:, 1]**2)
    cylinder_mask = (radial_distance <= radius).reshape(grid_shape)
    hollow_cylinder = ~cylinder_mask

    # Combine masks
    final_mask = hollow_cylinder | point_cloud_mask

    return final_mask, (x, y, z)

def main():
    """Load cached points, build the KD-tree based mask and visualise it.

    Reads atom coordinates from 'points.npy', transforms them with the two
    landmark points of structure 3J7Z and displays the grid positions where
    the combined mask is False.
    """
    points = np.load('points.npy')
    RCSB_ID = '3J7Z'
    base_point = np.array(PTC_location(RCSB_ID).location)
    axis_point = np.array(get_constriction(RCSB_ID))
    print("loaded and got axis")
    transformed_points = transform_points_to_C0(points, base_point, axis_point)

    mask_settings = dict(
        radius=40,
        height=80,
        voxel_size=1.0,
        radius_around_point=2.0,
    )
    final_mask, (x, y, z) = create_point_cloud_mask(transformed_points, **mask_settings)

    # Grid positions that are neither outside the cylinder nor near a point.
    ix, iy, iz = np.nonzero(np.logical_not(final_mask))
    visualization_points = np.column_stack((x[ix], y[iy], z[iz]))

    visualize_pointcloud(pv.PolyData(visualization_points))


if __name__ == '__main__':
    main()
5 changes: 1 addition & 4 deletions ribctl/lib/landmarks/constriction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ def get_constriction(rcsb_id: str)->np.ndarray:
else:
uL4 = ro.get_poly_by_polyclass('uL4')
uL22 = ro.get_poly_by_polyclass('uL22')

if uL4 is None or uL22 is None:
raise ValueError("Could not find uL4 or uL22 in {}".format(rcsb_id))

structure = ro.assets.biopython_structure()

uL4_c :Chain = structure[0][uL4.auth_asym_id]
uL4_c :Chain = structure[0][uL4.auth_asym_id]
uL22_c :Chain = structure[0][uL22.auth_asym_id]

uL4_coords = [(r.center_of_mass() ) for r in uL4_c.child_list]
Expand Down

0 comments on commit d50ac4e

Please sign in to comment.