Skip to content

Commit

Permalink
cutout in threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
falkben committed Jun 13, 2023
1 parent 0baa937 commit 8e19da6
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions astrocut/cube_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from itertools import product
from time import time
from typing import Any, Dict
from typing import Any, Dict, Union

import astropy.units as u
import numpy as np
Expand Down Expand Up @@ -36,11 +37,23 @@ class CutoutFactory():
Future versions will include more generalized cutout functionality.
"""

def __init__(self):
def __init__(self, threads: Union[int, None] = None):
"""
Initialization function.
Parameters
----------
table_data : `~astropy.io.fits.fitsrec.FITS_rec`
The cube image header data table.
threads : int or `None`
Number of threads to use when making cutouts.
<=1 disables the threadpool, > 1 sets the number of threads,
None (default) uses Python's default threads
"""

self.threads = threads

self.cube_wcs = None # WCS information from the image cube
self.cutout_wcs = None # WCS information (linear) for the cutout
self.cutout_wcs_fit = {'WCS_MSEP': [None, "[deg] Max offset between cutout WCS and FFI WCS"],
Expand Down Expand Up @@ -424,8 +437,14 @@ def _get_cutout(self, transposed_cube, verbose=True):
ymax = ymax_cube

# Doing the cutout
cutout = transposed_cube[xmin:xmax, ymin:ymax, :, :]

if self.threads is None or self.threads > 1:
with ThreadPoolExecutor(max_workers=self.threads) as pool:
cutouts = list(pool.map(lambda x: transposed_cube[x, ymin:ymax, :, :], range(xmin, xmax)))
# stack the list of cutouts
cutout = np.stack(cutouts)
else:
cutout = transposed_cube[xmin:xmax, ymin:ymax, :, :]

Check warning on line 446 in astrocut/cube_cut.py

View check run for this annotation

Codecov / codecov/patch

astrocut/cube_cut.py#L446

Added line #L446 was not covered by tests

img_cutout = cutout[:, :, :, 0].transpose((2, 0, 1))
uncert_cutout = cutout[:, :, :, 1].transpose((2, 0, 1))

Expand Down

0 comments on commit 8e19da6

Please sign in to comment.