Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize symmetry computation #168

Merged
merged 19 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ jobs:

- name: mamba setup
if: matrix.install-method == 'mamba'
uses: mamba-org/provision-with-micromamba@v14
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment.yml
cache-downloads: true

- name: Python setup
if: matrix.install-method == 'pip'
Expand All @@ -58,10 +61,12 @@ jobs:
check-latest: true

- name: Install dependencies
env:
PYTHON_VERSION: ${{ matrix.python-version }}
run: |
python --version
pip install pytest-cov restructuredtext-lint pytest-xdist 'coverage!=6.3.0'
pip install .[all]
pip install pytest-cov pytest-xdist 'coverage!=6.3.0'
pip install -e .[all]
pip freeze

- name: List installed package versions (conda)
Expand Down
3 changes: 3 additions & 0 deletions docs/changes/168.optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- add keyword for half of the image
- distinguish between tensor and array in get_ifft
- fix micromamba installation
25 changes: 1 addition & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
[build-system]
requires = ["setuptools>=61.0"]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "radionets"
version = "0.2.0"
authors = [
{ name="Kevin Schmidt", email="[email protected]" },
]
description = "Imaging radio interferometric data with Neural Networks."
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Astronomy",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
]

[project.urls]
"Homepage" = "https://github.com/pypa/sampleproject"
"Bug Tracker" = "https://github.com/pypa/sampleproject/issues"
2 changes: 1 addition & 1 deletion radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_prediction(conf, mode="test"):
images["pred"] = pred
images["indices"] = indices

if images["pred"].shape[-1] == 128:
if images["pred"].shape[-2] < images["pred"].shape[-1]:
images = apply_symmetry(images)

return images
Expand Down
28 changes: 17 additions & 11 deletions radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@ def get_images(test_ds, num_images, rand=False, indices=None):

