Skip to content

Commit

Permalink
First attempt at keeping a sorted and unsorted list of frame props
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadair committed Oct 16, 2023
1 parent 2e33bb2 commit e65454e
Showing 1 changed file with 86 additions and 67 deletions.
153 changes: 86 additions & 67 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from collections import defaultdict
import logging
import numpy as np
from dataclasses import dataclass, InitVar

from astropy.utils.misc import isiterable
from astropy import time
Expand Down Expand Up @@ -176,7 +177,6 @@ def get_ctype_from_ucd(ucd):
return UCD1_TO_CTYPE.get(ucd, "")



class BaseCoordinateFrame(abc.ABC):
"""
API Definition for a Coordinate frame
Expand Down Expand Up @@ -264,6 +264,63 @@ def world_axis_object_components(self):
"""


@dataclass
class FrameProperties:
naxes: InitVar[int]
axes_type: tuple[str]
unit: tuple[u.Unit] = None
axes_names: tuple[str] = None
axis_physical_types: list[str] = None

def __post_init__(self, naxes):
if isinstance(self.axes_type, str):
self.axes_type = (self.axes_type,)
else:
self.axes_type = tuple(self.axes_type)

if len(self.axes_type) != naxes:
raise ValueError("Length of axes_type does not match number of axes.")

if self.unit is not None:
if astutil.isiterable(self.unit):
unit = tuple(self.unit)
else:
unit = (self.unit,)
if len(unit) != naxes:
raise ValueError("Number of units does not match number of axes.")
else:
self.unit = tuple(u.Unit(au) for au in unit)
else:
self.unit = tuple(u.dimensionless_unscaled for na in range(naxes))

if self.axes_names is not None:
if isinstance(self.axes_names, str):
self.axes_names = (self.axes_names,)
else:
self.axes_names = tuple(self.axes_names)
if len(self.axes_names) != naxes:
raise ValueError("Number of axes names does not match number of axes.")
else:
self.axes_names = tuple([""] * naxes)

if self.axis_physical_types is not None:
if isinstance(self.axis_physical_types, str):
self.axis_physical_types = (self.axis_physical_types,)
elif not isiterable(self.axis_physical_types):
raise TypeError("axis_physical_types must be of type string or iterable of strings")
if len(self.axis_physical_types) != naxes:
raise ValueError(f'"axis_physical_types" must be of length {naxes}')
ph_type = []
for axt in self.axis_physical_types:
if axt not in VALID_UCDS and not axt.startswith("custom:"):
ph_type.append("custom:{axt}")
else:
ph_type.append(axt)

validate_physical_types(ph_type)
self.axes_physical_types = tuple(ph_type)


class CoordinateFrame(BaseCoordinateFrame):
"""
Base class for Coordinate Frames.
Expand Down Expand Up @@ -291,68 +348,28 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None,
name=None, axis_physical_types=None):
self._naxes = naxes
self._axes_order = tuple(axes_order)
if isinstance(axes_type, str):
self._axes_type = (axes_type,)
else:
self._axes_type = tuple(axes_type)

self._reference_frame = reference_frame
if unit is not None:
if astutil.isiterable(unit):
unit = tuple(unit)
else:
unit = (unit,)
if len(unit) != naxes:
raise ValueError("Number of units does not match number of axes.")
else:
self._unit = tuple([u.Unit(au) for au in unit])
else:
self._unit = tuple(u.Unit("") for na in range(naxes))
if axes_names is not None:
if isinstance(axes_names, str):
axes_names = (axes_names,)
else:
axes_names = tuple(axes_names)
if len(axes_names) != naxes:
raise ValueError("Number of axes names does not match number of axes.")
else:
axes_names = tuple([""] * naxes)
self._axes_names = axes_names

if name is None:
self._name = self.__class__.__name__
else:
self._name = name

if len(self._axes_type) != naxes:
raise ValueError("Length of axes_type does not match number of axes.")
if len(self._axes_order) != naxes:
raise ValueError("Length of axes_order does not match number of axes.")

super(CoordinateFrame, self).__init__()
# _axis_physical_types holds any user supplied physical types
self._axis_physical_types = self._set_axis_physical_types(axis_physical_types)

def _set_axis_physical_types(self, pht):
"""
Set the physical type of the coordinate axes using VO UCD1+ v1.23 definitions.
"""
if pht is not None:
if isinstance(pht, str):
pht = (pht,)
elif not isiterable(pht):
raise TypeError("axis_physical_types must be of type string or iterable of strings")
if len(pht) != self.naxes:
raise ValueError('"axis_physical_types" must be of length {}'.format(self.naxes))
ph_type = []
for axt in pht:
if axt not in VALID_UCDS and not axt.startswith("custom:"):
ph_type.append("custom:{}".format(axt))
else:
ph_type.append(axt)
if isinstance(axes_type, str):
axes_type = (axes_type,)
default_apt = tuple([f"custom:{t}" for t in axes_type])
self._prop = FrameProperties(
naxes,
axes_type,
unit,
axes_names,
axis_physical_types or default_apt,
)

validate_physical_types(ph_type)
return tuple(ph_type)
super().__init__()

def __repr__(self):
fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format(
Expand All @@ -368,6 +385,10 @@ def __str__(self):
return self._name
return self.__class__.__name__

def _sort_property(self, property):
return tuple(dict(sorted(zip(property, self.axes_order),
key=lambda x: x[1])).keys())

@property
def name(self):
""" A custom name of this frame."""
Expand All @@ -386,12 +407,12 @@ def naxes(self):
@property
def unit(self):
"""The unit of this frame."""
return self._unit
return self._sort_property(self._prop.unit)

@property
def axes_names(self):
""" Names of axes in the frame."""
return self._axes_names
return self._sort_property(self._prop.axes_names)

@property
def axes_order(self):
Expand All @@ -406,15 +427,7 @@ def reference_frame(self):
@property
def axes_type(self):
""" Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'. """
return self._axes_type

@property
def _default_axis_physical_types(self):
"""
The default physical types to use for this frame if none are specified
by the user.
"""
return tuple("custom:{}".format(t) for t in self.axes_type)
return self._sort_property(self._prop.axes_type)

@property
def axis_physical_types(self):
Expand All @@ -423,7 +436,7 @@ def axis_physical_types(self):
These physical types are the types in frame order, not transform order.
"""
return self._axis_physical_types or self._default_axis_physical_types
return self._sort_property(self._prop.axis_physical_types)

@property
def world_axis_object_classes(self):
Expand All @@ -434,7 +447,7 @@ def world_axis_object_classes(self):

@property
def world_axis_object_components(self):
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)]
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)]

@property
def _native_world_axis_object_components(self):
Expand Down Expand Up @@ -681,8 +694,14 @@ def __init__(self, frames, name=None):
for frame in frames:
axes_order.extend(frame.axes_order)
for frame in frames:
for ind, axtype, un, n, pht in zip(frame.axes_order, frame.axes_type,
frame.unit, frame.axes_names, frame.axis_physical_types):
unsorted_prop = zip(
frame.axes_order,
frame._prop.axes_type,
frame._prop.unit,
frame._prop.axes_names,
frame._prop.axis_physical_types
)
for ind, axtype, un, n, pht in unsorted_prop:
axes_type[ind] = axtype
axes_names[ind] = n
unit[ind] = un
Expand Down

0 comments on commit e65454e

Please sign in to comment.