diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5e7956f5..eb9b64f8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: 'CI' on: push: branches: - - 'master' + - '*' tags: - '*' pull_request: diff --git a/CHANGES.rst b/CHANGES.rst index 9cd94fb4..4fa91620 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,8 @@ - Fix an issue with units in ``wcs_from_points``. [#507] +- Fix incorrect units being returned in the low level WCS API. [#512] + 0.21.0 (2024-03-10) ------------------- diff --git a/docs/index.rst b/docs/index.rst index 481f2f76..d90f59a7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -100,7 +100,7 @@ To install the latest release:: pip install gwcs -The latest release of GWCS is also available as part of `astroconda `__. +The latest release of GWCS is also available as a conda package via `conda-forge `__. .. _getting-started: @@ -240,13 +240,7 @@ To convert a pixel (x, y) = (1, 2) to sky coordinates, call the WCS object as a The :meth:`~gwcs.wcs.WCS.invert` method evaluates the :meth:`~gwcs.wcs.WCS.backward_transform` if available, otherwise applies an iterative method to calculate the reverse coordinates. -.. doctest-skip:: - - >>> wcsobj.invert(*sky) - (0.9999999996185807, 1.999999999186798) - -GWCS supports the common WCS interface which defines several methods -to work with high level Astropy objects: +GWCS supports the :ref:`wcsapi` which defines several methods to work with high level Astropy objects: .. doctest-skip:: diff --git a/gwcs/api.py b/gwcs/api.py index 4f2ce9fc..bcff2c1b 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -5,17 +5,14 @@ """ -from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS +from astropy.wcs.wcsapi import BaseLowLevelWCS, HighLevelWCSMixin from astropy.modeling import separable import astropy.units as u -from . import utils -from . import coordinate_frames as cf - __all__ = ["GWCSAPIMixin"] -class GWCSAPIMixin(BaseHighLevelWCS, BaseLowLevelWCS): +class GWCSAPIMixin(BaseLowLevelWCS, HighLevelWCSMixin): """ A mix-in class that is intended to be inherited by the :class:`~gwcs.wcs.WCS` class and provides the low- and high-level @@ -52,14 +49,7 @@ def world_axis_physical_types(self): arbitrary string. Alternatively, if the physical type is unknown/undefined, an element can be `None`. """ - # A CompositeFrame orders the output correctly based on axes_order. - if isinstance(self.output_frame, cf.CompositeFrame): - return self.output_frame.axis_physical_types - - # If we don't have a CompositeFrame, where this is taken care of for us, - # we need to make sure we re-order the output to match the transform. - # The underlying frames don't reorder themselves because axes_order is global. - return tuple(self.output_frame.axis_physical_types[i] for i in self.output_frame.axes_order) + return self.output_frame.axis_physical_types @property def world_axis_units(self): @@ -78,19 +68,14 @@ def _remove_quantity_output(self, result, frame): if self.output_frame.naxes == 1: result = [result] - result = tuple(r.to_value(unit) for r, unit in zip(result, frame.unit)) + result = tuple(r.to_value(unit) if isinstance(r, u.Quantity) else r + for r, unit in zip(result, frame.unit)) # If we only have one output axes, we shouldn't return a tuple. if self.output_frame.naxes == 1 and isinstance(result, tuple): return result[0] return result - def _add_units_input(self, arrays, transform, frame): - if transform.uses_quantity: - return tuple(u.Quantity(array, unit) for array, unit in zip(arrays, frame.unit)) - - return arrays - def pixel_to_world_values(self, *pixel_arrays): """ Convert pixel coordinates to world coordinates. @@ -104,8 +89,7 @@ def pixel_to_world_values(self, *pixel_arrays): order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - pixel_arrays = self._add_units_input(pixel_arrays, self.forward_transform, self.input_frame) - result = self(*pixel_arrays, with_units=False) + result = self._call_forward(*pixel_arrays) return self._remove_quantity_output(result, self.output_frame) @@ -132,9 +116,7 @@ def world_to_pixel_values(self, *world_arrays): be returned in the ``(x, y)`` order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame) - - result = self.invert(*world_arrays, with_units=False) + result = self._call_backward(*world_arrays) return self._remove_quantity_output(result, self.input_frame) @@ -259,78 +241,11 @@ def serialized_classes(self): @property def world_axis_object_classes(self): - return self.output_frame._world_axis_object_classes + return self.output_frame.world_axis_object_classes @property def world_axis_object_components(self): - return self.output_frame._world_axis_object_components - - # High level APE 14 API - - @property - def low_level_wcs(self): - """ - Returns a reference to the underlying low-level WCS object. - """ - return self - - def _sanitize_pixel_inputs(self, *pixel_arrays): - pixels = [] - if self.forward_transform.uses_quantity: - for i, pixel in enumerate(pixel_arrays): - if not isinstance(pixel, u.Quantity): - pixel = u.Quantity(value=pixel, unit=self.input_frame.unit[i]) - pixels.append(pixel) - else: - for i, pixel in enumerate(pixel_arrays): - if isinstance(pixel, u.Quantity): - if pixel.unit != self.input_frame.unit[i]: - raise ValueError('Quantity input does not match the ' - 'input_frame unit.') - pixel = pixel.value - pixels.append(pixel) - - return pixels - - def pixel_to_world(self, *pixel_arrays): - """ - Convert pixel values to world coordinates. - """ - pixels = self._sanitize_pixel_inputs(*pixel_arrays) - return self(*pixels, with_units=True) - - def array_index_to_world(self, *index_arrays): - """ - Convert array indices to world coordinates (represented by Astropy - objects). - """ - pixel_arrays = index_arrays[::-1] - pixels = self._sanitize_pixel_inputs(*pixel_arrays) - return self(*pixels, with_units=True) - - def world_to_pixel(self, *world_objects): - """ - Convert world coordinates to pixel values. - """ - result = self.invert(*world_objects, with_units=True) - - if self.input_frame.naxes > 1: - first_res = result[0] - if not utils.isnumerical(first_res): - result = [i.value for i in result] - else: - if not utils.isnumerical(result): - result = result.value - - return result - - def world_to_array_index(self, *world_objects): - """ - Convert world coordinates (represented by Astropy objects) to array - indices. - """ - result = self.invert(*world_objects, with_units=True)[::-1] - return tuple([utils._toindex(r) for r in result]) + return self.output_frame.world_axis_object_components @property def pixel_axis_names(self): diff --git a/gwcs/converters/wcs.py b/gwcs/converters/wcs.py index a2f40b3f..274a5973 100644 --- a/gwcs/converters/wcs.py +++ b/gwcs/converters/wcs.py @@ -138,19 +138,8 @@ def from_yaml_tree(self, node, tag, ctx): from ..coordinate_frames import SpectralFrame node = self._from_yaml_tree(node, tag, ctx) - if 'reference_position' in node: - node['reference_position'] = node['reference_position'].upper() - return SpectralFrame(**node) - def to_yaml_tree(self, frame, tag, ctx): - node = self._to_yaml_tree(frame, tag, ctx) - - if frame.reference_position is not None: - node['reference_position'] = frame.reference_position.lower() - - return node - class CompositeFrameConverter(FrameConverter): tags = ["tag:stsci.edu:gwcs/composite_frame-*"] diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 5131a4d4..e549c9e4 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -1,10 +1,123 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst """ -Defines coordinate frames and ties them to data axes. +This module defines coordinate frames for describing the inputs and/or outputs +of a transform. + +In the block diagram, the WCS pipeline has a two stage transformation (two +astropy Model instances), with an input frame, an output frame and an +intermediate frame. + +.. code-block:: + + ┌───────────────┐ + │ │ + │ Input │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Intermediate │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Output │ + │ Frame │ + │ │ + └───────────────┘ + + +Each frame instance is both metadata for the inputs/outputs of a transform and +also a converter between those inputs/outputs and richer coordinate +representations of those inputs/ouputs. + +For example, an output frame of type `~gwcs.coordinate_frames.SpectralFrame` +provides metadata to the `.WCS` object such as the ``axes_type`` being +``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a +converter of the numeric output of the transform to a +`~astropy.coordinates.SpectralCoord` object, by combining this metadata with the +numerical values. + +``axes_order`` and conversion between objects and arguments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +One of the key concepts regarding coordinate frames is the ``axes_order`` argument. +This argument is used to map from the components of the frame to the inputs/outputs of the transform. +To illustrate this consider this situation where you have a forward transform +which outputs three coordinates ``[lat, lambda, lon]``. These would be +represented as a `.SpectralFrame` and a `.CelestialFrame`, however, the axes of +a `.CelestialFrame` are always ``[lon, lat]``, so by specifying two frames as + +.. code-block:: python + + [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] + +we would map the outputs of this transform into the correct positions in the frames. + As shown below, this is also used when constructing the inputs to the inverse transform. + + +When taking the output from the forward transform the following transformation is performed by the coordinate frames: + +.. code-block:: + + lat, lambda, lon + │ │ │ + └──────┼─────┼────────┐ + ┌───────────┘ └──┐ │ + │ │ │ + ┌─────────▼────────┐ ┌──────▼─────▼─────┐ + │ │ │ │ + │ SpectralFrame │ │ CelestialFrame │ + │ │ │ │ + │ (1,) │ │ (2, 0) │ + │ │ │ │ + └─────────┬────────┘ └──────────┬────┬──┘ + │ │ │ + │ │ │ + ▼ ▼ ▼ + SpectralCoord(lambda) SkyCoord((lon, lat)) + + +When considering the backward transform the following transformations take place in the coordinate frames before the transform is called: + +.. code-block:: + + SpectralCoord(lambda) SkyCoord((lon, lat)) + │ │ │ + └─────┐ ┌────────────┘ │ + │ │ ┌────────────┘ + ▼ ▼ ▼ + [lambda, lon, lat] + │ │ │ + │ │ │ + ┌──────▼─────▼────▼────┐ + │ │ + │ Sort by axes_order │ + │ │ + └────┬──────┬─────┬────┘ + │ │ │ + ▼ ▼ ▼ + lat, lambda, lon + """ + +import abc from collections import defaultdict import logging import numpy as np +from dataclasses import dataclass, InitVar from astropy.utils.misc import isiterable from astropy import time @@ -16,7 +129,7 @@ from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 from astropy.coordinates import StokesCoord -__all__ = ['Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame', +__all__ = ['BaseCoordinateFrame', 'Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame', 'CoordinateFrame', 'TemporalFrame', 'StokesFrame'] @@ -58,10 +171,6 @@ def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates): STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in coord.builtin_frames.__all__] -STANDARD_REFERENCE_POSITION = ["GEOCENTER", "BARYCENTER", "HELIOCENTER", - "TOPOCENTER", "LSR", "LSRK", "LSRD", - "GALACTIC_CENTER", "LOCAL_GROUP_CENTER"] - def get_ctype_from_ucd(ucd): """ @@ -80,7 +189,188 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") -class CoordinateFrame: +@dataclass +class FrameProperties: + naxes: InitVar[int] + axes_type: tuple[str] + unit: tuple[u.Unit] = None + axes_names: tuple[str] = None + axis_physical_types: list[str] = None + + def __post_init__(self, naxes): + if isinstance(self.axes_type, str): + self.axes_type = (self.axes_type,) + else: + self.axes_type = tuple(self.axes_type) + + if len(self.axes_type) != naxes: + raise ValueError("Length of axes_type does not match number of axes.") + + if self.unit is not None: + if astutil.isiterable(self.unit): + unit = tuple(self.unit) + else: + unit = (self.unit,) + if len(unit) != naxes: + raise ValueError("Number of units does not match number of axes.") + else: + self.unit = tuple(u.Unit(au) for au in unit) + else: + self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) + + if self.axes_names is not None: + if isinstance(self.axes_names, str): + self.axes_names = (self.axes_names,) + else: + self.axes_names = tuple(self.axes_names) + if len(self.axes_names) != naxes: + raise ValueError("Number of axes names does not match number of axes.") + else: + self.axes_names = tuple([""] * naxes) + + if self.axis_physical_types is not None: + if isinstance(self.axis_physical_types, str): + self.axis_physical_types = (self.axis_physical_types,) + elif not isiterable(self.axis_physical_types): + raise TypeError("axis_physical_types must be of type string or iterable of strings") + if len(self.axis_physical_types) != naxes: + raise ValueError(f'"axis_physical_types" must be of length {naxes}') + ph_type = [] + for axt in self.axis_physical_types: + if axt not in VALID_UCDS and not axt.startswith("custom:"): + ph_type.append(f"custom:{axt}") + else: + ph_type.append(axt) + + validate_physical_types(ph_type) + self.axis_physical_types = tuple(ph_type) + + @property + def _default_axis_physical_type(self): + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple("custom:{}".format(t) for t in self.axes_type) + + +class BaseCoordinateFrame(abc.ABC): + """ + API Definition for a Coordinate frame + """ + + _prop: FrameProperties + """ + The FrameProperties object holding properties in native frame order. + """ + + @property + @abc.abstractmethod + def naxes(self) -> int: + """ + The number of axes described by this frame. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """ + The name of the coordinate frame. + """ + + @property + @abc.abstractmethod + def unit(self) -> tuple[u.Unit, ...]: + """ + The units of the axes in this frame. + """ + + @property + @abc.abstractmethod + def axes_names(self) -> tuple[str, ...]: + """ + Names describing the axes of the frame. + """ + + @property + @abc.abstractmethod + def axes_order(self) -> tuple[int, ...]: + """ + The position of the axes in the frame in the transform. + """ + + @property + @abc.abstractmethod + def reference_frame(self): + """ + The reference frame of the coordinates described by this frame. + + This is usually an Astropy object such as ``SkyCoord`` or ``Time``. + """ + + @property + @abc.abstractmethod + def axes_type(self): + """ + An upcase string describing the type of the axis. + + Known values are ``"SPATIAL", "TEMPORAL", "STOKES", "SPECTRAL", "PIXEL"``. + """ + + @property + @abc.abstractmethod + def axis_physical_types(self): + """ + The UCD 1+ physical types for the axes, in frame order. + """ + + @property + @abc.abstractmethod + def world_axis_object_classes(self): + """ + The APE 14 object classes for this frame. + + See Also + -------- + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes + """ + + @property + def world_axis_object_components(self): + """ + The APE 14 object components for this frame. + + See Also + -------- + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components + """ + if self.naxes == 1: + return self._native_world_axis_object_components + + # If we have more than one axis then we should sort the native + # components by the axes_order. + ordered = np.array(self._native_world_axis_object_components, + dtype=object)[np.argsort(self.axes_order)] + return list(map(tuple, ordered)) + + @property + @abc.abstractmethod + def _native_world_axis_object_components(self): + """ + This property holds the "native" frame order of the components. + + The native order of the componets is the order the frame assumes the + axes are in when creating the high level objects, for example + ``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat + order (in their positional args). + + This property is used both to construct the ordered + ``world_axis_object_components`` property as well as by `CompositeFrame` + to be able to get the components in their native order. + """ + + +class CoordinateFrame(BaseCoordinateFrame): """ Base class for Coordinate Frames. @@ -94,8 +384,6 @@ class CoordinateFrame: A dimension in the input data that corresponds to this axis. reference_frame : astropy.coordinates.builtin_frames Reference frame (usually used with output_frame to convert to world coordinate objects). - reference_position : str - Reference position - one of ``STANDARD_REFERENCE_POSITION`` unit : list of astropy.units.Unit Unit for each axis. axes_names : list @@ -105,81 +393,44 @@ class CoordinateFrame: """ def __init__(self, naxes, axes_type, axes_order, reference_frame=None, - reference_position=None, unit=None, axes_names=None, + unit=None, axes_names=None, name=None, axis_physical_types=None): self._naxes = naxes self._axes_order = tuple(axes_order) - if isinstance(axes_type, str): - self._axes_type = (axes_type,) - else: - self._axes_type = tuple(axes_type) - self._reference_frame = reference_frame - if unit is not None: - if astutil.isiterable(unit): - unit = tuple(unit) - else: - unit = (unit,) - if len(unit) != naxes: - raise ValueError("Number of units does not match number of axes.") - else: - self._unit = tuple([u.Unit(au) for au in unit]) - else: - self._unit = tuple(u.Unit("") for na in range(naxes)) - if axes_names is not None: - if isinstance(axes_names, str): - axes_names = (axes_names,) - else: - axes_names = tuple(axes_names) - if len(axes_names) != naxes: - raise ValueError("Number of axes names does not match number of axes.") - else: - axes_names = tuple([""] * naxes) - self._axes_names = axes_names if name is None: self._name = self.__class__.__name__ else: self._name = name - self._reference_position = reference_position - - if len(self._axes_type) != naxes: - raise ValueError("Length of axes_type does not match number of axes.") if len(self._axes_order) != naxes: raise ValueError("Length of axes_order does not match number of axes.") - super(CoordinateFrame, self).__init__() - # _axis_physical_types holds any user supplied physical types - self._axis_physical_types = self._set_axis_physical_types(axis_physical_types) + if isinstance(axes_type, str): + axes_type = (axes_type,) + + self._prop = FrameProperties( + naxes, + axes_type, + unit, + axes_names, + axis_physical_types or self._default_axis_physical_type(axes_type) + ) + + super().__init__() - def _set_axis_physical_types(self, pht): + def _default_axis_physical_type(self, axes_type): """ - Set the physical type of the coordinate axes using VO UCD1+ v1.23 definitions. + The default physical types to use for this frame if none are specified + by the user. """ - if pht is not None: - if isinstance(pht, str): - pht = (pht,) - elif not isiterable(pht): - raise TypeError("axis_physical_types must be of type string or iterable of strings") - if len(pht) != self.naxes: - raise ValueError('"axis_physical_types" must be of length {}'.format(self.naxes)) - ph_type = [] - for axt in pht: - if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append("custom:{}".format(axt)) - else: - ph_type.append(axt) - - validate_physical_types(ph_type) - return tuple(ph_type) + return tuple("custom:{}".format(t) for t in axes_type) def __repr__(self): fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format( self.__class__.__name__, self.name, self.unit, self.axes_names, self.axes_order) - if self.reference_position is not None: - fmt += ', reference_position="{0}"'.format(self.reference_position) if self.reference_frame is not None: fmt += ", reference_frame={0}".format(self.reference_frame) fmt += ")>" @@ -190,6 +441,11 @@ def __str__(self): return self._name return self.__class__.__name__ + def _sort_property(self, property): + sorted_prop = sorted(zip(property, self.axes_order), + key=lambda x: x[1]) + return tuple([t[0] for t in sorted_prop]) + @property def name(self): """ A custom name of this frame.""" @@ -208,12 +464,12 @@ def naxes(self): @property def unit(self): """The unit of this frame.""" - return self._unit + return self._sort_property(self._prop.unit) @property def axes_names(self): """ Names of axes in the frame.""" - return self._axes_names + return self._sort_property(self._prop.axes_names) @property def axes_order(self): @@ -225,41 +481,10 @@ def reference_frame(self): """ Reference frame, used to convert to world coordinate objects. """ return self._reference_frame - @property - def reference_position(self): - """ Reference Position. """ - return getattr(self, "_reference_position", None) - @property def axes_type(self): """ Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'. """ - return self._axes_type - - def coordinates(self, *args): - """ Create world coordinates object""" - coo = tuple([arg * un if not hasattr(arg, "to") else arg.to(un) for arg, un in zip(args, self.unit)]) - if self.naxes == 1: - return coo[0] - return coo - - def coordinate_to_quantity(self, *coords): - """ - Given a rich coordinate object return an astropy quantity object. - """ - # NoOp leaves it to the model to handle - # If coords is a 1-tuple of quantity then return the element of the tuple - # This aligns the behavior with the other implementations - if not hasattr(coords, 'unit') and len(coords) == 1: - return coords[0] - return coords - - @property - def _default_axis_physical_types(self): - """ - The default physical types to use for this frame if none are specified - by the user. - """ - return tuple("custom:{}".format(t) for t in self.axes_type) + return self._sort_property(self._prop.axes_type) @property def axis_physical_types(self): @@ -268,23 +493,28 @@ def axis_physical_types(self): These physical types are the types in frame order, not transform order. """ - return self._axis_physical_types or self._default_axis_physical_types + apt = self._prop.axis_physical_types or self._default_axis_physical_types + return self._sort_property(apt) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {f"{at}{i}" if i != 0 else at: (u.Quantity, (), {'unit': unit}) - for i, (at, unit) in enumerate(zip(self._axes_type, self.unit))} + for i, (at, unit) in enumerate(zip(self.axes_type, self.unit))} @property - def _world_axis_object_components(self): - return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)] + def _native_world_axis_object_components(self): + return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)] class CelestialFrame(CoordinateFrame): """ - Celestial Frame Representation + Representation of a Celesital coordinate system. + + This class has a native order of longitude then latitude, meaning + ``axes_names``, ``unit`` and ``axis_physical_types`` should be lon, lat ordered. If your transform is + in a different order this should be specified with ``axes_order``. Parameters ---------- @@ -298,6 +528,8 @@ class CelestialFrame(CoordinateFrame): Names of the axes in this frame. name : str Name of this frame. + axis_physical_types : list + The UCD 1+ physical types for the axes, in frame order (lon, lat). """ def __init__(self, axes_order=None, reference_frame=None, @@ -313,91 +545,50 @@ def __init__(self, axes_order=None, reference_frame=None, if axes_names is None: axes_names = _axes_names naxes = len(_axes_names) - _unit = list(reference_frame.representation_component_units.values()) - if unit is None and _unit: - unit = _unit + self.native_axes_order = tuple(range(naxes)) if axes_order is None: - axes_order = tuple(range(naxes)) + axes_order = self.native_axes_order if unit is None: unit = tuple([u.degree] * naxes) axes_type = ['SPATIAL'] * naxes - super(CelestialFrame, self).__init__(naxes=naxes, axes_type=axes_type, - axes_order=axes_order, - reference_frame=reference_frame, - unit=unit, - axes_names=axes_names, - name=name, axis_physical_types=axis_physical_types) - - @property - def _default_axis_physical_types(self): - if isinstance(self.reference_frame, coord.Galactic): + pht = axis_physical_types or self._default_axis_physical_types(reference_frame, axes_names) + super().__init__(naxes=naxes, + axes_type=axes_type, + axes_order=axes_order, + reference_frame=reference_frame, + unit=unit, + axes_names=axes_names, + name=name, + axis_physical_types=pht) + + def _default_axis_physical_types(self, reference_frame, axes_names): + if isinstance(reference_frame, coord.Galactic): return "pos.galactic.lon", "pos.galactic.lat" - elif isinstance(self.reference_frame, (coord.GeocentricTrueEcliptic, - coord.GCRS, - coord.PrecessedGeocentric)): + elif isinstance(reference_frame, (coord.GeocentricTrueEcliptic, + coord.GCRS, + coord.PrecessedGeocentric)): return "pos.bodyrc.lon", "pos.bodyrc.lat" - elif isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame): + elif isinstance(reference_frame, coord.builtin_frames.BaseRADecFrame): return "pos.eq.ra", "pos.eq.dec" - elif isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame): + elif isinstance(reference_frame, coord.builtin_frames.BaseEclipticFrame): return "pos.ecliptic.lon", "pos.ecliptic.lat" else: - return tuple("custom:{}".format(t) for t in self.axes_names) + return tuple("custom:{}".format(t) for t in axes_names) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'celestial': ( coord.SkyCoord, (), {'frame': self.reference_frame, - 'unit': self.unit})} + 'unit': self._prop.unit})} @property - def _world_axis_object_components(self): - return [('celestial', 0, 'spherical.lon'), - ('celestial', 1, 'spherical.lat')] - - def coordinates(self, *args): - """ - Create a SkyCoord object. - - Parameters - ---------- - args : float - inputs to wcs.input_frame - """ - if isinstance(args[0], coord.SkyCoord): - return args[0].transform_to(self.reference_frame) - return coord.SkyCoord(*args, unit=self.unit, frame=self.reference_frame) - - def coordinate_to_quantity(self, *coords): - """ Convert a ``SkyCoord`` object to quantities.""" - if len(coords) == 2: - arg = coords - elif len(coords) == 1: - arg = coords[0] - else: - raise ValueError("Unexpected number of coordinates in " - "input to frame {} : " - "expected 2, got {}".format(self.name, len(coords))) - - if isinstance(arg, coord.SkyCoord): - arg = arg.transform_to(self._reference_frame) - try: - lon = arg.data.lon - lat = arg.data.lat - except AttributeError: - lon = arg.spherical.lon - lat = arg.spherical.lat - - return lon, lat - - elif all(isinstance(a, u.Quantity) for a in arg): - return tuple(arg) - - else: - raise ValueError("Could not convert input {} to lon and lat quantities.".format(arg)) + def _native_world_axis_object_components(self): + return [('celestial', 0, lambda sc: sc.spherical.lon.to_value(self._prop.unit[0])), + ('celestial', 1, lambda sc: sc.spherical.lat.to_value(self._prop.unit[1]))] class SpectralFrame(CoordinateFrame): @@ -416,63 +607,48 @@ class SpectralFrame(CoordinateFrame): Spectral axis name. name : str Name for this frame. - reference_position : str - Reference position - one of ``STANDARD_REFERENCE_POSITION`` """ def __init__(self, axes_order=(0,), reference_frame=None, unit=None, - axes_names=None, name=None, axis_physical_types=None, - reference_position=None): + axes_names=None, name=None, axis_physical_types=None): - super(SpectralFrame, self).__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, - axes_names=axes_names, reference_frame=reference_frame, - unit=unit, name=name, - reference_position=reference_position, - axis_physical_types=axis_physical_types) + if not isiterable(unit): + unit = (unit,) + unit = [u.Unit(un) for un in unit] + pht = axis_physical_types or self._default_axis_physical_types(unit) - @property - def _default_axis_physical_types(self): - if self.unit[0].physical_type == "frequency": + super().__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, + axes_names=axes_names, reference_frame=reference_frame, + unit=unit, name=name, + axis_physical_types=pht) + + def _default_axis_physical_types(self, unit): + if unit[0].physical_type == "frequency": return ("em.freq",) - elif self.unit[0].physical_type == "length": + elif unit[0].physical_type == "length": return ("em.wl",) - elif self.unit[0].physical_type == "energy": + elif unit[0].physical_type == "energy": return ("em.energy",) - elif self.unit[0].physical_type == "speed": + elif unit[0].physical_type == "speed": return ("spect.dopplerVeloc",) logging.warning("Physical type may be ambiguous. Consider " "setting the physical type explicitly as " "either 'spect.dopplerVeloc.optical' or " "'spect.dopplerVeloc.radio'.") else: - return ("custom:{}".format(self.unit[0].physical_type),) + return ("custom:{}".format(unit[0].physical_type),) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'spectral': ( coord.SpectralCoord, (), {'unit': self.unit[0]})} @property - def _world_axis_object_components(self): - return [('spectral', 0, 'value')] - - def coordinates(self, *args): - # using SpectralCoord - if isinstance(args[0], coord.SpectralCoord): - return args[0].to(self.unit[0]) - else: - if hasattr(args[0], 'unit'): - return coord.SpectralCoord(*args).to(self.unit[0]) - else: - return coord.SpectralCoord(*args, self.unit[0]) - - def coordinate_to_quantity(self, *coords): - if hasattr(coords[0], 'unit'): - return coords[0] - return coords[0] * self.unit[0] + def _native_world_axis_object_components(self): + return [('spectral', 0, lambda sc: sc.to_value(self.unit[0]))] class TemporalFrame(CoordinateFrame): @@ -502,9 +678,11 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), reference_frame.scale, reference_frame.location) + pht = axis_physical_types or self._default_axis_physical_types() + super().__init__(naxes=1, axes_type="TIME", axes_order=axes_order, axes_names=axes_names, reference_frame=reference_frame, - unit=unit, name=name, axis_physical_types=axis_physical_types) + unit=unit, name=name, axis_physical_types=pht) self._attrs = {} for a in self.reference_frame.info._represent_as_dict_extra_attrs: try: @@ -512,12 +690,22 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), except AttributeError: pass - @property def _default_axis_physical_types(self): return ("time",) + def _convert_to_time(self, dt, *, unit, **kwargs): + if (not isinstance(dt, time.TimeDelta) and + isinstance(dt, time.Time) or + isinstance(self.reference_frame.value, np.ndarray)): + return time.Time(dt, **kwargs) + + if not hasattr(dt, 'unit'): + dt = dt * unit + + return self.reference_frame + dt + @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): comp = ( time.Time, (), @@ -527,7 +715,7 @@ def _world_axis_object_classes(self): return {'temporal': comp} @property - def _world_axis_object_components(self): + def _native_world_axis_object_components(self): if isinstance(self.reference_frame.value, np.ndarray): return [('temporal', 0, 'value')] @@ -535,42 +723,6 @@ def offset_from_time_and_reference(time): return (time - self.reference_frame).sec return [('temporal', 0, offset_from_time_and_reference)] - def coordinates(self, *args): - if np.isscalar(args): - dt = args - else: - dt = args[0] - - return self._convert_to_time(dt, unit=self.unit[0], **self._attrs) - - def _convert_to_time(self, dt, *, unit, **kwargs): - if (not isinstance(dt, time.TimeDelta) and - isinstance(dt, time.Time) or - isinstance(self.reference_frame.value, np.ndarray)): - return time.Time(dt, **kwargs) - - if not hasattr(dt, 'unit'): - dt = dt * unit - - return self.reference_frame + dt - - def coordinate_to_quantity(self, *coords): - if isinstance(coords[0], time.Time): - ref_value = self.reference_frame.value - if not isinstance(ref_value, np.ndarray): - return (coords[0] - self.reference_frame).to(self.unit[0]) - else: - # If we can't convert to a quantity just drop the object out - # and hope the transform can cope. - return coords[0] - # Is already a quantity - elif hasattr(coords[0], 'unit'): - return coords[0] - if isinstance(coords[0], np.ndarray): - return coords[0] * self.unit[0] - else: - raise ValueError("Can not convert {} to Quantity".format(coords[0])) - class CompositeFrame(CoordinateFrame): """ @@ -579,81 +731,53 @@ class CompositeFrame(CoordinateFrame): Parameters ---------- frames : list - List of frames (TemporalFrame, CelestialFrame, SpectralFrame, CoordinateFrame). + List of constituient frames. name : str Name for this frame. - """ def __init__(self, frames, name=None): self._frames = frames[:] naxes = sum([frame._naxes for frame in self._frames]) - axes_type = list(range(naxes)) - unit = list(range(naxes)) - axes_names = list(range(naxes)) + axes_order = [] - ph_type = list(range(naxes)) + axes_type = [] + axes_names = [] + unit = [] + ph_type = [] + for frame in frames: axes_order.extend(frame.axes_order) + + # Stack the raw (not-native) ordered properties for frame in frames: - for ind, axtype, un, n, pht in zip(frame.axes_order, frame.axes_type, - frame.unit, frame.axes_names, frame.axis_physical_types): - axes_type[ind] = axtype - axes_names[ind] = n - unit[ind] = un - ph_type[ind] = pht + axes_type += list(frame._prop.axes_type) + axes_names += list(frame._prop.axes_names) + unit += list(frame._prop.unit) + ph_type += list(frame._prop.axis_physical_types) + if len(np.unique(axes_order)) != len(axes_order): raise ValueError("Incorrect numbering of axes, " "axes_order should contain unique numbers, " - "got {}.".format(axes_order)) + f"got {axes_order}.") - super(CompositeFrame, self).__init__(naxes, axes_type=axes_type, - axes_order=axes_order, - unit=unit, axes_names=axes_names, - name=name) + super().__init__(naxes, axes_type=axes_type, + axes_order=axes_order, + unit=unit, axes_names=axes_names, + axis_physical_types=tuple(ph_type), + name=name) self._axis_physical_types = tuple(ph_type) @property def frames(self): + """ + The constituient frames that comprise this `CompositeFrame`. + """ return self._frames def __repr__(self): return repr(self.frames) - def coordinates(self, *args): - coo = [] - if len(args) == len(self.frames): - for frame, arg in zip(self.frames, args): - coo.append(frame.coordinates(arg)) - else: - for frame in self.frames: - fargs = [args[i] for i in frame.axes_order] - coo.append(frame.coordinates(*fargs)) - return coo - - def coordinate_to_quantity(self, *coords): - if len(coords) == len(self.frames): - args = coords - elif len(coords) == self.naxes: - args = [] - for _frame in self.frames: - if _frame.naxes > 1: - # Collect the arguments for this frame based on axes_order - args.append([coords[i] for i in _frame.axes_order]) - else: - args.append(coords[_frame.axes_order[0]]) - else: - raise ValueError("Incorrect number of arguments") - - qs = [] - for _frame, arg in zip(self.frames, args): - ret = _frame.coordinate_to_quantity(arg) - if isinstance(ret, tuple): - qs += list(ret) - else: - qs.append(ret) - return qs - @property def _wao_classes_rename_map(self): mapper = defaultdict(dict) @@ -661,7 +785,7 @@ def _wao_classes_rename_map(self): for frame in self.frames: # ensure the frame is in the mapper mapper[frame] - for key in frame._world_axis_object_classes.keys(): + for key in frame.world_axis_object_classes.keys(): if key in seen_names: new_key = f"{key}{seen_names.count(key)}" mapper[frame][key] = new_key @@ -673,7 +797,7 @@ def _wao_renamed_components_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: renamed_components = [] - for comp in frame._world_axis_object_components: + for comp in frame._native_world_axis_object_components: comp = list(comp) rename = mapper[frame].get(comp[0]) if rename: @@ -685,28 +809,27 @@ def _wao_renamed_components_iter(self): def _wao_renamed_classes_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: - for key, value in frame._world_axis_object_classes.items(): + for key, value in frame.world_axis_object_classes.items(): rename = mapper[frame].get(key) if rename: key = rename yield key, value @property - def _world_axis_object_components(self): - """ - We need to generate the components respecting the axes_order. - """ + def world_axis_object_components(self): out = [None] * self.naxes + for frame, components in self._wao_renamed_components_iter: for i, ao in enumerate(frame.axes_order): out[ao] = components[i] if any([o is None for o in out]): raise ValueError("axes_order leads to incomplete world_axis_object_components") + return out @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return dict(self._wao_renamed_classes_iter) @@ -723,16 +846,18 @@ class StokesFrame(CoordinateFrame): """ def __init__(self, axes_order=(0,), axes_names=("stokes",), name=None, axis_physical_types=None): - super(StokesFrame, self).__init__(1, ["STOKES"], axes_order, name=name, - axes_names=axes_names, unit=u.one, - axis_physical_types=axis_physical_types) - @property + pht = axis_physical_types or self._default_axis_physical_types() + + super().__init__(1, ["STOKES"], axes_order, name=name, + axes_names=axes_names, unit=u.one, + axis_physical_types=pht) + def _default_axis_physical_types(self): return ("phys.polarization.stokes",) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'stokes': ( StokesCoord, (), @@ -740,22 +865,9 @@ def _world_axis_object_classes(self): )} @property - def _world_axis_object_components(self): + def _native_world_axis_object_components(self): return [('stokes', 0, 'value')] - def coordinates(self, *args): - if isinstance(args[0], u.Quantity): - arg = args[0].value - else: - arg = args[0] - - return StokesCoord(arg) - - def coordinate_to_quantity(self, *coords): - if isinstance(coords[0], StokesCoord): - return coords[0].value << u.one - return coords[0] - class Frame2D(CoordinateFrame): """ @@ -774,38 +886,19 @@ class Frame2D(CoordinateFrame): """ def __init__(self, axes_order=(0, 1), unit=(u.pix, u.pix), axes_names=('x', 'y'), - name=None, axis_physical_types=None): + name=None, axes_type=["SPATIAL", "SPATIAL"], axis_physical_types=None): - super(Frame2D, self).__init__(naxes=2, axes_type=["SPATIAL", "SPATIAL"], - axes_order=axes_order, name=name, - axes_names=axes_names, unit=unit, - axis_physical_types=axis_physical_types) + pht = axis_physical_types or self._default_axis_physical_types(axes_names, axes_type) - @property - def _default_axis_physical_types(self): - if all(self.axes_names): - ph_type = self.axes_names - else: - ph_type = self.axes_type - return tuple("custom:{}".format(t) for t in ph_type) + super().__init__(naxes=2, axes_type=axes_type, + axes_order=axes_order, name=name, + axes_names=axes_names, unit=unit, + axis_physical_types=pht) - def coordinates(self, *args): - args = [args[i] for i in self.axes_order] - coo = tuple([arg * un for arg, un in zip(args, self.unit)]) - return coo - - def coordinate_to_quantity(self, *coords): - # list or tuple - if len(coords) == 1 and astutil.isiterable(coords[0]): - coords = list(coords[0]) - elif len(coords) == 2: - coords = list(coords) + def _default_axis_physical_types(self, axes_names, axes_type): + if axes_names is not None and all(axes_names): + ph_type = axes_names else: - raise ValueError("Unexpected number of coordinates in " - "input to frame {} : " - "expected 2, got {}".format(self.name, len(coords))) - - for i in range(2): - if not hasattr(coords[i], 'unit'): - coords[i] = coords[i] * self.unit[i] - return tuple(coords) + ph_type = axes_type + + return tuple("custom:{}".format(t) for t in ph_type) diff --git a/gwcs/selector.py b/gwcs/selector.py index 12ac7914..8b4d88b4 100644 --- a/gwcs/selector.py +++ b/gwcs/selector.py @@ -531,7 +531,20 @@ def __init__(self, inputs, outputs, selector, label_mapper, undefined_transform_ raise ValueError('"0" and " " are not allowed as keys.') self._input_units_strict = {key: False for key in self._inputs} self._input_units_allow_dimensionless = {key: False for key in self._inputs} - super(RegionsSelector, self).__init__(n_models=1, name=name, **kwargs) + super().__init__(n_models=1, name=name, **kwargs) + # Validate uses_quantity at init time for nicer error message + self.uses_quantity # noqa + + @property + def uses_quantity(self): + all_uses_quantity = [t.uses_quantity for t in self._selector.values()] + not_all_uses_quantity = [not uq for uq in all_uses_quantity] + if all(all_uses_quantity): + return True + elif not_all_uses_quantity: + return False + else: + raise ValueError("You can not mix models which use quantity and do not use quantity inside a RegionSelector") def set_input(self, rid): """ diff --git a/gwcs/tests/conftest.py b/gwcs/tests/conftest.py index 0b0a8878..a7f19182 100644 --- a/gwcs/tests/conftest.py +++ b/gwcs/tests/conftest.py @@ -10,10 +10,10 @@ from astropy import coordinates as coord from astropy.modeling import models -from .. import coordinate_frames as cf -from .. import spectroscopy as sp -from .. import wcs -from .. import geometry +from gwcs import coordinate_frames as cf +from gwcs import spectroscopy as sp +from gwcs import wcs +from gwcs import geometry # frames detector_1d = cf.CoordinateFrame(name='detector', axes_order=(0,), naxes=1, axes_type="detector") @@ -54,7 +54,7 @@ def gwcs_2d_spatial_reordered(): A simple one step spatial WCS, in ICRS with a 1 and 2 px shift. """ out_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), - axes_order=(1, 0)) + axes_order=(1, 0)) return wcs.WCS(model_2d_shift | models.Mapping((1, 0)), input_frame=detector_2d, output_frame=out_frame) @@ -125,7 +125,8 @@ def gwcs_3d_identity_units(): models.Multiply(1 * u.nm / u.pixel)) sky_frame = cf.CelestialFrame(axes_order=(0, 1), name='icrs', reference_frame=coord.ICRS(), - axes_names=("longitude", "latitude")) + axes_names=("longitude", "latitude"), + unit=(u.arcsec, u.arcsec)) wave_frame = cf.SpectralFrame(axes_order=(2, ), unit=u.nm, axes_names=("wavelength",)) frame = cf.CompositeFrame([sky_frame, wave_frame]) @@ -255,7 +256,7 @@ def gwcs_3d_galactic_spectral(): shift = models.Shift(-crpix3) & models.Shift(-crpix1) scale = models.Multiply(cdelt3) & models.Multiply(cdelt1) - proj = models.Pix2Sky_CAR() + proj = models.Pix2Sky_TAN() skyrot = models.RotateNative2Celestial(crval3, 90 + crval1, 180) celestial = shift | scale | proj | skyrot diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index fd6f916c..a3d22bc9 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -79,6 +79,11 @@ def test_names(wcsobj): assert wcsobj.pixel_axis_names == wcsobj.input_frame.axes_names +def test_names_split(gwcs_3d_galactic_spectral): + wcs = gwcs_3d_galactic_spectral + assert wcs.world_axis_names == wcs.output_frame.axes_names == ("Latitude", "Frequency", "Longitude") + + @fixture_wcs_ndim_types_units def test_pixel_n_dim(wcs_ndim_types_units): wcsobj, ndims, *_ = wcs_ndim_types_units @@ -106,7 +111,7 @@ def test_world_axis_units(wcs_ndim_types_units): @pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr))) def test_pixel_to_world_values(gwcs_2d_spatial_shift, x, y): wcsobj = gwcs_2d_spatial_shift - assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y, with_units=False)) + assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y)) @pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr))) @@ -116,7 +121,7 @@ def test_pixel_to_world_values_units_2d(gwcs_2d_shift_scale_quantity, x, y): call_pixel = x*u.pix, y*u.pix api_pixel = x, y - call_world = wcsobj(*call_pixel, with_units=False) + call_world = wcsobj(*call_pixel) api_world = wcsobj.pixel_to_world_values(*api_pixel) # Check that call returns quantities and api dosen't @@ -126,7 +131,7 @@ def test_pixel_to_world_values_units_2d(gwcs_2d_shift_scale_quantity, x, y): # Check that they are the same (and implicitly in the same units) assert_allclose(u.Quantity(call_world).value, api_world) - new_call_pixel = wcsobj.invert(*call_world, with_units=False) + new_call_pixel = wcsobj.invert(*call_world) [assert_allclose(n, p) for n, p in zip(new_call_pixel, call_pixel)] new_api_pixel = wcsobj.world_to_pixel_values(*api_world) @@ -140,7 +145,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x): call_pixel = x * u.pix api_pixel = x - call_world = wcsobj(call_pixel, with_units=False) + call_world = wcsobj(call_pixel) api_world = wcsobj.pixel_to_world_values(api_pixel) # Check that call returns quantities and api dosen't @@ -150,7 +155,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x): # Check that they are the same (and implicitly in the same units) assert_allclose(u.Quantity(call_world).value, api_world) - new_call_pixel = wcsobj.invert(call_world, with_units=False) + new_call_pixel = wcsobj.invert(call_world) assert_allclose(new_call_pixel, call_pixel) new_api_pixel = wcsobj.world_to_pixel_values(api_world) @@ -160,13 +165,15 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x): @pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr))) def test_array_index_to_world_values(gwcs_2d_spatial_shift, x, y): wcsobj = gwcs_2d_spatial_shift - assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x, with_units=False)) + assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x)) def test_world_axis_object_components_2d(gwcs_2d_spatial_shift): waoc = gwcs_2d_spatial_shift.world_axis_object_components - assert waoc == [('celestial', 0, 'spherical.lon'), - ('celestial', 1, 'spherical.lat')] + assert waoc[0][:2] == ('celestial', 0) + assert callable(waoc[0][2]) + assert waoc[1][:2] == ('celestial', 1) + assert callable(waoc[1][2]) def test_world_axis_object_components_2d_generic(gwcs_2d_quantity_shift): @@ -177,15 +184,19 @@ def test_world_axis_object_components_2d_generic(gwcs_2d_quantity_shift): def test_world_axis_object_components_1d(gwcs_1d_freq): waoc = gwcs_1d_freq.world_axis_object_components - assert waoc == [('spectral', 0, 'value')] + assert [c[:2] for c in waoc] == [('spectral', 0)] + assert callable(waoc[0][2]) def test_world_axis_object_components_4d(gwcs_4d_identity_units): waoc = gwcs_4d_identity_units.world_axis_object_components - assert waoc[0:3] == [('celestial', 0, 'spherical.lon'), - ('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value')] - assert waoc[3][0:2] == ('temporal', 0) + first_two = [c[:2] for c in waoc] + last_one = [c[2] for c in waoc] + assert first_two == [('celestial', 0), + ('celestial', 1), + ('spectral', 0), + ('temporal', 0)] + assert all([callable(l) for l in last_one]) def test_world_axis_object_classes_2d(gwcs_2d_spatial_shift): @@ -195,7 +206,7 @@ def test_world_axis_object_classes_2d(gwcs_2d_spatial_shift): assert 'frame' in waoc['celestial'][2] assert 'unit' in waoc['celestial'][2] assert isinstance(waoc['celestial'][2]['frame'], coord.ICRS) - assert waoc['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(waoc['celestial'][2]['unit']) == (u.deg, u.deg) def test_world_axis_object_classes_2d_generic(gwcs_2d_quantity_shift): @@ -217,7 +228,7 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units): assert 'frame' in waoc['celestial'][2] assert 'unit' in waoc['celestial'][2] assert isinstance(waoc['celestial'][2]['frame'], coord.ICRS) - assert waoc['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(waoc['celestial'][2]['unit']) == (u.deg, u.deg) temporal = waoc['temporal'] assert temporal[0] is time.Time @@ -263,10 +274,8 @@ def test_high_level_wrapper(wcsobj, request): pixel_input = [3] * wcsobj.pixel_n_dim - # If the model expects units we have to pass in units - if wcsobj.forward_transform.uses_quantity: - pixel_input *= u.pix - + # Assert that both APE 14 API and GWCS give the same answer The APE 14 API + # uses the mixin class and __call__ calls values_to_high_level_objects wc1 = hlvl.pixel_to_world(*pixel_input) wc2 = wcsobj(*pixel_input, with_units=True) @@ -278,6 +287,22 @@ def test_high_level_wrapper(wcsobj, request): else: _compare_frame_output(wc1, wc2) + # we have just asserted that wc1 and wc2 are equal + if not isinstance(wc1, (list, tuple)): + wc1 = (wc1,) + + pix_out1 = hlvl.world_to_pixel(*wc1) + pix_out2 = wcsobj.invert(*wc1) + + if not isinstance(pix_out2, (list, tuple)): + pix_out2 = (pix_out2,) + + if wcsobj.forward_transform.uses_quantity: + pix_out2 = tuple(p.to_value(unit) for p, unit in zip(pix_out2, wcsobj.input_frame.unit)) + + np.testing.assert_allclose(pix_out1, pixel_input) + np.testing.assert_allclose(pix_out2, pixel_input) + def test_stokes_wrapper(gwcs_stokes_lookup): pytest.importorskip("astropy", minversion="4.0dev0") @@ -362,24 +387,20 @@ def test_low_level_wcs(wcsobj): @wcs_objs def test_pixel_to_world(wcsobj): - comp = wcsobj(x, y, with_units=True) - comp = wcsobj.output_frame.coordinates(comp) + values = wcsobj(x, y) result = wcsobj.pixel_to_world(x, y) - assert isinstance(comp, coord.SkyCoord) assert isinstance(result, coord.SkyCoord) - assert_allclose(comp.data.lon, result.data.lon) - assert_allclose(comp.data.lat, result.data.lat) + assert_allclose(values[0] * u.deg, result.data.lon) + assert_allclose(values[1] * u.deg, result.data.lat) @wcs_objs def test_array_index_to_world(wcsobj): - comp = wcsobj(x, y, with_units=True) - comp = wcsobj.output_frame.coordinates(comp) + values = wcsobj(x, y) result = wcsobj.array_index_to_world(y, x) - assert isinstance(comp, coord.SkyCoord) assert isinstance(result, coord.SkyCoord) - assert_allclose(comp.data.lon, result.data.lon) - assert_allclose(comp.data.lat, result.data.lat) + assert_allclose(values[0] * u.deg, result.data.lon) + assert_allclose(values[1] * u.deg, result.data.lat) def test_pixel_to_world_quantity(gwcs_2d_shift_scale, gwcs_2d_shift_scale_quantity): @@ -460,28 +481,28 @@ def sky_ra_dec(request, gwcs_2d_spatial_shift): def test_world_to_pixel(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec, with_units=False)) + assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec)) def test_world_to_array_index(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec, with_units=False)[::-1]) + assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec)[::-1]) def test_world_to_pixel_values(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_pixel_values(sky), wcsobj.invert(ra, dec, with_units=False)) + assert_allclose(wcsobj.world_to_pixel_values(ra, dec), wcsobj.invert(ra, dec)) def test_world_to_array_index_values(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_array_index_values(sky), - wcsobj.invert(ra, dec, with_units=False)[::-1]) + assert_allclose(wcsobj.world_to_array_index_values(ra, dec), + wcsobj.invert(ra, dec)[::-1]) def test_ndim_str_frames(gwcs_with_frames_strings): @@ -494,12 +515,12 @@ def test_composite_many_base_frame(): q_frame_2 = cf.CoordinateFrame(name='distance', axes_order=(1,), naxes=1, axes_type="SPATIAL", unit=(u.m,)) frame = cf.CompositeFrame([q_frame_1, q_frame_2]) - wao_classes = frame._world_axis_object_classes + wao_classes = frame.world_axis_object_classes assert len(wao_classes) == 2 assert not set(wao_classes.keys()).difference({"SPATIAL", "SPATIAL1"}) - wao_components = frame._world_axis_object_components + wao_components = frame.world_axis_object_components assert len(wao_components) == 2 assert not {c[0] for c in wao_components}.difference({"SPATIAL", "SPATIAL1"}) @@ -521,3 +542,19 @@ def test_coordinate_frame_api(): pixel2 = wcs.invert(world) assert u.allclose(pixel2, 0*u.pix) + + +def test_world_axis_object_components_units(gwcs_3d_identity_units): + from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values + + wcs = gwcs_3d_identity_units + world = wcs.pixel_to_world(1, 1, 1) + + values = high_level_objects_to_values(*world, low_level_wcs=wcs) + + expected_values = [world[0].spherical.lon.to_value(wcs.output_frame.unit[0]), + world[0].spherical.lon.to_value(wcs.output_frame.unit[1]), + world[1].to_value(wcs.output_frame.unit[2])] + + assert not any([isinstance(o, u.Quantity) for o in values]) + np.testing.assert_allclose(values, expected_values) diff --git a/gwcs/tests/test_api_slicing.py b/gwcs/tests/test_api_slicing.py index d4866242..3d38a9e9 100644 --- a/gwcs/tests/test_api_slicing.py +++ b/gwcs/tests/test_api_slicing.py @@ -1,6 +1,7 @@ import astropy.units as u from astropy.coordinates import Galactic, SkyCoord, SpectralCoord +from astropy.wcs.wcsapi import wcs_info_str from astropy.wcs.wcsapi.wrappers import SlicedLowLevelWCS from numpy.testing import assert_allclose, assert_equal @@ -31,6 +32,11 @@ """ +def test_no_ellipsis(gwcs_3d_galactic_spectral): + expected_repr = EXPECTED_ELLIPSIS_REPR.replace("SlicedLowLevel", "") + assert wcs_info_str(gwcs_3d_galactic_spectral) == expected_repr.strip() + + def test_ellipsis(gwcs_3d_galactic_spectral): wcs = SlicedLowLevelWCS(gwcs_3d_galactic_spectral, Ellipsis) @@ -44,24 +50,28 @@ def test_ellipsis(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[True, False, True], [False, True, False], [True, False, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('spectral', 0), + ('celestial', 0)] + + assert all([callable(l) for l in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 29)) assert str(wcs) == EXPECTED_ELLIPSIS_REPR.strip() assert EXPECTED_ELLIPSIS_REPR.strip() in repr(wcs) @@ -106,19 +116,23 @@ def test_spectral_slice(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[True, True], [True, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('celestial', 0)] + + assert all([callable(l) for l in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) - assert_allclose(wcs.pixel_to_world_values(29, 44), (10, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 29), (10, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 44), (80, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 29), (80, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 25), (29., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 25), (44, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 205), (29., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 205), (44, 29)) assert_equal(wcs.pixel_bounds, [(-1, 35), (5, 50)]) @@ -166,24 +180,28 @@ def test_spectral_range(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[True, False, True], [False, True, False], [True, False, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('spectral', 0), + ('celestial', 0)] + + assert all([callable(l) for l in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(29, 35, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 35, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 35, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 35, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 35., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 35, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 35., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 35, 29)) assert_equal(wcs.pixel_bounds, [(-1, 35), (-6, 41), (5, 50)]) @@ -230,24 +248,28 @@ def test_celestial_slice(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[False, True], [True, False], [False, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('spectral', 0), + ('celestial', 0)] + + assert all([callable(l) for l in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(39, 44), (10.24, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39), (10.24, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(39, 44), (79.76, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39), (79.76, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(12.4, 20, 25), (39., 44.)) - assert_equal(wcs.world_to_array_index_values(12.4, 20, 25), (44, 39)) + assert_allclose(wcs.world_to_pixel_values(79.76, 20, 205), (39., 44.)) + assert_equal(wcs.world_to_array_index_values(79.76, 20, 205), (44, 39)) assert_equal(wcs.pixel_bounds, [(-2, 45), (5, 50)]) @@ -295,24 +317,28 @@ def test_celestial_range(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[True, False, True], [False, True, False], [True, False, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('spectral', 0), + ('celestial', 0)] + + assert all([callable(l) for l in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(24, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 24), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(24, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 24), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (24., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 24)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (24., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 24)) assert_equal(wcs.pixel_bounds, [(-6, 30), (-2, 45), (5, 50)]) @@ -363,24 +389,28 @@ def test_no_array_shape(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[True, False, True], [False, True, False], [True, False, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('spectral', 0), + ('celestial', 0)] + + assert all([callable(l) for l in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 29)) assert str(wcs) == EXPECTED_NO_SHAPE_REPR.strip() assert EXPECTED_NO_SHAPE_REPR.strip() in repr(wcs) @@ -416,8 +446,9 @@ def test_no_array_shape(gwcs_3d_galactic_spectral): def test_ellipsis_none_types(gwcs_3d_galactic_spectral): pht = list(gwcs_3d_galactic_spectral.output_frame._axis_physical_types) - pht[1] = None - gwcs_3d_galactic_spectral.output_frame._axis_physical_types = tuple(pht) + # This index is in "axes_order" ordering + pht[2] = None + gwcs_3d_galactic_spectral.output_frame._prop.axis_physical_types = tuple(pht) wcs = SlicedLowLevelWCS(gwcs_3d_galactic_spectral, Ellipsis) @@ -431,20 +462,24 @@ def test_ellipsis_none_types(gwcs_3d_galactic_spectral): assert_equal(wcs.axis_correlation_matrix, [[True, False, True], [False, True, False], [True, False, True]]) - assert wcs.world_axis_object_components == [('celestial', 1, 'spherical.lat'), - ('spectral', 0, 'value'), - ('celestial', 0, 'spherical.lon')] + first_two = [c[:2] for c in wcs.world_axis_object_components] + last_one = [c[2] for c in wcs.world_axis_object_components] + assert first_two == [('celestial', 1), + ('spectral', 0), + ('celestial', 0)] + + assert all([callable(last) for last in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) - assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 29)) assert_equal(wcs.pixel_bounds, [(-1, 35), (-2, 45), (5, 50)]) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 967657f8..75f2895f 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -10,11 +10,12 @@ from astropy.tests.helper import assert_quantity_allclose from astropy.modeling import models as m from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 -from astropy.coordinates import StokesCoord +from astropy.coordinates import StokesCoord, SpectralCoord from .. import WCS from .. import coordinate_frames as cf +from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values import astropy astropy_version = astropy.__version__ @@ -33,7 +34,7 @@ focal = cf.Frame2D(name='focal', axes_order=(0, 1), unit=(u.m, u.m)) spec1 = cf.SpectralFrame(name='freq', unit=[u.Hz, ], axes_order=(2, )) -spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', )) +spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda',)) spec3 = cf.SpectralFrame(name='energy', unit=[u.J, ], axes_order=(2, )) spec4 = cf.SpectralFrame(name='pixel', unit=[u.pix, ], axes_order=(2, )) spec5 = cf.SpectralFrame(name='speed', unit=[u.m/u.s, ], axes_order=(2, )) @@ -55,6 +56,19 @@ inputs3 = [(xscalar, yscalar, xscalar), (xarr, yarr, xarr)] +@pytest.fixture(autouse=True, scope="module") +def serialized_classes(): + """ + In the rest of this test file we are passing the CoordinateFrame object to + astropy helper functions as if they were a low level WCS object. + + This little patch means that this works. + """ + cf.CoordinateFrame.serialized_classes = False + yield + del cf.CoordinateFrame.serialized_classes + + def test_units(): assert(comp1.unit == (u.deg, u.deg, u.Hz)) assert(comp2.unit == (u.m, u.m, u.m)) @@ -64,19 +78,34 @@ def test_units(): assert(comp.unit == (u.deg, u.deg, u.Hz, u.m)) +# These two functions fake the old methods on CoordinateFrame to reduce the +# amount of refactoring that needed doing in these tests. +def coordinates(*inputs, frame): + results = values_to_high_level_objects(*inputs, low_level_wcs=frame) + if isinstance(results, list) and len(results) == 1: + return results[0] + return results + + +def coordinate_to_quantity(*inputs, frame): + results = high_level_objects_to_values(*inputs, low_level_wcs=frame) + results = [r << unit for r, unit in zip(results, frame.unit)] + return results + + @pytest.mark.parametrize('inputs', inputs2) def test_coordinates_spatial(inputs): - sky_coo = icrs.coordinates(*inputs) + sky_coo = coordinates(*inputs, frame=icrs) assert isinstance(sky_coo, coord.SkyCoord) assert_allclose((sky_coo.ra.value, sky_coo.dec.value), inputs) - focal_coo = focal.coordinates(*inputs) + focal_coo = coordinates(*inputs, frame=focal) assert_allclose([coo.value for coo in focal_coo], inputs) assert [coo.unit for coo in focal_coo] == [u.m, u.m] @pytest.mark.parametrize('inputs', inputs1) def test_coordinates_spectral(inputs): - wave = spec2.coordinates(inputs) + wave = coordinates(inputs, frame=spec2) assert_allclose(wave.value, inputs) assert wave.unit == 'meter' assert isinstance(wave, u.Quantity) @@ -85,7 +114,7 @@ def test_coordinates_spectral(inputs): @pytest.mark.parametrize('inputs', inputs3) def test_coordinates_composite(inputs): frame = cf.CompositeFrame([icrs, spec2]) - result = frame.coordinates(*inputs) + result = coordinates(*inputs, frame=frame) assert isinstance(result[0], coord.SkyCoord) assert_allclose((result[0].ra.value, result[0].dec.value), inputs[:2]) assert_allclose(result[1].value, inputs[2]) @@ -96,7 +125,7 @@ def test_coordinates_composite_order(): dist = cf.CoordinateFrame(name='distance', naxes=1, axes_type=["SPATIAL"], unit=[u.m, ], axes_order=(1, )) frame = cf.CompositeFrame([time, dist]) - result = frame.coordinates(0, 0) + result = coordinates(0, 0, frame=frame) assert result[0] == Time("2011-01-01T00:00:00") assert u.allclose(result[1], 0*u.m) @@ -104,7 +133,8 @@ def test_coordinates_composite_order(): def test_bare_baseframe(): # This is a regression test for the following call: frame = cf.CoordinateFrame(1, "SPATIAL", (0,), unit=(u.km,)) - assert u.allclose(frame.coordinate_to_quantity((1*u.m,)), 1*u.m) + quantity = coordinate_to_quantity(1*u.m, frame=frame) + assert u.allclose(quantity, 1*u.m) # Now also setup the same situation through the whole call stack to be safe. w = WCS(forward_transform=m.Tabular1D(points=np.arange(10)*u.pix, @@ -158,55 +188,52 @@ def test_base_coordinate(): assert frame.name == 'CoordinateFrame' frame = cf.CoordinateFrame(name="CustomFrame", naxes=2, axes_type=("SPATIAL", "SPATIAL"), - axes_order=(0, 1)) + axes_order=(0, 1), + unit=(u.deg, u.arcsec)) assert frame.name == 'CustomFrame' frame.name = "DeLorean" assert frame.name == 'DeLorean' - q1, q2 = frame.coordinate_to_quantity(12 * u.deg, 3 * u.arcsec) + q1, q2 = coordinate_to_quantity(12 * u.deg, 3 * u.arcsec, frame=frame) assert_quantity_allclose(q1, 12 * u.deg) assert_quantity_allclose(q2, 3 * u.arcsec) - q1, q2 = frame.coordinate_to_quantity((12 * u.deg, 3 * u.arcsec)) + q1, q2 = coordinate_to_quantity(*(12 * u.deg, 3 * u.arcsec), frame=frame) assert_quantity_allclose(q1, 12 * u.deg) assert_quantity_allclose(q2, 3 * u.arcsec) def test_temporal_relative(): t = cf.TemporalFrame(reference_frame=Time("2018-01-01T00:00:00"), unit=u.s) - assert t.coordinates(10) == Time("2018-01-01T00:00:00") + 10 * u.s - assert t.coordinates(10 * u.s) == Time("2018-01-01T00:00:00") + 10 * u.s + assert coordinates(10, frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s + assert coordinates(10 * u.s, frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s - a = t.coordinates((10, 20)) + a = coordinates((10, 20), frame=t) assert a[0] == Time("2018-01-01T00:00:00") + 10 * u.s assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s t = cf.TemporalFrame(reference_frame=Time("2018-01-01T00:00:00")) - assert t.coordinates(10 * u.s) == Time("2018-01-01T00:00:00") + 10 * u.s - assert t.coordinates(TimeDelta(10, format='sec')) == Time("2018-01-01T00:00:00") + 10 * u.s + assert coordinates(10 * u.s, frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s + assert coordinates(TimeDelta(10, format='sec'), frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s - a = t.coordinates((10, 20) * u.s) + a = coordinates((10, 20) * u.s, frame=t) assert a[0] == Time("2018-01-01T00:00:00") + 10 * u.s assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s -@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher") def test_temporal_absolute(): t = cf.TemporalFrame(reference_frame=Time([], format='isot')) - assert t.coordinates("2018-01-01T00:00:00") == Time("2018-01-01T00:00:00") + assert coordinates("2018-01-01T00:00:00", frame=t) == Time("2018-01-01T00:00:00") - a = t.coordinates(("2018-01-01T00:00:00", "2018-01-01T00:10:00")) + a = coordinates(("2018-01-01T00:00:00", "2018-01-01T00:10:00"), frame=t) assert a[0] == Time("2018-01-01T00:00:00") assert a[1] == Time("2018-01-01T00:10:00") t = cf.TemporalFrame(reference_frame=Time([], scale='tai', format='isot')) - assert t.coordinates("2018-01-01T00:00:00") == Time("2018-01-01T00:00:00", scale='tai') + assert coordinates("2018-01-01T00:00:00", frame=t) == Time("2018-01-01T00:00:00", scale='tai') @pytest.mark.parametrize('inp', [ - (10 * u.deg, 20 * u.deg), - ((10 * u.deg, 20 * u.deg),), - (u.Quantity([10, 20], u.deg),), (coord.SkyCoord(10 * u.deg, 20 * u.deg, frame=coord.ICRS),), # This is the same as 10,20 in ICRS (coord.SkyCoord(119.26936774, -42.79039286, unit=u.deg, frame='galactic'),) @@ -214,54 +241,40 @@ def test_temporal_absolute(): def test_coordinate_to_quantity_celestial(inp): cel = cf.CelestialFrame(reference_frame=coord.ICRS(), axes_order=(0, 1)) - lon, lat = cel.coordinate_to_quantity(*inp) + lon, lat = coordinate_to_quantity(*inp, frame=cel) assert_quantity_allclose(lon, 10 * u.deg) assert_quantity_allclose(lat, 20 * u.deg) with pytest.raises(ValueError): - cel.coordinate_to_quantity(10 * u.deg, 2 * u.deg, 3 * u.deg) + coordinate_to_quantity(10 * u.deg, 2 * u.deg, 3 * u.deg, frame=cel) with pytest.raises(ValueError): - cel.coordinate_to_quantity((1, 2)) + coordinate_to_quantity((1, 2), frame=cel) @pytest.mark.parametrize('inp', [ - (100,), - (100 * u.nm,), - (0.1 * u.um,), + (SpectralCoord(100 * u.nm),), + (SpectralCoord(0.1 * u.um),), ]) def test_coordinate_to_quantity_spectral(inp): spec = cf.SpectralFrame(unit=u.nm, axes_order=(1, )) - wav = spec.coordinate_to_quantity(*inp) + wav = coordinate_to_quantity(*inp, frame=spec) assert_quantity_allclose(wav, 100 * u.nm) @pytest.mark.parametrize('inp', [ (Time("2011-01-01T00:00:10"),), - (10 * u.s,) ]) -@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.") def test_coordinate_to_quantity_temporal(inp): temp = cf.TemporalFrame(reference_frame=Time("2011-01-01T00:00:00"), unit=u.s) - t = temp.coordinate_to_quantity(*inp) + t = coordinate_to_quantity(*inp, frame=temp) assert_quantity_allclose(t, 10 * u.s) - temp2 = cf.TemporalFrame(reference_frame=Time([], format='isot'), unit=u.s) - - tt = Time("2011-01-01T00:00:00") - t = temp2.coordinate_to_quantity(tt) - - assert t is tt - @pytest.mark.parametrize('inp', [ - (211 * u.AA, 0 * u.s, 0 * u.arcsec, 0 * u.arcsec), - (211 * u.AA, 0 * u.s, (0 * u.arcsec, 0 * u.arcsec)), - (211 * u.AA, 0 * u.s, (0, 0) * u.arcsec), - (211 * u.AA, Time("2011-01-01T00:00:00"), (0, 0) * u.arcsec), - (211 * u.AA, Time("2011-01-01T00:00:00"), coord.SkyCoord(0, 0, unit=u.arcsec)), + (SpectralCoord(211 * u.AA), Time("2011-01-01T00:00:00"), coord.SkyCoord(0, 0, unit=u.arcsec)), ]) def test_coordinate_to_quantity_composite(inp): # Composite @@ -272,29 +285,46 @@ def test_coordinate_to_quantity_composite(inp): comp = cf.CompositeFrame([wave_frame, time_frame, sky_frame]) - coords = comp.coordinate_to_quantity(*inp) + coords = coordinate_to_quantity(*inp, frame=comp) expected = (211 * u.AA, 0 * u.s, 0 * u.arcsec, 0 * u.arcsec) for output, exp in zip(coords, expected): assert_quantity_allclose(output, exp) +def test_coordinate_to_quantity_composite_split(): + inp = ( + SpectralCoord(211 * u.AA), + coord.SkyCoord(0, 0, unit=u.arcsec), + Time("2011-01-01T00:00:00"), + ) + + # Composite + wave_frame = cf.SpectralFrame(axes_order=(1, ), unit=u.AA) + sky_frame = cf.CelestialFrame(axes_order=(2, 0), reference_frame=coord.ICRS()) + time_frame = cf.TemporalFrame( + axes_order=(3,), unit=u.s, reference_frame=Time("2011-01-01T00:00:00")) + + comp = cf.CompositeFrame([wave_frame, sky_frame, time_frame]) + + coords = coordinate_to_quantity(*inp, frame=comp) + + expected = (0 * u.arcsec, 211 * u.AA, 0 * u.arcsec, 0 * u.s) + for output, exp in zip(coords, expected): + assert_quantity_allclose(output, exp) + + def test_stokes_frame(): sf = cf.StokesFrame() - assert sf.coordinates(1) == 'I' - assert sf.coordinates(1 * u.pix) == 'I' - assert sf.coordinate_to_quantity(StokesCoord('I')) == 1 * u.one - assert sf.coordinate_to_quantity(1) == 1 + assert coordinates(1, frame=sf) == 'I' + assert coordinates(1 * u.one, frame=sf) == 'I' + assert coordinate_to_quantity(StokesCoord('I'), frame=sf) == 1 * u.one + assert coordinate_to_quantity(StokesCoord(1), frame=sf) == 1 * u.one -@pytest.mark.parametrize('inp', [ - (211 * u.AA, 0 * u.s, 0 * u.one, 0 * u.one), - (211 * u.AA, 0 * u.s, (0 * u.one, 0 * u.one)), - (211 * u.AA, 0 * u.s, (0, 0) * u.one), - (211 * u.AA, Time("2011-01-01T00:00:00"), (0, 0) * u.one) -]) -def test_coordinate_to_quantity_frame2d_composite(inp): +def test_coordinate_to_quantity_frame2d_composite(): + inp = (SpectralCoord(211 * u.AA), Time("2011-01-01T00:00:00"), 0 * u.one, 0 * u.one) wave_frame = cf.SpectralFrame(axes_order=(0, ), unit=u.AA) time_frame = cf.TemporalFrame( axes_order=(1, ), unit=u.s, reference_frame=Time("2011-01-01T00:00:00")) @@ -303,7 +333,7 @@ def test_coordinate_to_quantity_frame2d_composite(inp): comp = cf.CompositeFrame([wave_frame, time_frame, frame2d]) - coords = comp.coordinate_to_quantity(*inp) + coords = coordinate_to_quantity(*inp, frame=comp) expected = (211 * u.AA, 0 * u.s, 0 * u.one, 0 * u.one) for output, exp in zip(coords, expected): @@ -312,31 +342,24 @@ def test_coordinate_to_quantity_frame2d_composite(inp): def test_coordinate_to_quantity_frame_2d(): frame = cf.Frame2D(unit=(u.one, u.arcsec)) - inp = (1, 2) + inp = (1 * u.one, 2 * u.arcsec) expected = (1 * u.one, 2 * u.arcsec) - result = frame.coordinate_to_quantity(*inp) + result = coordinate_to_quantity(*inp, frame=frame) for output, exp in zip(result, expected): assert_quantity_allclose(output, exp) - inp = (1 * u.one, 2) - expected = (1 * u.one, 2 * u.arcsec) - result = frame.coordinate_to_quantity(*inp) - for output, exp in zip(result, expected): - assert_quantity_allclose(output, exp) - -@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.") def test_coordinate_to_quantity_error(): frame = cf.Frame2D(unit=(u.one, u.arcsec)) with pytest.raises(ValueError): - frame.coordinate_to_quantity(1) + coordinate_to_quantity(1, frame=frame) with pytest.raises(ValueError): - comp1.coordinate_to_quantity((1, 1), 2) + coordinate_to_quantity((1, 1), 2, frame=frame) frame = cf.TemporalFrame(reference_frame=Time([], format='isot'), unit=u.s) with pytest.raises(ValueError): - frame.coordinate_to_quantity(1) + coordinate_to_quantity(1, frame=frame) def test_axis_physical_type(): @@ -352,7 +375,7 @@ def test_axis_physical_type(): assert comp.axis_physical_types == ('pos.eq.ra', 'pos.eq.dec', 'em.freq', 'em.wl') spec6 = cf.SpectralFrame(name='waven', axes_order=(1,), - axis_physical_types='em.wavenumber') + axis_physical_types='em.wavenumber', unit=u.Unit(1)) assert spec6.axis_physical_types == ('em.wavenumber',) t = cf.TemporalFrame(reference_frame=Time("2018-01-01T00:00:00"), unit=u.s) @@ -406,7 +429,7 @@ def test_base_frame(): assert frame.naxes == 1 assert frame.axes_names == ("x",) - frame.coordinate_to_quantity(1, 2) + coordinate_to_quantity(1*u.one, frame=frame) def test_ucd1_to_ctype_not_out_of_sync(caplog): @@ -452,3 +475,50 @@ def test_ucd1_to_ctype(caplog): assert ctype_to_ucd[v] == k assert inv_map['new.repeated.type'] in new_ctype_to_ucd + + +def test_celestial_ordering(): + c1 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(0, 1), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + c2 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(1, 0), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + + assert c1.axes_names == ("lon", "lat") + assert c2.axes_names == ("lat", "lon") + + assert c1.unit == (u.deg, u.arcsec) + assert c2.unit == (u.arcsec, u.deg) + + assert c1.axis_physical_types == ("custom:lon", "custom:lat") + assert c2.axis_physical_types == ("custom:lat", "custom:lon") + + +def test_composite_ordering(): + print("boo") + c1 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(1, 0), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + spec = cf.SpectralFrame( + axes_order=(2,), + axes_names=("spectral",), + unit=u.AA, + ) + comp = cf.CompositeFrame([c1, spec]) + assert comp.axes_names == ("lat", "lon", "spectral") + assert comp.axis_physical_types == ("custom:lat", "custom:lon", "em.wl") + assert comp.unit == (u.arcsec, u.deg, u.AA) + assert comp.axes_order == (1, 0, 2) diff --git a/gwcs/tests/test_region.py b/gwcs/tests/test_region.py index 5fe69f90..6304b786 100644 --- a/gwcs/tests/test_region.py +++ b/gwcs/tests/test_region.py @@ -8,8 +8,9 @@ from numpy.testing import assert_equal, assert_allclose from astropy.modeling import models import pytest -from .. import region, selector -from .. import utils as gwutils +from gwcs import region, selector, WCS +from gwcs import utils as gwutils +from gwcs import coordinate_frames as cf def test_LabelMapperArray_from_vertices_int(): @@ -237,6 +238,10 @@ def test_RegionsSelector(): reg_selector.undefined_transform_value = -100 assert_equal(reg_selector(0, 0), [-100, -100]) + wcs = WCS(forward_transform=reg_selector, output_frame=cf.Frame2D()) + out = wcs(1, 1) + assert out == (-100, -100) + def test_overalpping_ranges(): """ diff --git a/gwcs/tests/test_utils.py b/gwcs/tests/test_utils.py index e69ec536..7f880e65 100644 --- a/gwcs/tests/test_utils.py +++ b/gwcs/tests/test_utils.py @@ -90,21 +90,6 @@ def test_get_axes(): assert not other -def test_isnumerical(): - sky = coord.SkyCoord(1 * u.deg, 2 * u.deg) - assert not gwutils.isnumerical(sky) - - assert not gwutils.isnumerical(2 * u.m) - - assert gwutils.isnumerical(float(0)) - assert gwutils.isnumerical(np.array(0)) - - assert not gwutils.isnumerical(np.array(['s200', '234'])) - - assert gwutils.isnumerical(np.array(0, dtype='>f8')) - assert gwutils.isnumerical(np.array(0, dtype='>i4')) - - def test_get_values(): args = 2 * u.cm units=(u.m, ) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index c285a7b1..34a1d567 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -17,13 +17,12 @@ from astropy.utils.introspection import minversion import asdf -from .. import wcs -from ..wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) -from .. import coordinate_frames as cf -from .. import utils -from ..utils import CoordinateFrameError -from .utils import _gwcs_from_hst_fits_wcs -from . import data +from gwcs import wcs +from gwcs.wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) +from gwcs import coordinate_frames as cf +from gwcs.utils import CoordinateFrameError +from . utils import _gwcs_from_hst_fits_wcs +from gwcs.tests import data data_path = os.path.split(os.path.abspath(data.__file__))[0] @@ -33,7 +32,7 @@ m2 = models.Scale(2) & models.Scale(-2) m = m1 | m2 -icrs = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs') +icrs = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg)) detector = cf.Frame2D(name='detector', axes_order=(0, 1)) focal = cf.Frame2D(name='focal', axes_order=(0, 1), unit=(u.m, u.m)) spec = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', )) @@ -206,48 +205,6 @@ def test_backward_transform_has_inverse(): assert_allclose(w.backward_transform.inverse(1, 2), w(1, 2)) -def test_return_coordinates(): - """Test converting to coordinate objects or quantities.""" - w = wcs.WCS(pipe[:]) - x = 1 - y = 2.3 - numerical_result = (26.8, -0.6) - # Celestial frame - num_plus_output = w(x, y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(num_plus_output) - assert_allclose(w(x, y), numerical_result) - assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - assert_allclose(w.invert(num_plus_output), (x, y)) - assert isinstance(num_plus_output, coord.SkyCoord) - - # Spectral frame - poly = models.Polynomial1D(1, c0=1, c1=2) - w = wcs.WCS(forward_transform=poly, output_frame=spec) - numerical_result = poly(y) - num_plus_output = w(y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(num_plus_output) - assert_allclose(utils.get_values(w.unit, output_quant), numerical_result) - assert isinstance(num_plus_output, u.Quantity) - - # CompositeFrame - [celestial, spectral] - output_frame = cf.CompositeFrame(frames=[icrs, spec]) - transform = m1 & poly - w = wcs.WCS(forward_transform=transform, output_frame=output_frame) - numerical_result = transform(x, y, y) - num_plus_output = w(x, y, y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(*num_plus_output) - assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - - # CompositeFrame - [celestial, Stokes] - output_frame = cf.CompositeFrame(frames=[icrs, stokes]) - transform = m1 & models.Identity(1) - w = wcs.WCS(forward_transform=transform, output_frame=output_frame) - numerical_result = transform(x, y, y) - num_plus_output = w(x, y, y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(*num_plus_output) - assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - - def test_from_fiducial_sky(): sky = coord.SkyCoord(1.63 * u.radian, -72.4 * u.deg, frame='fk5') tan = models.Pix2Sky_TAN() @@ -267,7 +224,7 @@ def test_from_fiducial_composite(): assert isinstance(w.cube_frame.frames[1].reference_frame, coord.FK5) assert_allclose(w(1, 1, 1), (1.5, 96.52373368309931, -71.37420187296995)) # test returning coordinate objects with composite output_frame - res = w(1, 2, 2, with_units=True) + res = w.pixel_to_world(1, 2, 2) assert_allclose(res[0], u.Quantity(1.5 * u.micron)) assert isinstance(res[1], coord.SkyCoord) assert_allclose(res[1].ra.value, 99.329496642319) @@ -279,7 +236,7 @@ def test_from_fiducial_composite(): assert_allclose(w(1, 1, 1), (11.5, 99.97738475762152, -72.29039139739766)) # test coordinate object output - coord_result = w(1, 1, 1, with_units=True) + coord_result = w.pixel_to_world(1, 1, 1) assert_allclose(coord_result[0], u.Quantity(11.5 * u.micron)) @@ -310,13 +267,16 @@ def test_bounding_box(): with pytest.raises(ValueError): w.bounding_box = ((1, 5), (2, 6)) + +def test_bounding_box_units(): # Test that bounding_box with quantities can be assigned and evaluates bb = ((1 * u.pix, 5 * u.pix), (2 * u.pix, 6 * u.pix)) trans = models.Shift(10 * u .pix) & models.Shift(2 * u.pix) pipeline = [('detector', trans), ('sky', None)] w = wcs.WCS(pipeline) w.bounding_box = bb - assert_allclose(w(-1*u.pix, -1*u.pix), (np.nan, np.nan)) + world = w(-1*u.pix, -1*u.pix) + assert_allclose(world, (np.nan, np.nan)) def test_compound_bounding_box(): @@ -641,7 +601,7 @@ def test_inverse(self): def test_back_coordinates(self): sky_coord = self.wcs(1, 2, with_units=True) - res = self.wcs.transform('sky', 'focal', sky_coord) + res = self.wcs.transform('sky', 'focal', sky_coord, with_units=True) assert_allclose(res, self.wcs.get_transform('detector', 'focal')(1, 2)) def test_units(self): @@ -761,7 +721,7 @@ def test_to_fits_sip_composite_frame(gwcs_cube_with_separable_spectral): assert fw_hdr['NAXIS2'] == 64 fw = astwcs.WCS(fw_hdr) - gskyval = w(1, 60, 55, with_units=True)[0] + gskyval = w.pixel_to_world(1, 60, 55)[1] fskyval = fw.all_pix2world(1, 60, 0) fskyval = [float(fskyval[ra_axis - 1]), float(fskyval[dec_axis - 1])] assert np.allclose([gskyval.ra.value, gskyval.dec.value], fskyval) @@ -774,7 +734,7 @@ def test_to_fits_sip_composite_frame_galactic(gwcs_3d_galactic_spectral): assert fw_hdr['CTYPE1'] == 'GLAT-TAN' fw = astwcs.WCS(fw_hdr) - gskyval = w(7, 8, 9, with_units=True)[0] + gskyval = w.pixel_to_world(7, 8, 9)[0] assert np.allclose([gskyval.b.value, gskyval.l.value], fw.all_pix2world(7, 9, 0), atol=1e-3) @@ -1338,3 +1298,84 @@ def test_spatial_spectral_stokes(): def test_wcs_str(): w = wcs.WCS(output_frame="icrs") assert 'icrs' in str(w) + + +def test_split_frame_wcs(): + # Setup a WCS where the pixel & world axes are (lat, wave, lon) + + # We setup a model which is pretending to be a celestial transform. Note + # that we are pretending that this model is ordered lon, lat because that's + # what the projections require in astropy. + + # Input is (lat, wave, lon) + # lat: multuply by 20 arcsec, lon: multiply by 15 deg + # result should be 20 arcsec, 10nm, 45 deg + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) + compound = models.Linear1D(intercept=0*u.nm, slope=10*u.nm/u.pix) & spatial + # This forward transforms uses mappings to be (lat, wave, lon) + forward = models.Mapping((1, 0, 2)) | compound | models.Mapping((1, 0, 2)) + + # Setup the output frame + celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.deg, u.arcsec), + reference_frame=coord.ICRS(), axes_names=('lon', 'lat')) + #celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.arcsec, u.deg), + # reference_frame=coord.ICRS()) + spectral_frame = cf.SpectralFrame(axes_order=(1,), unit=u.nm, axes_names='wave') + output_frame = cf.CompositeFrame([spectral_frame, celestial_frame]) + #output_frame = cf.CompositeFrame([celestial_frame, spectral_frame]) + + input_frame = cf.CoordinateFrame(3, ["PIXEL"]*3, + axes_order=list(range(3)), unit=[u.pix]*3) + + iwcs = wcs.WCS(forward, input_frame, output_frame) + input_pixel = [1*u.pix, 1*u.pix, 3*u.pix] + output_world = iwcs.pixel_to_world_values(*input_pixel) + output_pixel = iwcs.world_to_pixel_values(*output_world) + assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) + + expected_world = [20*u.arcsec, 10*u.nm, 45*u.deg] + #expected_world = [15*u.deg, 20*u.nm, 60*u.arcsec] + for expected, output in zip(expected_world, output_world): + assert_allclose(output, expected.value) + + world_obj = iwcs.pixel_to_world(*input_pixel) + assert isinstance(world_obj[0], coord.SkyCoord) + assert isinstance(world_obj[1], coord.SpectralCoord) + + assert u.allclose(world_obj[0].spherical.lat, expected_world[0]) + assert u.allclose(world_obj[0].spherical.lon, expected_world[2]) + assert u.allclose(world_obj[1], expected_world[1]) + + obj_pixel = iwcs.world_to_pixel(*world_obj) + assert_allclose(obj_pixel, u.Quantity(input_pixel).to_value(u.pix)) + + +def test_reordered_celestial(): + # This is a spatial model which is ordered lat, lon for the purposes of this test. + # Expected lat=45 deg, lon=20 arcsec + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) | models.Mapping((1,0)) + + celestial_frame = cf.CelestialFrame(axes_order=(1, 0), unit=(u.arcsec, u.deg), + reference_frame=coord.ICRS()) + + input_frame = cf.CoordinateFrame(2, ["PIXEL"]*2, + axes_order=list(range(2)), unit=[u.pix]*2) + + iwcs = wcs.WCS(spatial, input_frame, celestial_frame) + + input_pixel = [1*u.pix, 3*u.pix] + output_world = iwcs.pixel_to_world_values(*input_pixel) + output_pixel = iwcs.world_to_pixel_values(*output_world) + assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) + + expected_world = [45*u.deg, 20*u.arcsec]#, 45*u.deg] + assert_allclose(output_world, [e.value for e in expected_world]) + + world_obj = iwcs.pixel_to_world(*input_pixel) + assert isinstance(world_obj, coord.SkyCoord) + + assert u.allclose(world_obj.spherical.lat, expected_world[0]) + assert u.allclose(world_obj.spherical.lon, expected_world[1]) + + obj_pixel = iwcs.world_to_pixel(world_obj) + assert_allclose(obj_pixel, u.Quantity(input_pixel).to_value(u.pix)) diff --git a/gwcs/utils.py b/gwcs/utils.py index 104558cf..c04c105d 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -11,7 +11,6 @@ from astropy.io import fits from astropy import coordinates as coords from astropy import units as u -from astropy.time import Time, TimeDelta from astropy.wcs import Celprm @@ -465,19 +464,13 @@ def create_projection_transform(projcode): return projklass(**projparams) -def isnumerical(val): +def is_high_level(*args, low_level_wcs): """ - Determine if a value is numerical (number or np.array of numbers). + Determine if args matches the high level classes as defined by + ``low_level_wcs``. """ - isnum = True - if isinstance(val, coords.SkyCoord): - isnum = False - elif isinstance(val, u.Quantity): - isnum = False - elif isinstance(val, (Time, TimeDelta)): - isnum = False - elif (isinstance(val, np.ndarray) - and not np.issubdtype(val.dtype, np.floating) - and not np.issubdtype(val.dtype, np.integer)): - isnum = False - return isnum + if len(args) != len(low_level_wcs.world_axis_object_classes): + return False + + return all([type(arg) is waoc[0] + for arg, waoc in zip(args, low_level_wcs.world_axis_object_classes.values())]) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 6dede33e..96bcad1b 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -3,9 +3,11 @@ import itertools import warnings +import astropy.units as u import astropy.io.fits as fits import numpy as np import numpy.linalg as npla +from astropy import utils as astutil from astropy.modeling import fix_inputs, projections from astropy.modeling.bounding_box import CompoundBoundingBox from astropy.modeling.bounding_box import ModelBoundingBox as Bbox @@ -15,6 +17,7 @@ Sky2Pix_TAN) from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales from scipy import linalg, optimize +from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects from . import coordinate_frames as cf from . import utils @@ -128,7 +131,6 @@ class WCS(GWCSAPIMixin): def __init__(self, forward_transform=None, input_frame='detector', output_frame=None, name=""): - #self.low_level_wcs = self self._approx_inverse = None self._available_frames = [] self._pipeline = [] @@ -254,9 +256,7 @@ def forward_transform(self): Return the total forward transform - from input to output coordinate frame. """ - if self._pipeline: - #return functools.reduce(lambda x, y: x | y, [step[1] for step in self._pipeline[: -1]]) return functools.reduce(lambda x, y: x | y, [step.transform for step in self._pipeline[:-1]]) else: return None @@ -318,6 +318,19 @@ def _get_frame_name(self, frame): frame_obj = frame return name, frame_obj + def _add_units_input(self, arrays, frame): + if frame is not None: + return tuple(u.Quantity(array, unit) for array, unit in zip(arrays, frame.unit)) + + return arrays + + def _remove_units_input(self, arrays, frame): + if frame is not None: + return tuple(array.to_value(unit) if isinstance(array, u.Quantity) else array + for array, unit in zip(arrays, frame.unit)) + + return arrays + def __call__(self, *args, **kwargs): """ Executes the forward transform. @@ -325,11 +338,6 @@ def __call__(self, *args, **kwargs): args : float or array-like Inputs in the input coordinate system, separate inputs for each dimension. - with_units : bool - If ``True`` returns a `~astropy.coordinates.SkyCoord` or - `~astropy.coordinates.SpectralCoord` object, by using the units of - the output cooridnate frame. - Optional, default=False. with_bounding_box : bool, optional If True(default) values in the result which correspond to any of the inputs being outside the bounding_box are set @@ -337,16 +345,42 @@ def __call__(self, *args, **kwargs): fill_value : float, optional Output value for inputs outside the bounding_box (default is np.nan). + with_units : bool, optional + If ``True`` then high level Astropy objects will be returned. + Optional, default=False. """ - transform = self.forward_transform + with_units = kwargs.pop("with_units", False) + + results = self._call_forward(*args, **kwargs) + + if with_units: + if not astutil.isiterable(results): + results = (results,) + high_level = values_to_high_level_objects(*results, low_level_wcs=self) + if len(high_level) == 1: + high_level = high_level[0] + return high_level + return results + + def _call_forward(self, *args, from_frame=None, to_frame=None, + with_bounding_box=True, fill_value=np.nan, **kwargs): + """ + Executes the forward transform, but values only. + """ + if from_frame is None and to_frame is None: + transform = self.forward_transform + else: + transform = self.get_transform(from_frame, to_frame) + if transform is None: raise NotImplementedError("WCS.forward_transform is not implemented.") - with_units = kwargs.pop("with_units", False) - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan + # Validate that the input type matches what the transform expects + input_is_quantity = any((isinstance(a, u.Quantity) for a in args)) + if not input_is_quantity and transform.uses_quantity: + args = self._add_units_input(args, self.input_frame) + if not transform.uses_quantity and input_is_quantity: + args = self._remove_units_input(args, self.input_frame) if self.bounding_box is not None: # Currently compound models do not attempt to combine individual model @@ -354,15 +388,10 @@ def __call__(self, *args, **kwargs): # before evaluating it. The order Model.bounding_box is reversed. transform.bounding_box = self.bounding_box - result = transform(*args, **kwargs) - - if with_units: - if self.output_frame.naxes == 1: - result = self.output_frame.coordinates(result) - else: - result = self.output_frame.coordinates(*result) - - return result + return transform(*args, + with_bounding_box=with_bounding_box, + fill_value=fill_value, + **kwargs) def in_image(self, *args, **kwargs): """ @@ -446,9 +475,8 @@ def invert(self, *args, **kwargs): Output value for inputs outside the bounding_box (default is ``np.nan``). with_units : bool, optional - If ``True`` returns a `~astropy.coordinates.SkyCoord` or - `~astropy.coordinates.SpectralCoord` object, by using the units of - the output cooridnate frame. Default is `False`. + If ``True`` then high level astropy object (i.e. ``Quantity``) will be returned. + Optional, default=False. Other Parameters ---------------- @@ -461,44 +489,47 @@ def invert(self, *args, **kwargs): result : tuple or value Returns a tuple of scalar or array values for each axis. Unless ``input_frame.naxes == 1`` when it shall return the value. + The return type will be `~astropy.units.Quantity` objects if the + transform returns ``Quantity`` objects, else values. """ - with_units = kwargs.pop('with_units', False) + if utils.is_high_level(*args, low_level_wcs=self): + args = high_level_objects_to_values(*args, low_level_wcs=self) - if not utils.isnumerical(args[0]): - args = self.output_frame.coordinate_to_quantity(*args) - if self.output_frame.naxes == 1: - args = [args] - try: - if not self.backward_transform.uses_quantity: - args = utils.get_values(self.output_frame.unit, *args) - except (NotImplementedError, KeyError): - args = utils.get_values(self.output_frame.unit, *args) + results = self._call_backward(*args, **kwargs) - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True + with_units = kwargs.pop('with_units', False) + if with_units: + high_level = values_to_high_level_objects(*results, low_level_wcs=self) + if len(high_level) == 1: + high_level = high_level[0] + return high_level - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan + return results + def _call_backward(self, *args, with_bounding_box=True, fill_value=np.nan, **kwargs): try: + transform = self.backward_transform + # Validate that the input type matches what the transform expects + input_is_quantity = any((isinstance(a, u.Quantity) for a in args)) + if not input_is_quantity and transform.uses_quantity: + args = self._add_units_input(args, self.output_frame) + if not transform.uses_quantity and input_is_quantity: + args = self._remove_units_input(args, self.output_frame) + # remove iterative inverse-specific keyword arguments: akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} - result = self.backward_transform(*args, **akwargs) + result = transform(*args, with_bounding_box=with_bounding_box, fill_value=fill_value, **akwargs) except (NotImplementedError, KeyError): - result = self.numerical_inverse(*args, **kwargs, with_units=with_units) + # Always strip units for numerical inverse + args = self._remove_units_input(args, self.output_frame) + result = self.numerical_inverse(*args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs) - if with_units and self.input_frame: - if self.input_frame.naxes == 1: - return self.input_frame.coordinates(result) - else: - return self.input_frame.coordinates(*result) - else: - return result + return result def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, detect_divergence=True, quiet=True, with_bounding_box=True, - fill_value=np.nan, with_units=False, **kwargs): + fill_value=np.nan, **kwargs): """ Invert coordinates from output frame to input frame using numerical inverse. @@ -525,11 +556,6 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, fill_value : float, optional Output value for inputs outside the bounding_box (default is ``np.nan``). - with_units : bool, optional - If ``True`` returns a `~astropy.coordinates.SkyCoord` or - `~astropy.coordinates.SpectralCoord` object, by using the units of - the output cooridnate frame. Default is `False`. - tolerance : float, optional *Absolute tolerance* of solution. Iteration terminates when the iterative solver estimates that the "true solution" is @@ -734,11 +760,8 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, [2.76552923e-05 1.14789013e-05]] """ - if not utils.isnumerical(args[0]): - args = self.output_frame.coordinate_to_quantity(*args) - if self.output_frame.naxes == 1: - args = [args] - args = utils.get_values(self.output_frame.unit, *args) + if kwargs.pop("with_units", False): + raise ValueError("Support for with_units in numerical_inverse has been removed, use inverse") args_shape = np.shape(args) nargs = args_shape[0] @@ -808,13 +831,7 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, result = tuple(np.reshape(result, args_shape)) - if with_units and self.input_frame: - if self.input_frame.naxes == 1: - return self.input_frame.coordinates(result) - else: - return self.input_frame.coordinates(*result) - else: - return result + return result def _vectorized_fixed_point(self, pix0, world, tolerance, maxiter, adaptive, detect_divergence, quiet, @@ -1098,33 +1115,20 @@ def transform(self, from_frame, to_frame, *args, **kwargs): fill_value : float, optional Output value for inputs outside the bounding_box (default is np.nan). """ - transform = self.get_transform(from_frame, to_frame) - if not utils.isnumerical(args[0]): - inp_frame = getattr(self, from_frame) - args = inp_frame.coordinate_to_quantity(*args) - if not transform.uses_quantity: - args = utils.get_values(inp_frame.unit, *args) + # Determine if the transform is actually an inverse + from_ind = self._get_frame_index(from_frame) + to_ind = self._get_frame_index(to_frame) + backward = to_ind < from_ind with_units = kwargs.pop("with_units", False) - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan + if with_units and backward: + args = high_level_objects_to_values(*args, low_level_wcs=self) - result = transform(*args, **kwargs) + results = self._call_forward(*args, from_frame=from_frame, to_frame=to_frame, **kwargs) - if with_units: - to_frame_name, to_frame_obj = self._get_frame_name(to_frame) - if to_frame_obj is not None: - if to_frame_obj.naxes == 1: - result = to_frame_obj.coordinates(result) - else: - result = to_frame_obj.coordinates(*result) - else: - raise TypeError("Coordinate objects could not be created because" - "frame {0} is not defined.".format(to_frame_name)) - - return result + if with_units and not backward: + return values_to_high_level_objects(*results, low_level_wcs=self) + return results @property def available_frames(self): diff --git a/gwcs/wcstools.py b/gwcs/wcstools.py index 20987e75..3340eef8 100644 --- a/gwcs/wcstools.py +++ b/gwcs/wcstools.py @@ -303,7 +303,7 @@ def wcs_from_points(xy, world_coords, proj_point='center', "Only one of {} is supported.".format(polynomial_type, supported_poly_types.keys())) - skyrot = models.RotateCelestial2Native(crval[0].deg, crval[1].deg, 180) + skyrot = models.RotateCelestial2Native(crval[0].to_value(u.deg), crval[1].to_value(u.deg), 180) trans = (skyrot | projection) projection_x, projection_y = trans(lon, lat) poly = supported_poly_types[polynomial_type](poly_degree) diff --git a/tox.ini b/tox.ini index c1b112eb..d72d5e39 100644 --- a/tox.ini +++ b/tox.ini @@ -50,16 +50,19 @@ description = warnings: treating warnings as errors cov: with coverage xdist: using parallel processing -passenv = +pass_env = HOME GITHUB_* TOXENV CI CODECOV_* DISPLAY + CC + LOCALE_ARCHIVE + LC_ALL + jwst,romancal: CRDS_* set_env = dev: PIP_EXTRA_INDEX_URL = https://pypi.anaconda.org/astropy/simple https://pypi.anaconda.org/liberfa/simple https://pypi.anaconda.org/scientific-python-nightly-wheels/simple - args_are_paths = false change_dir = pyargs: {env:HOME} extras = @@ -72,8 +75,6 @@ deps = romancal: romancal[test] @ git+https://github.com/spacetelescope/romancal.git numpy123: numpy==1.23.* numpy125: numpy==1.25.* -pass_env = - jwst,romancal: CRDS_* commands_pre = dev: pip install -r requirements-dev.txt -U --upgrade-strategy eager pip freeze