img_test = test_ds[indices][0]
img_true = test_ds[indices][1]
img_test = img_test[:, :, :65, :]
img_true = img_true[:, :, :65, :]
return img_test, img_true, indices
else:
mean = test_ds[indices][0]
Expand Down Expand Up @@ -356,7 +354,10 @@ def get_ifft(array, amp_phase=False, scale=False):
image(s) in image space
"""
if len(array.shape) == 3:
array = array.unsqueeze(0)
if hasattr(array, "numpy"):
array = array.unsqueeze(0)
else:
array = array[np.newaxis, :]
if amp_phase:
if scale:
amp = 10 ** (10 * array[:, 0] - 10) - 1e-10
Expand Down Expand Up @@ -439,18 +440,19 @@ def symmetry(image, key):
image = torch.tensor(image)
if len(image.shape) == 3:
image = image.view(1, image.shape[0], image.shape[1], image.shape[2])
upper_half = image[:, :, :64, :].clone()
half_image = image.shape[-1] // 2
upper_half = image[:, :, :half_image, :].clone()
a = torch.rot90(upper_half, 2, dims=[-2, -1])

image[:, 0, 65:, 1:] = a[:, 0, :-1, :-1]
image[:, 0, 65:, 0] = a[:, 0, :-1, -1]
image[:, 0, half_image + 1 :, 1:] = a[:, 0, :-1, :-1]
image[:, 0, half_image + 1 :, 0] = a[:, 0, :-1, -1]

if key == "unc":
image[:, 1, 65:, 1:] = a[:, 1, :-1, :-1]
image[:, 1, 65:, 0] = a[:, 1, :-1, -1]
image[:, 1, half_image + 1 :, 1:] = a[:, 1, :-1, :-1]
image[:, 1, half_image + 1 :, 0] = a[:, 1, :-1, -1]
else:
image[:, 1, 65:, 1:] = -a[:, 1, :-1, :-1]
image[:, 1, 65:, 0] = -a[:, 1, :-1, -1]
image[:, 1, half_image + 1 :, 1:] = -a[:, 1, :-1, :-1]
image[:, 1, half_image + 1 :, 0] = -a[:, 1, :-1, -1]

return image

Expand All @@ -473,8 +475,12 @@ def apply_symmetry(img_dict):
if key != "indices":
if isinstance(img_dict[key], np.ndarray):
img_dict[key] = torch.tensor(img_dict[key])
half_image = img_dict[key].shape[-1] // 2
output = F.pad(
input=img_dict[key], pad=(0, 0, 0, 63), mode="constant", value=0
input=img_dict[key],
pad=(0, 0, 0, half_image - 1),
mode="constant",
value=0,
)
output = symmetry(output, key)
img_dict[key] = output
Expand Down
57 changes: 56 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,60 @@
[metadata]
name = radionets
version = 0.3.1
author = Kevin Schmidt, Felix Geyer
author_email = [email protected], [email protected]
license = MIT
description = Imaging radio interferometric data with neural networks
url = https://github.com/radionets-project/radionets
classifiers =
Development Status :: 4 - Beta
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Natural Language :: English
Operating System :: OS Independent
Programming Language :: Python
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3 :: Only
Topic :: Scientific/Engineering :: Astronomy
Topic :: Scientific/Engineering :: Physics
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Scientific/Engineering :: Information Analysis

[aliases]
test = pytest

[options]
packages = find:
zip_safe = False
setup_requires = pytest-runner
install_requires =
fastai
kornia
pytorch-msssim
numpy
astropy
tqdm
click
numba
jupyter
h5py
scikit-image
pandas
toml
pytest
pytest-cov
pytest-order
comet_ml
pre-commit
tests_require = pytest

[tool:pytest]
addopts = --verbose
addopts = --verbose

[options.entry_points]
console_scripts =
radionets_simulations = radionets.simulations.scripts.simulate_images:main
radionets_training = radionets.dl_training.scripts.start_training:main
radionets_evaluation = radionets.evaluation.scripts.start_evaluation:main
58 changes: 0 additions & 58 deletions setup.py

This file was deleted.

40 changes: 20 additions & 20 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ def test_get_prediction(self):
out_path.mkdir(parents=True, exist_ok=True)
save_pred(str(out_path) + "/predictions_model_eval.h5", img)

def test_get_ifft(self):
import numpy as np
import torch

from radionets.evaluation.utils import get_ifft

a = torch.zeros([10, 2, 64, 64])
test_torch = get_ifft(a, amp_phase=True)
b = np.zeros([2, 64, 64])
test_numpy = get_ifft(b, amp_phase=True)
print(test_numpy.shape)
assert ~np.isnan([test_torch]).any()
assert ~np.isnan([test_numpy]).any()
assert len(test_torch.shape) == len(test_numpy.shape) + 1

def test_contour(self):
import numpy as np
import toml
Expand Down Expand Up @@ -329,26 +344,11 @@ def test_gan_sources(self):
def test_symmetry(self):
import torch

from radionets.dl_framework.model import symmetry

x = torch.randint(0, 9, size=(1, 2, 4, 4))
x_symm = symmetry(x.clone())
for i in range(x.shape[-1]):
for j in range(x.shape[-1]):
assert (
x_symm[0, 0, i, j]
== x_symm[0, 0, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)
assert (
x_symm[0, 1, i, j]
== -x_symm[0, 1, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)

rot_amp = torch.rot90(x_symm[0, 0], 2)
rot_phase = torch.rot90(x_symm[0, 1], 2)

assert torch.isclose(rot_amp - x_symm[0, 0], torch.tensor(0)).all()
assert torch.isclose(rot_phase + x_symm[0, 1], torch.tensor(0)).all()
from radionets.evaluation.utils import symmetry

x = torch.randint(0, 9, size=(1, 2, 64, 64))
x_symm = symmetry(x.clone(), key="unc")
assert x_symm.shape == x.shape

def test_sample_images(self):
import numpy as np
Expand Down