Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect the bounding_box in inverse transforms #498

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
0.22.0 (unreleased)
-------------------

- Fixed a bug where evaluating the inverse transform did not
respect the bounding box. [#498]

0.21.0 (2024-03-10)
-------------------
Expand Down
8 changes: 7 additions & 1 deletion gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@
be returned in the ``(x, y)`` order, where for an image, ``x`` is the
horizontal coordinate and ``y`` is the vertical coordinate.
"""
world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame)
try:
backward_transform = self.backward_transform
world_arrays = self._add_units_input(world_arrays,
backward_transform,
self.output_frame)
except NotImplementedError:
pass

Check warning on line 141 in gwcs/api.py

View check run for this annotation

Codecov / codecov/patch

gwcs/api.py#L140-L141

Added lines #L140 - L141 were not covered by tests

result = self.invert(*world_arrays, with_units=False)

Expand Down
37 changes: 33 additions & 4 deletions gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from astropy import coordinates as coord
from astropy.modeling import models

from .. import coordinate_frames as cf
from .. import spectroscopy as sp
from .. import wcs
from .. import geometry
from gwcs import coordinate_frames as cf
from gwcs import spectroscopy as sp
from gwcs import wcs
from gwcs import geometry

# frames
detector_1d = cf.CoordinateFrame(name='detector', axes_order=(0,), naxes=1, axes_type="detector")
Expand Down Expand Up @@ -522,3 +522,32 @@
@pytest.fixture
def cart_to_spher():
return geometry.CartesianToSpherical()


@pytest.fixture
def gwcs_simple_imaging_no_units():
shift_by_crpix = models.Shift(-2048) & models.Shift(-1024)
matrix = np.array([[1.290551569736E-05, 5.9525007864732E-06],

Check warning on line 530 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L529-L530

Added lines #L529 - L530 were not covered by tests
[5.0226382102765E-06 , -1.2644844123757E-05]])
rotation = models.AffineTransformation2D(matrix,

Check warning on line 532 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L532

Added line #L532 was not covered by tests
translation=[0, 0])

rotation.inverse = models.AffineTransformation2D(np.linalg.inv(matrix),

Check warning on line 535 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L535

Added line #L535 was not covered by tests
translation=[0, 0])
tan = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(5.63056810618,

Check warning on line 538 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L537-L538

Added lines #L537 - L538 were not covered by tests
-72.05457184279,
180)
det2sky = shift_by_crpix | rotation | tan | celestial_rotation
det2sky.name = "linear_transform"

Check warning on line 542 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L541-L542

Added lines #L541 - L542 were not covered by tests

detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"),

Check warning on line 544 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L544

Added line #L544 was not covered by tests
unit=(u.pix, u.pix))
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs',

Check warning on line 546 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L546

Added line #L546 was not covered by tests
unit=(u.deg, u.deg))
pipeline = [(detector_frame, det2sky),

Check warning on line 548 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L548

Added line #L548 was not covered by tests
(sky_frame, None)
]
w = wcs.WCS(pipeline)
w.bounding_box = ((2, 100), (5, 500))
return w

Check warning on line 553 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L551-L553

Added lines #L551 - L553 were not covered by tests
17 changes: 4 additions & 13 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def wcs_ndim_types_units(request):

@fixture_all_wcses
def test_lowlevel_types(wcsobj):
pytest.importorskip("typeguard")
try:
# Skip this on older versions of astropy where it dosen't exist.
from astropy.wcs.wcsapi.tests.utils import validate_low_level_wcs_types
Expand Down Expand Up @@ -230,12 +229,12 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units):
def _compare_frame_output(wc1, wc2):
if isinstance(wc1, coord.SkyCoord):
assert isinstance(wc1.frame, type(wc2.frame))
assert u.allclose(wc1.spherical.lon, wc2.spherical.lon)
assert u.allclose(wc1.spherical.lat, wc2.spherical.lat)
assert u.allclose(wc1.spherical.distance, wc2.spherical.distance)
assert u.allclose(wc1.spherical.lon, wc2.spherical.lon, equal_nan=True)
assert u.allclose(wc1.spherical.lat, wc2.spherical.lat, equal_nan=True)
assert u.allclose(wc1.spherical.distance, wc2.spherical.distance, equal_nan=True)

elif isinstance(wc1, u.Quantity):
assert u.allclose(wc1, wc2)
assert u.allclose(wc1, wc2, equal_nan=True)

elif isinstance(wc1, time.Time):
assert u.allclose((wc1 - wc2).to(u.s), 0*u.s)
Expand All @@ -252,12 +251,6 @@ def _compare_frame_output(wc1, wc2):

@fixture_all_wcses
def test_high_level_wrapper(wcsobj, request):
if request.node.callspec.params['wcsobj'] in ('gwcs_4d_identity_units', 'gwcs_stokes_lookup'):
pytest.importorskip("astropy", minversion="4.0dev0")

# Remove the bounding box because the type test is a little broken with the
# bounding box.
del wcsobj._pipeline[0].transform.bounding_box

hlvl = HighLevelWCSWrapper(wcsobj)

Expand All @@ -280,8 +273,6 @@ def test_high_level_wrapper(wcsobj, request):


def test_stokes_wrapper(gwcs_stokes_lookup):
pytest.importorskip("astropy", minversion="4.0dev0")

hlvl = HighLevelWCSWrapper(gwcs_stokes_lookup)

pixel_input = [0, 1, 2, 3]
Expand Down
4 changes: 2 additions & 2 deletions gwcs/tests/test_api_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ def test_celestial_slice(gwcs_3d_galactic_spectral):
assert_allclose(wcs.pixel_to_world_values(39, 44), (10.24, 20, 25))
assert_allclose(wcs.array_index_to_world_values(44, 39), (10.24, 20, 25))

assert_allclose(wcs.world_to_pixel_values(12.4, 20, 25), (39., 44.))
assert_equal(wcs.world_to_array_index_values(12.4, 20, 25), (44, 39))
assert_allclose(wcs.world_to_pixel_values(10.24, 20, 25), (39., 44.))
assert_equal(wcs.world_to_array_index_values(10.24, 20, 25), (44, 39))

assert_equal(wcs.pixel_bounds, [(-2, 45), (5, 50)])

Expand Down
7 changes: 4 additions & 3 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def test_bare_baseframe():
output_frame=frame,
input_frame=cf.CoordinateFrame(1, "PIXEL", (0,), unit=(u.pix,), name="detector_frame")
)
#w.bounding_box = (0, 9)
assert u.allclose(w.world_to_pixel(0*u.km), 0)


Expand Down Expand Up @@ -190,7 +191,7 @@ def test_temporal_relative():
assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s


@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher")
#@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher")
def test_temporal_absolute():
t = cf.TemporalFrame(reference_frame=Time([], format='isot'))
assert t.coordinates("2018-01-01T00:00:00") == Time("2018-01-01T00:00:00")
Expand Down Expand Up @@ -240,7 +241,7 @@ def test_coordinate_to_quantity_spectral(inp):
(Time("2011-01-01T00:00:10"),),
(10 * u.s,)
])
@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
#@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
def test_coordinate_to_quantity_temporal(inp):
temp = cf.TemporalFrame(reference_frame=Time("2011-01-01T00:00:00"), unit=u.s)

Expand Down Expand Up @@ -325,7 +326,7 @@ def test_coordinate_to_quantity_frame_2d():
assert_quantity_allclose(output, exp)


@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
#@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
def test_coordinate_to_quantity_error():
frame = cf.Frame2D(unit=(u.one, u.arcsec))
with pytest.raises(ValueError):
Expand Down
6 changes: 6 additions & 0 deletions gwcs/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from astropy import units as u
from astropy import coordinates as coord
from astropy.modeling import models
from astropy import table

from astropy.tests.helper import assert_quantity_allclose
import pytest
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -104,6 +106,10 @@ def test_isnumerical():
assert gwutils.isnumerical(np.array(0, dtype='>f8'))
assert gwutils.isnumerical(np.array(0, dtype='>i4'))

# check a table column
t = table.Table(data=[[1,2,3], [4,5,6]], names=['x', 'y'])
assert not gwutils.isnumerical(t['x'])


def test_get_values():
args = 2 * u.cm
Expand Down
4 changes: 4 additions & 0 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ def test_iter_inv():
*w(x, y),
adaptive=True,
detect_divergence=True,
tolerance=1e-4, maxiter=50,
quiet=False
)
assert np.allclose((x, y), (xp, yp))
Expand All @@ -1144,6 +1145,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y),
adaptive=True,
tolerance=1e-5, maxiter=50,
detect_divergence=False,
quiet=False
)
Expand Down Expand Up @@ -1178,6 +1180,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y, with_bounding_box=False),
adaptive=False,
tolerance=1e-5, maxiter=50,
detect_divergence=True,
quiet=False,
with_bounding_box=False
Expand All @@ -1191,6 +1194,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y, with_bounding_box=False),
adaptive=False,
tolerance=1e-5, maxiter=50,
detect_divergence=True,
quiet=False,
with_bounding_box=False
Expand Down
11 changes: 5 additions & 6 deletions gwcs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astropy import coordinates as coords
from astropy import units as u
from astropy.time import Time, TimeDelta
from astropy import table
from astropy.wcs import Celprm


Expand Down Expand Up @@ -470,14 +471,12 @@ def isnumerical(val):
Determine if a value is numerical (number or np.array of numbers).
"""
isnum = True
if isinstance(val, coords.SkyCoord):
isnum = False
elif isinstance(val, u.Quantity):
isnum = False
elif isinstance(val, (Time, TimeDelta)):
astropy_types=(coords.SkyCoord, u.Quantity, Time, TimeDelta, table.Column, table.Row)
if isinstance(val, astropy_types):
isnum = False
elif (isinstance(val, np.ndarray)
and not np.issubdtype(val.dtype, np.floating)
and not np.issubdtype(val.dtype, np.integer)):
and not np.issubdtype(val.dtype, np.integer)
):
isnum = False
return isnum
87 changes: 69 additions & 18 deletions gwcs/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,7 @@
and `False` if input is outside the footprint.

