Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize symmetry computation #168

Merged
merged 19 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ jobs:

- name: mamba setup
if: matrix.install-method == 'mamba'
uses: mamba-org/provision-with-micromamba@v14
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment.yml
cache-downloads: true

- name: Python setup
if: matrix.install-method == 'pip'
Expand All @@ -58,10 +61,12 @@ jobs:
check-latest: true

- name: Install dependencies
env:
PYTHON_VERSION: ${{ matrix.python-version }}
run: |
python --version
pip install pytest-cov restructuredtext-lint pytest-xdist 'coverage!=6.3.0'
pip install .[all]
pip install pytest-cov pytest-xdist 'coverage!=6.3.0'
pip install -e .[all]
pip freeze

- name: List installed package versions (conda)
Expand Down
3 changes: 3 additions & 0 deletions docs/changes/168.optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- add keyword for half of the image
- distinguish between tensor and array in get_ifft
- fix micromamba installation
25 changes: 1 addition & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
[build-system]
requires = ["setuptools>=61.0"]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "radionets"
version = "0.2.0"
authors = [
{ name="Kevin Schmidt", email="[email protected]" },
]
description = "Imaging radio interferometric data with Neural Networks."
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Astronomy",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
]

[project.urls]
"Homepage" = "https://github.com/pypa/sampleproject"
"Bug Tracker" = "https://github.com/pypa/sampleproject/issues"
2 changes: 1 addition & 1 deletion radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_prediction(conf, mode="test"):
images["pred"] = pred
images["indices"] = indices

if images["pred"].shape[-1] == 128:
if images["pred"].shape[-2] < images["pred"].shape[-1]:
images = apply_symmetry(images)

return images
Expand Down
28 changes: 17 additions & 11 deletions radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@ def get_images(test_ds, num_images, rand=False, indices=None):

img_test = test_ds[indices][0]
img_true = test_ds[indices][1]
img_test = img_test[:, :, :65, :]
img_true = img_true[:, :, :65, :]
return img_test, img_true, indices
else:
mean = test_ds[indices][0]
Expand Down Expand Up @@ -356,7 +354,10 @@ def get_ifft(array, amp_phase=False, scale=False):
image(s) in image space
"""
if len(array.shape) == 3:
array = array.unsqueeze(0)
if hasattr(array, "numpy"):
array = array.unsqueeze(0)
else:
array = array[np.newaxis, :]
if amp_phase:
if scale:
amp = 10 ** (10 * array[:, 0] - 10) - 1e-10
Expand Down Expand Up @@ -439,18 +440,19 @@ def symmetry(image, key):
image = torch.tensor(image)
if len(image.shape) == 3:
image = image.view(1, image.shape[0], image.shape[1], image.shape[2])
upper_half = image[:, :, :64, :].clone()
half_image = image.shape[-1] // 2
upper_half = image[:, :, :half_image, :].clone()
a = torch.rot90(upper_half, 2, dims=[-2, -1])

image[:, 0, 65:, 1:] = a[:, 0, :-1, :-1]
image[:, 0, 65:, 0] = a[:, 0, :-1, -1]
image[:, 0, half_image + 1 :, 1:] = a[:, 0, :-1, :-1]
image[:, 0, half_image + 1 :, 0] = a[:, 0, :-1, -1]

if key == "unc":
image[:, 1, 65:, 1:] = a[:, 1, :-1, :-1]
image[:, 1, 65:, 0] = a[:, 1, :-1, -1]
image[:, 1, half_image + 1 :, 1:] = a[:, 1, :-1, :-1]
image[:, 1, half_image + 1 :, 0] = a[:, 1, :-1, -1]
else:
image[:, 1, 65:, 1:] = -a[:, 1, :-1, :-1]
image[:, 1, 65:, 0] = -a[:, 1, :-1, -1]
image[:, 1, half_image + 1 :, 1:] = -a[:, 1, :-1, :-1]
image[:, 1, half_image + 1 :, 0] = -a[:, 1, :-1, -1]

return image

Expand All @@ -473,8 +475,12 @@ def apply_symmetry(img_dict):
if key != "indices":
if isinstance(img_dict[key], np.ndarray):
img_dict[key] = torch.tensor(img_dict[key])
half_image = img_dict[key].shape[-1] // 2
output = F.pad(
input=img_dict[key], pad=(0, 0, 0, 63), mode="constant", value=0
input=img_dict[key],
pad=(0, 0, 0, half_image - 1),
mode="constant",
value=0,
)
output = symmetry(output, key)
img_dict[key] = output
Expand Down
40 changes: 20 additions & 20 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ def test_get_prediction(self):
out_path.mkdir(parents=True, exist_ok=True)
save_pred(str(out_path) + "/predictions_model_eval.h5", img)

def test_get_ifft(self):
import numpy as np
import torch

from radionets.evaluation.utils import get_ifft

a = torch.zeros([10, 2, 64, 64])
test_torch = get_ifft(a, amp_phase=True)
b = np.zeros([2, 64, 64])
test_numpy = get_ifft(b, amp_phase=True)
print(test_numpy.shape)
assert ~np.isnan([test_torch]).any()
assert ~np.isnan([test_numpy]).any()
assert len(test_torch.shape) == len(test_numpy.shape) + 1

def test_contour(self):
import numpy as np
import toml
Expand Down Expand Up @@ -329,26 +344,11 @@ def test_gan_sources(self):
def test_symmetry(self):
import torch

from radionets.dl_framework.model import symmetry

x = torch.randint(0, 9, size=(1, 2, 4, 4))
x_symm = symmetry(x.clone())
for i in range(x.shape[-1]):
for j in range(x.shape[-1]):
assert (
x_symm[0, 0, i, j]
== x_symm[0, 0, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)
assert (
x_symm[0, 1, i, j]
== -x_symm[0, 1, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)

rot_amp = torch.rot90(x_symm[0, 0], 2)
rot_phase = torch.rot90(x_symm[0, 1], 2)

assert torch.isclose(rot_amp - x_symm[0, 0], torch.tensor(0)).all()
assert torch.isclose(rot_phase + x_symm[0, 1], torch.tensor(0)).all()
from radionets.evaluation.utils import symmetry

x = torch.randint(0, 9, size=(1, 2, 64, 64))
x_symm = symmetry(x.clone(), key="unc")
assert x_symm.shape == x.shape

def test_sample_images(self):
import numpy as np
Expand Down