Skip to content

Commit

Permalink
Add fix for uncertainty training
Browse files Browse the repository at this point in the history
  • Loading branch information
FeGeyer committed Apr 19, 2024
1 parent 33460e0 commit 608cc86
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions radionets/dl_framework/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def plot_test_pred(self):
pred = rescale_normalization(pred, norm_dict)
if pred.shape[1] == 4:
self.uncertainty = True
pred = torch.stack((pred[:, 0, :], pred[:, 2, :]), dim=1)
images = {"pred": pred, "truth": img_true}
images = apply_symmetry(images)
pred = images["pred"]
Expand All @@ -74,10 +75,7 @@ def plot_test_pred(self):
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 10))
lim_phase = check_vmin_vmax(img_true[0, 1])
im1 = ax1.imshow(pred[0, 0], cmap="inferno")
if self.uncertainty:
im2 = ax2.imshow(pred[0, 2], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
else:
im2 = ax2.imshow(pred[0, 1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
im2 = ax2.imshow(pred[0, 1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
im3 = ax3.imshow(img_true[0, 0], cmap="inferno")
im4 = ax4.imshow(img_true[0, 1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
make_axes_nice(fig, ax1, im1, "Real")
Expand All @@ -103,6 +101,8 @@ def plot_test_fft(self):
with torch.no_grad():
pred = eval_model(img_test, model)
pred = rescale_normalization(pred, norm_dict)
if self.uncertainty:
pred = torch.stack((pred[:, 0, :], pred[:, 2, :]), dim=1)
images = {"pred": pred, "truth": img_true}
images = apply_symmetry(images)
pred = images["pred"]
Expand Down

0 comments on commit 608cc86

Please sign in to comment.