Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove old bbox support #458

Merged
merged 1 commit into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 27 additions & 44 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from astropy.wcs import wcsapi
from astropy.time import Time

from gwcs.wcs import new_bbox

from .. import wcs
from ..wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points)
from .. import coordinate_frames as cf
Expand Down Expand Up @@ -287,10 +285,7 @@ def test_bounding_box():
pipeline = [('detector', trans2), ('sky', None)]
w = wcs.WCS(pipeline)
w.bounding_box = bb
if new_bbox:
assert w.bounding_box == w.forward_transform.bounding_box
else:
assert w.bounding_box == w.forward_transform.bounding_box[::-1]
assert w.bounding_box == w.forward_transform.bounding_box

pipeline = [("detector", models.Shift(2)), ("sky", None)]
w = wcs.WCS(pipeline)
Expand All @@ -317,29 +312,24 @@ def test_compound_bounding_box():
2: ((-1, 5), (3, 17)),
3: ((-3, 7), (1, 27)),
}
if new_bbox:
# Test attaching a valid bounding box (ignoring input 'x')
w.attach_compound_bounding_box(cbb, [('x',)])
from astropy.modeling.bounding_box import CompoundBoundingBox
cbb = CompoundBoundingBox.validate(trans3, cbb, selector_args=[('x',)], order='F')
assert w.bounding_box == cbb
assert w.bounding_box is trans3.bounding_box

# Test evaluating
assert_allclose(w(13, 2, 1), (np.nan, np.nan, np.nan))
assert_allclose(w(13, 2, 2), (np.nan, np.nan, np.nan))
assert_allclose(w(13, 0, 3), (np.nan, np.nan, np.nan))
# No bounding box for selector
with pytest.raises(RuntimeError):
w(13, 13, 4)

# Test attaching a invalid bounding box (not ignoring input 'x')
with pytest.raises(ValueError):
w.attach_compound_bounding_box(cbb, [('x', False)])
else:
with pytest.raises(NotImplementedError) as err:
w.attach_compound_bounding_box(cbb, [('x',)])
assert str(err.value) == 'Compound bounding box is not supported for your version of astropy'
# Test attaching a valid bounding box (ignoring input 'x')
w.attach_compound_bounding_box(cbb, [('x',)])
from astropy.modeling.bounding_box import CompoundBoundingBox
cbb = CompoundBoundingBox.validate(trans3, cbb, selector_args=[('x',)], order='F')
assert w.bounding_box == cbb
assert w.bounding_box is trans3.bounding_box

# Test evaluating
assert_allclose(w(13, 2, 1), (np.nan, np.nan, np.nan))
assert_allclose(w(13, 2, 2), (np.nan, np.nan, np.nan))
assert_allclose(w(13, 0, 3), (np.nan, np.nan, np.nan))
# No bounding box for selector
with pytest.raises(RuntimeError):
w(13, 13, 4)

# Test attaching a invalid bounding box (not ignoring input 'x')
with pytest.raises(ValueError):
w.attach_compound_bounding_box(cbb, [('x', False)])

# Test that bounding_box with quantities can be assigned and evaluates
trans = models.Shift(10 * u .pix) & models.Shift(2 * u.pix)
Expand All @@ -349,19 +339,15 @@ def test_compound_bounding_box():
1 * u.pix: (1 * u.pix, 5 * u.pix),
2 * u.pix: (2 * u.pix, 6 * u.pix)
}
if new_bbox:
w.attach_compound_bounding_box(cbb, [('x1',)])
w.attach_compound_bounding_box(cbb, [('x1',)])

from astropy.modeling.bounding_box import CompoundBoundingBox
cbb = CompoundBoundingBox.validate(trans, cbb, selector_args=[('x1',)], order='F')
assert w.bounding_box == cbb
assert w.bounding_box is trans.bounding_box
from astropy.modeling.bounding_box import CompoundBoundingBox
cbb = CompoundBoundingBox.validate(trans, cbb, selector_args=[('x1',)], order='F')
assert w.bounding_box == cbb
assert w.bounding_box is trans.bounding_box

assert_allclose(w(-1*u.pix, 1*u.pix), (np.nan, np.nan))
assert_allclose(w(7*u.pix, 2*u.pix), (np.nan, np.nan))
else:
with pytest.raises(NotImplementedError) as err:
w.attach_compound_bounding_box(cbb, [('x1',)])
assert_allclose(w(-1*u.pix, 1*u.pix), (np.nan, np.nan))
assert_allclose(w(7*u.pix, 2*u.pix), (np.nan, np.nan))


def test_grid_from_bounding_box():
Expand Down Expand Up @@ -940,10 +926,7 @@ def test_to_fits_1D_round_trip(gwcs_1d_spectral):

# test points:
np.random.seed(1)
if new_bbox:
(xmin, xmax) = w.bounding_box.bounding_box()
else:
(xmin, xmax) = w.bounding_box
(xmin, xmax) = w.bounding_box.bounding_box()
x = xmin + (xmax - xmin) * np.random.random(100)

# test forward transformation:
Expand Down
101 changes: 29 additions & 72 deletions gwcs/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,27 @@
import functools
import itertools
import warnings

import astropy.io.fits as fits
import numpy as np
import numpy.linalg as npla
from scipy import optimize, linalg
from astropy import units as u
from astropy.modeling import fix_inputs, projections
from astropy.modeling.bounding_box import CompoundBoundingBox
from astropy.modeling.bounding_box import ModelBoundingBox as Bbox
from astropy.modeling.core import Model
from astropy.modeling.models import (
Identity, Mapping, Const1D, Shift, Polynomial2D,
Sky2Pix_TAN, RotateCelestial2Native
)
from astropy.modeling import projections, fix_inputs
import astropy.io.fits as fits
from astropy.modeling.models import (Const1D, Identity, Mapping, Polynomial2D,
RotateCelestial2Native, Shift,
Sky2Pix_TAN)
from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales
from scipy import linalg, optimize