"""
kwargs['with_bounding_box'] = True
kwargs['fill_value'] = np.nan

coords = self.invert(*args, **kwargs)
coords = self.invert(*args, with_bounding_box=False, **kwargs)

result = np.isfinite(coords)
if self.input_frame.naxes > 1:
Expand Down Expand Up @@ -466,38 +463,86 @@
"""
with_units = kwargs.pop('with_units', False)

try:
btrans = self.backward_transform
except NotImplementedError:
btrans = None

if not utils.isnumerical(args[0]):
# convert astropy objects to numbers and arrays
args = self.output_frame.coordinate_to_quantity(*args)
if self.output_frame.naxes == 1:
args = [args]
try:
if not self.backward_transform.uses_quantity:
args = utils.get_values(self.output_frame.unit, *args)
except (NotImplementedError, KeyError):

# if the transform does not use units, getthe numerical values
if btrans is not None and not btrans.uses_quantity:
args = utils.get_values(self.output_frame.unit, *args)

if 'with_bounding_box' not in kwargs:
kwargs['with_bounding_box'] = True
with_bounding_box = kwargs.pop('with_bounding_box', True)
fill_value = kwargs.pop('fill_value', np.nan)
akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS}

if 'fill_value' not in kwargs:
kwargs['fill_value'] = np.nan
if with_bounding_box and self.bounding_box is not None:
result = self.outside_footprint(args)

