Skip to content

Commit

Permalink
make tests pass, ecept slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadia Dencheva committed Nov 20, 2023
1 parent e65454e commit 065d4ce
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 75 deletions.
153 changes: 107 additions & 46 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __post_init__(self, naxes):
raise ValueError("Number of axes names does not match number of axes.")
else:
self.axes_names = tuple([""] * naxes)

if self.axis_physical_types is not None:
if isinstance(self.axis_physical_types, str):
self.axis_physical_types = (self.axis_physical_types,)
Expand All @@ -313,12 +313,22 @@ def __post_init__(self, naxes):
ph_type = []
for axt in self.axis_physical_types:
if axt not in VALID_UCDS and not axt.startswith("custom:"):
ph_type.append("custom:{axt}")
ph_type.append(f"custom:{axt}")
else:
ph_type.append(axt)

validate_physical_types(ph_type)
self.axes_physical_types = tuple(ph_type)
self.axis_physical_types = tuple(ph_type)
#self.world_axis_physical_types = tuple(ph_type)

@property
def _default_axis_physical_type(self):
"""
The default physical types to use for this frame if none are specified
by the user.
"""
return tuple("custom:{}".format(t) for t in self.axes_type)



class CoordinateFrame(BaseCoordinateFrame):
Expand Down Expand Up @@ -360,17 +370,25 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None,

if isinstance(axes_type, str):
axes_type = (axes_type,)
default_apt = tuple([f"custom:{t}" for t in axes_type])

self._prop = FrameProperties(
naxes,
axes_type,
unit,
axes_names,
axis_physical_types or default_apt,
axis_physical_types or self._default_axis_physical_type(axes_type)
#axis_physical_types# or default_apt,
)

super().__init__()

def _default_axis_physical_type(self, axes_type):
"""
The default physical types to use for this frame if none are specified
by the user.
"""
return tuple("custom:{}".format(t) for t in axes_type)

