Skip to content

Commit

Permalink
Add initial code
Browse files Browse the repository at this point in the history
  • Loading branch information
frosinastojanovska committed Aug 27, 2024
1 parent d3a031e commit 8b182a5
Show file tree
Hide file tree
Showing 24 changed files with 1,769 additions and 0 deletions.
54 changes: 54 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/

.idea/

experiments/*
Empty file added cryosiam/__init__.py
Empty file.
Empty file added cryosiam/apps/__init__.py
Empty file.
1 change: 1 addition & 0 deletions cryosiam/apps/dense_simsiam_regression/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .utils import load_backbone_model, load_prediction_model
28 changes: 28 additions & 0 deletions cryosiam/apps/dense_simsiam_regression/config_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
data_folder: '/scratch/stojanov/mycoplasma_data/vpp/240422/'
gt_folder: '/scratch/stojanov/mycoplasma_data/vpp/240422/'
log_dir: '/scratch/stojanov/development/cryoet-torch/experiments/dense_simsiam_semantic'
prediction_folder: '/scratch/stojanov/development/cryoet-torch/experiments/dense_simsiam_semantic/predictions'
trained_model: '/scratch/stojanov/development/cryoet-torch/experiments/dense_simsiam_semantic/model/model-best.ckpt'
file_extension: '.mrc'

test_files: [ 'TS_56_6.80Apx.mrc', 'TS_61_6.80Apx.mrc' ]

eval_skip_prediction: False

scale_prediction: True

parameters:
  gpu_devices: 1
  data:
    patch_size: [ 64, 64, 64 ]
    min: 0
    max: 15
    mean: 0.5789
    std: 0.12345
  network:
    in_channels: 1
    spatial_dims: 3
    n_output_channels: 1

hyper_parameters:
  batch_size: 40
133 changes: 133 additions & 0 deletions cryosiam/apps/dense_simsiam_regression/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os
import yaml
import h5py
import torch
import numpy as np
from torch.utils.data import DataLoader
from monai.data import Dataset, list_data_collate, GridPatchDataset, ITKReader, ITKWriter
from monai.transforms import (
Compose,
LoadImaged,
NormalizeIntensityd,
ScaleIntensityRanged,
SpatialPadd,
EnsureChannelFirstd,
EnsureTyped,
EnsureType
)

from cryosiam.utils import parser_helper
from cryosiam.data import MrcReader, TiffReader, PatchIter, MrcWriter, TiffWriter
from cryosiam.apps.dense_simsiam_regression import load_backbone_model, load_prediction_model


def main(config_file_path):
    """Run patch-wise prediction for the files listed in a YAML configuration.

    The backbone and the prediction head are restored from the checkpoint given by
    ``trained_model`` (or ``<log_dir>/model/model_best.ckpt`` when it is missing/None).
    Every input file is split into half-overlapping patches, each patch is passed through
    the backbone and the prediction head, and the central part of every patch prediction is
    written into the full-size output. The result is saved both as ``<name>_preds.h5``
    (dataset ``preds``) and in the input file format inside ``prediction_folder``.

    :param config_file_path: path to the YAML configuration file
    :type config_file_path: str
    """
    with open(config_file_path, "r") as ymlfile:
        cfg = yaml.safe_load(ymlfile)

    checkpoint_path = cfg.get('trained_model')
    if checkpoint_path is None:
        # NOTE(review): the sample config points to 'model-best.ckpt' (hyphen) - confirm this fallback name
        checkpoint_path = os.path.join(cfg['log_dir'], 'model', 'model_best.ckpt')
    backbone = load_backbone_model(checkpoint_path)
    prediction_model = load_prediction_model(checkpoint_path)

    # only the stored training configuration is needed here, so keep the checkpoint tensors on the CPU
    net_config = torch.load(checkpoint_path, map_location='cpu')['hyper_parameters']['config']

    test_folder = cfg['data_folder']
    prediction_folder = cfg['prediction_folder']
    file_extension = cfg['file_extension']
    num_output_channels = net_config['parameters']['network']['n_output_channels']
    patch_size = net_config['parameters']['data']['patch_size']
    spatial_dims = net_config['parameters']['network']['spatial_dims']
    os.makedirs(prediction_folder, exist_ok=True)

    # a missing or empty `test_files` entry means "every file in the data folder"
    files = cfg.get('test_files')
    if files is None:
        files = [x for x in os.listdir(test_folder) if os.path.isfile(os.path.join(test_folder, x))]
    test_data = [{'image': os.path.join(test_folder, file),
                  'file_name': os.path.join(test_folder, file)} for file in files]

    # choose the reader/writer pair that matches the configured file extension
    if file_extension in ['.mrc', '.rec']:
        reader = MrcReader(read_in_mem=True)
        writer = MrcWriter(output_dtype=np.float32, overwrite=True)
        writer.set_metadata({'voxel_size': 1})
    elif file_extension in ['.tiff', '.tif']:
        reader = TiffReader()
        writer = TiffWriter(output_dtype=np.float32)
    else:
        reader = ITKReader()
        writer = ITKWriter()

    transforms = Compose(
        [
            LoadImaged(keys='image', reader=reader),
            EnsureChannelFirstd(keys='image'),
            ScaleIntensityRanged(keys='image', a_min=cfg['parameters']['data']['min'],
                                 a_max=cfg['parameters']['data']['max'], b_min=0, b_max=1, clip=True),
            SpatialPadd(keys='image', spatial_size=patch_size),
            NormalizeIntensityd(keys='image', subtrahend=cfg['parameters']['data']['mean'],
                                divisor=cfg['parameters']['data']['std']),
            EnsureTyped(keys='image', data_type='tensor')
        ]
    )
    # the first overlap entry is 0; presumably it belongs to the channel dimension - TODO confirm in PatchIter
    if spatial_dims == 2:
        patch_iter = PatchIter(patch_size=tuple(patch_size), start_pos=(0, 0), overlap=(0, 0.5, 0.5))
    else:
        patch_iter = PatchIter(patch_size=tuple(patch_size), start_pos=(0, 0, 0), overlap=(0, 0.5, 0.5, 0.5))
    post_pred = Compose([EnsureType('numpy', dtype=np.float32, device=torch.device('cpu'))])

    test_dataset = Dataset(data=test_data, transform=transforms)
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, collate_fn=list_data_collate)

    print('Prediction')
    with torch.no_grad():
        for test_sample in test_loader:
            out_file = os.path.join(prediction_folder, os.path.basename(test_sample['file_name'][0]))
            # strip the extension only when it is the real suffix; splitting on the first occurrence
            # would cut the path short if the extension text also appears in a folder or file name
            if file_extension and out_file.endswith(file_extension):
                out_base = out_file[:-len(file_extension)]
            else:
                out_base = out_file
            patch_dataset = GridPatchDataset(data=[test_sample['image'][0]],
                                             patch_iter=patch_iter)
            input_size = list(test_sample['image'][0][0].shape)
            preds_out = np.zeros([num_output_channels] + input_size, dtype=np.float32)
            loader = DataLoader(patch_dataset, batch_size=cfg['hyper_parameters']['batch_size'], num_workers=2)
            for item in loader:
                img, coord = item[0], item[1].numpy().astype(int)
                z, _ = backbone.forward_predict(img.cuda())
                out = prediction_model(z)
                out = post_pred(out)
                for batch_i in range(img.shape[0]):
                    # drop the first coordinate pair and keep one (start, end) pair per spatial dimension
                    c_batch = coord[batch_i][1:]
                    o_batch = out[batch_i]
                    # avoid getting patch that is outside of the original dimensions of the image
                    if c_batch[0][0] >= input_size[0] - patch_size[0] // 4 or \
                            c_batch[1][0] >= input_size[1] - patch_size[1] // 4 or \
                            (spatial_dims == 3 and c_batch[2][0] >= input_size[2] - patch_size[2] // 4):
                        continue
                    # create slices for the coordinates in the output to get only the middle of the patch
                    # and the separate cases for the first and last patch in each dimension
                    slices = tuple(
                        slice(c[0], c[1] - p // 4) if c[0] == 0 else slice(c[0] + p // 4, c[1])
                        if c[1] >= s else slice(c[0] + p // 4, c[1] - p // 4)
                        for c, s, p in zip(c_batch, input_size, patch_size))
                    # create slices to crop the patch so we only get the middle information
                    # and the separate cases for the first and last patch in each dimension
                    slices2 = tuple(
                        slice(0, 3 * p // 4) if c[0] == 0 else slice(p // 4, p - (c[1] - s))
                        if c[1] >= s else slice(p // 4, 3 * p // 4)
                        for c, s, p in zip(c_batch, input_size, patch_size))
                    preds_out[(slice(0, num_output_channels),) + slices] = o_batch[(slice(0, num_output_channels),)
                                                                                   + slices2]

            if cfg['scale_prediction']:
                pred_min, pred_max = preds_out.min(), preds_out.max()
                if pred_max > pred_min:
                    preds_out = (preds_out - pred_min) / (pred_max - pred_min)
                else:
                    # constant prediction: min-max scaling would divide by zero and fill the output with NaN
                    preds_out = np.zeros_like(preds_out)

            with h5py.File(out_base + '_preds.h5', 'w') as f:
                f.create_dataset('preds', data=preds_out)

            writer.set_data_array(preds_out[0], channel_dim=None)
            writer.write(out_base + file_extension)


if __name__ == "__main__":
    # read the configuration file path from the command line and run the prediction
    cli_args = parser_helper().parse_args()
    main(cli_args.config_file)
82 changes: 82 additions & 0 deletions cryosiam/apps/dense_simsiam_regression/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import collections

from cryosiam.networks.nets import (
CNNHead,
DenseSimSiam
)


def load_backbone_model(checkpoint_path, device="cuda:0"):
    """Load DenseSimSiam trained model from given checkpoint
    :param checkpoint_path: path to the checkpoint
    :type checkpoint_path: str
    :param device: on which device should the model be loaded, default is cuda:0
    :type device: str
    :return: DenseSimSiam model with loaded trained weights, in eval mode and moved to `device`
    :rtype: cryosiam.networks.nets.DenseSimSiam
    """
    # deserialize on the CPU so loading does not depend on the device the checkpoint was saved from;
    # the model is moved to the requested device at the end
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net_cfg = checkpoint['hyper_parameters']['backbone_config']['parameters']['network']
    # the options read with .get() are optional in the stored config and fall back to the given defaults
    model_backbone = DenseSimSiam(block_type=net_cfg['block_type'],
                                  spatial_dims=net_cfg['spatial_dims'],
                                  n_input_channels=net_cfg['in_channels'],
                                  num_layers=net_cfg['num_layers'],
                                  num_filters=net_cfg['num_filters'],
                                  fpn_channels=net_cfg['fpn_channels'],
                                  no_max_pool=net_cfg['no_max_pool'],
                                  dim=net_cfg['dim'],
                                  pred_dim=net_cfg['pred_dim'],
                                  dense_dim=net_cfg['dense_dim'],
                                  dense_pred_dim=net_cfg['dense_pred_dim'],
                                  include_levels=net_cfg.get('include_levels_loss', False),
                                  add_later_conv=net_cfg.get('add_fpn_later_conv', False),
                                  decoder_type=net_cfg.get('decoder_type', 'fpn'),
                                  decoder_layers=net_cfg.get('fpn_layers', 2))
    # keep only the backbone weights and strip the leading `_model_backbone.` prefix
    # (slicing removes just the prefix, while str.replace would remove every occurrence in the key)
    prefix = '_model_backbone.'
    new_state_dict = collections.OrderedDict(
        (k[len(prefix):], v) for k, v in checkpoint['state_dict'].items() if k.startswith(prefix))
    model_backbone.load_state_dict(new_state_dict)
    model_backbone.eval()
    model_backbone.to(torch.device(device))
    return model_backbone


def load_prediction_model(checkpoint_path, device="cuda:0"):
    """Load CNNHead trained prediction model from given checkpoint
    :param checkpoint_path: path to the checkpoint
    :type checkpoint_path: str
    :param device: on which device should the model be loaded, default is cuda:0
    :type device: str
    :return: CNNHead model with loaded trained weights, in eval mode and moved to `device`
    :rtype: cryosiam.networks.nets.CNNHead
    """
    # deserialize on the CPU so loading does not depend on the device the checkpoint was saved from;
    # the model is moved to the requested device at the end
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net_cfg = checkpoint['hyper_parameters']['config']['parameters']['network']
    model = CNNHead(n_input_channels=net_cfg['dense_dim'],
                    n_output_channels=net_cfg['n_output_channels'],
                    spatial_dims=net_cfg['spatial_dims'],
                    filters=net_cfg['filters'],
                    kernel_size=net_cfg['kernel_size'],
                    padding=net_cfg['padding'])

    # keep only the head weights and strip the leading `_model.` prefix; slicing is used because
    # str.replace would also delete a `_model.` that occurs later in the key (e.g. `_model.sub_model.w`)
    prefix = '_model.'
    new_state_dict = collections.OrderedDict(
        (k[len(prefix):], v) for k, v in checkpoint['state_dict'].items() if k.startswith(prefix))
    model.load_state_dict(new_state_dict)
    model.eval()
    model.to(torch.device(device))
    return model
Empty file.
65 changes: 65 additions & 0 deletions cryosiam/apps/preprocessing/match_pixel_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import mrcfile
import argparse
import numpy as np

from cryoet_torch.utils import match_pixel_size

def scale_tomogram(tomo, percentile_lower=None, percentile_upper=None):
    """Min-max scale a tomogram to the range [0, 1]
    :param tomo: tomogram values
    :type tomo: np.ndarray
    :param percentile_lower: percentile used as lower bound, the minimum value is used when None
    :type percentile_lower: float or None
    :param percentile_upper: percentile used as upper bound, the maximum value is used when None
    :type percentile_upper: float or None
    :return: scaled tomogram with the values clipped to [0, 1]
    :rtype: np.ndarray
    """
    # compare with None explicitly so that a percentile of 0 is honoured instead of being ignored
    if percentile_lower is not None:
        min_val = np.percentile(tomo, percentile_lower)
    else:
        min_val = tomo.min()

    if percentile_upper is not None:
        max_val = np.percentile(tomo, percentile_upper)
    else:
        max_val = tomo.max()

    if max_val == min_val:
        # flat value range: the division below would be 0/0 and produce NaN, return zeros instead
        out_dtype = tomo.dtype if np.issubdtype(tomo.dtype, np.floating) else np.float64
        return np.zeros_like(tomo, dtype=out_dtype)

    tomo = (tomo - min_val) / (max_val - min_val)

    return np.clip(tomo, 0, 1)


def parser_helper(description=None):
    """Create the command line parser for matching the pixel size of tomograms
    :param description: description shown in the help message, a default text is used when None
    :type description: str or None
    :return: parser with the input/output paths, the pixel sizes and the smoothing switches
    :rtype: argparse.ArgumentParser
    """
    description = "Match the pixel size of a tomogram with a desired pixel size" if description is None else description
    # pass the text by keyword: the first positional argument of ArgumentParser is `prog`, not `description`
    parser = argparse.ArgumentParser(description=description, add_help=True,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_path', type=str, required=True, help='path to the input tomogram or '
                                                                      'path to the folder with input tomogram/s')
    parser.add_argument('--output_path', type=str, required=True, help='path to save the output tomogram or '
                                                                       'path to folder to save the output tomogram/s')
    parser.add_argument("--pixel_size_in", type=float, required=True, help="Pixel size (angstroms) of the input.")
    parser.add_argument("--pixel_size_out", type=float, required=True, help="Pixel size (angstroms) of the output.")
    parser.add_argument("--disable_smooth", action="store_true", default=True, help="Disable smoothing of the output.")
    # `disable_smooth` defaults to True, so the flag above alone can never switch smoothing on;
    # this counterpart keeps the default unchanged and makes the option controllable
    parser.add_argument("--enable_smooth", dest="disable_smooth", action="store_false",
                        help="Enable smoothing of the output.")
    return parser


def _rescale_tomogram_file(input_file, output_file, pixel_size_in, pixel_size_out, disable_smooth):
    """Read one MRC tomogram, match its pixel size if needed and write it as float32
    :param input_file: path to the input tomogram
    :type input_file: str
    :param output_file: path where the output tomogram is written (overwritten if it exists)
    :type output_file: str
    :param pixel_size_in: pixel size of the input
    :type pixel_size_in: float
    :param pixel_size_out: pixel size of the output, also stored as voxel size of the output file
    :type pixel_size_out: float
    :param disable_smooth: value forwarded as last argument to match_pixel_size
    :type disable_smooth: bool
    """
    with mrcfile.open(input_file, permissive=True) as m:
        tomogram = m.data

    # equal pixel sizes need no resampling, the data is only rewritten as float32
    if pixel_size_in != pixel_size_out:
        tomogram = match_pixel_size(tomogram, pixel_size_in, pixel_size_out, disable_smooth)

    with mrcfile.new(output_file, overwrite=True) as m:
        m.set_data(tomogram.astype(np.float32))
        m.voxel_size = pixel_size_out


if __name__ == '__main__':
    parser = parser_helper()
    args = parser.parse_args()

    if os.path.isdir(args.input_path):
        # folder mode: convert every .mrc/.rec file and keep the file names in the output folder
        os.makedirs(args.output_path, exist_ok=True)
        for tomo in os.listdir(args.input_path):
            if tomo.endswith((".mrc", ".rec")):
                _rescale_tomogram_file(os.path.join(args.input_path, tomo), os.path.join(args.output_path, tomo),
                                       args.pixel_size_in, args.pixel_size_out, args.disable_smooth)
    else:
        # single file mode: input_path and output_path are file paths
        _rescale_tomogram_file(args.input_path, args.output_path,
                               args.pixel_size_in, args.pixel_size_out, args.disable_smooth)
Loading

0 comments on commit 8b182a5

Please sign in to comment.