try:
# remove iterative inverse-specific keyword arguments:
akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS}
result = self.backward_transform(*args, **akwargs)
except (NotImplementedError, KeyError):
if btrans is not None:
#akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS}
result = btrans(*args, **akwargs)
else:
result = self.numerical_inverse(*args, **kwargs, with_units=with_units)

# deal with values outside the bounding box
if with_bounding_box and self.bounding_box is not None:
result = self.out_of_bounds(result, fill_value=fill_value)

if with_units and self.input_frame:
if self.input_frame.naxes == 1:
return self.input_frame.coordinates(result)
else:
return self.input_frame.coordinates(*result)
else:
return result

def outside_footprint(self, world_arrays):
for axis in world_arrays:
if np.isscalar(world_arrays) or self.output_frame.naxes == 1:
world_arrays = [world_arrays]
world_arrays = list(world_arrays)
footprint = self.footprint()
for idim, coord in enumerate(world_arrays):
axis_range = footprint[:, idim]
range = [axis_range.min(), axis_range.max()]
outside = (coord < range[0]) | (coord > range[1])
if np.any(outside):
if np.isscalar(coord):
coord = np.nan

Check warning on line 518 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L517-L518

Added lines #L517 - L518 were not covered by tests
else:
coord[outside] = np.nan
world_arrays[idim] = coord

Check warning on line 521 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L520-L521

Added lines #L520 - L521 were not covered by tests

return world_arrays


def out_of_bounds(self, pixel_arrays, fill_value=np.nan):
if np.isscalar(pixel_arrays) or self.input_frame.naxes == 1:
pixel_arrays = [pixel_arrays]

Check warning on line 528 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L528

Added line #L528 was not covered by tests

pixel_arrays = list(pixel_arrays)
bbox = self.bounding_box
for idim, pix in enumerate(pixel_arrays):
outside = (pix < bbox[idim][0]) | (pix > bbox[idim][1])
if np.any(outside):
if np.isscalar(pix):
pixel_arrays[idim] = np.nan

Check warning on line 536 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L535-L536

Added lines #L535 - L536 were not covered by tests
else:
pix = pixel_arrays[idim].astype(float, copy=True)
pix[outside] = np.nan
pixel_arrays[idim] = pix

Check warning on line 540 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L538-L540

Added lines #L538 - L540 were not covered by tests
if self.input_frame.naxes == 1:
pixel_arrays = pixel_arrays[0]

Check warning on line 542 in gwcs/wcs.py

View check run for this annotation

Codecov / codecov/patch

gwcs/wcs.py#L542

Added line #L542 was not covered by tests
return pixel_arrays

def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True,
def numerical_inverse(self, *args, tolerance=1e-5, maxiter=30, adaptive=True,
detect_divergence=True, quiet=True, with_bounding_box=True,
fill_value=np.nan, with_units=False, **kwargs):
"""
Expand Down Expand Up @@ -1380,6 +1425,11 @@

"""
def _order_clockwise(v):
# if self.input_frame.naxes == 1:
# bb = self.bounding_box.bounding_box()
# if isinstance(bb[0], u.Quantity):
# bb = [v.value for v in bb] * bb[0].unit
# return (bb,)
return np.asarray([[v[0][0], v[1][0]], [v[0][0], v[1][1]],
[v[0][1], v[1][1]], [v[0][1], v[1][0]]]).T

Expand All @@ -1397,6 +1447,7 @@
else:
vertices = np.array(list(itertools.product(*bb))).T

# workaround an issue with bbox with quantity, interval needs to be a cquantity, not a list of quantities
if center:
vertices = utils._toindex(vertices)

Expand Down
Loading