Skip to content

Commit

Permalink
Generalize symmetry computation (#168)
Browse files Browse the repository at this point in the history
* Add keyword for generalizing symmetry

* Towncrier

* Fix shape and get_ifft

* Test new micromamba installation

* Update towncrier

* Another try

* Try again

* Try

* remove -e

* New micromamba installation and towncrier update

* Update tests

* Delete print

* Test unc keyword

* Add tests

* More generalization

* Add pyproject toml

* Remove setup.py, add setup.cfg
  • Loading branch information
FeGeyer authored Jan 18, 2024
1 parent d13293f commit 7eff0fb
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 118 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ jobs:
- name: mamba setup
if: matrix.install-method == 'mamba'
uses: mamba-org/provision-with-micromamba@v14
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment.yml
cache-downloads: true

- name: Python setup
if: matrix.install-method == 'pip'
Expand All @@ -58,10 +61,12 @@ jobs:
check-latest: true

- name: Install dependencies
env:
PYTHON_VERSION: ${{ matrix.python-version }}
run: |
python --version
pip install pytest-cov restructuredtext-lint pytest-xdist 'coverage!=6.3.0'
pip install .[all]
pip install pytest-cov pytest-xdist 'coverage!=6.3.0'
pip install -e .[all]
pip freeze
- name: List installed package versions (conda)
Expand Down
3 changes: 3 additions & 0 deletions docs/changes/168.optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- add keyword for half of the image
- distinguish between tensor and array in get_ifft
- fix micromamba installation
25 changes: 1 addition & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
[build-system]
requires = ["setuptools>=61.0"]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "radionets"
version = "0.2.0"
authors = [
{ name="Kevin Schmidt", email="[email protected]" },
]
description = "Imaging radio interferometric data with Neural Networks."
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Astronomy",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
]

[project.urls]
"Homepage" = "https://github.com/pypa/sampleproject"
"Bug Tracker" = "https://github.com/pypa/sampleproject/issues"
2 changes: 1 addition & 1 deletion radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_prediction(conf, mode="test"):
images["pred"] = pred
images["indices"] = indices

if images["pred"].shape[-1] == 128:
if images["pred"].shape[-2] < images["pred"].shape[-1]:
images = apply_symmetry(images)

return images
Expand Down
28 changes: 17 additions & 11 deletions radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@ def get_images(test_ds, num_images, rand=False, indices=None):

img_test = test_ds[indices][0]
img_true = test_ds[indices][1]
img_test = img_test[:, :, :65, :]
img_true = img_true[:, :, :65, :]
return img_test, img_true, indices
else:
mean = test_ds[indices][0]
Expand Down Expand Up @@ -356,7 +354,10 @@ def get_ifft(array, amp_phase=False, scale=False):
image(s) in image space
"""
if len(array.shape) == 3:
array = array.unsqueeze(0)
if hasattr(array, "numpy"):
array = array.unsqueeze(0)
else:
array = array[np.newaxis, :]
if amp_phase:
if scale:
amp = 10 ** (10 * array[:, 0] - 10) - 1e-10
Expand Down Expand Up @@ -439,18 +440,19 @@ def symmetry(image, key):
image = torch.tensor(image)
if len(image.shape) == 3:
image = image.view(1, image.shape[0], image.shape[1], image.shape[2])
upper_half = image[:, :, :64, :].clone()
half_image = image.shape[-1] // 2
upper_half = image[:, :, :half_image, :].clone()
a = torch.rot90(upper_half, 2, dims=[-2, -1])

image[:, 0, 65:, 1:] = a[:, 0, :-1, :-1]
image[:, 0, 65:, 0] = a[:, 0, :-1, -1]
image[:, 0, half_image + 1 :, 1:] = a[:, 0, :-1, :-1]
image[:, 0, half_image + 1 :, 0] = a[:, 0, :-1, -1]

if key == "unc":
image[:, 1, 65:, 1:] = a[:, 1, :-1, :-1]
image[:, 1, 65:, 0] = a[:, 1, :-1, -1]
image[:, 1, half_image + 1 :, 1:] = a[:, 1, :-1, :-1]
image[:, 1, half_image + 1 :, 0] = a[:, 1, :-1, -1]
else:
image[:, 1, 65:, 1:] = -a[:, 1, :-1, :-1]
image[:, 1, 65:, 0] = -a[:, 1, :-1, -1]
image[:, 1, half_image + 1 :, 1:] = -a[:, 1, :-1, :-1]
image[:, 1, half_image + 1 :, 0] = -a[:, 1, :-1, -1]

return image

Expand All @@ -473,8 +475,12 @@ def apply_symmetry(img_dict):
if key != "indices":
if isinstance(img_dict[key], np.ndarray):
img_dict[key] = torch.tensor(img_dict[key])
half_image = img_dict[key].shape[-1] // 2
output = F.pad(
input=img_dict[key], pad=(0, 0, 0, 63), mode="constant", value=0
input=img_dict[key],
pad=(0, 0, 0, half_image - 1),
mode="constant",
value=0,
)
output = symmetry(output, key)
img_dict[key] = output
Expand Down
57 changes: 56 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,60 @@
[metadata]
name = radionets
version = 0.3.1
author = Kevin Schmidt, Felix Geyer
author_email = [email protected], [email protected]
license = MIT
description = Imaging radio interferometric data with neural networks
url = https://github.com/radionets-project/radionets
classifiers =
Development Status :: 4 - Beta
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Natural Language :: English
Operating System :: OS Independent
Programming Language :: Python
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3 :: Only
Topic :: Scientific/Engineering :: Astronomy
Topic :: Scientific/Engineering :: Physics
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Scientific/Engineering :: Information Analysis

[aliases]
test = pytest

[options]
packages = find:
zip_safe = False
setup_requires = pytest-runner
install_requires =
fastai
kornia
pytorch-msssim
numpy
astropy
tqdm
click
numba
jupyter
h5py
scikit-image
pandas
toml
pytest
pytest-cov
pytest-order
comet_ml
pre-commit
tests_require = pytest

[tool:pytest]
addopts = --verbose
addopts = --verbose

[options.entry_points]
console_scripts =
radionets_simulations = radionets.simulations.scripts.simulate_images:main
radionets_training = radionets.dl_training.scripts.start_training:main
radionets_evaluation = radionets.evaluation.scripts.start_evaluation:main
58 changes: 0 additions & 58 deletions setup.py

This file was deleted.

40 changes: 20 additions & 20 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ def test_get_prediction(self):
out_path.mkdir(parents=True, exist_ok=True)
save_pred(str(out_path) + "/predictions_model_eval.h5", img)

def test_get_ifft(self):
import numpy as np
import torch

from radionets.evaluation.utils import get_ifft

a = torch.zeros([10, 2, 64, 64])
test_torch = get_ifft(a, amp_phase=True)
b = np.zeros([2, 64, 64])
test_numpy = get_ifft(b, amp_phase=True)
print(test_numpy.shape)
assert ~np.isnan([test_torch]).any()
assert ~np.isnan([test_numpy]).any()
assert len(test_torch.shape) == len(test_numpy.shape) + 1

def test_contour(self):
import numpy as np
import toml
Expand Down Expand Up @@ -329,26 +344,11 @@ def test_gan_sources(self):
def test_symmetry(self):
import torch

from radionets.dl_framework.model import symmetry

x = torch.randint(0, 9, size=(1, 2, 4, 4))
x_symm = symmetry(x.clone())
for i in range(x.shape[-1]):
for j in range(x.shape[-1]):
assert (
x_symm[0, 0, i, j]
== x_symm[0, 0, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)
assert (
x_symm[0, 1, i, j]
== -x_symm[0, 1, x.shape[-1] - 1 - i, x.shape[-1] - 1 - j]
)

rot_amp = torch.rot90(x_symm[0, 0], 2)
rot_phase = torch.rot90(x_symm[0, 1], 2)

assert torch.isclose(rot_amp - x_symm[0, 0], torch.tensor(0)).all()
assert torch.isclose(rot_phase + x_symm[0, 1], torch.tensor(0)).all()
from radionets.evaluation.utils import symmetry

x = torch.randint(0, 9, size=(1, 2, 64, 64))
x_symm = symmetry(x.clone(), key="unc")
assert x_symm.shape == x.shape

def test_sample_images(self):
import numpy as np
Expand Down

0 comments on commit 7eff0fb

Please sign in to comment.