Skip to content

Commit

Permalink
Merge pull request #140 from radionets-project/symmetry_training
Browse files Browse the repository at this point in the history
Utilize symmetry
  • Loading branch information
Kevin2 authored Feb 1, 2023
2 parents 99d346f + 888d43b commit df40587
Show file tree
Hide file tree
Showing 23 changed files with 252 additions and 138 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ __pycache__/
dist/
radionets/dl_framework/values.yaml
*.csv
*.swp
35 changes: 35 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,38 @@
Radionets 0.2.0 (2023-01-31)
============================


API Changes
-----------

- train on half-sized iamges and applying symmetry afterward is a backward incompatible change
models trained with early versions of `radionets` are not supported anymore [`#140 <https://github.com/radionets-project/radionets/pull/140>`__]


Bug Fixes
---------

- fixed sampling of test data set
fixed same indices for plots [`#140 <https://github.com/radionets-project/radionets/pull/140>`__]


New Features
------------

- enabled training and evaluation of half sized images (for 128 pixel images) [`#140 <https://github.com/radionets-project/radionets/pull/140>`__]


Maintenance
-----------

- Deleted unusable functions for new source types
Deleted unused hardcoded scaling [`#140 <https://github.com/radionets-project/radionets/pull/140>`__]


Refactoring and Optimization
----------------------------


Radionets 0.1.18 (2023-01-30)
=============================

Expand Down
7 changes: 0 additions & 7 deletions docs/changes/129.feature.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/changes/130.maintenance.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/changes/134.maintenance.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/changes/136.maintenance.rst

This file was deleted.

2 changes: 2 additions & 0 deletions docs/changes/140.bugfixes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
fixed sampling of test data set
fixed same indices for plots
1 change: 1 addition & 0 deletions docs/changes/140.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
enabled training and evaluation of half sized images (for 128 pixel images)
2 changes: 2 additions & 0 deletions docs/changes/140.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Deleted unusable functions for new source types
Deleted unused hardcoded scaling
1 change: 0 additions & 1 deletion examples/default_eval_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ visualize_prediction = true
visualize_source_reconstruction = true
visualize_contour = true
visualize_dynamic_range = false
visualize_blobs = true
visualize_ms_ssim = false
random = false
num_images = 5
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "radionets"
version = "0.1.16"
version = "0.2.0"
authors = [
{ name="Kevin Schmidt", email="[email protected]" },
]
Expand Down
8 changes: 3 additions & 5 deletions radionets/dl_framework/architectures/res_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def __init__(self):
nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False), nn.BatchNorm2d(64)
)

self.final = nn.Sequential(
nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2),
)
self.final = nn.Sequential(nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2))

self.hardtanh = nn.Hardtanh(-pi, pi)
self.relu = nn.ReLU()
Expand All @@ -167,9 +165,9 @@ def forward(self, x):

x = self.final(x)

