diff --git a/astrocut/asdf_cutouts.py b/astrocut/asdf_cutouts.py index 8a6ec0ed..8b41b331 100644 --- a/astrocut/asdf_cutouts.py +++ b/astrocut/asdf_cutouts.py @@ -6,11 +6,12 @@ import asdf import astropy import gwcs +import numpy as np from astropy.coordinates import SkyCoord -def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: +def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: """ Get the center pixel from a roman 2d science image For an input RA, Dec sky coordinate, get the closest pixel location @@ -18,7 +19,7 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: Parameters ---------- - gwcs : gwcs.wcs.WCS + gwcsobj : gwcs.wcs.WCS the Roman GWCS object ra : float the input Right Ascension @@ -32,7 +33,7 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: """ # Convert the gwcs object to an astropy FITS WCS header - header = gwcs.to_fits_sip() + header = gwcsobj.to_fits_sip() # Update WCS header with some keywords that it's missing. # Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?) @@ -47,14 +48,14 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: coordinates = SkyCoord(ra, dec, unit='deg') # Map the coordinates to a pixel's location on the Roman 2d array (row, col) - row, col = astropy.wcs.utils.skycoord_to_pixel(coords=coordinates, wcs=wcs_updated) + row, col = gwcsobj.invert(coordinates) return (row, col), wcs_updated def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord], wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits", - write_file: bool = True) -> astropy.nddata.Cutout2D: + write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D: """ Get a Roman image cutout Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y @@ -75,6 +76,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk the name of the output cutout file, by default "example_roman_cutout.fits" write_file : bool, by default True Flag to write the cutout to a file or not + fill_value: int | float, by default np.nan + The fill value for pixels outside the original image. Returns ------- @@ -85,6 +88,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk ------ ValueError: when a wcs is not present when coords is a SkyCoord object + RuntimeError: + when the requested cutout does not overlap with the original image """ # check for correct inputs @@ -92,7 +97,12 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk raise ValueError('wcs must be input if coords is a SkyCoord.') # create the cutout - cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size)) + try: + cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size), mode='partial', + fill_value=fill_value) + except astropy.nddata.utils.NoOverlapError as e: + raise RuntimeError('Could not create 2d cutout. The requested cutout does not overlap with the ' + 'original image.') from e # check if the data is a quantity and get the array data if isinstance(cutout.data, astropy.units.Quantity): @@ -109,7 +119,7 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, output_file: str = "example_roman_cutout.fits", - write_file: bool = True) -> astropy.nddata.Cutout2D: + write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D: """ Preliminary proof-of-concept functionality. Takes a single ASDF input file (``input_file``) and generates a cutout of designated size ``cutout_size`` @@ -129,6 +139,8 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, the name of the output cutout file, by default "example_roman_cutout.fits" write_file : bool, by default True Flag to write the cutout to a file or not + fill_value: int | float, by default np.nan + The fill value for pixels outside the original image. Returns ------- @@ -139,11 +151,11 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, # get the 2d image data with asdf.open(input_file) as f: data = f['roman']['data'] - gwcs = f['roman']['meta']['wcs'] + gwcsobj = f['roman']['meta']['wcs'] # get the center pixel - pixel_coordinates, wcs = get_center_pixel(gwcs, ra, dec) + pixel_coordinates, wcs = get_center_pixel(gwcsobj, ra, dec) # create the 2d image cutout return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file, - write_file=write_file) + write_file=write_file, fill_value=fill_value) diff --git a/astrocut/tests/test_asdf_cut.py b/astrocut/tests/test_asdf_cut.py index c8dc743d..de54f757 100644 --- a/astrocut/tests/test_asdf_cut.py +++ b/astrocut/tests/test_asdf_cut.py @@ -9,6 +9,7 @@ from astropy import units as u from astropy.io import fits from astropy.wcs import WCS +from astropy.wcs.utils import pixel_to_skycoord from gwcs import wcs from gwcs import coordinate_frames as cf from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut @@ -19,38 +20,62 @@ def make_wcs(xsize, ysize, ra=30., dec=45.): # todo - refine this to better reflect roman wcs # create transformations + # - shift coords so array center is at 0, 0 ; reference pixel + # - scale pixels to correct angular scale + # - project coords onto sky with TAN projection + # - transform center pixel to the input celestial coordinate pixelshift = models.Shift(-xsize) & models.Shift(-ysize) pixelscale = models.Scale(0.1 / 3600.) & models.Scale(0.1 / 3600.) # 0.1 arcsec/pixel tangent_projection = models.Pix2Sky_TAN() celestial_rotation = models.RotateNative2Celestial(ra, dec, 180.) - # transform pixels to sky + # net transforms pixels to sky det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation # define the wcs object detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix)) - sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg)) + sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world', unit=(u.deg, u.deg)) return wcs.WCS([(detector_frame, det2sky), (sky_frame, None)]) @pytest.fixture() -def fakedata(): +def makefake(): + """ fixture factory to make a fake gwcs and dataset """ + + def _make_fake(nx, ny, ra, dec, zero=False, asint=False): + # create the wcs + wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec) + wcsobj.bounding_box = ((0, nx), (0, ny)) + + # create the data + if zero: + data = np.zeros([nx, ny]) + else: + size = nx * ny + data = np.arange(size).reshape(nx, ny) + + # make a quantity + data *= (u.electron / u.second) + + # make integer array + if asint: + data = data.astype(int) + + return data, wcsobj + + yield _make_fake + + +@pytest.fixture() +def fakedata(makefake): """ fixture to create fake data and wcs """ # set up initial parameters nx = 100 ny = 100 - size = nx * ny ra = 30. dec = 45. - # create the wcs - wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec) - wcsobj.bounding_box = ((0, nx), (0, ny)) - - # create the data - data = np.arange(size).reshape(nx, ny) * (u.electron / u.second) - - yield data, wcsobj + yield makefake(nx, ny, ra, dec) @pytest.fixture() @@ -108,7 +133,9 @@ def test_get_cutout(output_file, fakedata, quantity): data = data.value # create cutout - get_cutout(data, skycoord, wcs, size=10, outfile=output_file) + cutout = get_cutout(data, skycoord, wcs, size=10, outfile=output_file) + + assert_same_coord(5, 10, cutout, wcs) # test output with fits.open(output_file) as hdulist: @@ -140,3 +167,76 @@ def test_cutout_nofile(make_file, output_file): assert cutout.shape == (10, 10) +def test_cutout_poles(makefake): + """ test we can make cutouts around poles """ + # make fake zero data around the pole + ra, dec = 315.0, 89.995 + data, gwcs = makefake(1000, 1000, ra, dec, zero=True) + + # add some values (5x5 array) + data.value[245:250, 245:250] = 1 + + # check central pixel is correct + ss = gwcs(500, 500) + assert ss == (ra, dec) + + # set input cutout coord + cc = coord.SkyCoord(284.702, 89.986, unit=u.degree) + wcs = WCS(gwcs.to_fits_sip()) + + # get cutout + cutout = get_cutout(data, cc, wcs, size=50, write_file=False) + assert_same_coord(5, 10, cutout, wcs) + + # check cutout contains all data + assert len(np.where(cutout.data.value == 1)[0]) == 25 + + +def test_fail_cutout_outside(fakedata): + """ test we fail when cutout completely outside range """ + data, gwcs = fakedata + wcs = WCS(gwcs.to_fits_sip()) + cc = coord.SkyCoord(200.0, 50.0, unit=u.degree) + + with pytest.raises(RuntimeError, match='Could not create 2d cutout. The requested ' + 'cutout does not overlap with the original image'): + get_cutout(data, cc, wcs, size=50, write_file=False) + + +def assert_same_coord(x, y, cutout, wcs): + """ assert we get the same sky coordinate from cutout and original wcs """ + cutout_coord = pixel_to_skycoord(x, y, cutout.wcs) + ox, oy = cutout.to_original_position((x, y)) + orig_coord = pixel_to_skycoord(ox, oy, wcs) + assert cutout_coord == orig_coord + + +@pytest.mark.parametrize('asint, fill', [(False, None), (True, -9999)], ids=['fillfloat', 'fillint']) +def test_partial_cutout(makefake, asint, fill): + """ test we get a partial cutout with nans or fill value """ + ra, dec = 30.0, 45.0 + data, gwcs = makefake(100, 100, ra, dec, asint=asint) + + wcs = WCS(gwcs.to_fits_sip()) + cc = coord.SkyCoord(29.999, 44.998, unit=u.degree) + cutout = get_cutout(data, cc, wcs, size=50, write_file=False, fill_value=fill) + assert cutout.shape == (50, 50) + if asint: + assert -9999 in cutout.data + else: + assert np.isnan(cutout.data).any() + + +def test_bad_fill(makefake): + """ test error is raised on bad fill value """ + ra, dec = 30.0, 45.0 + data, gwcs = makefake(100, 100, ra, dec, asint=True) + wcs = WCS(gwcs.to_fits_sip()) + cc = coord.SkyCoord(29.999, 44.998, unit=u.degree) + with pytest.raises(ValueError, match='fill_value is inconsistent with the data type of the input array'): + get_cutout(data, cc, wcs, size=50, write_file=False) + + + + +