From 4362daaaf1c108344eb8cca8040ce0fd36104302 Mon Sep 17 00:00:00 2001 From: Kevin Cortacero Date: Sat, 4 Jan 2025 22:57:10 +0100 Subject: [PATCH] refactoring --- src/kartezio/core/components.py | 124 ++++---- src/kartezio/core/endpoints.py | 489 ++++++++++++++++++++----------- src/kartezio/preprocessing.py | 2 +- src/kartezio/vision/watershed.py | 391 ++++++++++++++++++++++++ 4 files changed, 781 insertions(+), 225 deletions(-) diff --git a/src/kartezio/core/components.py b/src/kartezio/core/components.py index 86df228..f034abe 100644 --- a/src/kartezio/core/components.py +++ b/src/kartezio/core/components.py @@ -138,58 +138,59 @@ def name_of(component_class: type) -> str: def display(): pprint(Components._registry) + def add_as(self, fundamental: type, replace: type = None): + """ + Register a component to the Components registry. -def register(fundamental: type, replace: type = None): - """ - Register a component to the Components registry. + Args: + fundamental (type): The fundamental type of the component. + replace (type): If not None, replace an existing component with the type. - Args: - fundamental (type): The fundamental type of the component. - replace (type): If not None, replace an existing component with the type. + Returns: + Callable: A decorator for registering the component. + """ + fundamental_name = fundamental.__name__ + + def inner(item_cls): + name = item_cls.__name__ + if Components._contains(fundamental_name, name): + if not replace: + raise KeyError( + f"""Error registering {fundamental_name} called '{name}'. + Here is the list of all registered {fundamental_name} components: + \n{Components._registry[fundamental_name].keys()}. + \n > Replace it using 'replace=True' in @register, or use another name. + """ + ) + if replace: + replace_name = replace.__name__ + if Components._contains(fundamental_name, replace_name): + print( + f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'" + ) + Components.add(fundamental_name, replace_name, item_cls) + else: + Components.add(fundamental_name, name, item_cls) + return item_cls - Returns: - Callable: A decorator for registering the component. - """ - fundamental_name = fundamental.__name__ - - def inner(item_cls): - name = item_cls.__name__ - if Components._contains(fundamental_name, name): - if not replace: - raise KeyError( - f"""Error registering {fundamental_name} called '{name}'. - Here is the list of all registered {fundamental_name} components: - \n{Components._registry[fundamental_name].keys()}. - \n > Replace it using 'replace=True' in @register, or use another name. - """ - ) - if replace: - replace_name = replace.__name__ - if Components._contains(fundamental_name, replace_name): - print( - f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'" - ) - Components.add(fundamental_name, replace_name, item_cls) - else: - Components.add(fundamental_name, name, item_cls) - return item_cls + return inner - return inner + def declare(self): + """ + Register a fundamental component to the Components registry. + Returns: + Callable: A decorator for registering the fundamental component. + """ -def component(): - """ - Register a fundamental component to the Components registry. + def inner(item_cls): + Components.add_component(item_cls.__name__) + return item_cls - Returns: - Callable: A decorator for registering the fundamental component. - """ + return inner - def inner(item_cls): - Components.add_component(item_cls.__name__) - return item_cls - return inner +registry = Components() def load_component( @@ -223,7 +224,7 @@ def dump_component(component: KartezioComponent) -> Dict: return base_dict -@component() +@registry.declare() class Node(KartezioComponent, ABC): """ Abstract base class for a Node in the CGP framework. @@ -232,7 +233,7 @@ class Node(KartezioComponent, ABC): pass -@component() +@registry.declare() class Preprocessing(Node, ABC): """ Preprocessing node, called before training loop. @@ -265,7 +266,7 @@ def then(self, preprocessing: "Preprocessing"): return self -@component() +@registry.declare() class Primitive(Node, ABC): """ Primitive function called inside the CGP Graph. @@ -286,7 +287,7 @@ def __to_dict__(self) -> Dict: return {"name": self.name} -@component() +@registry.declare() class Genotype(KartezioComponent): """ Represents the genotype for Cartesian Genetic Programming (CGP). @@ -392,7 +393,7 @@ def clone(self) -> "Genotype": return copy.deepcopy(self) -@component() +@registry.declare() class Reducer(Node, ABC): def batch(self, x: List): y = [] @@ -405,7 +406,7 @@ def reduce(self, x): pass -@component() +@registry.declare() class Endpoint(Node, ABC): """ Represents the final node in a CGP graph, responsible for producing the final outputs. @@ -425,7 +426,6 @@ def __init__(self, inputs: List[KType]): @classmethod def __from_dict__(cls, dict_infos: Dict) -> "Endpoint": - from kartezio.core.endpoints import Endpoint """ Create an Endpoint instance from a dictionary representation. @@ -441,8 +441,14 @@ def __from_dict__(cls, dict_infos: Dict) -> "Endpoint": **dict_infos["args"], ) + @classmethod + def from_config(cls, config): + return registry.instantiate( + cls.__name__, config["name"], **config["args"] + ) + -@component() +@registry.declare() class Fitness(KartezioComponent, ABC): def __init__(self, reduction="mean"): super().__init__() @@ -482,6 +488,7 @@ def evaluate(self, y_true, y_pred): @classmethod def __from_dict__(cls, dict_infos: Dict) -> "Fitness": from kartezio.core.fitness import Fitness + return Components.instantiate( "Fitness", dict_infos["name"], @@ -489,7 +496,7 @@ def __from_dict__(cls, dict_infos: Dict) -> "Fitness": ) -@component() +@registry.declare() class Library(KartezioComponent): def __init__(self, rtype): super().__init__() @@ -596,7 +603,7 @@ def size(self): return len(self._primitives) -@component() +@registry.declare() class Mutation(KartezioComponent, ABC): def __init__(self, adapter): super().__init__() @@ -685,7 +692,7 @@ def __to_dict__(self) -> Dict: return {} -@component() +@registry.declare() class Initialization(KartezioComponent, ABC): """ """ @@ -693,7 +700,7 @@ def __init__(self): super().__init__() -@register(Initialization) +@registry.add_as(Initialization) class CopyGenotype(Initialization): @classmethod def __from_dict__(cls, dict_infos: Dict) -> "CopyGenotype": @@ -707,7 +714,7 @@ def mutate(self, genotype): return self.genotype.clone() -@register(Initialization) +@registry.add_as(Initialization) class RandomInit(Initialization, Mutation): """ Can be used to initialize genome (genome) randomly @@ -736,3 +743,8 @@ def mutate(self, genotype: Genotype): def random(self): genotype = self.adapter.new_genotype() return self.mutate(genotype) + + +if __name__ == "__main__": + registry.display() + print("Done!") diff --git a/src/kartezio/core/endpoints.py b/src/kartezio/core/endpoints.py index 88bc1e8..61f538d 100644 --- a/src/kartezio/core/endpoints.py +++ b/src/kartezio/core/endpoints.py @@ -1,41 +1,332 @@ +from abc import ABC from typing import Dict import cv2 import numpy as np -from scipy import ndimage -from skimage.feature import peak_local_max -from skimage.segmentation import watershed from skimage.transform import hough_ellipse -from kartezio.core.components import Endpoint, register -from kartezio.preprocessing import Resize -from kartezio.primitives.array import Sobel +from kartezio.core.components import Endpoint, registry + +# from kartezio.preprocessing import Resize +# from kartezio.primitives.array import Sobel from kartezio.types import TypeArray, TypeLabels from kartezio.vision.common import ( - WatershedSkimage, contours_fill, contours_find, image_new, threshold_tozero, ) +from kartezio.vision.watershed import ( + _connected_components, + distance_watershed, + double_threshold_watershed, + local_max_watershed, + marker_controlled_watershed, + threshold_local_max_watershed, + threshold_watershed, +) -@register(Endpoint) -class ToLabels(Endpoint): +class EndpointWatershed(Endpoint, ABC): + def __init__(self, arity, watershed_line=True): + super().__init__([TypeArray] * arity) + self.watershed_line = watershed_line + + +class PeakedMarkersWatershed(EndpointWatershed, ABC): + def __init__(self, watershed_line=True, min_distance=1, downsample=0): + super().__init__(1, watershed_line=watershed_line) + self.min_distance = min_distance + self.downsample = downsample + + +@registry.add_as(Endpoint) +class MarkerControlledWatershed(EndpointWatershed): + """ + MarkerControlledWatershed + + An endpoint class for a Cartesian Genetic Programming pipeline (or other plugin system) + that applies a marker-controlled watershed algorithm to segment an image. + + The input is expected to be a list/tuple of two NumPy arrays: + 1) `x[0]`: The primary image (e.g., grayscale or distance-transformed) to be segmented. + 2) `x[1]`: The binary or labeled marker array specifying the regions or points + from which the watershed should grow. + + Parameters + ---------- + watershed_line : bool, optional + If True, the watershed algorithm computes a "line" (pixel value = 0) + that separates adjacent segmented regions. Defaults to True. + + Examples + -------- + >>> # Suppose 'img' is a 2D NumPy array (grayscale) + >>> # and 'markers' is a 2D array of the same shape, with nonzero regions. + >>> segmenter = MarkerControlledWatershed(watershed_line=False) + >>> segmented = segmenter.call([img, markers])[0] + >>> segmented.shape + (height, width) + """ + + def __init__(self, watershed_line=True): + super().__init__(2, watershed_line=watershed_line) + + def call(self, x): + """ + Apply the marker-controlled watershed transform to the given input and marker arrays. + + Parameters + ---------- + x : list or tuple of np.ndarray + A container with two arrays: + - x[0]: The primary image (2D ndarray) to be segmented. + - x[1]: The marker array (2D ndarray) identifying regions/seeds. + + Returns + ------- + list of np.ndarray + A single-element list containing the segmented (labeled) image as a 2D ndarray. + Each connected region is assigned a unique integer label. + """ + return [marker_controlled_watershed(x[0], x[1], self.watershed_line)] + + +@registry.add_as(Endpoint) +class LocalMaxWatershed(PeakedMarkersWatershed): + """ + LocalMaxWatershed + + An endpoint that finds local maxima within a single input image (which may + be a grayscale or distance-transformed image) to generate markers, then + applies the watershed transform. + + Parameters + ---------- + min_distance : int, optional + The minimum distance separating local maxima. Defaults to 10. + watershed_line : bool, optional + If True, produces watershed lines (pixel = 0) where regions meet. + Defaults to True. + downsample : int, optional + If > 0, downsample the input by 2^downsample before detecting maxima. + Defaults to 0 (no downsampling). + """ + + def __init__( + self, watershed_line: bool, min_distance: int, downsample: int = 0 + ): + super().__init__( + watershed_line, min_distance, downsample + ) # Single input image + + def call(self, x): + """ + Perform local-maxima-based watershed on the input image. + + Parameters + ---------- + x : list of np.ndarray + - x[0] : The image on which to perform watershed (2D ndarray). + + Returns + ------- + list of np.ndarray + A single-element list containing the labeled segmentation result. + """ + return [ + local_max_watershed( + image=x[0], + min_distance=self.min_distance, + watershed_line=self.watershed_line, + downsample=self.downsample, + ) + ] + + +@registry.add_as(Endpoint) +class DistanceWatershed(PeakedMarkersWatershed): + """ + DistanceWatershed + + An endpoint that computes a distance transform internally, finds local maxima, + and applies a watershed transform. Typically used for segmenting binary masks + (foreground vs. background). + + Parameters + ---------- + min_distance : int, optional + Minimum distance to separate local maxima in the distance map. Defaults to 10. + watershed_line : bool, optional + Whether to produce watershed lines. Defaults to True. + downsample : int, optional + If > 0, downsample the distance-transformed image before local maxima detection. + Defaults to 0. + """ + + def __init__( + self, watershed_line: bool, min_distance: int, downsample: int = 0 + ): + super().__init__( + watershed_line, min_distance, downsample + ) # Single input (binary mask recommended) + + def call(self, x): + """ + Perform distance-based watershed on a binary mask or grayscale image. + + Parameters + ---------- + x : list of np.ndarray + - x[0] : The input image (2D ndarray), typically a binary mask. + + Returns + ------- + list of np.ndarray + A single-element list containing the labeled segmentation. + """ + return [ + distance_watershed( + image=x[0], + min_distance=self.min_distance, + watershed_line=self.watershed_line, + downsample=self.downsample, + ) + ] + + +@registry.add_as(Endpoint) +class ThresholdLocalMaxWatershed(PeakedMarkersWatershed): + """ + ThresholdLocalMaxWatershed + + An endpoint that first thresholds the input image (zeroing out pixels below + the threshold), then detects local maxima in the thresholded image and applies + a watershed transform. + + Parameters + ---------- + threshold : float, optional + Pixel intensity threshold. Pixels below are zeroed out. Defaults to 128.0. + min_distance : int, optional + Minimum distance separating local maxima. Defaults to 10. + watershed_line : bool, optional + If True, produce watershed lines. Defaults to True. + downsample : int, optional + If > 0, downsample before local maxima detection. Defaults to 0. + """ + + def __init__( + self, + watershed_line: bool = True, + min_distance: int = 10, + downsample: int = 0, + threshold: int = 128, + ): + super().__init__( + watershed_line, min_distance, downsample + ) # Single input image + self.threshold = threshold + self.min_distance = min_distance + self.watershed_line = watershed_line + self.downsample = downsample + + def call(self, x): + """ + Apply threshold + local maxima + watershed to the input image. + + Parameters + ---------- + x : list of np.ndarray + - x[0] : The input image (2D ndarray) to be thresholded & segmented. + + Returns + ------- + list of np.ndarray + A single-element list containing the labeled segmentation result. + """ + return [ + threshold_local_max_watershed( + image=x[0], + threshold=self.threshold, + min_distance=self.min_distance, + watershed_line=self.watershed_line, + downsample=self.downsample, + ) + ] + + +@registry.add_as(Endpoint) +class ThresholdWatershed(EndpointWatershed): + """ + ThresholdWatershed + + An endpoint that applies a threshold to the input image to generate a marker + array, then directly performs a marker-controlled watershed. + + Parameters + ---------- + threshold : float, optional + Pixel intensity threshold below which pixels become 0 in the marker array. + Defaults to 128.0. + watershed_line : bool, optional + If True, produce watershed lines. Defaults to True. + """ + + def __init__( + self, watershed_line: bool, threshold: int = 128, threshold_2=None + ): + super().__init__(1, watershed_line) # Single input image + self.threshold = threshold + self.threshold_2 = threshold_2 + if threshold_2 is not None: + if not (self.threshold < self.threshold_2): + raise ValueError( + f"threshold1 ({self.threshold}) must be < threshold2 ({self.threshold_2})" + ) + def call(self, x): + """ + Apply threshold-based watershed to the input image. + + Parameters + ---------- + x : list of np.ndarray + - x[0] : The input image (2D ndarray). + + Returns + ------- + list of np.ndarray + A single-element list containing the labeled segmentation result. + """ + if self.threshold_2 is not None: + return [ + double_threshold_watershed( + image=x[0], + threshold1=self.threshold1, + threshold2=self.threshold2, + watershed_line=self.watershed_line, + ) + ] return [ - x[0], - cv2.connectedComponents( - x[0], connectivity=self.connectivity, ltype=cv2.CV_16U - )[1], + threshold_watershed( + image=x[0], + threshold=self.threshold, + watershed_line=self.watershed_line, + ) ] - def __init__(self, connectivity=4): + +@registry.add_as(Endpoint) +class ToLabels(Endpoint): + def __init__(self): super().__init__([TypeArray]) - self.connectivity = connectivity + def call(self, x): + return [_connected_components(x[0])] -@register(Endpoint) + +@registry.add_as(Endpoint) class EndpointSubtract(Endpoint): def __init__(self): super().__init__([TypeArray, TypeArray]) @@ -43,43 +334,35 @@ def __init__(self): def call(self, x): return [cv2.subtract(x[0], x[1])] - def __to_dict__(self) -> Dict: - return { - "name": "subtract", - "args": {}, - } - -@register(Endpoint) +@registry.add_as(Endpoint) class EndpointThreshold(Endpoint): - def __init__(self, threshold, normalize=False, mode="binary"): + def __init__(self, threshold): super().__init__([TypeArray]) self.threshold = threshold - self.normalize = normalize - self.mode = mode def call(self, x): - image = x[0] - if self.normalize: - image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX) - if self.mode == "binary": - return [ - cv2.threshold(image, self.threshold, 255, cv2.THRESH_BINARY)[1] - ] - return [ - cv2.threshold(image, self.threshold, 255, cv2.THRESH_TOZERO)[1] - ] + return [threshold_tozero(x[0], self.threshold)] def __to_dict__(self) -> Dict: return { "args": { "threshold": self.threshold, - "normalize": self.normalize, - "mode": self.mode, }, } +print(registry.display()) +test_endpoint = Endpoint.from_config( + { + "name": "ThresholdWatershed", + "args": {"watershed_line": True, "threshold": 128}, + } +) +print(test_endpoint) +print(test_endpoint.__to_dict__()) + + @register(Endpoint) class EndpointHoughCircle(Endpoint): def __init__( @@ -209,136 +492,6 @@ def __to_dict__(self) -> Dict: } -@register(Endpoint) -class EndpointWatershed(Endpoint): - def __init__(self): - super().__init__([TypeArray, TypeArray]) - - def call(self, x): - markers = cv2.connectedComponents( - x[1], connectivity=8, ltype=cv2.CV_16U - )[1] - scaled = cv2.exp((x[0] / 255.0).astype(np.float32)) - labels = watershed( - -scaled, - markers=markers, - mask=x[0] > 0, - watershed_line=True, - ) - return [labels] - - -@register(Endpoint) -class LocalMaxWatershed(Endpoint): - """Watershed based KartezioEndpoint, but only based on one single mask. - Markers are computed as the local max of the distance transform of the mask - - """ - - def __init__(self, min_distance=21): - super().__init__([TypeArray]) - self.min_distance = min_distance - - def call(self, x): - distance = cv2.distanceTransform(x[0], cv2.DIST_L2, 3) - distance = cv2.normalize(distance, None, 0.0, 1.0, cv2.NORM_MINMAX) - # distance[distance < 0.5] = 0 - # distance = (distance * 255).astype(np.uint8) - # markers = cv2.connectedComponents(distance, connectivity=8, ltype=cv2.CV_16U)[1] - scaled = cv2.exp((x[0] / 255.0).astype(np.float32)) - markers = _fast_local_max(distance, min_distance=self.min_distance) - labels = watershed( - -scaled, - markers=markers, - mask=x[0] > 0, - watershed_line=True, - ) - return [labels] - - def __to_dict__(self) -> Dict: - return { - "name": "local-max_watershed", - "args": { - "min_distance": self.min_distance, - }, - } - - -def _peak_local_max(image, min_distance=21): - peak_idx = peak_local_max( - image, - min_distance=min_distance, - exclude_border=0, - ) - peak_mask = np.zeros_like(image, dtype=np.uint8) - labels = list(range(1, peak_idx.shape[0] + 1)) - peak_mask[tuple(peak_idx.T)] = labels - return peak_mask - - -def _fast_local_max(image, min_distance=21): - image_down = cv2.pyrDown(image) - peak_idx = peak_local_max( - image_down, - min_distance=min_distance // 2, - exclude_border=0, - ) - peak_mask = np.zeros_like(image, dtype=np.uint8) - labels = list(range(1, peak_idx.shape[0] + 1)) - remaped_peaks = (peak_idx * 2).astype(np.int32) - peak_mask[tuple(remaped_peaks.T)] = labels - return peak_mask - - -@register(Endpoint) -class RawWatershed(Endpoint): - def __init__(self, min_distance=21): - super().__init__([TypeArray]) - self.min_distance = min_distance - - def call(self, x): - marker_labels = _fast_local_max(x[0], min_distance=self.min_distance) - labels = watershed( - -x[0], - markers=marker_labels, - mask=x[0] > 0, - watershed_line=True, - ) - return [labels] - - -@register(Endpoint) -class RawLocalMaxWatershed(Endpoint): - """Watershed based KartezioEndpoint, but only based on one single mask. - Markers are computed as the local max of the mask - - """ - - def __init__(self, threshold=1, markers_distance=21): - super().__init__([TypeArray]) - self.wt = WatershedSkimage(markers_distance=markers_distance) - self.threshold = threshold - - def call(self, x): - mask = threshold_tozero(x[0], self.threshold) - mask, markers, labels = self.wt.apply( - mask, markers=None, mask=mask > 0 - ) - return { - "mask_raw": x[0], - "mask": mask, - "markers": markers, - "count": len(np.unique(labels)) - 1, - "labels": labels, - } - - def _to_json_kwargs(self) -> dict: - return { - "threshold": self.threshold, - "markers_distance": self.wt.markers_distance, - } - - @register(Endpoint) class EndpointHoughCircleSmall(Endpoint): def __init__(self, min_dist=4, p1=256, p2=8, min_radius=2, max_radius=12): diff --git a/src/kartezio/preprocessing.py b/src/kartezio/preprocessing.py index ebbd2c5..67ee97b 100755 --- a/src/kartezio/preprocessing.py +++ b/src/kartezio/preprocessing.py @@ -3,7 +3,7 @@ import cv2 import numpy as np -from kartezio.core.components import Preprocessing, register +from kartezio.core.components import Preprocessing, registry from kartezio.vision.common import image_split diff --git a/src/kartezio/vision/watershed.py b/src/kartezio/vision/watershed.py index e69de29..e8beefb 100644 --- a/src/kartezio/vision/watershed.py +++ b/src/kartezio/vision/watershed.py @@ -0,0 +1,391 @@ +""" +watershed.py + +This module provides a collection of image segmentation routines based on the +watershed transform. It relies on OpenCV (cv2) for image processing operations +(e.g., distance transforms, connected components) and on scikit-image for +finding local maxima (peak_local_max) and performing the watershed algorithm. + +Functions Overview +------------------ +- _pyrdown(image, n=1): + Downsamples an image by repeatedly applying cv2.pyrDown n times. + +- watershed_transform(image, markers, watershed_line): + Applies the watershed transform using negative-exponential-scaled intensity. + +- _connected_components(image): + Labels connected components of a binary image using cv2.connectedComponents. + +- _distance_transform(image, normalize=False): + Computes the distance transform of a binary image; optionally normalizes it. + +- coordinates_to_mask(coordinates, shape, scale): + Converts peak coordinates into a binary mask, optionally rescaling them. + +- _local_max_markers(image, min_distance): + Creates markers for watershed by identifying local maxima in an image. + +- _fast_local_max_markers(image, min_distance, n): + Performs local maxima detection on a downsampled image, then scales results. + +- marker_controlled_watershed(image, markers, watershed_line): + Runs a watershed transform using the given marker image. + +- local_max_watershed(image, min_distance, watershed_line, downsample): + Combines a distance transform or intensity-based approach with local maxima + to produce markers and apply watershed. + +- distance_watershed(image, min_distance, watershed_line, downsample=0): + Shortcut to compute a distance transform, detect local maxima, and run watershed. + +- threshold_local_max_watershed(image, threshold, min_distance, watershed_line, downsample=0): + Applies a threshold, finds local maxima, and runs watershed with those markers. + +- threshold_watershed(image, threshold, watershed_line): + Applies a threshold directly as markers for watershed. + +- double_threshold_watershed(image, threshold1, threshold2, watershed_line): + Uses two thresholds to create marker regions for watershed. + +Dependencies +------------ +- cv2 (OpenCV) +- numpy +- skimage.feature.peak_local_max +- skimage.segmentation.watershed +""" + +import cv2 +import numpy as np +from skimage.feature import peak_local_max +from skimage.segmentation import watershed + +from kartezio.vision.common import threshold_tozero + + +def _pyrdown(image, n=1): + """ + Downsample an image n times using cv2.pyrDown. + + Parameters + ---------- + image : np.ndarray + Input image. + n : int, optional + Number of times to downsample (default is 1). + + Returns + ------- + np.ndarray + The downsampled image. + """ + for _ in range(n): + image = cv2.pyrDown(image) + return image + + +def watershed_transform(image, markers, watershed_line): + """ + Apply a watershed transform using negative exponential scaling of intensities. + + Parameters + ---------- + image : np.ndarray + Input grayscale image. Expected range [0, 255] for scaling. + markers : np.ndarray + Marker image, typically an integer-labeled array where non-zero regions + represent different labels. + watershed_line : bool + If True, the function computes a lines-producing watershed. The lines + separate the regions, setting them to 0. + + Returns + ------- + np.ndarray + Labeled image (same shape as input). Each connected region has its own label. + Watershed lines are zero if watershed_line=True. + """ + scaled = cv2.exp((image / 255.0).astype(np.float32)) + return watershed( + -scaled, + markers=markers, + mask=image > 0, + watershed_line=watershed_line, + ) + + +def _connected_components(image): + """ + Label the connected components in a binary image. + + Parameters + ---------- + image : np.ndarray + Binary or label image from which connected components are computed. + + Returns + ------- + np.ndarray + An integer-labeled array of the same shape, where each component has a unique ID. + """ + # cv2.connectedComponents returns (num_labels, labels_img). We only return labels_img. + return cv2.connectedComponents(image, connectivity=8, ltype=cv2.CV_16U)[1] + + +def _distance_transform(image, normalize=False): + """ + Compute the distance transform of a binary image. + + Parameters + ---------- + image : np.ndarray + Binary image (non-zero pixels considered foreground). + normalize : bool, optional + If True, the resulting distance transform is normalized to the range [0,1]. + + Returns + ------- + np.ndarray + A distance transform of the same shape as `image`. + """ + distance = cv2.distanceTransform(image, cv2.DIST_L2, 3) + if normalize: + return cv2.normalize(distance, None, 0.0, 1.0, cv2.NORM_MINMAX) + return distance + + +def coordinates_to_mask(coordinates, shape, scale): + """ + Convert coordinates of local maxima into a binary mask. + + Parameters + ---------- + coordinates : np.ndarray + An (N, 2) array of (row, col) coordinates for local maxima. + shape : tuple + Shape of the desired mask (e.g., image.shape). + scale : int + Upsampling factor to rescale coordinates if they were found in a downsampled image. + + Returns + ------- + np.ndarray + Binary mask of the same shape with 1's at local maxima positions, 0 otherwise. + """ + mask = np.zeros(shape, dtype=np.uint8) + coordinates = (coordinates * scale).astype(np.int32) + mask[coordinates[:, 0], coordinates[:, 1]] = 1 + return mask + + +def _local_max_markers(image, min_distance): + """ + Identify local maxima in the image and convert them to a binary marker mask. + + Parameters + ---------- + image : np.ndarray + Grayscale or distance-transformed image. + min_distance : int + Minimum distance separating local maxima. + + Returns + ------- + np.ndarray + Binary mask of local maxima. + """ + peak_idx = peak_local_max( + image, + min_distance=min_distance, + exclude_border=1, + ) + return coordinates_to_mask(peak_idx, image.shape, 1) + + +def _fast_local_max_markers(image, min_distance, n): + """ + Identify local maxima on a downsampled image, then scale them back up. + + Parameters + ---------- + image : np.ndarray + Input image to find local maxima. + min_distance : int + Minimum distance separating local maxima (adjusted for downsampling). + n : int + Number of times the image is downsampled. + + Returns + ------- + np.ndarray + Binary mask of local maxima upsampled to match the original image size. + """ + scale = 2**n + image_down = _pyrdown(image, n) + peak_coordinates = peak_local_max( + image_down, + min_distance=min_distance // scale, + exclude_border=1, + ) + return coordinates_to_mask(peak_coordinates, image.shape, scale) + + +def marker_controlled_watershed(image, markers, watershed_line): + """ + Run a watershed transform given an initial marker image. + + Parameters + ---------- + image : np.ndarray + Input image to be segmented (e.g., intensity or distance transform). + markers : np.ndarray + Binary or labeled image to serve as markers. + watershed_line : bool + If True, produce watershed lines (separating boundaries). + + Returns + ------- + np.ndarray + Integer-labeled segmented image. + """ + markers = _connected_components(markers) + return watershed_transform(image, markers, watershed_line) + + +def local_max_watershed(image, min_distance, watershed_line, downsample): + """ + Segment an image by detecting local maxima as markers and running watershed. + + Parameters + ---------- + image : np.ndarray + Image to segment (could be a distance transform or raw intensity). + min_distance : int + Minimum distance separating local maxima. + watershed_line : bool + Whether to produce watershed lines. + downsample : int or None + If > 0, downsample the image by 'downsample' times before detecting maxima. + Otherwise, detect maxima at the full resolution. + + Returns + ------- + np.ndarray + Integer-labeled segmented image. + """ + if downsample: + markers = _fast_local_max_markers(image, min_distance, downsample) + else: + markers = _local_max_markers(image, min_distance) + return marker_controlled_watershed(image, markers, watershed_line) + + +def distance_watershed( + image, min_distance, watershed_line, normalize, downsample=0 +): + """ + Shortcut for running watershed using a distance transform + local maxima approach. + + Parameters + ---------- + image : np.ndarray + Binary image on which distance transform is computed. + min_distance : int + Minimum distance for local maxima detection. + watershed_line : bool + Whether to produce watershed lines. + downsample : int, optional + If > 0, downsample the distance transform before local maxima detection. + + Returns + ------- + np.ndarray + Integer-labeled segmentation of the input image. + """ + distance = _distance_transform(image, normalize=normalize) + return local_max_watershed( + distance, min_distance, watershed_line, downsample + ) + + +def threshold_local_max_watershed( + image, threshold, min_distance, watershed_line, downsample=0 +): + """ + Watershed segmentation where markers are found from thresholded local maxima. + + Parameters + ---------- + image : np.ndarray + Grayscale or distance-transformed image. + threshold : float + Pixel intensity threshold; below this, pixels become 0. + min_distance : int + Minimum distance for local maxima detection. + watershed_line : bool + Whether to produce watershed lines. + downsample : int, optional + If > 0, downsample before local maxima detection. + + Returns + ------- + np.ndarray + Segmented, labeled image. + """ + markers = threshold_tozero(image, threshold) + if downsample: + markers = _fast_local_max_markers(markers, min_distance, downsample) + else: + markers = _local_max_markers(markers, min_distance) + return marker_controlled_watershed(image, markers, watershed_line) + + +def threshold_watershed(image, threshold, watershed_line): + """ + Basic threshold-based watershed segmentation. + + Parameters + ---------- + image : np.ndarray + Grayscale or distance-transformed image to segment. + threshold : float + Pixel intensity threshold; below this, pixels become 0 (markers). + watershed_line : bool + Whether to produce watershed lines. + + Returns + ------- + np.ndarray + Labeled watershed segmentation. + """ + markers = threshold_tozero(image, threshold) + return marker_controlled_watershed(image, markers, watershed_line) + + +def double_threshold_watershed(image, threshold1, threshold2, watershed_line): + """ + Watershed segmentation using two thresholds. + + Pixels below threshold1 are set to zero, then below threshold2 are further + refined. This can help isolate more confident markers from less confident ones. + + Parameters + ---------- + image : np.ndarray + Grayscale or distance-transformed image. + threshold1 : float + First threshold; pixels below are zeroed out. + threshold2 : float + Second threshold; pixels below are zeroed out (refined). + watershed_line : bool + Whether to produce watershed lines. + + Returns + ------- + np.ndarray + Labeled segmentation from watershed. + """ + image = threshold_tozero(image, threshold1) + markers = threshold_tozero(image, threshold2) + return marker_controlled_watershed(image, markers, watershed_line)