Skip to content

Commit

Permalink
Remove hardcoded numbers in sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
FeGeyer committed Apr 19, 2024
1 parent 608cc86 commit 3223403
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
7 changes: 6 additions & 1 deletion radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,12 @@ def save_sampled(conf):
result = sample_images(img["pred"], img["unc"], 100, conf)

# pad true image
output = F.pad(input=img["true"], pad=(0, 0, 0, 63), mode="constant", value=0)
output = F.pad(
input=img["true"],
pad=(0, 0, 0, img_size // 2 - 1),
mode="constant",
value=0,
)
img["true"] = symmetry(output, None)
ifft_truth = get_ifft(img["true"], amp_phase=conf["amp_phase"])

Expand Down
14 changes: 9 additions & 5 deletions radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def tn_numba_vec_parallel(mu, sig, a, b):
return rv


def trunc_rvs(mu, sig, num_samples, mode, target="cpu", nthreads=1):
def trunc_rvs(mu, sig, num_samples, mode, target="parallel", nthreads=20):
if mode == "amp":
a = 0
b = np.inf
Expand Down Expand Up @@ -580,15 +580,15 @@ def sample_images(mean, std, num_samples, conf):
sig=std_amp,
mode=mode[0],
num_samples=num_samples,
).reshape(num_img * num_samples, 65, 128)
).reshape(num_img * num_samples, mean_amp.shape[-2], mean_amp.shape[-1])

# phase
sampled_gauss_phase = trunc_rvs(
mu=mean_phase,
sig=std_phase,
mode=mode[1],
num_samples=num_samples,
).reshape(num_img * num_samples, 65, 128)
).reshape(num_img * num_samples, mean_phase.shape[-2], mean_phase.shape[-1])

# masks
if conf["amp_phase"]:
Expand All @@ -604,13 +604,17 @@ def sample_images(mean, std, num_samples, conf):

# pad resulting images and utilize symmetry
sampled_gauss = F.pad(
input=torch.tensor(sampled_gauss), pad=(0, 0, 0, 63), mode="constant", value=0
input=torch.tensor(sampled_gauss),
pad=(0, 0, 0, mean_amp.shape[-2] - 2),
mode="constant",
value=0,
)
sampled_gauss_symmetry = symmetry(sampled_gauss, None)
print(sampled_gauss_symmetry.shape)

fft_sampled_symmetry = get_ifft(
sampled_gauss_symmetry, amp_phase=conf["amp_phase"], scale=False
).reshape(num_img, num_samples, 128, 128)
).reshape(num_img, num_samples, mean_amp.shape[-1], mean_amp.shape[-1])

results = {
"mean": fft_sampled_symmetry.mean(axis=1),
Expand Down

0 comments on commit 3223403

Please sign in to comment.