Skip to content

Commit

Permalink
change of folder structure
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahMuth committed Jan 31, 2025
1 parent 6ce0513 commit 99b8033
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 15 deletions.
6 changes: 3 additions & 3 deletions inference/run_protein_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import sys
sys.path.append("/user/muth9/u12095/czii-protein-challenge")

from utils.prediction import get_prediction_torch_em
from utils.protein_detection import protein_detection
from utils.tiling_helper import parse_tiling
from utils import get_prediction_torch_em
from utils import protein_detection
from utils import parse_tiling

def get_volume(input_path):
zarr_file = zarr.open(os.path.join(input_path, "VoxelSpacing10.000", "denoised.zarr", "0"), mode='r')
Expand Down
4 changes: 2 additions & 2 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import sys
sys.path.append("/user/muth9/u12095/czii-protein-challenge")

from utils.dataset_splits import get_paths # noqa
from utils.training import supervised_training # noqa
from utils import get_paths # noqa
from utils import supervised_training # noqa

TRAIN_ROOT = "/scratch-grete/projects/nim00007/cryo-et/challenge-data/train/static/"
LABEL_ROOT = "/scratch-grete/projects/nim00007/cryo-et/challenge-data/train/overlay/"
Expand Down
15 changes: 8 additions & 7 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .data_loader import create_data_loader
from .dataset_splits import get_paths
from .heatmap_dataset import HeatmapDataset
from .training import supervised_training
from .prediction import get_prediction_torch_em
from .protein_detection import protein_detection
from .tiling_helper import parse_tiling
from .training.data_loader import create_data_loader
from .training.dataset_splits import get_paths
from .training.heatmap_dataset import HeatmapDataset
from .training.training import supervised_training
from .prediction.prediction import get_prediction_torch_em
from .inference.protein_detection import protein_detection
from .training.tiling_helper import parse_tiling
from .evaluation.evaluation_metrics import metric_coords, get_distance_threshold_from_gridsearch
5 changes: 5 additions & 0 deletions utils/evaluation/evaluation_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def metric_coords():
return

def get_distance_threshold_from_gridsearch():
return
2 changes: 0 additions & 2 deletions utils/gridsearch.py

This file was deleted.

31 changes: 31 additions & 0 deletions utils/inference/gridsearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pandas
from tqdm import tqdm
from skimage.feature import blob_log
from ..evaluation.evaluation_metrics import metric_coords, get_distance_threshold_from_gridsearch


def gridsearch():
data = []
for i in tqdm(range(len(idx))):
image_path = grid_search_image_paths[i]
label_path = grid_search_label_paths[i]

model, device = load_model(config, in_channels)
preds = predict_image(image_path, model, tuple(config["patch_shape"]), device=device)
label_coords = get_center_coordinates(label_path)


for thresh in threshes:
#smalles protein structure: "beta-amylase": 33.27
#bigges protein structure: "ribosome": 109.02,
#0.3 is the factor to match the PDB size to the experimental data size
adj_factor=0.3 #TODO implement this as an argument, also when creating heatmap
pred_coords = blob_log(preds, min_sigma=33.27*adj_factor, max_sigma=109.02*adj_factor, threshold_abs=thresh)


_, _, f1, _, _, _ = metric_coords(label_coords, pred_coords)
data.append([f1, dist, thresh])

df = pandas.DataFrame(data=data, columns=["f1", "Distance", "Threshold"])
dist, thresh = get_distance_threshold_from_gridsearch(df, distances, threshes)
return dist, thresh, grid_search_numbers
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from skimage.feature import blob_log
from .gridsearch import gridsearch

def protein_detection(heatmap): #TODO do this properly
"""
Expand All @@ -15,7 +16,7 @@ def protein_detection(heatmap): #TODO do this properly
"""
detections = []

data_path=#TODO pass the val data paths
data_path="test"#TODO pass the val data paths
threshold = gridsearch(data_path, model)
#smalles protein structure: "beta-amylase": 33.27
#bigges protein structure: "ribosome": 109.02,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 99b8033

Please sign in to comment.