Skip to content

Commit

Permalink
Started fixing mypy.
Browse files Browse the repository at this point in the history
  • Loading branch information
kenfus committed Nov 8, 2023
1 parent e49dc43 commit 3c5e59d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 26 deletions.
21 changes: 11 additions & 10 deletions karabo/imaging/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
class Image(KaraboResource):
def __init__(
self,
path: Optional[str] = None,
data: Optional[np.ndarray] = None,
path: Optional[Union[str, FilePathType]] = None,
data: Optional[np.ndarray[np.float_]] = None, # type: ignore
header: Optional[fits.header.Header] = None,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -174,7 +174,9 @@ def _update_header_after_resize(self) -> None:
self.header["CDELT1"] = self.header["CDELT1"] * old_shape[1] / new_shape[1]
self.header["CDELT2"] = self.header["CDELT2"] * old_shape[0] / new_shape[0]

def cutout(self, center_xy: Tuple[float, float], size_xy: Tuple[float, float]) -> Image:
def cutout(
self, center_xy: Tuple[float, float], size_xy: Tuple[float, float]
) -> Image:
"""
Cutout the image to the given size and center.
Expand All @@ -196,7 +198,7 @@ def cutout(self, center_xy: Tuple[float, float], size_xy: Tuple[float, float]) -
def update_header_from_image_header(
new_header: fits.header.Header,
old_header: fits.header.Header,
keys_to_copy=HEADER_KEYS_TO_COPY_AFTER_CUTOUT,
keys_to_copy: List[str] = HEADER_KEYS_TO_COPY_AFTER_CUTOUT,
) -> fits.header.Header:
for key in keys_to_copy:
if key in old_header and key not in new_header:
Expand Down Expand Up @@ -491,10 +493,9 @@ def get_cellsize(self) -> np.float64:
def get_wcs(self) -> WCS:
return WCS(self.header)

def get_2d_wcs(self, ra_dec_axis: Tuple[IntLike, IntLike] = [1, 2]) -> WCS:
def get_2d_wcs(self, ra_dec_axis: Tuple[int, int] = tuple(1,2)) -> WCS:
wcs = WCS(self.header)
wcs_2d = wcs.sub(ra_dec_axis) # type: ignore

wcs_2d = wcs.sub(ra_dec_axis)
return wcs_2d


Expand Down Expand Up @@ -532,7 +533,7 @@ class ImageMosaicker:

def __init__(
self,
reproject_function: Callable = reproject_interp,
reproject_function: Callable[..., Any] = reproject_interp,
combine_function: str = "mean",
match_background: bool = False,
background_reference: Optional[int] = None,
Expand Down Expand Up @@ -575,11 +576,11 @@ def process(
images: List[Union[str, fits.HDUList, fits.PrimaryHDU, NDData]],
projection: str = "SIN",
weights: Optional[
List[Union[str, fits.HDUList, fits.PrimaryHDU, np.ndarray]]
List[Union[str, fits.HDUList, fits.PrimaryHDU, np.ndarray[np.float_]]],
] = None,
shape_out: Optional[Tuple[int]] = None,
image_for_header: Optional[Image] = None,
) -> Tuple[Image, np.ndarray]:
) -> Tuple[Image, np.ndarray[np.float_]]:
"""
Combine the provided images into a single mosaicked image.
Expand Down
29 changes: 20 additions & 9 deletions karabo/sourcedetection/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import tempfile
from typing import Any, List, Optional, Tuple, Type, TypeVar
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union
from warnings import warn

import bdsf
Expand Down Expand Up @@ -52,11 +52,9 @@ def __init__(
@classmethod
def detect_sources_in_image(
cls: Type[T],
image: Image,
image: Union[Image, List[Image]],
beam: Optional[Tuple[float, float, float]] = None,
quiet: bool = False,
n_splits: int = 0,
overlap: int = 0,
use_dask: Optional[bool] = None,
client: Optional[Any] = None,
**kwargs: Any,
Expand All @@ -68,8 +66,9 @@ def detect_sources_in_image(
----------
cls : Type[T]
The class on which this method is called.
image : Image
Image object for source detection.
image : Image or List[Image]
Image object for source detection. Can be a single image or a list of
images.
beam : Optional[Tuple[float, float, float]], optional
The Full Width Half Maximum (FWHM) of the restoring beam, given as a tuple
(major axis, minor axis, position angle). If None, tries to extract from
Expand Down Expand Up @@ -131,6 +130,19 @@ def detect_sources_in_image(
if use_dask and not client:
client = DaskHandler.get_dask_client()

if isinstance(image, List):
if beam is None:
warn(
KaraboWarning(
"Beam was not passed, trying to extract from image metadata."
)
)
beam = (
image[0].header["BMAJ"],
image[0].header["BMIN"],
image[0].header["BPA"],
)

if beam is None:
if image.has_beam_parameters():
beam = (image.header["BMAJ"], image.header["BMIN"], image.header["BPA"])
Expand All @@ -142,13 +154,12 @@ def detect_sources_in_image(
)

try:
if n_splits > 1:
if isinstance(image, List):
# Check if there is a dask client
if DaskHandler.dask_client is None:
_ = DaskHandler.get_dask_client()
cutouts = image.split_image(n_splits, overlap)
results = []
for cutout in cutouts:
for cutout in image:
results.append(
delayed(bdsf.process_image)(
input=cutout.path,
Expand Down
18 changes: 11 additions & 7 deletions karabo/test/test_source_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_automatic_assignment_of_ground_truth_and_prediction():

def test_source_detection(tobject: TFiles):
restored = Image.read_from_file(tobject.restored_fits)
detection_results = PyBDSFSourceDetectionResult.detect_sources_in_image(
detection_result = PyBDSFSourceDetectionResult.detect_sources_in_image(
restored, thresh_isl=15, thresh_pix=20
)
gtruth = np.array(
Expand All @@ -215,17 +215,21 @@ def test_source_detection(tobject: TFiles):
[1212.06660484, 930.03800074],
]
)
detected = detection_results.get_pixel_position_of_sources()
closest_distances = np.linalg.norm(gtruth - detected, axis=1)
assert np.all(closest_distances < 5), "Source detection is not correct"
detected = detection_result.get_pixel_position_of_sources()
mse = np.linalg.norm(gtruth - detected, axis=1)
assert np.all(mse < 1), "Source detection is not correct"

# Now compare it with splitting the image
restored_cuts = restored.split_image(N=2, overlap=100)
detection_results = PyBDSFSourceDetectionResult.detect_sources_in_image(
restored, thresh_isl=15, thresh_pix=20, n_splits=4
restored_cuts, thresh_isl=15, thresh_pix=20
)
assert len(detection_results) == 4, "Splitting the image did not work"
detected = detection_results.get_pixel_position_of_sources()
assert np.all(closest_distances < 5), "Source detection is not correct"
# Sometimes the order of the sources is different, so we need to sort them
detected = detected[np.argsort(detected[:, 0])]
gtruth = gtruth[np.argsort(gtruth[:, 0])]
mse = np.linalg.norm(gtruth - detected, axis=1)
assert np.all(mse < 1), "Source detection is not correct"


@pytest.mark.skipif(not RUN_GPU_TESTS, reason="GPU tests are disabled")
Expand Down

0 comments on commit 3c5e59d

Please sign in to comment.