Skip to content

Commit

Permalink
Handle an s3 URI in asdf_cut() (#117)
Browse files Browse the repository at this point in the history
Handle an s3 URI in asdf_cut()
  • Loading branch information
snbianco authored May 21, 2024
1 parent 5d0d3e0 commit 56c32a2
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 23 deletions.
24 changes: 23 additions & 1 deletion astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,28 @@
import astropy
import gwcs
import numpy as np
import s3fs

from astropy.coordinates import SkyCoord
from astropy.modeling import models


def _get_cloud_http(s3_uri: str) -> str:
""" Get the HTTP URI of a cloud resource from an S3 URI
Parameters
----------
s3_uri : string
the S3 URI of the cloud resource
"""
# create file system
fs = s3fs.S3FileSystem(anon=True)

# open resource and get URL
with fs.open(s3_uri, 'rb') as f:
return f.url()


def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
""" Get the center pixel from a roman 2d science image
Expand Down Expand Up @@ -247,8 +264,13 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
an image cutout object
"""

# if file comes from AWS cloud bucket, get HTTP URL to open with asdf
file = input_file
if isinstance(input_file, str) and input_file.startswith('s3://'):
file = _get_cloud_http(input_file)

# get the 2d image data
with asdf.open(input_file) as f:
with asdf.open(file) as f:
data = f['roman']['data']
gwcsobj = f['roman']['meta']['wcs']

Expand Down
42 changes: 31 additions & 11 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import pathlib
from unittest.mock import MagicMock, patch
import numpy as np
import pytest

Expand All @@ -12,7 +13,7 @@
from astropy.wcs.utils import pixel_to_skycoord
from gwcs import wcs
from gwcs import coordinate_frames as cf
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs, _get_cloud_http


def make_wcs(xsize, ysize, ra=30., dec=45.):
Expand Down Expand Up @@ -99,16 +100,6 @@ def make_file(tmp_path, fakedata):
yield filename


def test_get_center_pixel(fakedata):
""" test we can get the correct center pixel """
# get the fake data
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


@pytest.fixture()
def output(tmp_path):
""" fixture to create the output path """
Expand All @@ -121,6 +112,16 @@ def _output_file(ext='fits'):
yield _output_file


def test_get_center_pixel(fakedata):
""" test we can get the correct center pixel """
# get the fake data
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


@pytest.mark.parametrize('quantity', [True, False], ids=['quantity', 'array'])
def test_get_cutout(output, fakedata, quantity):
""" test we can create a cutout """
Expand Down Expand Up @@ -312,3 +313,22 @@ def test_slice_gwcs(fakedata):
# gwcs footprint/bounding_box expects ((x0, x1), (y0, y1)) but cutout.bbox is in ((y0, y1), (x0, x1))
assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all()


@patch('s3fs.S3FileSystem')
def test_get_cloud_http(mock_s3fs):
""" test we can get HTTP URI of cloud resource """
# mock s3 file system operations
HTTP_URI = "http_test"
mock_file = MagicMock()
mock_fs = MagicMock()
mock_file.url.return_value = HTTP_URI
mock_fs.open.return_value.__enter__.return_value = mock_file
mock_s3fs.return_value = mock_fs

s3_uri = "s3://test_bucket/test_file.asdf"
http_uri = _get_cloud_http(s3_uri)

assert http_uri == HTTP_URI
mock_s3fs.assert_called_once_with(anon=True)
mock_fs.open.assert_called_once_with(s3_uri, 'rb')
mock_file.url.assert_called_once()
22 changes: 11 additions & 11 deletions astrocut/tests/test_make_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ def test_make_cube(tmpdir):
ecube[:, :, i, 0] = -plane
ecube[:, :, i, 1] = plane
plane += img_sz*img_sz
assert np.alltrue(cube == ecube), "Cube values do not match expected values"
assert np.all(cube == ecube), "Cube values do not match expected values"

tab = Table(hdu[2].data)
assert np.alltrue(tab['TSTART'] == np.arange(num_im)), "TSTART mismatch in table"
assert np.alltrue(tab['TSTOP'] == np.arange(num_im)+1), "TSTOP mismatch in table"
assert np.all(tab['TSTART'] == np.arange(num_im)), "TSTART mismatch in table"
assert np.all(tab['TSTOP'] == np.arange(num_im)+1), "TSTOP mismatch in table"

filenames = np.array([path.split(x)[1] for x in ffi_files])
assert np.alltrue(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"
assert np.all(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"

hdu.close()

Expand Down Expand Up @@ -86,7 +86,7 @@ def test_make_and_update_cube(tmpdir):
# ecube[:, :, i, 1] = plane
plane += img_sz*img_sz

assert np.alltrue(cube == ecube), "Cube values do not match expected values"
assert np.all(cube == ecube), "Cube values do not match expected values"

hdu.close()

Expand All @@ -110,14 +110,14 @@ def test_make_and_update_cube(tmpdir):
# ecube[:, :, i, 1] = plane
plane += img_sz*img_sz

assert np.alltrue(cube == ecube), "Cube values do not match expected values"
assert np.all(cube == ecube), "Cube values do not match expected values"

tab = Table(hdu[2].data)
assert np.alltrue(tab['STARTTJD'] == np.arange(num_im)), "STARTTJD mismatch in table"
assert np.alltrue(tab['ENDTJD'] == np.arange(num_im)+1), "ENDTJD mismatch in table"
assert np.all(tab['STARTTJD'] == np.arange(num_im)), "STARTTJD mismatch in table"
assert np.all(tab['ENDTJD'] == np.arange(num_im)+1), "ENDTJD mismatch in table"

filenames = np.array([path.split(x)[1] for x in ffi_files])
assert np.alltrue(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"
assert np.all(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"

hdu.close()

Expand Down Expand Up @@ -156,7 +156,7 @@ def test_iteration(tmpdir, capsys):
cube_2 = hdu_2[1].data

assert cube_1.shape == cube_2.shape, "Mismatch between cube shape for 1 vs 2 iterations"
assert np.alltrue(cube_1 == cube_2), "Cubes made in 1 vs 2 iterations do not match"
assert np.all(cube_1 == cube_2), "Cubes made in 1 vs 2 iterations do not match"

# expected values for cube
ecube = np.zeros((img_sz, img_sz, num_im, 2))
Expand All @@ -168,7 +168,7 @@ def test_iteration(tmpdir, capsys):
ecube[:, :, i, 1] = plane
plane += img_sz*img_sz

assert np.alltrue(cube_1 == ecube), "Cube values do not match expected values"
assert np.all(cube_1 == ecube), "Cube values do not match expected values"


@pytest.mark.parametrize("ffi_type", ["TICA", "SPOC"])
Expand Down

0 comments on commit 56c32a2

Please sign in to comment.