x0 = x[:, 0].reshape(-1, 1, s, s)
x0 = x[:, 0].reshape(-1, 1, s // 2 + 1, s)
x0 = self.relu(x0)
x1 = self.hardtanh(x[:, 1]).reshape(-1, 1, s, s)
x1 = self.hardtanh(x[:, 1]).reshape(-1, 1, s // 2 + 1, s)

return torch.cat([x0, x1], dim=1)

Expand Down
20 changes: 10 additions & 10 deletions radionets/dl_framework/architectures/unc_archs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@
GeneralELU,
LocallyConnected2d,
)
from radionets.dl_framework.architectures.res_exp import SRResNet
from radionets.dl_framework.architectures.res_exp import SRResNet_16


class Uncertainty(nn.Module):
def __init__(self, img_size):
super().__init__()

self.conv1 = nn.Sequential(
nn.Conv2d(4, 8, 9, stride=1, padding=4, groups=2),
nn.BatchNorm2d(8),
nn.Conv2d(4, 16, 9, stride=1, padding=4, groups=2),
nn.BatchNorm2d(16),
nn.PReLU(),
)

self.conv2 = nn.Sequential(
nn.Conv2d(8, 16, 3, stride=1, padding=1),
nn.BatchNorm2d(16),
nn.Conv2d(16, 32, 3, stride=1, padding=1),
nn.BatchNorm2d(32),
nn.PReLU(),
)

self.conv3 = nn.Sequential(
nn.Conv2d(16, 32, 9, stride=1, padding=4, groups=2),
nn.BatchNorm2d(32),
nn.Conv2d(32, 64, 9, stride=1, padding=4, groups=2),
nn.BatchNorm2d(64),
nn.PReLU(),
)

self.final = nn.Sequential(
LocallyConnected2d(
32,
64,
2,
img_size,
[img_size // 2 + 1, img_size],
1,
stride=1,
bias=False,
Expand All @@ -55,7 +55,7 @@ def forward(self, x):
class UncertaintyWrapper(nn.Module):
def __init__(self, img_size):
super().__init__()
self.pred = SRResNet()
self.pred = SRResNet_16()

self.uncertainty = Uncertainty(img_size)

Expand Down
2 changes: 1 addition & 1 deletion radionets/dl_framework/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class DataAug(Callback):
def before_batch(self):
x = self.xb[0].clone()
y = self.yb[0].clone()
randint = np.random.randint(0, 4, x.shape[0])
randint = np.random.randint(0, 1, x.shape[0]) * 2
last_axis = len(x.shape) - 1
for i in range(x.shape[0]):
x[i] = torch.rot90(x[i], int(randint[i]), [last_axis - 2, last_axis - 1])
Expand Down
2 changes: 2 additions & 0 deletions radionets/dl_framework/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def open_image(self, var, i):
if data.shape[0] == 1:
data = data.squeeze(0)

data = data[:, :65, :]

return data.float()


Expand Down
9 changes: 6 additions & 3 deletions radionets/dl_framework/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ def __init__(
self, in_channels, out_channels, output_size, kernel_size, stride, bias=False
):
super(LocallyConnected2d, self).__init__()
output_size = _pair(output_size)
self.weight = nn.Parameter(
torch.randn(
1,
Expand Down Expand Up @@ -281,10 +280,14 @@ def forward(self, x):

def _conv_block(self, ni, nf, stride):
return nn.Sequential(
nn.Conv2d(ni, nf, 3, stride=stride, padding=1, bias=False),
nn.Conv2d(
ni, nf, 3, stride=stride, padding=1, bias=False, padding_mode="reflect"
),
nn.BatchNorm2d(nf),
nn.PReLU(),
nn.Conv2d(nf, nf, 3, stride=1, padding=1, bias=False),
nn.Conv2d(
nf, nf, 3, stride=1, padding=1, bias=False, padding_mode="reflect"
),
nn.BatchNorm2d(nf),
)

Expand Down
69 changes: 10 additions & 59 deletions radionets/evaluation/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import torch
import matplotlib as mpl
from matplotlib.colors import ListedColormap, LogNorm
from matplotlib.lines import Line2D
from matplotlib.patches import Arc, Rectangle
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pytorch_msssim import ms_ssim
from radionets.evaluation.blob_detection import calc_blobs

from radionets.evaluation.contour import compute_area_ratio
from radionets.evaluation.dynamic_range import calc_dr, get_boxsize
from radionets.evaluation.jet_angle import calc_jet_angle

from radionets.evaluation.utils import (
check_vmin_vmax,
make_axes_nice,
Expand Down Expand Up @@ -209,11 +208,6 @@ def visualize_with_fourier(
real_pred, imag_pred = img_pred[0], img_pred[1]
real_truth, imag_truth = img_truth[0], img_truth[1]

if amp_phase:
inp_real = 10 ** (10 * inp_real - 10) - 1e-10
real_pred = 10 ** (10 * real_pred - 10) - 1e-10
real_truth = 10 ** (10 * real_truth - 10) - 1e-10

# plotting
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(
2, 3, figsize=(16, 10), sharex=True, sharey=True
Expand Down Expand Up @@ -291,21 +285,17 @@ def visualize_with_fourier_diff(
real_pred, imag_pred = img_pred[0], img_pred[1]
real_truth, imag_truth = img_truth[0], img_truth[1]

if amp_phase:
real_pred = 10 ** (10 * real_pred - 10) - 1e-10
real_truth = 10 ** (10 * real_truth - 10) - 1e-10

# plotting
# plt.style.use('./paper_large_3_2.rc')
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(
2, 3, figsize=(16, 10), sharex=True, sharey=True
)

if amp_phase:
im1 = ax1.imshow(real_pred, cmap="inferno", norm=LogNorm())
im1 = ax1.imshow(real_pred, cmap="inferno")
make_axes_nice(fig, ax1, im1, r"Amplitude Prediction")

im2 = ax2.imshow(real_truth, cmap="inferno", norm=LogNorm())
im2 = ax2.imshow(real_truth, cmap="inferno")
make_axes_nice(fig, ax2, im2, r"Amplitude Truth")

a = check_vmin_vmax(real_pred - real_truth)
Expand Down Expand Up @@ -344,37 +334,16 @@ def visualize_source_reconstruction(
out_path,
i,
dr=False,
blobs=False,
msssim=False,
plot_format="png",
):
m_truth, n_truth, alpha_truth = calc_jet_angle(ifft_truth)
m_pred, n_pred, alpha_pred = calc_jet_angle(ifft_pred)
x_space = torch.arange(0, 63, 1)

# plt.style.use("./paper_large_3.rc")
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 10), sharey=True)

# Plot prediction
ax1.plot(x_space, m_pred * x_space + n_pred, "w--", alpha=0.5)
ax1.axvline(32, 0, 1, linestyle="--", color="white", alpha=0.5)

# create angle visualization
theta1 = min(0, -alpha_pred.numpy()[0])
theta2 = max(0, -alpha_pred.numpy()[0])
ax1.add_patch(Arc([32, 32], 50, 50, 90, theta1, theta2, color="white"))

im1 = ax1.imshow(ifft_pred, vmax=ifft_truth.max(), cmap="inferno")

# Plot truth
ax2.plot(x_space, m_truth * x_space + n_truth, "w--", alpha=0.5)
ax2.axvline(32, 0, 1, linestyle="--", color="white", alpha=0.5)

# create angle visualization
theta1 = min(0, -alpha_truth.numpy()[0])
theta2 = max(0, -alpha_truth.numpy()[0])
ax2.add_patch(Arc([32, 32], 50, 50, 90, theta1, theta2, color="white"))

im2 = ax2.imshow(ifft_truth, cmap="inferno")

a = check_vmin_vmax(ifft_pred - ifft_truth)
Expand All @@ -389,15 +358,6 @@ def visualize_source_reconstruction(
ax2.set_xlabel(r"Pixels")
ax3.set_xlabel(r"Pixels")

# ax1.tick_params(axis="both", labelsize=20)
# ax2.tick_params(axis="both", labelsize=20)
# ax3.tick_params(axis="both", labelsize=20)

if blobs:
blobs_pred, blobs_truth = calc_blobs(ifft_pred, ifft_truth)
plot_blobs(blobs_pred, ax1)
plot_blobs(blobs_truth, ax2)

if dr:
dr_truth, dr_pred, num_boxes, corners = calc_dr(
ifft_truth[None, ...], ifft_pred[None, ...]
Expand All @@ -417,15 +377,6 @@ def visualize_source_reconstruction(

outpath = str(out_path) + f"/fft_pred_{i}.{plot_format}"

line = Line2D(
[], [], linestyle="-", color="w", label=rf"$\alpha = {alpha_pred[0]:.2f}\,$deg"
)
line_truth = Line2D(
[], [], linestyle="-", color="w", label=rf"$\alpha = {alpha_truth[0]:.2f}\,$deg"
)

ax1.legend(loc="best", handles=[line])
ax2.legend(loc="best", handles=[line_truth])
fig.tight_layout(pad=1)
plt.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
plt.close("all")
Expand Down Expand Up @@ -472,19 +423,19 @@ def visualize_uncertainty(
2, 2, sharey=True, sharex=True, figsize=(12, 10)
)

im1 = ax1.imshow(true_phase)
im1 = ax1.imshow(true_phase, cmap=OrBu, vmin=-np.pi, vmax=np.pi)

im2 = ax2.imshow(pred_phase)
im2 = ax2.imshow(pred_phase, cmap=OrBu, vmin=-np.pi, vmax=np.pi)

im3 = ax3.imshow(unc_phase)

a = check_vmin_vmax(true_phase - pred_phase)
im4 = ax4.imshow(true_phase - pred_phase, cmap=OrBu, vmin=-a, vmax=a)

make_axes_nice(fig, ax1, im1, r"Simulation")
make_axes_nice(fig, ax2, im2, r"Predicted $\mu$")
make_axes_nice(fig, ax1, im1, r"Simulation", phase=True)
make_axes_nice(fig, ax2, im2, r"Predicted $\mu$", phase=True)
make_axes_nice(fig, ax3, im3, r"Predicted $\sigma^2$", unc=True)
make_axes_nice(fig, ax4, im4, r"Difference")
make_axes_nice(fig, ax4, im4, r"Difference", phase_diff=True)

ax1.set_ylabel(r"pixels")
ax3.set_ylabel(r"pixels")
Expand Down
4 changes: 1 addition & 3 deletions radionets/evaluation/scripts/start_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def main(configuration_path):
print(eval_conf, "\n")

if eval_conf["sample_unc"]:
click.echo("Sampling test data set.\n")
save_sampled(eval_conf)

for entry in conf["inspection"]:
Expand Down Expand Up @@ -66,9 +67,6 @@ def main(configuration_path):

click.echo(f"\nCreated {eval_conf['num_images']} test predictions.\n")

if eval_conf["vis_blobs"]:
click.echo("\nBlob visualization is enabled for source plots.\n")

if eval_conf["vis_ms_ssim"]:
click.echo("\nVisualization of ms ssim is enabled for source plots.\n")

Expand Down
Loading

0 comments on commit df40587

Please sign in to comment.