Skip to content

Commit

Permalink
added unit tests for validating cutouts and proc time
Browse files Browse the repository at this point in the history
  • Loading branch information
jaymedina committed Jun 29, 2023
1 parent 8e19da6 commit 83573da
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions astrocut/tests/test_cube_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import astropy.units as u
import numpy as np
import pytest
import time
from astropy import wcs
from astropy.coordinates import SkyCoord
from astropy.io import fits
Expand Down Expand Up @@ -527,3 +528,45 @@ def test_s3_cube_cut(tmp_path: Path):
assert np.isclose(hdulist[1].data["FLUX_ERR"][200][1, 2], 1.1239403)
assert hdulist[0].header["CAMERA"] == 2
hdulist.close()


@pytest.mark.parametrize("cutout_size", [[5, 10], 20, 31])
def test_multiprocessing(cutout_size, tmp_path):

tmpdir = str(tmp_path)

coord = SkyCoord(217.42893801, -62.67949189, unit="deg", frame="icrs")
cube_file = "s3://stpubdata/tess/public/mast/tess-s0038-2-2-cube.fits"

cutf_0threads = CutoutFactory(threads=0)
start_0time = time.time()
cutout_0threads = cutf_0threads.cube_cut(cube_file, coordinates=coord,
output_path=tmpdir, verbose=False,
cutout_size=cutout_size)
time_0threads = time.time() - start_0time

cutf_4threads = CutoutFactory(threads=4)
start_4time = time.time()
cutout_4threads = cutf_4threads.cube_cut(cube_file, coordinates=coord,
output_path=tmpdir, verbose=False,
cutout_size=cutout_size)
time_4threads = time.time() - start_4time

cutf_8threads = CutoutFactory(threads=8)
start_8time = time.time()
cutout_8threads = cutf_8threads.cube_cut(cube_file, coordinates=coord,
output_path=tmpdir, verbose=False,
cutout_size=cutout_size)
time_8threads = time.time() - start_8time

y, x = 1, 2
index = np.random.randint(0, len(fits.getdata(cutout_0threads)["FLUX"]) - 1)
pixels, means = [], []
for cutout in [cutout_0threads, cutout_4threads, cutout_8threads]:

pixels.append(fits.getdata(cutout)["FLUX"][index][y, x])

This comment has been minimized.

Copy link
@falkben

falkben Jun 29, 2023

Member

I think we could probably do an np.array_equal comparison instead.

means.append(np.mean(fits.getdata(cutout)["FLUX"][index]))

assert len(set(pixels)) == 1, f"pixel values are different for coord (1, 2): {pixels}"
assert len(set(means)) == 1
assert time_0threads > time_4threads > time_8threads

This comment has been minimized.

Copy link
@falkben

falkben Jun 29, 2023

Member

I don't think this specific assertion (test0 > test4 > test8) is necessary and probably can cause CI failures which can be frustrating to deal with.

1 comment on commit 83573da

@falkben
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to undo this commit, I think I have an idea of what we want here, and can write the test.

Please sign in to comment.