Skip to content

Commit

Permalink
pre-commits and symmetry layer (#126)
Browse files Browse the repository at this point in the history
* Add pre-commits to repository

* Add warning to symmetry layer

* Add test for symmetry layer

* Delete unused import

* Add another symmetry test

* Delete old symmetry function and rename

* Update url

* Update README.md
  • Loading branch information
FeGeyer authored Dec 1, 2022
1 parent 5a97566 commit eae8dac
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 41 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 22.6.0
hooks:
- id: black-jupyter
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
args: [--max-line-length=88, "--extend-ignore=E203"]
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ You can create one by running the following command in this repository:
$ conda env create -f environment.yml
```
Depending on your `cuda` version you have to specify the `cudatoolkit` version used by `pytorch`. If you are working on machines
with `cuda` versions < 10.2, please change the version number in the environment.yml file.
with `cuda` versions < 10.2, please change the version number in the environment.yml file. Since the package `pre-commit` is used, you need to execute
```
$ pre-commit install
```
after the installation.

## Usage

Expand Down
5 changes: 0 additions & 5 deletions radionets/dl_framework/architectures/unc_archs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import torch
from torch import nn
from radionets.dl_framework.model import (
Lambda,
symmetry,
GeneralELU,
LocallyConnected2d,
)
from radionets.dl_framework.architectures.res_exp import SRResNet
from functools import partial


class Uncertainty(nn.Module):
Expand Down Expand Up @@ -43,8 +40,6 @@ def __init__(self, img_size):
)
)

self.symmetry = Lambda(partial(symmetry, mode="real"))

self.elu = GeneralELU(add=+(1 + 1e-7))

def forward(self, x):
Expand Down
36 changes: 7 additions & 29 deletions radionets/dl_framework/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from pathlib import Path
from math import sqrt, pi
from math import pi


class Lambda(nn.Module):
Expand All @@ -15,30 +15,6 @@ def forward(self, x):
return self.func(x)


def symmetry(x, mode="real"):
center = (x.shape[1]) // 2
u = torch.arange(center)
v = torch.arange(center)

diag1 = torch.arange(center, x.shape[1])
diag2 = torch.arange(center, x.shape[1])
diag_indices = torch.stack((diag1, diag2))
grid = torch.tril_indices(x.shape[1], x.shape[1], -1)

x_sym = torch.cat((grid[0].reshape(-1, 1), diag_indices[0].reshape(-1, 1)))
y_sym = torch.cat((grid[1].reshape(-1, 1), diag_indices[1].reshape(-1, 1)))
x = torch.rot90(x, 1, dims=(1, 2))
i = center + (center - x_sym)
j = center + (center - y_sym)
u = center - (center - x_sym)
v = center - (center - y_sym)
if mode == "real":
x[:, i, j] = x[:, u, v]
if mode == "imag":
x[:, i, j] = -x[:, u, v]
return torch.rot90(x, 3, dims=(1, 2))


class GeneralRelu(nn.Module):
def __init__(self, leak=None, sub=None, maxv=None):
super().__init__()
Expand Down Expand Up @@ -75,8 +51,8 @@ def init_cnn_(m, f):
f(m.weight, a=0.1)
if getattr(m, "bias", None) is not None:
m.bias.data.zero_()
for l in m.children():
init_cnn_(l, f)
for c in m.children():
init_cnn_(c, f)


def init_cnn(m, uniform=False):
Expand Down Expand Up @@ -223,7 +199,7 @@ def __init__(
in_channels,
output_size[0],
output_size[1],
kernel_size ** 2,
kernel_size**2,
)
)
if bias:
Expand Down Expand Up @@ -313,7 +289,9 @@ def _conv_block(self, ni, nf, stride):
)


def even_better_symmetry(x):
def symmetry(x):
if x.shape[-1] % 2 != 0:
raise ValueError("The symmetry function only works for even image sizes.")
upper_half = x[:, :, 0 : x.shape[2] // 2, :].clone()
upper_left = upper_half[:, :, :, 0 : upper_half.shape[3] // 2].clone()
upper_right = upper_half[:, :, :, upper_half.shape[3] // 2 :].clone()
Expand Down
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name="radionets",
version="0.1.14",
description="Imaging radio interferometric data with neural networks",
url="https://github.com/Kevin2/radionets",
url="https://github.com/radionets-project/radionets",
author="Kevin Schmidt, Felix Geyer",
author_email="[email protected], [email protected]",
license="MIT",
Expand All @@ -25,14 +25,16 @@
"pytest",
"pytest-cov",
"pytest-order",
"comet_ml"
"comet_ml",
"pre-commit",
],
setup_requires=["pytest-runner"],
tests_require=["pytest"],
zip_safe=False,
entry_points={
"console_scripts": [
"radionets_simulations = radionets.simulations.scripts.simulate_images:main",
"radionets_simulations = radionets.simulations.scripts.simulate_images\
:main",
"radionets_training = radionets.dl_training.scripts.start_training:main",
"radionets_evaluation = radionets.evaluation.scripts.start_evaluation:main",
],
Expand Down
29 changes: 26 additions & 3 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def test_im_to_array_value(self):

x_coords, y_coords, value = im_to_array_value(image)

assert x_coords.shape == (2, 64 ** 2)
assert y_coords.shape == (2, 64 ** 2)
assert value.shape == (2, 64 ** 2)
assert x_coords.shape == (2, 64**2)
assert y_coords.shape == (2, 64**2)
assert value.shape == (2, 64**2)

def test_bmul(self):
import torch
Expand Down Expand Up @@ -323,6 +323,29 @@ def test_gan_sources(self):

assert evaluate_gan_sources(conf) is None

def test_symmetry(self):
import torch
from radionets.dl_framework.model import symmetry

x = torch.randint(0, 9, size=(1, 2, 4, 4))
x_symm = symmetry(x.clone())
for i in range(x.shape[-1]):
for j in range(x.shape[-1]):
assert (
x_symm[0, 0, i, j]
== x_symm[0, 0, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)
assert (
x_symm[0, 1, i, j]
== -x_symm[0, 1, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)

rot_amp = torch.rot90(x_symm[0, 0], 2)
rot_phase = torch.rot90(x_symm[0, 1], 2)

assert torch.isclose(rot_amp - x_symm[0, 0], torch.tensor(0)).all()
assert torch.isclose(rot_phase + x_symm[0, 1], torch.tensor(0)).all()

def test_evaluation(self):
import shutil
import os
Expand Down

0 comments on commit eae8dac

Please sign in to comment.