Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
FeGeyer committed Feb 1, 2023
1 parent 5d53314 commit 888d43b
Showing 1 changed file with 54 additions and 18 deletions.
72 changes: 54 additions & 18 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
@@ -335,34 +335,58 @@ def test_symmetry(self):

def test_sample_images(self):
import numpy as np
import torch
import torch.nn.functional as F
from radionets.evaluation.utils import (
read_pred,
trunc_rvs,
even_better_symmetry,
get_ifft,
sym_new,
)

num_samples = 100
num_img = 2
img = read_pred("./tests/model/predictions_unc.h5")
mean = img["pred"][0]
std = img["unc"][0]
mean_amp, mean_phase = mean[0], mean[1]
std_amp, std_phase = std[0], std[1]
mean_amp, mean_phase = (
img["pred"][:num_img, 0, :65, :],
img["pred"][:num_img, 1, :65, :],
)
std_amp, std_phase = (
img["unc"][:num_img, 0, :65, :],
img["unc"][:num_img, 1, :65, :],
)
img_size = mean_amp.shape[-1]

# amplitude
sampled_gauss_amp = trunc_rvs(mean_amp, std_amp, "amp", num_samples, num_img=1)
sampled_gauss_amp = trunc_rvs(
mean_amp, std_amp, "amp", num_samples, num_img=num_img
)

# phase
sampled_gauss_phase = trunc_rvs(
mean_phase, std_phase, "phase", num_samples, num_img=1
mean_phase, std_phase, "phase", num_samples, num_img=num_img
)

assert sampled_gauss_amp.shape == (1, num_samples, img_size, img_size)
assert sampled_gauss_phase.shape == (1, num_samples, img_size, img_size)
assert sampled_gauss_amp.shape == (
num_img,
num_samples,
img_size // 2 + 1,
img_size,
)
assert sampled_gauss_phase.shape == (
num_img,
num_samples,
img_size // 2 + 1,
img_size,
)

sampled_gauss_amp = sampled_gauss_amp.reshape(num_img * num_samples, 65, 128)
sampled_gauss_phase = sampled_gauss_phase.reshape(
num_img * num_samples, 65, 128
)

with pytest.raises(ValueError):
trunc_rvs(mean_phase, std_phase, "pase", num_samples, num_img=1)
trunc_rvs(mean_phase, std_phase, "pase", num_samples, num_img=num_img)

# masks
mask_invalid_amp = sampled_gauss_amp < 0
@@ -373,24 +397,36 @@ def test_sample_images(self):
assert mask_invalid_phase.sum() == 0

sampled_gauss = np.stack([sampled_gauss_amp, sampled_gauss_phase], axis=1)
sampled_gauss_symmetry = even_better_symmetry(sampled_gauss)

# pad resulting images and utilize symmetry
sampled_gauss = F.pad(
input=torch.tensor(sampled_gauss),
pad=(0, 0, 0, 63),
mode="constant",
value=0,
)
sampled_gauss_symmetry = sym_new(sampled_gauss, None)

fft_sampled_symmetry = get_ifft(
sampled_gauss_symmetry, amp_phase=True, scale=False
)
).reshape(num_img, num_samples, 128, 128)

results = {
"mean": fft_sampled_symmetry.mean(axis=0),
"std": fft_sampled_symmetry.std(axis=0),
"mean": fft_sampled_symmetry.mean(axis=1),
"std": fft_sampled_symmetry.std(axis=1),
}
assert results["mean"].shape == (img_size, img_size)
assert results["std"].shape == (img_size, img_size)
assert results["mean"].shape == (num_img, img_size, img_size)
assert results["std"].shape == (num_img, img_size, img_size)

def test_uncertainty_plots(self):
from radionets.evaluation.utils import read_pred, sample_images

img = read_pred("./tests/model/predictions_unc.h5")
results = sample_images(img["pred"][0:2], img["unc"][0:2], 100)
assert results is not None
results = sample_images(
img["pred"][:2, :, :65, :], img["unc"][:2, :, :65, :], 100
)
assert results["mean"].shape == (2, 128, 128)
assert results["std"].shape == (2, 128, 128)

def test_evaluation(self):
import shutil

0 comments on commit 888d43b

Please sign in to comment.