-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
282 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import os | ||
import numpy as np | ||
from typing import Tuple | ||
import multiprocessing as mp | ||
from cylinder import transform_points_to_C0 | ||
from mesh_generation.mes_visualization import visualize_pointcloud | ||
from numba import jit | ||
import pyvista as pv | ||
|
||
from ribctl.lib.landmarks.constriction import get_constriction | ||
from ribctl.lib.landmarks.ptc_via_trna import PTC_location | ||
from ribctl.lib.npet.tunnel_bbox_ptc_constriction import filter_residues_parallel, ribosome_entities | ||
|
||
def chunk_points(points: np.ndarray, n_chunks: int) -> list: | ||
"""Split points into chunks for parallel processing""" | ||
chunk_size = len(points) // n_chunks | ||
return [points[i:i + chunk_size] for i in range(0, len(points), chunk_size)] | ||
|
||
@jit(nopython=True) | ||
def process_points_chunk(points: np.ndarray, grid_coords: np.ndarray, radius: float) -> np.ndarray: | ||
"""Process a chunk of points with Numba acceleration""" | ||
mask = np.zeros(grid_coords.shape[1], dtype=np.bool_) | ||
for point in points: | ||
distances = np.sqrt(np.sum((grid_coords.T - point)**2, axis=1)) | ||
mask |= (distances <= radius) | ||
return mask | ||
|
||
def parallel_point_cloud_mask(points: np.ndarray, X: np.ndarray, Y: np.ndarray, Z: np.ndarray, | ||
radius_around_point: float) -> np.ndarray: | ||
"""Generate point cloud mask using parallel processing""" | ||
# Prepare grid coordinates once | ||
grid_coords = np.stack([X, Y, Z]) | ||
original_shape = X.shape | ||
grid_coords = grid_coords.reshape(3, -1) | ||
|
||
# Determine number of chunks based on CPU cores | ||
n_cores = mp.cpu_count() - 4 | ||
chunks = chunk_points(points, n_cores) | ||
|
||
# Process chunks in parallel | ||
with mp.Pool(n_cores) as pool: | ||
results = pool.starmap(process_points_chunk, | ||
[(chunk, grid_coords, radius_around_point) for chunk in chunks]) | ||
|
||
# Combine results | ||
final_mask = np.any(results, axis=0) | ||
return final_mask.reshape(original_shape) | ||
|
||
def main(): | ||
# Your existing setup code... | ||
RCSB_ID = '3J7Z' | ||
radius = 40 | ||
height = 80 | ||
voxel_size = 1 | ||
ATOM_RADIUS = 2 | ||
|
||
# residues, base, axis = get_npet_cylinder_residues(RCSB_ID, radius=radius, height=height) | ||
|
||
base_point = np.array(PTC_location(RCSB_ID).location) | ||
axis_point = np.array( get_constriction(RCSB_ID) ) | ||
# translation, rotation = get_transformation_to_C0(base, axis) | ||
# t_base = ( base + translation ) @ rotation.T | ||
# t_axis = ( axis + translation ) @ rotation.T | ||
|
||
if os.path.exists('points.npy'): | ||
points = np.load('points.npy') | ||
print("Loaded") | ||
else: | ||
residues= filter_residues_parallel( ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height, ) | ||
points = np.array([atom.get_coord() for residue in residues for atom in residue.child_list]) | ||
np.save('points.npy', points) | ||
print("Saved") | ||
|
||
nx = ny = int(2 * radius / voxel_size) + 1 | ||
nz = int(height / voxel_size) + 1 | ||
x = np.linspace(-radius, radius, nx) | ||
y = np.linspace(-radius, radius, ny) | ||
z = np.linspace(0, height, nz) | ||
X, Y, Z = np.meshgrid(x, y, z, indexing='ij') | ||
|
||
# Transform points (vectorized) | ||
transformed = transform_points_to_C0(points, base_point, axis_point) | ||
X_I, Y_I, Z_I = transformed.T | ||
points = np.column_stack((X_I, Y_I, Z_I)) | ||
|
||
# Generate cylinder mask (vectorized) | ||
cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius) | ||
hollow_cylinder = ~cylinder_mask | ||
|
||
# Generate point cloud mask in parallel | ||
point_cloud_mask = parallel_point_cloud_mask(points, X, Y, Z, ATOM_RADIUS) | ||
|
||
# Combine masks | ||
final_mask = hollow_cylinder | point_cloud_mask | ||
|
||
# Visualize results | ||
occupied = np.where(~final_mask) | ||
visualization_points = np.column_stack(( | ||
x[occupied[0]], | ||
y[occupied[1]], | ||
z[occupied[2]] | ||
)) | ||
occupied_points = pv.PolyData(visualization_points) | ||
visualize_pointcloud(occupied_points) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import numpy as np | ||
from scipy.spatial import cKDTree | ||
import pyvista as pv | ||
|
||
from cylinder import transform_points_to_C0 | ||
from mesh_generation.mes_visualization import visualize_pointcloud | ||
from ribctl.lib.landmarks.constriction import get_constriction | ||
from ribctl.lib.landmarks.ptc_via_trna import PTC_location | ||
|
||
def generate_voxel_centers(radius: float, height: float, voxel_size: float) -> tuple: | ||
"""Generate centers of all voxels in the grid""" | ||
nx = ny = int(2 * radius / voxel_size) + 1 | ||
nz = int(height / voxel_size) + 1 | ||
|
||
x = np.linspace(-radius, radius, nx) | ||
y = np.linspace(-radius, radius, ny) | ||
z = np.linspace(0, height, nz) | ||
|
||
# Generate all voxel center coordinates | ||
X, Y, Z = np.meshgrid(x, y, z, indexing='ij') | ||
voxel_centers = np.column_stack((X.ravel(), Y.ravel(), Z.ravel())) | ||
|
||
return voxel_centers, (X.shape, x, y, z) | ||
|
||
def create_point_cloud_mask(points: np.ndarray, | ||
radius: float, | ||
height: float, | ||
voxel_size: float = 1.0, | ||
radius_around_point: float = 2.0): | ||
""" | ||
Create point cloud mask using KDTree for efficient spatial queries | ||
""" | ||
# Generate voxel centers | ||
voxel_centers, (grid_shape, x, y, z) = generate_voxel_centers(radius, height, voxel_size) | ||
|
||
# Create KDTree from the transformed points | ||
tree = cKDTree(points) | ||
|
||
# Find all voxels that have points within radius_around_point | ||
# This is much more efficient than checking each point against each voxel | ||
indices = tree.query_ball_point(voxel_centers, radius_around_point) | ||
|
||
# Create mask from the indices | ||
point_cloud_mask = np.zeros(len(voxel_centers), dtype=bool) | ||
point_cloud_mask[[i for i, idx in enumerate(indices) if idx]] = True | ||
|
||
# Reshape mask back to grid shape | ||
point_cloud_mask = point_cloud_mask.reshape(grid_shape) | ||
|
||
# Create cylinder mask | ||
X, Y, Z = np.meshgrid(x, y, z, indexing='ij') | ||
cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius) | ||
hollow_cylinder = ~cylinder_mask | ||
|
||
# Combine masks | ||
final_mask = hollow_cylinder | point_cloud_mask | ||
|
||
return final_mask, (x, y, z) | ||
|
||
def main(): | ||
# Load your points and transform them as before | ||
points = np.load('points.npy') | ||
RCSB_ID = '3J7Z' | ||
base_point = np.array(PTC_location(RCSB_ID).location) | ||
axis_point = np.array(get_constriction(RCSB_ID) ) | ||
print("loaded and got axis") | ||
transformed_points = transform_points_to_C0(points, base_point, axis_point) | ||
|
||
final_mask, (x, y, z) = create_point_cloud_mask( | ||
transformed_points, | ||
radius = 40, | ||
height = 80, | ||
voxel_size = 1.0, | ||
radius_around_point = 2.0 | ||
) | ||
|
||
# Extract points for visualization | ||
occupied = np.where(~final_mask) | ||
visualization_points = np.column_stack(( | ||
x[occupied[0]], | ||
y[occupied[1]], | ||
z[occupied[2]] | ||
)) | ||
|
||
# Visualize results | ||
occupied_points = pv.PolyData(visualization_points) | ||
visualize_pointcloud(occupied_points) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters