Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
nden committed Oct 28, 2024
1 parent 3f55514 commit 926123b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 80 deletions.
8 changes: 3 additions & 5 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS
from astropy.modeling import separable
from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values
import astropy.units as u

from . import utils
Expand Down Expand Up @@ -298,10 +299,6 @@ def _sanitize_pixel_inputs(self, *pixel_arrays):

return pixels

def _sanitize_world_inputs(self, *world_arrays):
world_coord = []


def pixel_to_world(self, *pixel_arrays):
"""
Convert pixel values to world coordinates.
Expand All @@ -322,8 +319,9 @@ def world_to_pixel(self, *world_objects):
"""
Convert world coordinates to pixel values.
"""
#args = high_level_objects_to_values(*world_objects, low_level_wcs=self)
#result = self.invert(*args)
result = self.invert(*world_objects, with_units=True)

if self.input_frame.naxes > 1:
first_res = result[0]
if not utils.isnumerical(first_res):
Expand Down
3 changes: 2 additions & 1 deletion gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def gwcs_3d_galactic_spectral():
transform.bounding_box = ((5, 50), (-2, 45), (-1, 35))

sky_frame = cf.CelestialFrame(axes_order=(2, 0),
reference_frame=coord.Galactic(), axes_names=("Longitude", "Latitude"))
reference_frame=coord.Galactic(),
axes_names=("Longitude", "Latitude"))
wave_frame = cf.SpectralFrame(axes_order=(1, ), unit=u.Hz, axes_names=("Frequency",))

frame = cf.CompositeFrame([sky_frame, wave_frame])
Expand Down
57 changes: 14 additions & 43 deletions gwcs/tests/test_bounding_box.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
from numpy.testing import assert_array_equal, assert_allclose

from astropy import units as u

import pytest


Expand All @@ -13,7 +11,7 @@

@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x), (y, y))
((x, x),(y, y))
])
def test_2d_spatial(gwcs_2d_spatial_shift, input, output):
w = gwcs_2d_spatial_shift
Expand Down Expand Up @@ -52,7 +50,7 @@ def test_2d_spatial_coordinate_reordered(gwcs_2d_spatial_reordered, input, outpu

@pytest.mark.parametrize((("input", "output")), [(2, 2),
((10, 200), (10, np.nan)),
(x, (np.nan, 2, 4, 13))
(x, (np.nan, 2, 4, 13))
])
def test_1d_freq(gwcs_1d_freq, input, output):
w = gwcs_1d_freq
Expand All @@ -76,46 +74,19 @@ def test_3d_spatial_wave(gwcs_3d_spatial_wave, input, output):
assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [(2, 2),
((10, 200), (10, np.nan)),
(x, (np.nan, 2, 4, 13))
])
def test_1d_freq_quantity(gwcs_1d_freq_quantity, input, output):
w = gwcs_1d_freq_quantity
#w.bounding_box = (-.5*u.pix, 21*u.pix)
w.bounding_box = (-.5, 21)

# assert_array_equal(w.invert(w(input)), output)
# assert_array_equal(w.world_to_pixel_values(w.pixel_to_world_values(*input)), output)
# assert_array_equal(w.world_to_pixel(w.pixel_to_world(input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x), (y, y))
])
def test_2d_shift_scale_quantity(gwcs_2d_shift_scale_quantity, input, output):
w = gwcs_2d_shift_scale_quantity
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5)),
((100, 200, 5), (np.nan, np.nan, np.nan)),
((x, x, x), (y1, y1, y1))
@pytest.mark.parametrize((("input", "output")), [((1, 2, 3, 4), (1., 2., 3., 4.)),
((100, 3, 3, 3), (np.nan, 3, 3, 3)),
((x, x, x, x), [[np.nan, 2., 4., 13.],
[np.nan, 2., 4., 13.],
[np.nan, 2., 4., 13.],
[np.nan, 2., 4., np.nan]])
])
def test_3d_identity_units(gwcs_3d_identity_units, input, output):
w = gwcs_3d_identity_units
w.bounding_box = ((-.5, 21), (4, 12), (1, 21))
def test_gwcs_spec_cel_time_4d(gwcs_spec_cel_time_4d, input, output):
w = gwcs_spec_cel_time_4d

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output)
assert_allclose(w.invert(*w(*input, with_bounding_box=False)), output, atol=1e-8)


def test_4d_identity_units(gwcs_4d_identity_units, input, ooutput):
w = gwcs_4d_identity_units
w.bounding_box = ((-.5, 21), (4, 12), (1, 21), (5, 10))
# @pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5))]
# def test_gwcs_1d_freq_quantity(gwcs_1d_freq_quantity, input, output):
# w = gwcs_1d_freq_quantity
46 changes: 15 additions & 31 deletions gwcs/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from astropy.modeling.models import (Const1D, Identity, Mapping, Polynomial2D,
RotateCelestial2Native, Shift,
Sky2Pix_TAN)
from astropy.modeling.parameters import _tofloat
from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales
from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values

from astropy import units as u
from scipy import linalg, optimize

from . import coordinate_frames as cf
Expand Down Expand Up @@ -466,9 +470,7 @@ def invert(self, *args, **kwargs):
btrans = self.backward_transform
except NotImplementedError:
btrans = None
print(f"args[0], {args[0]}")
if not utils.isnumerical(args[0]):
print(f"args1, {args}")
# convert astropy objects to numbers and arrays
args = self.output_frame.coordinate_to_quantity(*args)
if self.output_frame.naxes == 1:
Expand All @@ -481,12 +483,10 @@ def invert(self, *args, **kwargs):
with_bounding_box = kwargs.pop('with_bounding_box', True)
fill_value = kwargs.pop('fill_value', np.nan)
akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS}
print(f"args, {args}")
if with_bounding_box and self.bounding_box is not None:
result = self.outside_footprint(args)

if btrans is not None:
#akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS}
result = btrans(*args, **akwargs)
else:
result = self.numerical_inverse(*args, **kwargs, with_units=with_units)
Expand All @@ -502,38 +502,27 @@ def invert(self, *args, **kwargs):
return self.input_frame.coordinates(*result)
else:
return result

def outside_footprint(self, world_arrays):
# for axis in world_arrays:
# if np.isscalar(axis):
# world_arrays = np.asarray(list(world_arrays), dtype = np.float64)
# world_arrays = [world_arrays]
#print('axis', axis, world_arrays)
#if np.isscalar(axis) or self.output_frame.naxes == 1:
if self.output_frame.naxes == 1:
#print('axis', axis)
#axis = float(axis)
world_arrays = [world_arrays]
#print('axis', axis)
print('world_arrays1', world_arrays)
#world_arrays = np.asarray(list(world_arrays))#, dtype = np.float64)
print('world_arrays2', world_arrays)
world_arrays = list(world_arrays)

axes_types = set(self.output_frame.axes_type)
footprint = self.footprint()
world_arrays = [coo.to(unit) for coo, unit in zip(world_arrays, self.output_frame.unit)
if isinstance(coo, u.Quantity)]
world_arrays = [high_level_objects_to_values(coo, low_level_wcs=self) for
coo in world_arrays if not utils.isnumerical(coo)]

for axtyp in axes_types:
footprint = self.footprint(axis_type=axtyp)

ind = np.asarray((np.asarray(self.output_frame.axes_type) == axtyp))
#print('ind', ind)

for idim, coord in enumerate(world_arrays[ind]):
#print('footprint', footprint)
for idim, coord in enumerate(world_arrays):
coord = _tofloat(coord)
if np.asarray(ind).sum() > 1:
axis_range = footprint[:, idim]

Check warning on line 522 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L522

Added line #L522 was not covered by tests
else:
axis_range = footprint
range = [axis_range.min(), axis_range.max()]
#print('idim', idim, coord, range)
outside = (coord < range[0]) | (coord > range[1])
if np.any(outside):
if np.isscalar(coord):
Expand Down Expand Up @@ -1447,11 +1436,6 @@ def footprint(self, bounding_box=None, center=False, axis_type="all"):
"""
def _order_clockwise(v):
# if self.input_frame.naxes == 1:
# bb = self.bounding_box.bounding_box()
# if isinstance(bb[0], u.Quantity):
# bb = [v.value for v in bb] * bb[0].unit
# return (bb,)
return np.asarray([[v[0][0], v[1][0]], [v[0][0], v[1][1]],
[v[0][1], v[1][1]], [v[0][1], v[1][0]]]).T

Expand Down Expand Up @@ -1482,7 +1466,7 @@ def _order_clockwise(v):
axis_type = axis_type.lower()
if axis_type == 'spatial' and all_spatial:
return result.T

if axis_type != "all":
axtyp_ind = np.array([t.lower() for t in self.output_frame.axes_type]) == axis_type
if not axtyp_ind.any():
Expand Down

0 comments on commit 926123b

Please sign in to comment.