Skip to content

Commit

Permalink
Merge pull request #115 from spacetelescope/asdfpoles
Browse files Browse the repository at this point in the history
Updates to asdf cutout for edge cases
  • Loading branch information
havok2063 authored Jan 24, 2024
2 parents 84b30ed + 8a20161 commit d31a3cf
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 23 deletions.
32 changes: 22 additions & 10 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
import asdf
import astropy
import gwcs
import numpy as np

from astropy.coordinates import SkyCoord


def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
""" Get the center pixel from a roman 2d science image
For an input RA, Dec sky coordinate, get the closest pixel location
on the input Roman image.
Parameters
----------
gwcs : gwcs.wcs.WCS
gwcsobj : gwcs.wcs.WCS
the Roman GWCS object
ra : float
the input Right Ascension
Expand All @@ -32,7 +33,7 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
"""

# Convert the gwcs object to an astropy FITS WCS header
header = gwcs.to_fits_sip()
header = gwcsobj.to_fits_sip()

# Update WCS header with some keywords that it's missing.
# Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?)
Expand All @@ -47,14 +48,14 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
coordinates = SkyCoord(ra, dec, unit='deg')

# Map the coordinates to a pixel's location on the Roman 2d array (row, col)
row, col = astropy.wcs.utils.skycoord_to_pixel(coords=coordinates, wcs=wcs_updated)
row, col = gwcsobj.invert(coordinates)

return (row, col), wcs_updated


def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
write_file: bool = True) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
""" Get a Roman image cutout
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
Expand All @@ -75,6 +76,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
the name of the output cutout file, by default "example_roman_cutout.fits"
write_file : bool, by default True
Flag to write the cutout to a file or not
fill_value: int | float, by default np.nan
The fill value for pixels outside the original image.
Returns
-------
Expand All @@ -85,14 +88,21 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
------
ValueError:
when a wcs is not present when coords is a SkyCoord object
RuntimeError:
when the requested cutout does not overlap with the original image
"""

# check for correct inputs
if isinstance(coords, SkyCoord) and not wcs:
raise ValueError('wcs must be input if coords is a SkyCoord.')

# create the cutout
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size))
try:
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size), mode='partial',
fill_value=fill_value)
except astropy.nddata.utils.NoOverlapError as e:
raise RuntimeError('Could not create 2d cutout. The requested cutout does not overlap with the '
'original image.') from e

# check if the data is a quantity and get the array data
if isinstance(cutout.data, astropy.units.Quantity):
Expand All @@ -109,7 +119,7 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk

def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
output_file: str = "example_roman_cutout.fits",
write_file: bool = True) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
""" Preliminary proof-of-concept functionality.
Takes a single ASDF input file (``input_file``) and generates a cutout of designated size ``cutout_size``
Expand All @@ -129,6 +139,8 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
the name of the output cutout file, by default "example_roman_cutout.fits"
write_file : bool, by default True
Flag to write the cutout to a file or not
fill_value: int | float, by default np.nan
The fill value for pixels outside the original image.
Returns
-------
Expand All @@ -139,11 +151,11 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
# get the 2d image data
with asdf.open(input_file) as f:
data = f['roman']['data']
gwcs = f['roman']['meta']['wcs']
gwcsobj = f['roman']['meta']['wcs']

# get the center pixel
pixel_coordinates, wcs = get_center_pixel(gwcs, ra, dec)
pixel_coordinates, wcs = get_center_pixel(gwcsobj, ra, dec)

# create the 2d image cutout
return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
write_file=write_file)
write_file=write_file, fill_value=fill_value)
126 changes: 113 additions & 13 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astropy import units as u
from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_skycoord
from gwcs import wcs
from gwcs import coordinate_frames as cf
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut
Expand All @@ -19,38 +20,62 @@ def make_wcs(xsize, ysize, ra=30., dec=45.):
# todo - refine this to better reflect roman wcs

# create transformations
# - shift coords so array center is at 0, 0 ; reference pixel
# - scale pixels to correct angular scale
# - project coords onto sky with TAN projection
# - transform center pixel to the input celestial coordinate
pixelshift = models.Shift(-xsize) & models.Shift(-ysize)
pixelscale = models.Scale(0.1 / 3600.) & models.Scale(0.1 / 3600.) # 0.1 arcsec/pixel
tangent_projection = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(ra, dec, 180.)

# transform pixels to sky
# net transforms pixels to sky
det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation

# define the wcs object
detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix))
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg))
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world', unit=(u.deg, u.deg))
return wcs.WCS([(detector_frame, det2sky), (sky_frame, None)])


@pytest.fixture()
def fakedata():
def makefake():
""" fixture factory to make a fake gwcs and dataset """

def _make_fake(nx, ny, ra, dec, zero=False, asint=False):
# create the wcs
wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec)
wcsobj.bounding_box = ((0, nx), (0, ny))

# create the data
if zero:
data = np.zeros([nx, ny])
else:
size = nx * ny
data = np.arange(size).reshape(nx, ny)

# make a quantity
data *= (u.electron / u.second)

# make integer array
if asint:
data = data.astype(int)

return data, wcsobj

yield _make_fake


@pytest.fixture()
def fakedata(makefake):
""" fixture to create fake data and wcs """
# set up initial parameters
nx = 100
ny = 100
size = nx * ny
ra = 30.
dec = 45.

# create the wcs
wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec)
wcsobj.bounding_box = ((0, nx), (0, ny))

# create the data
data = np.arange(size).reshape(nx, ny) * (u.electron / u.second)

yield data, wcsobj
yield makefake(nx, ny, ra, dec)


@pytest.fixture()
Expand Down Expand Up @@ -108,7 +133,9 @@ def test_get_cutout(output_file, fakedata, quantity):
data = data.value

# create cutout
get_cutout(data, skycoord, wcs, size=10, outfile=output_file)
cutout = get_cutout(data, skycoord, wcs, size=10, outfile=output_file)

assert_same_coord(5, 10, cutout, wcs)

# test output
with fits.open(output_file) as hdulist:
Expand Down Expand Up @@ -140,3 +167,76 @@ def test_cutout_nofile(make_file, output_file):
assert cutout.shape == (10, 10)


def test_cutout_poles(makefake):
""" test we can make cutouts around poles """
# make fake zero data around the pole
ra, dec = 315.0, 89.995
data, gwcs = makefake(1000, 1000, ra, dec, zero=True)

# add some values (5x5 array)
data.value[245:250, 245:250] = 1

# check central pixel is correct
ss = gwcs(500, 500)
assert ss == (ra, dec)

# set input cutout coord
cc = coord.SkyCoord(284.702, 89.986, unit=u.degree)
wcs = WCS(gwcs.to_fits_sip())

# get cutout
cutout = get_cutout(data, cc, wcs, size=50, write_file=False)
assert_same_coord(5, 10, cutout, wcs)

# check cutout contains all data
assert len(np.where(cutout.data.value == 1)[0]) == 25


def test_fail_cutout_outside(fakedata):
""" test we fail when cutout completely outside range """
data, gwcs = fakedata
wcs = WCS(gwcs.to_fits_sip())
cc = coord.SkyCoord(200.0, 50.0, unit=u.degree)

with pytest.raises(RuntimeError, match='Could not create 2d cutout. The requested '
'cutout does not overlap with the original image'):
get_cutout(data, cc, wcs, size=50, write_file=False)


def assert_same_coord(x, y, cutout, wcs):
""" assert we get the same sky coordinate from cutout and original wcs """
cutout_coord = pixel_to_skycoord(x, y, cutout.wcs)
ox, oy = cutout.to_original_position((x, y))
orig_coord = pixel_to_skycoord(ox, oy, wcs)
assert cutout_coord == orig_coord


@pytest.mark.parametrize('asint, fill', [(False, None), (True, -9999)], ids=['fillfloat', 'fillint'])
def test_partial_cutout(makefake, asint, fill):
""" test we get a partial cutout with nans or fill value """
ra, dec = 30.0, 45.0
data, gwcs = makefake(100, 100, ra, dec, asint=asint)

wcs = WCS(gwcs.to_fits_sip())
cc = coord.SkyCoord(29.999, 44.998, unit=u.degree)
cutout = get_cutout(data, cc, wcs, size=50, write_file=False, fill_value=fill)
assert cutout.shape == (50, 50)
if asint:
assert -9999 in cutout.data
else:
assert np.isnan(cutout.data).any()


def test_bad_fill(makefake):
""" test error is raised on bad fill value """
ra, dec = 30.0, 45.0
data, gwcs = makefake(100, 100, ra, dec, asint=True)
wcs = WCS(gwcs.to_fits_sip())
cc = coord.SkyCoord(29.999, 44.998, unit=u.degree)
with pytest.raises(ValueError, match='fill_value is inconsistent with the data type of the input array'):
get_cutout(data, cc, wcs, size=50, write_file=False)





0 comments on commit d31a3cf

Please sign in to comment.