def __repr__(self):
fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format(
self.__class__.__name__, self.name,
Expand All @@ -386,8 +404,11 @@ def __str__(self):
return self.__class__.__name__

def _sort_property(self, property):
return tuple(dict(sorted(zip(property, self.axes_order),
key=lambda x: x[1])).keys())
#return tuple(dict(sorted(zip(property, self.axes_order),
# key=lambda x: x[1])).keys())
sorted_prop = sorted(zip(property, self.axes_order),
key=lambda x: x[1])
return tuple([t[0] for t in sorted_prop])

@property
def name(self):
Expand All @@ -408,6 +429,7 @@ def naxes(self):
def unit(self):
"""The unit of this frame."""
return self._sort_property(self._prop.unit)
#return self._prop.unit

@property
def axes_names(self):
Expand Down Expand Up @@ -436,21 +458,21 @@ def axis_physical_types(self):
These physical types are the types in frame order, not transform order.
"""
return self._sort_property(self._prop.axis_physical_types)
return self._prop.axis_physical_types or self._default_axis_physical_types

@property
def world_axis_object_classes(self):
return {f"{at}{i}" if i != 0 else at: (u.Quantity,
(),
{'unit': unit})
for i, (at, unit) in enumerate(zip(self._axes_type, self.unit))}
for i, (at, unit) in enumerate(zip(self.axes_type, self.unit))}

@property
def world_axis_object_components(self):
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)]

@property
def _native_world_axis_object_components(self):
def _native_world_axis_object_components(self):
"""Defines the target component ordering (i.e. not taking into account axes_order)"""
return self.world_axis_object_components

Expand Down Expand Up @@ -486,47 +508,50 @@ def __init__(self, axes_order=None, reference_frame=None,
if axes_names is None:
axes_names = _axes_names
naxes = len(_axes_names)
_unit = list(reference_frame.representation_component_units.values())
if unit is None and _unit:
unit = _unit
#_unit = list(reference_frame.representation_component_units.values())
#if unit is None and _unit:
# unit = _unit


self.native_axes_order = tuple(range(naxes))
if axes_order is None:
axes_order = tuple(range(naxes))
axes_order = self.native_axes_order
if unit is None:
unit = tuple([u.degree] * naxes)
axes_type = ['SPATIAL'] * naxes

pht = axis_physical_types or self._default_axis_physical_types(reference_frame, axes_names)
super().__init__(naxes=naxes,
axes_type=axes_type,
axes_order=axes_order,
reference_frame=reference_frame,
unit=unit,
axes_names=axes_names,
name=name,
axis_physical_types=axis_physical_types)
axis_physical_types=pht)

@property
def _default_axis_physical_types(self):
if isinstance(self.reference_frame, coord.Galactic):
def _default_axis_physical_types(self, reference_frame, axes_names):
if isinstance(reference_frame, coord.Galactic):
return "pos.galactic.lon", "pos.galactic.lat"
elif isinstance(self.reference_frame, (coord.GeocentricTrueEcliptic,
coord.GCRS,
coord.PrecessedGeocentric)):
elif isinstance(reference_frame, (coord.GeocentricTrueEcliptic,
coord.GCRS,
coord.PrecessedGeocentric)):
return "pos.bodyrc.lon", "pos.bodyrc.lat"
elif isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame):
elif isinstance(reference_frame, coord.builtin_frames.BaseRADecFrame):
return "pos.eq.ra", "pos.eq.dec"
elif isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame):
elif isinstance(reference_frame, coord.builtin_frames.BaseEclipticFrame):
return "pos.ecliptic.lon", "pos.ecliptic.lat"
else:
return tuple("custom:{}".format(t) for t in self.axes_names)

@property
def world_axis_object_classes(self):
unit = np.array(self.unit)[np.argsort(self.axes_order)]
return {'celestial': (
coord.SkyCoord,
(),
{'frame': self.reference_frame,
'unit': self.unit})}
'unit': unit})}

@property
def _native_world_axis_object_components(self):
Expand Down Expand Up @@ -563,28 +588,34 @@ class SpectralFrame(CoordinateFrame):

def __init__(self, axes_order=(0,), reference_frame=None, unit=None,
axes_names=None, name=None, axis_physical_types=None):


if not isiterable(unit):
unit = (unit,)

pht = axis_physical_types or self._default_axis_physical_types(unit)

super().__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order,
axes_names=axes_names, reference_frame=reference_frame,
unit=unit, name=name,
axis_physical_types=axis_physical_types)
#axis_physical_types="em.wl")
axis_physical_types=pht)

@property
def _default_axis_physical_types(self):
if self.unit[0].physical_type == "frequency":

def _default_axis_physical_types(self, unit):
if unit[0].physical_type == "frequency":
return ("em.freq",)
elif self.unit[0].physical_type == "length":
elif unit[0].physical_type == "length":
return ("em.wl",)
elif self.unit[0].physical_type == "energy":
elif unit[0].physical_type == "energy":
return ("em.energy",)
elif self.unit[0].physical_type == "speed":
elif unit[0].physical_type == "speed":
return ("spect.dopplerVeloc",)
logging.warning("Physical type may be ambiguous. Consider "
"setting the physical type explicitly as "
"either 'spect.dopplerVeloc.optical' or "
"'spect.dopplerVeloc.radio'.")
else:
return ("custom:{}".format(self.unit[0].physical_type),)
return ("custom:{}".format(unit[0].physical_type),)

@property
def world_axis_object_classes(self):
Expand Down Expand Up @@ -625,17 +656,19 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,),
reference_frame.scale,
reference_frame.location)

pht = axis_physical_types or self._default_axis_physical_types()

super().__init__(naxes=1, axes_type="TIME", axes_order=axes_order,
axes_names=axes_names, reference_frame=reference_frame,
unit=unit, name=name, axis_physical_types=axis_physical_types)
unit=unit, name=name, axis_physical_types=pht)
self._attrs = {}
for a in self.reference_frame.info._represent_as_dict_extra_attrs:
try:
self._attrs[a] = getattr(self.reference_frame, a)
except AttributeError:
pass

@property
#@property
def _default_axis_physical_types(self):
return ("time",)

Expand Down Expand Up @@ -686,13 +719,16 @@ class CompositeFrame(CoordinateFrame):
def __init__(self, frames, name=None):
self._frames = frames[:]
naxes = sum([frame._naxes for frame in self._frames])

axes_type = list(range(naxes))
unit = list(range(naxes))
axes_names = list(range(naxes))
axes_order = []
ph_type = list(range(naxes))
axes_order = []

for frame in frames:
axes_order.extend(frame.axes_order)

for frame in frames:
unsorted_prop = zip(
frame.axes_order,
Expand All @@ -706,6 +742,8 @@ def __init__(self, frames, name=None):
axes_names[ind] = n
unit[ind] = un
ph_type[ind] = pht


if len(np.unique(axes_order)) != len(axes_order):
raise ValueError("Incorrect numbering of axes, "
"axes_order should contain unique numbers, "
Expand All @@ -714,13 +752,30 @@ def __init__(self, frames, name=None):
super().__init__(naxes, axes_type=axes_type,
axes_order=axes_order,
unit=unit, axes_names=axes_names,
axis_physical_types=tuple(ph_type),
name=name)
self._axis_physical_types = tuple(ph_type)

@property
def frames(self):
return self._frames

@property
def unit(self):
return self._prop.unit

@property
def axes_(self):
return self._prop.axes_names

@property
def axes_type(self):
return self._prop.axes_type

@property
def axis_physical_types(self):
return self._prop.axis_physical_types

def __repr__(self):
return repr(self.frames)

Expand Down Expand Up @@ -767,12 +822,14 @@ def world_axis_object_components(self):
We need to generate the components respecting the axes_order.
"""
out = [None] * self.naxes

for frame, components in self._wao_renamed_components_iter:
for i, ao in enumerate(frame.axes_order):
out[ao] = components[i]

if any([o is None for o in out]):
raise ValueError("axes_order leads to incomplete world_axis_object_components")

return out

@property
Expand All @@ -793,11 +850,13 @@ class StokesFrame(CoordinateFrame):
"""

def __init__(self, axes_order=(0,), axes_names=("stokes",), name=None, axis_physical_types=None):

pht = axis_physical_types or self._default_axis_physical_types()

super().__init__(1, ["STOKES"], axes_order, name=name,
axes_names=axes_names, unit=u.one,
axis_physical_types=axis_physical_types)
axis_physical_types=pht)

@property
def _default_axis_physical_types(self):
return ("phys.polarization.stokes",)

Expand Down Expand Up @@ -831,17 +890,19 @@ class Frame2D(CoordinateFrame):
"""

def __init__(self, axes_order=(0, 1), unit=(u.pix, u.pix), axes_names=('x', 'y'),
name=None, axis_physical_types=None):
name=None, axes_type=["SPATIAL", "SPATIAL"], axis_physical_types=None):

