diff --git a/karabo/imaging/image.py b/karabo/imaging/image.py index 6235dc4c..09df0743 100644 --- a/karabo/imaging/image.py +++ b/karabo/imaging/image.py @@ -49,8 +49,8 @@ class Image(KaraboResource): def __init__( self, - path: Optional[str] = None, - data: Optional[np.ndarray] = None, + path: Optional[Union[str, FilePathType]] = None, + data: Optional[np.ndarray[np.float_]] = None, # type: ignore header: Optional[fits.header.Header] = None, **kwargs: Any, ) -> None: @@ -174,7 +174,9 @@ def _update_header_after_resize(self) -> None: self.header["CDELT1"] = self.header["CDELT1"] * old_shape[1] / new_shape[1] self.header["CDELT2"] = self.header["CDELT2"] * old_shape[0] / new_shape[0] - def cutout(self, center_xy: Tuple[float, float], size_xy: Tuple[float, float]) -> Image: + def cutout( + self, center_xy: Tuple[float, float], size_xy: Tuple[float, float] + ) -> Image: """ Cutout the image to the given size and center. @@ -196,7 +198,7 @@ def cutout(self, center_xy: Tuple[float, float], size_xy: Tuple[float, float]) - def update_header_from_image_header( new_header: fits.header.Header, old_header: fits.header.Header, - keys_to_copy=HEADER_KEYS_TO_COPY_AFTER_CUTOUT, + keys_to_copy: List[str] = HEADER_KEYS_TO_COPY_AFTER_CUTOUT, ) -> fits.header.Header: for key in keys_to_copy: if key in old_header and key not in new_header: @@ -491,10 +493,9 @@ def get_cellsize(self) -> np.float64: def get_wcs(self) -> WCS: return WCS(self.header) - def get_2d_wcs(self, ra_dec_axis: Tuple[IntLike, IntLike] = [1, 2]) -> WCS: + def get_2d_wcs(self, ra_dec_axis: Tuple[int, int] = tuple(1,2)) -> WCS: wcs = WCS(self.header) - wcs_2d = wcs.sub(ra_dec_axis) # type: ignore - + wcs_2d = wcs.sub(ra_dec_axis) return wcs_2d @@ -532,7 +533,7 @@ class ImageMosaicker: def __init__( self, - reproject_function: Callable = reproject_interp, + reproject_function: Callable[..., Any] = reproject_interp, combine_function: str = "mean", match_background: bool = False, background_reference: Optional[int] = None, @@ -575,11 +576,11 @@ def process( images: List[Union[str, fits.HDUList, fits.PrimaryHDU, NDData]], projection: str = "SIN", weights: Optional[ - List[Union[str, fits.HDUList, fits.PrimaryHDU, np.ndarray]] + List[Union[str, fits.HDUList, fits.PrimaryHDU, np.ndarray[np.float_]]], ] = None, shape_out: Optional[Tuple[int]] = None, image_for_header: Optional[Image] = None, - ) -> Tuple[Image, np.ndarray]: + ) -> Tuple[Image, np.ndarray[np.float_]]: """ Combine the provided images into a single mosaicked image. diff --git a/karabo/sourcedetection/result.py b/karabo/sourcedetection/result.py index cffe3de6..0fdfd3b3 100644 --- a/karabo/sourcedetection/result.py +++ b/karabo/sourcedetection/result.py @@ -3,7 +3,7 @@ import os import shutil import tempfile -from typing import Any, List, Optional, Tuple, Type, TypeVar +from typing import Any, List, Optional, Tuple, Type, TypeVar, Union from warnings import warn import bdsf @@ -52,11 +52,9 @@ def __init__( @classmethod def detect_sources_in_image( cls: Type[T], - image: Image, + image: Union[Image, List[Image]], beam: Optional[Tuple[float, float, float]] = None, quiet: bool = False, - n_splits: int = 0, - overlap: int = 0, use_dask: Optional[bool] = None, client: Optional[Any] = None, **kwargs: Any, @@ -68,8 +66,9 @@ def detect_sources_in_image( ---------- cls : Type[T] The class on which this method is called. - image : Image - Image object for source detection. + image : Image or List[Image] + Image object for source detection. Can be a single image or a list of + images. beam : Optional[Tuple[float, float, float]], optional The Full Width Half Maximum (FWHM) of the restoring beam, given as a tuple (major axis, minor axis, position angle). If None, tries to extract from @@ -131,6 +130,19 @@ def detect_sources_in_image( if use_dask and not client: client = DaskHandler.get_dask_client() + if isinstance(image, List): + if beam is None: + warn( + KaraboWarning( + "Beam was not passed, trying to extract from image metadata." + ) + ) + beam = ( + image[0].header["BMAJ"], + image[0].header["BMIN"], + image[0].header["BPA"], + ) + if beam is None: if image.has_beam_parameters(): beam = (image.header["BMAJ"], image.header["BMIN"], image.header["BPA"]) @@ -142,13 +154,12 @@ def detect_sources_in_image( ) try: - if n_splits > 1: + if isinstance(image, List): # Check if there is a dask client if DaskHandler.dask_client is None: _ = DaskHandler.get_dask_client() - cutouts = image.split_image(n_splits, overlap) results = [] - for cutout in cutouts: + for cutout in image: results.append( delayed(bdsf.process_image)( input=cutout.path, diff --git a/karabo/test/test_source_detection.py b/karabo/test/test_source_detection.py index 810ec90a..0a8e5503 100644 --- a/karabo/test/test_source_detection.py +++ b/karabo/test/test_source_detection.py @@ -201,7 +201,7 @@ def test_automatic_assignment_of_ground_truth_and_prediction(): def test_source_detection(tobject: TFiles): restored = Image.read_from_file(tobject.restored_fits) - detection_results = PyBDSFSourceDetectionResult.detect_sources_in_image( + detection_result = PyBDSFSourceDetectionResult.detect_sources_in_image( restored, thresh_isl=15, thresh_pix=20 ) gtruth = np.array( @@ -215,17 +215,21 @@ def test_source_detection(tobject: TFiles): [1212.06660484, 930.03800074], ] ) - detected = detection_results.get_pixel_position_of_sources() - closest_distances = np.linalg.norm(gtruth - detected, axis=1) - assert np.all(closest_distances < 5), "Source detection is not correct" + detected = detection_result.get_pixel_position_of_sources() + mse = np.linalg.norm(gtruth - detected, axis=1) + assert np.all(mse < 1), "Source detection is not correct" # Now compare it with splitting the image + restored_cuts = restored.split_image(N=2, overlap=100) detection_results = PyBDSFSourceDetectionResult.detect_sources_in_image( - restored, thresh_isl=15, thresh_pix=20, n_splits=4 + restored_cuts, thresh_isl=15, thresh_pix=20 ) - assert len(detection_results) == 4, "Splitting the image did not work" detected = detection_results.get_pixel_position_of_sources() - assert np.all(closest_distances < 5), "Source detection is not correct" + # Sometimes the order of the sources is different, so we need to sort them + detected = detected[np.argsort(detected[:, 0])] + gtruth = gtruth[np.argsort(gtruth[:, 0])] + mse = np.linalg.norm(gtruth - detected, axis=1) + assert np.all(mse < 1), "Source detection is not correct" @pytest.mark.skipif(not RUN_GPU_TESTS, reason="GPU tests are disabled")