
Commit

Merge pull request #14 from juglab/merger
v0.1.18
tibuch authored Feb 5, 2021
2 parents 89dc8ec + 831bf50 commit aa6394e
Showing 10 changed files with 1,831 additions and 14 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -6,7 +6,7 @@ Build Python package:
`python setup.py bdist_wheel`

Build singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.17-py3-none-any.whl /fourier_image_transformers-0.1.17-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.17-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.17.Singularity`
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.18-py3-none-any.whl /fourier_image_transformers-0.1.18-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.18-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.18.Singularity`

Build singularity container:
`sudo singularity build fit_v0.1.17.simg v0.1.17.Singularity`
`sudo singularity build fit_v0.1.18.simg v0.1.18.Singularity`
172 changes: 172 additions & 0 deletions examples/datamodules/DataModule - Celeb SuperRes.ipynb

Large diffs are not rendered by default.

606 changes: 606 additions & 0 deletions examples/datamodules/DataModule - CelebA Tomo.ipynb

Large diffs are not rendered by default.

605 changes: 605 additions & 0 deletions examples/datamodules/DataModule - Kanji Tomo.ipynb

Large diffs are not rendered by default.

64 changes: 62 additions & 2 deletions fit/datamodules/super_res/SRecDataModule.py
@@ -1,8 +1,12 @@
from glob import glob
from os.path import join
from typing import Optional, Union, List

import numpy as np
import torch
from imageio import imread
from pytorch_lightning import LightningDataModule
from skimage.transform import resize
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

@@ -14,7 +18,7 @@
class MNISTSResFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 27

def __init__(self, root_dir, batch_size, inner_circle=True):
def __init__(self, root_dir, batch_size):
"""
:param root_dir:
:param batch_size:
@@ -23,7 +27,6 @@ def __init__(self, root_dir, batch_size, inner_circle=True):
super().__init__()
self.root_dir = root_dir
self.batch_size = batch_size
self.inner_circle = inner_circle
self.gt_ds = None
self.mean = None
self.std = None
@@ -76,3 +79,60 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=MNISTSResFourierTargetDataModule.IMG_SHAPE),
batch_size=1)


class CelebASResFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 127

def __init__(self, root_dir, batch_size):
"""
:param root_dir:
:param batch_size:
:param num_angles:
"""
super().__init__()
self.root_dir = root_dir
self.gt_shape = 63
self.batch_size = batch_size
self.gt_ds = None
self.mean = None
self.std = None
self.mag_min = None
self.mag_max = None

def setup(self, stage: Optional[str] = None):
gt_data = np.load(join(self.root_dir, 'gt_data.npz'))

gt_train = torch.from_numpy(gt_data['gt_train'])
gt_val = torch.from_numpy(gt_data['gt_val'])
gt_test = torch.from_numpy(gt_data['gt_test'])
self.mean = gt_train.mean()
self.std = gt_train.std()

gt_train = normalize(gt_train, self.mean, self.std)
gt_val = normalize(gt_val, self.mean, self.std)
gt_test = normalize(gt_test, self.mean, self.std)
self.gt_ds = GroundTruthDataset(gt_train, gt_val, gt_test)

tmp_fcds = SResFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
img_shape=self.gt_shape)
self.mag_min = tmp_fcds.mag_min
self.mag_max = tmp_fcds.mag_max

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=self.gt_shape),
batch_size=1)
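The new `CelebASResFourierTargetDataModule` mirrors the MNIST super-resolution module, but loads 63×63 ground-truth crops from a pre-computed `gt_data.npz`. A minimal usage sketch (paths and batch size are placeholders; the import path simply mirrors the file location shown in this diff):

```python
# Hypothetical usage of the new CelebA super-resolution data module.
from fit.datamodules.super_res.SRecDataModule import CelebASResFourierTargetDataModule

dm = CelebASResFourierTargetDataModule(root_dir='/data/celeba_sres', batch_size=8)
dm.setup()  # loads gt_data.npz, normalizes with the train mean/std, fixes mag_min/mag_max

# The Fourier-coefficient datasets for all three splits reuse the statistics
# computed from the training split above.
batch = next(iter(dm.train_dataloader()))
print(dm.mean, dm.std, dm.mag_min, dm.mag_max)
```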
156 changes: 150 additions & 6 deletions fit/datamodules/tomo_rec/TRecDataModule.py
@@ -1,8 +1,11 @@
from glob import glob
from os.path import join
from typing import Optional, Union, List

import dival
import numpy as np
import torch
from imageio import imread
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
@@ -166,21 +169,21 @@ def setup(self, stage: Optional[str] = None):
assert self.gt_shape <= self.IMG_SHAPE, 'GT is larger than original images.'
if self.gt_shape < self.IMG_SHAPE:
gt_train = np.array([resize(lodopab.get_sample(i, part='train', out=(False, True))[1][1:, 1:],
output_shape=(self.gt_shape, self.gt_shape), anti_aliasing=True) for i in
range(4000)])
output_shape=(self.gt_shape, self.gt_shape), anti_aliasing=True) for i in
range(4000)])
gt_val = np.array([resize(lodopab.get_sample(i, part='validation', out=(False, True))[1][1:, 1:],
output_shape=(self.gt_shape, self.gt_shape), anti_aliasing=True) for i in
range(400)])
gt_test = np.array([resize(lodopab.get_sample(i, part='test', out=(False, True))[1][1:, 1:],
output_shape=(self.gt_shape, self.gt_shape), anti_aliasing=True) for i in
range(3553)])
else:
gt_train = np.array([lodopab.get_sample(i, part='train', out=(False, True))[1][1:, 1:] for i in range(4000)])
gt_val = np.array([lodopab.get_sample(i, part='validation', out=(False, True))[1][1:, 1:] for i in range(400)])
gt_train = np.array(
[lodopab.get_sample(i, part='train', out=(False, True))[1][1:, 1:] for i in range(4000)])
gt_val = np.array(
[lodopab.get_sample(i, part='validation', out=(False, True))[1][1:, 1:] for i in range(400)])
gt_test = np.array([lodopab.get_sample(i, part='test', out=(False, True))[1][1:, 1:] for i in range(3553)])



gt_train = torch.from_numpy(gt_train)
gt_val = torch.from_numpy(gt_val)
gt_test = torch.from_numpy(gt_test)
@@ -229,3 +232,144 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=self.gt_shape),
batch_size=1)


class KanjiFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 63

def __init__(self, root_dir, batch_size, num_angles=33):
"""
:param root_dir:
:param batch_size:
:param num_angles:
"""
super().__init__()
self.root_dir = root_dir
self.batch_size = batch_size
self.num_angles = num_angles
self.inner_circle = True
self.gt_ds = None
self.mean = None
self.std = None

def setup(self, stage: Optional[str] = None):
gt_data = np.load(join(self.root_dir, 'gt_data.npz'))

gt_train = torch.from_numpy(gt_data['gt_train'])
gt_val = torch.from_numpy(gt_data['gt_val'])
gt_test = torch.from_numpy(gt_data['gt_test'])

x, y = torch.meshgrid(torch.arange(-self.IMG_SHAPE // 2 + 1,
self.IMG_SHAPE // 2 + 1),
torch.arange(-self.IMG_SHAPE // 2 + 1,
self.IMG_SHAPE // 2 + 1))
circle = torch.sqrt(x ** 2. + y ** 2.) <= self.IMG_SHAPE // 2
gt_train *= circle
gt_val *= circle
gt_test *= circle

self.mean = gt_train.mean()
self.std = gt_train.std()

gt_train = normalize(gt_train, self.mean, self.std)
gt_val = normalize(gt_val, self.mean, self.std)
gt_test = normalize(gt_test, self.mean, self.std)

self.gt_ds = get_projection_dataset(
GroundTruthDataset(gt_train, gt_val, gt_test),
num_angles=self.num_angles, im_shape=133, impl='astra_cpu', inner_circle=self.inner_circle)

tmp_fcds = TRecFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
img_shape=self.IMG_SHAPE)
self.mag_min = tmp_fcds.mag_min
self.mag_max = tmp_fcds.mag_max

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=self.IMG_SHAPE),
batch_size=1)
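All statistics the Kanji module needs (circle-masked mean/std and the Fourier magnitude range) are derived from the training split inside `setup()`. A short, hypothetical way to inspect them (the path is a placeholder):

```python
# Hypothetical sketch: inspect what KanjiFourierTargetDataModule.setup() computes.
# root_dir must contain gt_data.npz with gt_train / gt_val / gt_test arrays.
from fit.datamodules.tomo_rec.TRecDataModule import KanjiFourierTargetDataModule

dm = KanjiFourierTargetDataModule(root_dir='/data/kanji', batch_size=4, num_angles=33)
dm.setup()

# mean/std come from the circle-masked training images; mag_min/mag_max from a
# temporary TRecFourierCoefficientDataset built over the training split.
print('mean/std:', float(dm.mean), float(dm.std))
print('mag range:', dm.mag_min, dm.mag_max)
```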


class CelebAFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 127

def __init__(self, root_dir, batch_size, num_angles=33):
"""
:param root_dir:
:param batch_size:
:param num_angles:
"""
super().__init__()
self.root_dir = root_dir
self.batch_size = batch_size
self.gt_shape = 63
self.num_angles = num_angles
self.inner_circle = True
self.gt_ds = None
self.mean = None
self.std = None

def setup(self, stage: Optional[str] = None):
gt_data = np.load(join(self.root_dir, 'gt_data.npz'))

gt_train = torch.from_numpy(gt_data['gt_train'])
gt_val = torch.from_numpy(gt_data['gt_val'])
gt_test = torch.from_numpy(gt_data['gt_test'])

assert gt_train.shape[1] == self.gt_shape
assert gt_train.shape[2] == self.gt_shape
x, y = torch.meshgrid(torch.arange(-self.gt_shape // 2 + 1,
self.gt_shape // 2 + 1),
torch.arange(-self.gt_shape // 2 + 1,
self.gt_shape // 2 + 1))
circle = torch.sqrt(x ** 2. + y ** 2.) <= self.gt_shape // 2
gt_train *= circle
gt_val *= circle
gt_test *= circle

self.mean = gt_train.mean()
self.std = gt_train.std()

gt_train = normalize(gt_train, self.mean, self.std)
gt_val = normalize(gt_val, self.mean, self.std)
gt_test = normalize(gt_test, self.mean, self.std)

self.gt_ds = get_projection_dataset(
GroundTruthDataset(gt_train, gt_val, gt_test),
num_angles=self.num_angles, im_shape=153, impl='astra_cpu', inner_circle=self.inner_circle)

tmp_fcds = TRecFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
img_shape=self.gt_shape)
self.mag_min = tmp_fcds.mag_min
self.mag_max = tmp_fcds.mag_max

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=self.gt_shape),
batch_size=1)
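`CelebAFourierTargetDataModule` plugs into a standard PyTorch Lightning workflow like any `LightningDataModule`. A hedged sketch (the trainer arguments and the model are placeholders, not part of this commit):

```python
# Hypothetical sketch: use the new CelebA tomography data module with a Trainer.
from pytorch_lightning import Trainer
from fit.datamodules.tomo_rec.TRecDataModule import CelebAFourierTargetDataModule

dm = CelebAFourierTargetDataModule(root_dir='/data/celeba_tomo', batch_size=4, num_angles=33)

# trainer = Trainer(gpus=1, max_epochs=100)
# trainer.fit(some_lightning_module, datamodule=dm)   # model not defined in this sketch

# Or pull a single batch for a quick sanity check:
dm.setup()
batch = next(iter(dm.train_dataloader()))
```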
24 changes: 21 additions & 3 deletions fit/modules/TRecTransformerModule.py
@@ -18,15 +18,14 @@
class TRecTransformerModule(LightningModule):
def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords_img, src_flatten_coords,
dst_flatten_coords, dst_order, angles, img_shape=27, detector_len=27, init_bin_factor=4,
alpha=1.5, bin_factor_cd=10,
bin_factor_cd=10,
lr=0.0001,
weight_decay=0.01,
attention_type="linear", n_layers=4, n_heads=4, d_query=4, dropout=0.1, attention_dropout=0.1):
super().__init__()

self.save_hyperparameters("d_model",
"img_shape",
"alpha",
"bin_factor_cd",
"init_bin_factor",
"detector_len",
@@ -56,6 +55,8 @@ def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords
self.dft_shape = (img_shape, img_shape // 2 + 1)
self.bin_factor = init_bin_factor
self.bin_count = 0
self.best_mean_val_mse = 9999999
self.bin_factor_patience = 10
self.register_buffer('mask', psfft(self.bin_factor, pixel_res=img_shape))

self.trec = TRecTransformer(d_model=self.hparams.d_model,
@@ -220,6 +221,9 @@ def log_val_images(self, pred_img, x, y_fc, y_real, mag_min, mag_max):
self.trainer.logger.experiment.add_image('targets/img_{}'.format(i), y_img.unsqueeze(0),
global_step=self.trainer.global_step)

def _is_better(self, mean_val_mse):
return mean_val_mse < self.best_mean_val_mse * (1. - 0.0001)

def validation_epoch_end(self, outputs):
val_loss = [o['val_loss'] for o in outputs]
val_mse = [o['val_mse'] for o in outputs]
@@ -230,8 +234,19 @@ def validation_epoch_end(self, outputs):
mean_val_mse = torch.mean(torch.stack(val_mse))
mean_val_psnr = torch.mean(torch.stack(val_psnr))
bin_factor_threshold = torch.mean(torch.stack(bin_mse)) * self.bin_factor
if self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < bin_factor_threshold and self.bin_factor > 1:

if self._is_better(mean_val_mse):
self.best_mean_val_mse = mean_val_mse
self.bin_factor_patience = 10
else:
self.bin_factor_patience -= 1

reduce_bin_factor = (self.bin_factor_patience < 1) or (
self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < bin_factor_threshold)
if reduce_bin_factor and self.bin_factor > 1:
self.bin_count = 0
self.bin_factor_patience = 10
self.best_mean_val_mse = mean_val_mse
self.bin_factor = max(1, self.bin_factor // 2)
self.register_buffer('mask', psfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device))
print('Reduced bin_factor to {}.'.format(self.bin_factor))
@@ -241,6 +256,9 @@ def validation_epoch_end(self, outputs):

self.bin_count += 1

if self.bin_factor > 1:
self.trainer.lr_schedulers[0]['scheduler']._reset()

self.log('Train/avg_val_loss', torch.mean(torch.stack(val_loss)), logger=True, on_epoch=True)
self.log('Train/avg_val_mse', mean_val_mse, logger=True, on_epoch=True)
self.log('Train/avg_val_psnr', mean_val_psnr, logger=True, on_epoch=True)
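In plain terms, the new validation logic halves `bin_factor` when either (a) the `bin_factor_cd` cooldown has elapsed and the mean validation MSE drops below the binned-target MSE threshold, or (b) the mean validation MSE has not improved by a relative 1e-4 for 10 consecutive validation epochs; while `bin_factor > 1`, the LR scheduler is additionally reset each epoch. A stripped-down re-statement of that schedule (plain Python, illustration only, not part of the module):

```python
# Simplified illustration of the bin_factor schedule added in this commit.
# It mirrors validation_epoch_end()/_is_better(), but with plain floats.
class BinFactorSchedule:
    def __init__(self, init_bin_factor=4, bin_factor_cd=10, patience=10, rel_tol=1e-4):
        self.bin_factor = init_bin_factor
        self.bin_factor_cd = bin_factor_cd   # cooldown before the MSE threshold may fire
        self.patience = patience             # epochs without improvement before forcing a reduction
        self.rel_tol = rel_tol               # relative improvement required, as in _is_better()
        self.best = float('inf')
        self.wait = patience
        self.cooldown = 0

    def step(self, val_mse, bin_threshold):
        # Track improvement with a small relative tolerance.
        if val_mse < self.best * (1.0 - self.rel_tol):
            self.best = val_mse
            self.wait = self.patience
        else:
            self.wait -= 1

        reduce = (self.wait < 1) or (
            self.cooldown > self.bin_factor_cd and val_mse < bin_threshold)
        if reduce and self.bin_factor > 1:
            self.cooldown = 0
            self.wait = self.patience
            self.best = val_mse
            self.bin_factor = max(1, self.bin_factor // 2)
        self.cooldown += 1
        return self.bin_factor
```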
2 changes: 1 addition & 1 deletion fit/version.py
@@ -1 +1 @@
__version__ = '0.1.17'
__version__ = '0.1.18'