diff --git a/gwcs/api.py b/gwcs/api.py index b46d679c..ee4fc983 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -7,6 +7,7 @@ from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS from astropy.modeling import separable +from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values import astropy.units as u from . import utils @@ -298,10 +299,6 @@ def _sanitize_pixel_inputs(self, *pixel_arrays): return pixels - def _sanitize_world_inputs(self, *world_arrays): - world_coord = [] - - def pixel_to_world(self, *pixel_arrays): """ Convert pixel values to world coordinates. @@ -322,8 +319,9 @@ def world_to_pixel(self, *world_objects): """ Convert world coordinates to pixel values. """ + #args = high_level_objects_to_values(*world_objects, low_level_wcs=self) + #result = self.invert(*args) result = self.invert(*world_objects, with_units=True) - if self.input_frame.naxes > 1: first_res = result[0] if not utils.isnumerical(first_res): diff --git a/gwcs/tests/conftest.py b/gwcs/tests/conftest.py index 3c8d0425..014807de 100644 --- a/gwcs/tests/conftest.py +++ b/gwcs/tests/conftest.py @@ -266,7 +266,8 @@ def gwcs_3d_galactic_spectral(): transform.bounding_box = ((5, 50), (-2, 45), (-1, 35)) sky_frame = cf.CelestialFrame(axes_order=(2, 0), - reference_frame=coord.Galactic(), axes_names=("Longitude", "Latitude")) + reference_frame=coord.Galactic(), + axes_names=("Longitude", "Latitude")) wave_frame = cf.SpectralFrame(axes_order=(1, ), unit=u.Hz, axes_names=("Frequency",)) frame = cf.CompositeFrame([sky_frame, wave_frame]) diff --git a/gwcs/tests/test_bounding_box.py b/gwcs/tests/test_bounding_box.py index 5fc13f79..335dcc4c 100644 --- a/gwcs/tests/test_bounding_box.py +++ b/gwcs/tests/test_bounding_box.py @@ -1,8 +1,6 @@ import numpy as np from numpy.testing import assert_array_equal, assert_allclose -from astropy import units as u - import pytest @@ -13,7 +11,7 @@ @pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)), ((100, 200), (np.nan, np.nan)), - ((x, x), (y, y)) + ((x, x),(y, y)) ]) def test_2d_spatial(gwcs_2d_spatial_shift, input, output): w = gwcs_2d_spatial_shift @@ -52,7 +50,7 @@ def test_2d_spatial_coordinate_reordered(gwcs_2d_spatial_reordered, input, outpu @pytest.mark.parametrize((("input", "output")), [(2, 2), ((10, 200), (10, np.nan)), - (x, (np.nan, 2, 4, 13)) + (x, (np.nan, 2, 4, 13)) ]) def test_1d_freq(gwcs_1d_freq, input, output): w = gwcs_1d_freq @@ -76,46 +74,19 @@ def test_3d_spatial_wave(gwcs_3d_spatial_wave, input, output): assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output) -@pytest.mark.parametrize((("input", "output")), [(2, 2), - ((10, 200), (10, np.nan)), - (x, (np.nan, 2, 4, 13)) - ]) -def test_1d_freq_quantity(gwcs_1d_freq_quantity, input, output): - w = gwcs_1d_freq_quantity - #w.bounding_box = (-.5*u.pix, 21*u.pix) - w.bounding_box = (-.5, 21) - - # assert_array_equal(w.invert(w(input)), output) - # assert_array_equal(w.world_to_pixel_values(w.pixel_to_world_values(*input)), output) - # assert_array_equal(w.world_to_pixel(w.pixel_to_world(input)), output) - - -@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)), - ((100, 200), (np.nan, np.nan)), - ((x, x), (y, y)) - ]) -def test_2d_shift_scale_quantity(gwcs_2d_shift_scale_quantity, input, output): - w = gwcs_2d_shift_scale_quantity - w.bounding_box = ((-.5, 21), (4, 12)) - - assert_array_equal(w.invert(*w(*input)), output) - assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output) - assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output) - - -@pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5)), - ((100, 200, 5), (np.nan, np.nan, np.nan)), - ((x, x, x), (y1, y1, y1)) +@pytest.mark.parametrize((("input", "output")), [((1, 2, 3, 4), (1., 2., 3., 4.)), + ((100, 3, 3, 3), (np.nan, 3, 3, 3)), + ((x, x, x, x), [[np.nan, 2., 4., 13.], + [np.nan, 2., 4., 13.], + [np.nan, 2., 4., 13.], + [np.nan, 2., 4., np.nan]]) ]) -def test_3d_identity_units(gwcs_3d_identity_units, input, output): - w = gwcs_3d_identity_units - w.bounding_box = ((-.5, 21), (4, 12), (1, 21)) +def test_gwcs_spec_cel_time_4d(gwcs_spec_cel_time_4d, input, output): + w = gwcs_spec_cel_time_4d - assert_array_equal(w.invert(*w(*input)), output) - assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output) - assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output) + assert_allclose(w.invert(*w(*input, with_bounding_box=False)), output, atol=1e-8) -def test_4d_identity_units(gwcs_4d_identity_units, input, ooutput): - w = gwcs_4d_identity_units - w.bounding_box = ((-.5, 21), (4, 12), (1, 21), (5, 10)) \ No newline at end of file +# @pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5))] +# def test_gwcs_1d_freq_quantity(gwcs_1d_freq_quantity, input, output): +# w = gwcs_1d_freq_quantity \ No newline at end of file diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 089d6fa2..cb4c5bca 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -13,7 +13,11 @@ from astropy.modeling.models import (Const1D, Identity, Mapping, Polynomial2D, RotateCelestial2Native, Shift, Sky2Pix_TAN) +from astropy.modeling.parameters import _tofloat from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales +from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values + +from astropy import units as u from scipy import linalg, optimize from . import coordinate_frames as cf @@ -466,9 +470,7 @@ def invert(self, *args, **kwargs): btrans = self.backward_transform except NotImplementedError: btrans = None - print(f"args[0], {args[0]}") if not utils.isnumerical(args[0]): - print(f"args1, {args}") # convert astropy objects to numbers and arrays args = self.output_frame.coordinate_to_quantity(*args) if self.output_frame.naxes == 1: @@ -481,12 +483,10 @@ def invert(self, *args, **kwargs): with_bounding_box = kwargs.pop('with_bounding_box', True) fill_value = kwargs.pop('fill_value', np.nan) akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} - print(f"args, {args}") if with_bounding_box and self.bounding_box is not None: result = self.outside_footprint(args) if btrans is not None: - #akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} result = btrans(*args, **akwargs) else: result = self.numerical_inverse(*args, **kwargs, with_units=with_units) @@ -502,38 +502,27 @@ def invert(self, *args, **kwargs): return self.input_frame.coordinates(*result) else: return result - + def outside_footprint(self, world_arrays): - # for axis in world_arrays: - # if np.isscalar(axis): - # world_arrays = np.asarray(list(world_arrays), dtype = np.float64) - # world_arrays = [world_arrays] - #print('axis', axis, world_arrays) - #if np.isscalar(axis) or self.output_frame.naxes == 1: - if self.output_frame.naxes == 1: - #print('axis', axis) - #axis = float(axis) - world_arrays = [world_arrays] - #print('axis', axis) - print('world_arrays1', world_arrays) - #world_arrays = np.asarray(list(world_arrays))#, dtype = np.float64) - print('world_arrays2', world_arrays) + world_arrays = list(world_arrays) axes_types = set(self.output_frame.axes_type) + footprint = self.footprint() + world_arrays = [coo.to(unit) for coo, unit in zip(world_arrays, self.output_frame.unit) + if isinstance(coo, u.Quantity)] + world_arrays = [high_level_objects_to_values(coo, low_level_wcs=self) for + coo in world_arrays if not utils.isnumerical(coo)] + for axtyp in axes_types: - footprint = self.footprint(axis_type=axtyp) - ind = np.asarray((np.asarray(self.output_frame.axes_type) == axtyp)) - #print('ind', ind) - for idim, coord in enumerate(world_arrays[ind]): - #print('footprint', footprint) + for idim, coord in enumerate(world_arrays): + coord = _tofloat(coord) if np.asarray(ind).sum() > 1: axis_range = footprint[:, idim] else: axis_range = footprint range = [axis_range.min(), axis_range.max()] - #print('idim', idim, coord, range) outside = (coord < range[0]) | (coord > range[1]) if np.any(outside): if np.isscalar(coord): @@ -1447,11 +1436,6 @@ def footprint(self, bounding_box=None, center=False, axis_type="all"): """ def _order_clockwise(v): - # if self.input_frame.naxes == 1: - # bb = self.bounding_box.bounding_box() - # if isinstance(bb[0], u.Quantity): - # bb = [v.value for v in bb] * bb[0].unit - # return (bb,) return np.asarray([[v[0][0], v[1][0]], [v[0][0], v[1][1]], [v[0][1], v[1][1]], [v[0][1], v[1][0]]]).T @@ -1482,7 +1466,7 @@ def _order_clockwise(v): axis_type = axis_type.lower() if axis_type == 'spatial' and all_spatial: return result.T - + if axis_type != "all": axtyp_ind = np.array([t.lower() for t in self.output_frame.axes_type]) == axis_type if not axtyp_ind.any():