pht = axis_physical_types or self._default_axis_physical_types(axes_names, axes_type)

super().__init__(naxes=2, axes_type=["SPATIAL", "SPATIAL"],
super().__init__(naxes=2, axes_type=axes_type,
axes_order=axes_order, name=name,
axes_names=axes_names, unit=unit,
axis_physical_types=axis_physical_types)
axis_physical_types=pht)

@property
def _default_axis_physical_types(self):
if all(self.axes_names):
ph_type = self.axes_names
def _default_axis_physical_types(self, axes_names, axes_type):
if axes_names is not None and all(axes_names):
ph_type = axes_names
else:
ph_type = self.axes_type
ph_type = axes_type

return tuple("custom:{}".format(t) for t in ph_type)
8 changes: 4 additions & 4 deletions gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from astropy import coordinates as coord
from astropy.modeling import models

from .. import coordinate_frames as cf
from .. import spectroscopy as sp
from .. import wcs
from .. import geometry
from gwcs import coordinate_frames as cf
from gwcs import spectroscopy as sp
from gwcs import wcs
from gwcs import geometry

# frames
detector_1d = cf.CoordinateFrame(name='detector', axes_order=(0,), naxes=1, axes_type="detector")
Expand Down
4 changes: 2 additions & 2 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_world_axis_object_classes_2d(gwcs_2d_spatial_shift):
assert 'frame' in waoc['celestial'][2]
assert 'unit' in waoc['celestial'][2]
assert isinstance(waoc['celestial'][2]['frame'], coord.ICRS)
assert waoc['celestial'][2]['unit'] == (u.deg, u.deg)
assert tuple(waoc['celestial'][2]['unit']) == (u.deg, u.deg)


def test_world_axis_object_classes_2d_generic(gwcs_2d_quantity_shift):
Expand All @@ -217,7 +217,7 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units):
assert 'frame' in waoc['celestial'][2]
assert 'unit' in waoc['celestial'][2]
assert isinstance(waoc['celestial'][2]['frame'], coord.ICRS)
assert waoc['celestial'][2]['unit'] == (u.deg, u.deg)
assert tuple(waoc['celestial'][2]['unit']) == (u.deg, u.deg)

temporal = waoc['temporal']
assert temporal[0] is time.Time
Expand Down
Loading

0 comments on commit 065d4ce

Please sign in to comment.