diff --git a/utils/gridsearch.py b/utils/gridsearch.py
new file mode 100644
index 0000000..b9e0a83
--- /dev/null
+++ b/utils/gridsearch.py
@@ -0,0 +1,14 @@
+def gridsearch(data_path=None, model=None):
+    """
+    Placeholder for the detection-threshold grid search.
+
+    Parameters:
+        data_path: Location of the validation data (not used yet).
+        model: Model whose predictions would be evaluated (not used yet).
+
+    Raises:
+        NotImplementedError: Always; the search has not been written yet.
+    """
+    # TODO implement the search; until then fail loudly instead of
+    # handing back a threshold that was never tuned.
+    raise NotImplementedError("gridsearch is not implemented yet")
diff --git a/utils/protein_detection.py b/utils/protein_detection.py
index 7e5f6b2..9fb8bc3 100644
--- a/utils/protein_detection.py
+++ b/utils/protein_detection.py
@@ -1,9 +1,12 @@
 import numpy as np
-from scipy.ndimage import gaussian_laplace, label, find_objects, center_of_mass
-from scipy.ndimage.measurements import variance
+from skimage.feature import blob_log
+
+# Smallest ("beta-amylase") and biggest ("ribosome") protein structure sizes.
+MIN_PROTEIN_SIZE = 33.27
+MAX_PROTEIN_SIZE = 109.02
 
 
-def protein_detection(heatmap): # TODO do this properly
+def protein_detection(heatmap, threshold=None, adj_factor=0.3):
     """
     Detects local maxima and estimates sizes of Gaussians in a 3D heatmap.
 
@@ -15,40 +18,34 @@ def protein_detection(heatmap): # TODO do this properly
             - 'coordinates': Tuple of (z, y, x) for the local maxima
             - 'size': Estimated size of the Gaussian (sigma equivalent)
+
+    Optional tuning arguments:
+        threshold (float or None): Cut-off handed to blob_log as its
+            `threshold` argument; None keeps blob_log's own default.
+        adj_factor (float): Scale applied to the protein sizes to match the
+            PDB size to the experimental data size.
     """
-    # Apply Laplacian of Gaussian (LoG) filter to enhance Gaussian-like structures
-    log_filtered = -gaussian_laplace(heatmap, sigma=1)
-
-    # Find local maxima
-    labeled, num_features = label(log_filtered > np.mean(log_filtered)) # Binary threshold
-
-    # Get bounding boxes and compute centers of mass
-    regions = find_objects(labeled)
-    detections = []
-
-    for i, region in enumerate(regions):
-        if region is None:
-            continue
-
-        # Extract subregion
-        subregion = labeled[region]
-        sub_heatmap = heatmap[region]
-
-        # Mask specific to the current label
-        mask = (subregion == (i + 1))
-
-        # Compute center of mass as the coordinates of the local maximum
-        com = center_of_mass(sub_heatmap, labels=mask, index=1)
-
-        # Compute size: estimate the variance of the Gaussian
-        size = np.sqrt(variance(sub_heatmap, labels=mask, index=1))
-
-        # Adjust coordinates to global
-        starts = [r.start for r in region] # Dynamically handle dimensions
-        com_global = tuple(com[i] + starts[i] for i in range(len(starts)))
-
-        detections.append({
-            'coordinates': com_global,
-            'size': size
-        })
-
-    return detections
+    heatmap = np.asarray(heatmap)
+    if heatmap.size == 0:
+        # Nothing to search in; skip the blob detector entirely.
+        return []
+
+    # blob_log takes its cut-off as `threshold` (it has no `threshold_abs`);
+    # leave it out when no value was given so the library default applies.
+    # TODO tune the threshold on the validation data (utils/gridsearch.py).
+    extra_args = {} if threshold is None else {'threshold': threshold}
+    # NOTE(review): the sizes are used directly as sigma bounds - confirm.
+    blobs = blob_log(
+        heatmap,
+        min_sigma=MIN_PROTEIN_SIZE * adj_factor,
+        max_sigma=MAX_PROTEIN_SIZE * adj_factor,
+        **extra_args,
+    )
+
+    # Each row is the blob centre followed by the sigma that detected it.
+    return [
+        {
+            'coordinates': tuple(float(c) for c in blob[:-1]),
+            'size': float(blob[-1]),
+        }
+        for blob in blobs
+    ]