Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for asdf cutouts with data as astropy quantity #114

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:


def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits"):
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
write_file: bool = True) -> astropy.nddata.Cutout2D:
falkben marked this conversation as resolved.
Show resolved Hide resolved
""" Get a Roman image cutout

Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
Expand All @@ -72,6 +73,13 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
the image cutout pizel size, by default 20
outfile : str, optional
the name of the output cutout file, by default "example_roman_cutout.fits"
write_file : bool, by default True
Flag to write the cutout to a file or not

Returns
-------
astropy.nddata.Cutout2D:
an image cutout object

Raises
------
Expand All @@ -86,12 +94,22 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
# create the cutout
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size))

# check if the data is a quantity and get the array data
if isinstance(cutout.data, astropy.units.Quantity):
data = cutout.data.value
else:
data = cutout.data

# write the cutout to the output file
astropy.io.fits.writeto(outfile, data=cutout.data, header=cutout.wcs.to_header(), overwrite=True)
if write_file:
astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(), overwrite=True)

return cutout


def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
output_file: str = "example_roman_cutout.fits"):
output_file: str = "example_roman_cutout.fits",
write_file: bool = True) -> astropy.nddata.Cutout2D:
""" Preliminary proof-of-concept functionality.

Takes a single ASDF input file (``input_file``) and generates a cutout of designated size ``cutout_size``
Expand All @@ -109,6 +127,13 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
the image cutout pixel size, by default 20
output_file : str, optional
the name of the output cutout file, by default "example_roman_cutout.fits"
write_file : bool, by default True
Flag to write the cutout to a file or not

Returns
-------
astropy.nddata.Cutout2D:
an image cutout object
"""

# get the 2d image data
Expand All @@ -120,4 +145,5 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
pixel_coordinates, wcs = get_center_pixel(gwcs, ra, dec)

# create the 2d image cutout
get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file)
return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
write_file=write_file)
21 changes: 19 additions & 2 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import pathlib
import numpy as np
import pytest

Expand Down Expand Up @@ -47,7 +48,7 @@ def fakedata():
wcsobj.bounding_box = ((0, nx), (0, ny))

# create the data
data = np.arange(size).reshape(nx, ny)
data = np.arange(size).reshape(nx, ny) * (u.electron / u.second)

yield data, wcsobj

Expand Down Expand Up @@ -93,14 +94,19 @@ def output_file(tmp_path):
yield output_file


def test_get_cutout(output_file, fakedata):
@pytest.mark.parametrize('quantity', [True, False], ids=['quantity', 'array'])
def test_get_cutout(output_file, fakedata, quantity):
""" test we can create a cutout """

# get the input wcs
data, gwcs = fakedata
skycoord = gwcs(25, 25, with_units=True)
wcs = WCS(gwcs.to_fits_sip())

# convert quanity data back to array
if not quantity:
data = data.value

# create cutout
get_cutout(data, skycoord, wcs, size=10, outfile=output_file)

Expand All @@ -123,3 +129,14 @@ def test_asdf_cutout(make_file, output_file):
assert data.shape == (10, 10)
assert data[5, 5] == 2526


def test_cutout_nofile(make_file, output_file):
""" test we can make a cutout with no file output """
# make cutout
ra, dec = (29.99901792, 44.99930555)
cutout = asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file, write_file=False)

assert not pathlib.Path(output_file).exists()
assert cutout.shape == (10, 10)


Loading