From f0b75629fc161ab610f79da37c0ce9c68464cd4a Mon Sep 17 00:00:00 2001 From: Sam Bianco <70121323+snbianco@users.noreply.github.com> Date: Mon, 17 Jun 2024 22:43:18 -0400 Subject: [PATCH] Support for pathlib.Path and s3path.S3Path objects in asdf_cut() (#119) Support for pathlib.Path and s3path.S3Path objects in `asdf_cut` --- CHANGES.rst | 7 ++++++- astrocut/asdf_cutouts.py | 15 ++++++++------- astrocut/tests/test_asdf_cut.py | 9 ++++++++- setup.cfg | 1 + 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a996c3c5..54da5bd9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,4 +1,9 @@ -0.11.0 (Unreleased) +0.12.0 (Unreleased) +-------------------- + +- asdf_cut() function now accepts pathlib.Path and s3path.S3Path objects as an input file [#119] + +0.11.0 (2024-05-28) -------------------- - Add functionality for creating cutouts from the ASDF file format [#105] diff --git a/astrocut/asdf_cutouts.py b/astrocut/asdf_cutouts.py index 3f039a44..d0e9612c 100644 --- a/astrocut/asdf_cutouts.py +++ b/astrocut/asdf_cutouts.py @@ -10,18 +10,19 @@ import gwcs import numpy as np import s3fs +from s3path import S3Path from astropy.coordinates import SkyCoord from astropy.modeling import models -def _get_cloud_http(s3_uri: str) -> str: +def _get_cloud_http(s3_uri: Union[str, S3Path]) -> str: """ Get the HTTP URI of a cloud resource from an S3 URI. Parameters ---------- - s3_uri : string + s3_uri : string | S3Path the S3 URI of the cloud resource """ # create file system @@ -239,8 +240,8 @@ def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile: af.write_to(outfile) -def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, - output_file: str = "example_roman_cutout.fits", +def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float, cutout_size: int = 20, + output_file: Union[str, pathlib.Path] = "example_roman_cutout.fits", write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D: """ Takes a single ASDF input file (`input_file`) and generates a cutout of designated size `cutout_size` @@ -250,7 +251,7 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, Parameters ---------- - input_file : str + input_file : str | Path | S3Path The input ASDF file. ra : float The right ascension of the central cutout. @@ -258,7 +259,7 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, The declination of the central cutout. cutout_size : int Optional, default 20. The image cutout pixel size. - output_file : str + output_file : str | Path Optional, default "example_roman_cutout.fits". The name of the output cutout file. write_file : bool Optional, default True. Flag to write the cutout to a file or not. @@ -273,7 +274,7 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, # if file comes from AWS cloud bucket, get HTTP URL to open with asdf file = input_file - if isinstance(input_file, str) and input_file.startswith('s3://'): + if (isinstance(input_file, str) and input_file.startswith('s3://')) or isinstance(input_file, S3Path): file = _get_cloud_http(input_file) # get the 2d image data diff --git a/astrocut/tests/test_asdf_cut.py b/astrocut/tests/test_asdf_cut.py index 70c0c968..ae39c43f 100644 --- a/astrocut/tests/test_asdf_cut.py +++ b/astrocut/tests/test_asdf_cut.py @@ -13,6 +13,7 @@ from astropy.wcs.utils import pixel_to_skycoord from gwcs import wcs from gwcs import coordinate_frames as cf +from s3path import S3Path from astrocut.asdf_cutouts import get_center_pixel, asdf_cut, _get_cutout, _slice_gwcs, _get_cloud_http @@ -325,10 +326,16 @@ def test_get_cloud_http(mock_s3fs): mock_fs.open.return_value.__enter__.return_value = mock_file mock_s3fs.return_value = mock_fs + # test function with string input s3_uri = "s3://test_bucket/test_file.asdf" http_uri = _get_cloud_http(s3_uri) - assert http_uri == HTTP_URI mock_s3fs.assert_called_once_with() mock_fs.open.assert_called_once_with(s3_uri, 'rb') mock_file.url.assert_called_once() + + # test function with S3Path input + s3_uri_path = S3Path("test_bucket/test_file_2.asdf") + http_uri_path = _get_cloud_http(s3_uri_path) + assert http_uri_path == HTTP_URI + mock_fs.open.assert_called_with(s3_uri_path, 'rb') diff --git a/setup.cfg b/setup.cfg index d8bcbe06..723df454 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ install_requires = astropy>=5.2 # astropy with s3fs support fsspec[http]>=2022.8.2 # for remote cutouts s3fs>=2022.8.2 # for remote cutouts + s3path>=0.5.7 # for remote file paths roman_datamodels>=0.17.0 # for roman file support scipy Pillow