Skip to content

Commit

Permalink
Start work on porting tests to ape 14 infra
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadair committed Jun 20, 2023
1 parent adc2d41 commit a9bbf19
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 79 deletions.
4 changes: 2 additions & 2 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ def serialized_classes(self):

@property
def world_axis_object_classes(self):
return self.output_frame._world_axis_object_classes
return self.output_frame.world_axis_object_classes

@property
def world_axis_object_components(self):
return self.output_frame._world_axis_object_components
return self.output_frame.world_axis_object_components

@property
def pixel_axis_names(self):
Expand Down
56 changes: 28 additions & 28 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def axis_physical_types(self):

@property
@abc.abstractmethod
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
"""
The APE 14 object classes for this frame.
Expand All @@ -260,7 +260,7 @@ def _world_axis_object_classes(self):

@property
@abc.abstractmethod
def _world_axis_object_components(self):
def world_axis_object_components(self):
"""
The APE 14 object components for this frame.
Expand Down Expand Up @@ -444,14 +444,14 @@ def axis_physical_types(self):
return self._axis_physical_types or self._default_axis_physical_types

@property
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
return {f"{at}{i}" if i != 0 else at: (u.Quantity,
(),
{'unit': unit})
for i, (at, unit) in enumerate(zip(self._axes_type, self.unit))}

@property
def _world_axis_object_components(self):
def world_axis_object_components(self):
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)]


Expand Down Expand Up @@ -543,15 +543,15 @@ def _default_axis_physical_types(self):
return tuple("custom:{}".format(t) for t in self.axes_names)

@property
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
return {'celestial': (
coord.SkyCoord,
(),
{'frame': self.reference_frame,
'unit': self.unit})}

@property
def _world_axis_object_components(self):
def world_axis_object_components(self):
return [('celestial', 0, 'spherical.lon'),
('celestial', 1, 'spherical.lat')]

Expand Down Expand Up @@ -605,14 +605,14 @@ def _default_axis_physical_types(self):
return ("custom:{}".format(self.unit[0].physical_type),)

@property
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
return {'spectral': (
coord.SpectralCoord,
(),
{'unit': self.unit[0]})}

@property
def _world_axis_object_components(self):
def world_axis_object_components(self):
return [('spectral', 0, 'value')]


Expand Down Expand Up @@ -657,8 +657,19 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,),
def _default_axis_physical_types(self):
return ("time",)

def _convert_to_time(self, dt, *, unit, **kwargs):
if (not isinstance(dt, time.TimeDelta) and
isinstance(dt, time.Time) or
isinstance(self.reference_frame.value, np.ndarray)):
return time.Time(dt, **kwargs)

if not hasattr(dt, 'unit'):
dt = dt * unit

return self.reference_frame + dt

@property
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
comp = (
time.Time,
(),
Expand All @@ -668,25 +679,14 @@ def _world_axis_object_classes(self):
return {'temporal': comp}

@property
def _world_axis_object_components(self):
def world_axis_object_components(self):
if isinstance(self.reference_frame.value, np.ndarray):
return [('temporal', 0, 'value')]

def offset_from_time_and_reference(time):
return (time - self.reference_frame).sec
return [('temporal', 0, offset_from_time_and_reference)]

def _convert_to_time(self, dt, *, unit, **kwargs):
if (not isinstance(dt, time.TimeDelta) and
isinstance(dt, time.Time) or
isinstance(self.reference_frame.value, np.ndarray)):
return time.Time(dt, **kwargs)

if not hasattr(dt, 'unit'):
dt = dt * unit

return self.reference_frame + dt


class CompositeFrame(CoordinateFrame):
"""
Expand Down Expand Up @@ -743,7 +743,7 @@ def _wao_classes_rename_map(self):
for frame in self.frames:
# ensure the frame is in the mapper
mapper[frame]
for key in frame._world_axis_object_classes.keys():
for key in frame.world_axis_object_classes.keys():
if key in seen_names:
new_key = f"{key}{seen_names.count(key)}"
mapper[frame][key] = new_key
Expand All @@ -755,7 +755,7 @@ def _wao_renamed_components_iter(self):
mapper = self._wao_classes_rename_map
for frame in self.frames:
renamed_components = []
for comp in frame._world_axis_object_components:
for comp in frame.world_axis_object_components:
comp = list(comp)
rename = mapper[frame].get(comp[0])
if rename:
Expand All @@ -767,14 +767,14 @@ def _wao_renamed_components_iter(self):
def _wao_renamed_classes_iter(self):
mapper = self._wao_classes_rename_map
for frame in self.frames:
for key, value in frame._world_axis_object_classes.items():
for key, value in frame.world_axis_object_classes.items():
rename = mapper[frame].get(key)
if rename:
key = rename
yield key, value

@property
def _world_axis_object_components(self):
def world_axis_object_components(self):
"""
We need to generate the components respecting the axes_order.
"""
Expand All @@ -788,7 +788,7 @@ def _world_axis_object_components(self):
return out

@property
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
return dict(self._wao_renamed_classes_iter)


Expand All @@ -814,15 +814,15 @@ def _default_axis_physical_types(self):
return ("phys.polarization.stokes",)

@property
def _world_axis_object_classes(self):
def world_axis_object_classes(self):
return {'stokes': (
StokesCoord,
(),
{},
)}

@property
def _world_axis_object_components(self):
def world_axis_object_components(self):
return [('stokes', 0, 'value')]


Expand Down
4 changes: 2 additions & 2 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,12 @@ def test_composite_many_base_frame():
q_frame_2 = cf.CoordinateFrame(name='distance', axes_order=(1,), naxes=1, axes_type="SPATIAL", unit=(u.m,))
frame = cf.CompositeFrame([q_frame_1, q_frame_2])

wao_classes = frame._world_axis_object_classes
wao_classes = frame.world_axis_object_classes

assert len(wao_classes) == 2
assert not set(wao_classes.keys()).difference({"SPATIAL", "SPATIAL1"})

wao_components = frame._world_axis_object_components
wao_components = frame.world_axis_object_components

assert len(wao_components) == 2
assert not {c[0] for c in wao_components}.difference({"SPATIAL", "SPATIAL1"})
Expand Down
Loading

0 comments on commit a9bbf19

Please sign in to comment.