from .api import GWCSAPIMixin
from . import coordinate_frames as cf
from .utils import CoordinateFrameError
from . import utils
from .api import GWCSAPIMixin
from .utils import CoordinateFrameError
from .wcstools import grid_from_bounding_box

try:
from astropy.modeling.bounding_box import ModelBoundingBox as Bbox
from astropy.modeling.bounding_box import CompoundBoundingBox
new_bbox = True
except ImportError:
from astropy.modeling.utils import _BoundingBox as Bbox
new_bbox = False


__all__ = ['WCS', 'Step', 'NoConvergence']

_ITER_INV_KWARGS = ['tolerance', 'maxiter', 'adaptive', 'detect_divergence', 'quiet']
Expand Down Expand Up @@ -366,14 +359,7 @@ def __call__(self, *args, **kwargs):
# Currently compound models do not attempt to combine individual model
# bounding boxes. Get the forward transform and assign the bounding_box to it
# before evaluating it. The order Model.bounding_box is reversed.
if new_bbox:
transform.bounding_box = self.bounding_box
else:
axes_ind = self._get_axes_indices()
if transform.n_inputs > 1:
transform.bounding_box = [self.bounding_box[ind] for ind in axes_ind][::-1]
else:
transform.bounding_box = self.bounding_box
transform.bounding_box = self.bounding_box

result = transform(*args, **kwargs)

Expand Down Expand Up @@ -425,10 +411,7 @@ def in_image(self, *args, **kwargs):
return result

if self.input_frame.naxes == 1:
if new_bbox:
x1, x2 = self.bounding_box.bounding_box()
else:
x1, x2 = self.bounding_box
x1, x2 = self.bounding_box.bounding_box()

if len(np.shape(args[0])) > 0:
result[result] = (coords[result] >= x1) & (coords[result] <= x2)
Expand Down Expand Up @@ -1318,17 +1301,7 @@ def bounding_box(self):
except NotImplementedError:
return None

if new_bbox:
return bb
else:
if transform_0.n_inputs == 1:
return bb
try:
axes_order = self.input_frame.axes_order
except AttributeError:
axes_order = np.arange(transform_0.n_inputs)
# Model.bounding_box is in python order, need to reverse it first.
return tuple(bb[::-1][i] for i in axes_order)
return bb

@bounding_box.setter
def bounding_box(self, value):
Expand All @@ -1350,39 +1323,23 @@ def bounding_box(self, value):
else:
try:
# Make sure the dimensions of the new bbox are correct.
if new_bbox:
if isinstance(value, CompoundBoundingBox):
bbox = CompoundBoundingBox.validate(transform_0, value, order='F')
else:
bbox = Bbox.validate(transform_0, value, order='F')
if isinstance(value, CompoundBoundingBox):
bbox = CompoundBoundingBox.validate(transform_0, value, order='F')
else:
Bbox.validate(transform_0, value)
bbox = Bbox.validate(transform_0, value, order='F')
except Exception:
raise

if new_bbox:
transform_0.bounding_box = bbox
else:
# get the sorted order of axes' indices
axes_ind = self._get_axes_indices()
if transform_0.n_inputs == 1:
transform_0.bounding_box = value
else:
# The axes in bounding_box in modeling follow python order
#transform_0.bounding_box = np.array(value)[axes_ind][::-1]
transform_0.bounding_box = [value[ind] for ind in axes_ind][::-1]
transform_0.bounding_box = bbox

self.set_transform(frames[0], frames[1], transform_0)

def attach_compound_bounding_box(self, cbbox, selector_args):
if new_bbox:
frames = self.available_frames
transform_0 = self.get_transform(frames[0], frames[1])
frames = self.available_frames
transform_0 = self.get_transform(frames[0], frames[1])

self.bounding_box = CompoundBoundingBox.validate(transform_0, cbbox, selector_args=selector_args,
order='F')
else:
raise NotImplementedError('Compound bounding box is not supported for your version of astropy')
self.bounding_box = CompoundBoundingBox.validate(transform_0, cbbox, selector_args=selector_args,
order='F')

def _get_axes_indices(self):
try:
Expand All @@ -1394,6 +1351,7 @@ def _get_axes_indices(self):

def __str__(self):
from astropy.table import Table

#col1 = [item[0] for item in self._pipeline]
col1 = [step.frame for step in self._pipeline]
col2 = []
Expand Down Expand Up @@ -2554,13 +2512,12 @@ def _to_fits_tab(self, hdr, world_axes_group, use_cd, bounding_box,
if isinstance(bin_ext, str):
bin_ext = (bin_ext, 1)

if new_bbox:
if isinstance(bounding_box, Bbox):
bounding_box = bounding_box.bounding_box(order='F')
if isinstance(bounding_box, list):
for index, bbox in enumerate(bounding_box):
if isinstance(bbox, Bbox):
bounding_box[index] = bbox.bounding_box(order='F')
if isinstance(bounding_box, Bbox):
bounding_box = bounding_box.bounding_box(order='F')
if isinstance(bounding_box, list):
for index, bbox in enumerate(bounding_box):
if isinstance(bbox, Bbox):
bounding_box[index] = bbox.bounding_box(order='F')

# identify input axes:
input_axes = []
Expand Down