Skip to content

Commit

Permalink
fixed #161
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrik-code committed Dec 17, 2024
1 parent 7d72422 commit b2758af
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 133 deletions.
36 changes: 13 additions & 23 deletions panoptica/instance_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from panoptica.utils.constants import CCABackend
from panoptica._functionals import _connected_components
from panoptica.utils.numpy_utils import _get_smallest_fitting_uint
# from panoptica.utils.numpy_utils import _get_smallest_fitting_uint
from panoptica.utils.processing_pair import (
MatchedInstancePair,
SemanticPair,
Expand Down Expand Up @@ -80,7 +80,7 @@ def approximate_instances(
AssertionError: If there are negative values in the semantic maps, which is not allowed.
"""
# Check validity
pred_labels, ref_labels = semantic_pair._pred_labels, semantic_pair._ref_labels
pred_labels, ref_labels = semantic_pair.pred_labels, semantic_pair.ref_labels
pred_label_range = (
(np.min(pred_labels), np.max(pred_labels))
if len(pred_labels) > 0
Expand All @@ -95,10 +95,10 @@ def approximate_instances(
min_value >= 0
), "There are negative values in the semantic maps. This is not allowed!"
# Set dtype to smalles fitting uint
max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1]))
dtype = _get_smallest_fitting_uint(max_value)
semantic_pair.set_dtype(dtype)
print(f"-- Set dtype to {dtype}") if verbose else None
# max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1]))
# dtype = _get_smallest_fitting_uint(max_value)
# semantic_pair.set_dtype(dtype)
# print(f"-- Set dtype to {dtype}") if verbose else None

# Call algorithm
instance_pair = self._approximate_instances(semantic_pair, **kwargs)
Expand Down Expand Up @@ -148,31 +148,21 @@ def _approximate_instances(
"""
cca_backend = self.cca_backend
if cca_backend is None:
cca_backend = (
CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy
)
cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy
assert cca_backend is not None

empty_prediction = len(semantic_pair._pred_labels) == 0
empty_reference = len(semantic_pair._ref_labels) == 0
empty_prediction = len(semantic_pair.pred_labels) == 0
empty_reference = len(semantic_pair.ref_labels) == 0
prediction_arr, n_prediction_instance = (
_connected_components(semantic_pair._prediction_arr, cca_backend)
if not empty_prediction
else (semantic_pair._prediction_arr, 0)
_connected_components(semantic_pair.prediction_arr, cca_backend) if not empty_prediction else (semantic_pair.prediction_arr, 0)
)
reference_arr, n_reference_instance = (
_connected_components(semantic_pair._reference_arr, cca_backend)
if not empty_reference
else (semantic_pair._reference_arr, 0)
)

dtype = _get_smallest_fitting_uint(
max(prediction_arr.max(), reference_arr.max())
_connected_components(semantic_pair.reference_arr, cca_backend) if not empty_reference else (semantic_pair.reference_arr, 0)
)

return UnmatchedInstancePair(
prediction_arr=prediction_arr.astype(dtype),
reference_arr=reference_arr.astype(dtype),
prediction_arr=prediction_arr,
reference_arr=reference_arr,
n_prediction_instance=n_prediction_instance,
n_reference_instance=n_reference_instance,
)
Expand Down
2 changes: 1 addition & 1 deletion panoptica/instance_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def map_instance_labels(
# Build a MatchedInstancePair out of the newly derived data
matched_instance_pair = MatchedInstancePair(
prediction_arr=prediction_arr_relabeled,
reference_arr=processing_pair._reference_arr,
reference_arr=processing_pair.reference_arr,
)
return matched_instance_pair

Expand Down
155 changes: 63 additions & 92 deletions panoptica/utils/processing_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from panoptica.utils.constants import _Enum_Compare
from dataclasses import dataclass
from panoptica.metrics import Metric
from panoptica.utils.numpy_utils import _get_smallest_fitting_uint

uint_type: type = np.unsignedinteger
int_type: type = np.integer
Expand All @@ -26,90 +27,66 @@ class _ProcessingPair(ABC):
uncropped_shape (tuple[int, ...]): The original shape of the arrays before cropping.
"""

_prediction_arr: np.ndarray
_reference_arr: np.ndarray
# unique labels without zero
_ref_labels: tuple[int, ...]
_pred_labels: tuple[int, ...]
n_dim: int

def __init__(
self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None
) -> None:
def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> None:
"""Initializes the processing pair with prediction and reference arrays.
Args:
prediction_arr (np.ndarray): Numpy array of prediction labels.
reference_arr (np.ndarray): Numpy array of reference labels.
dtype (type | None): The expected datatype of arrays. If None, no datatype check is performed.
"""
_check_array_integrity(prediction_arr, reference_arr, dtype=dtype)
self._prediction_arr = prediction_arr
self._reference_arr = reference_arr
self.dtype = dtype
self.n_dim = reference_arr.ndim
self._ref_labels: tuple[int, ...] = tuple(
_unique_without_zeros(reference_arr)
) # type:ignore
self._pred_labels: tuple[int, ...] = tuple(
_unique_without_zeros(prediction_arr)
) # type:ignore
self.crop: tuple[slice, ...] = None
self.is_cropped: bool = False
self.uncropped_shape: tuple[int, ...] = reference_arr.shape
self.__prediction_arr: np.ndarray = prediction_arr
self.__reference_arr: np.ndarray = reference_arr
_check_array_integrity(self.__prediction_arr, self.__reference_arr, dtype=int_type)
max_value = max(prediction_arr.max(), reference_arr.max())
dtype = _get_smallest_fitting_uint(max_value)
self.set_dtype(dtype)
self.__dtype = dtype
self.__n_dim: int = reference_arr.ndim
self.__ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore
self.__pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore
self.__crop: tuple[slice, ...] = None
self.__is_cropped: bool = False
self.__uncropped_shape: tuple[int, ...] = reference_arr.shape

def crop_data(self, verbose: bool = False):
"""Crops prediction and reference arrays to non-zero regions.
Args:
verbose (bool, optional): If True, prints cropping details. Defaults to False.
"""
if self.is_cropped:
if self.__is_cropped:
return
if self.crop is None:
self.uncropped_shape = self._prediction_arr.shape
self.crop = _get_paired_crop(
self._prediction_arr,
self._reference_arr,
if self.__crop is None:
self.__uncropped_shape = self.__prediction_arr.shape
self.__crop = _get_paired_crop(
self.__prediction_arr,
self.__reference_arr,
)

self._prediction_arr = self._prediction_arr[self.crop]
self._reference_arr = self._reference_arr[self.crop]
(
print(
f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}"
)
if verbose
else None
)
self.is_cropped = True
self.__prediction_arr = self.__prediction_arr[self.__crop]
self.__reference_arr = self.__reference_arr[self.__crop]
(print(f"-- Cropped from {self.__uncropped_shape} to {self.__prediction_arr.shape}") if verbose else None)
self.__is_cropped = True

def uncrop_data(self, verbose: bool = False):
"""Restores the arrays to their original, uncropped shape.
Args:
verbose (bool, optional): If True, prints uncropping details. Defaults to False.
"""
if self.is_cropped == False:
if self.__is_cropped == False:
return
assert (
self.uncropped_shape is not None
), "Calling uncrop_data() without having cropped first"
prediction_arr = np.zeros(self.uncropped_shape)
prediction_arr[self.crop] = self._prediction_arr
self._prediction_arr = prediction_arr

reference_arr = np.zeros(self.uncropped_shape)
reference_arr[self.crop] = self._reference_arr
(
print(
f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}"
)
if verbose
else None
)
self._reference_arr = reference_arr
self.is_cropped = False
assert self.__uncropped_shape is not None, "Calling uncrop_data() without having cropped first"
prediction_arr = np.zeros(self.__uncropped_shape)
prediction_arr[self.__crop] = self.__prediction_arr
self.__prediction_arr = prediction_arr

reference_arr = np.zeros(self.__uncropped_shape)
reference_arr[self.__crop] = self.__reference_arr
(print(f"-- Uncropped from {self.__reference_arr.shape} to {self.__uncropped_shape}") if verbose else None)
self.__reference_arr = reference_arr
self.__is_cropped = False

def set_dtype(self, type):
"""Sets the data type for both prediction and reference arrays.
Expand All @@ -120,43 +97,38 @@ def set_dtype(self, type):
assert np.issubdtype(
type, int_type
), "set_dtype: tried to set dtype to something other than integers"
self._prediction_arr = self._prediction_arr.astype(type)
self._reference_arr = self._reference_arr.astype(type)
self.__prediction_arr = self.__prediction_arr.astype(type)
self.__reference_arr = self.__reference_arr.astype(type)

@property
def prediction_arr(self):
return self._prediction_arr
return self.__prediction_arr

@property
def reference_arr(self):
return self._reference_arr
return self.__reference_arr

@property
def pred_labels(self):
return self._pred_labels
return self.__pred_labels

@property
def ref_labels(self):
return self._ref_labels
return self.__ref_labels

@property
def n_dim(self):
return self.__n_dim

def copy(self):
"""
Creates an exact copy of this object
"""
return type(self)(
prediction_arr=self._prediction_arr,
reference_arr=self._reference_arr,
prediction_arr=self.__prediction_arr,
reference_arr=self.__reference_arr,
) # type:ignore

# Make all variables read-only!
# def __setattr__(self, attr, value):
# if hasattr(self, attr):
# raise Exception("Attempting to alter read-only value")


#
# self.__dict__[attr] = value


class _ProcessingPairInstanced(_ProcessingPair):
"""Represents a processing pair with labeled instances, including unique label counts.
Expand All @@ -175,7 +147,6 @@ def __init__(
self,
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
dtype: type | None,
n_prediction_instance: int | None = None,
n_reference_instance: int | None = None,
) -> None:
Expand All @@ -188,7 +159,7 @@ def __init__(
n_prediction_instance (int | None, optional): Pre-calculated number of prediction instances.
n_reference_instance (int | None, optional): Pre-calculated number of reference instances.
"""
super().__init__(prediction_arr, reference_arr, dtype)
super().__init__(prediction_arr, reference_arr)
if n_prediction_instance is None:
self.n_prediction_instance = _count_unique_without_zeros(prediction_arr)

Expand All @@ -204,8 +175,8 @@ def copy(self):
Creates an exact copy of this object
"""
return type(self)(
prediction_arr=self._prediction_arr,
reference_arr=self._reference_arr,
prediction_arr=self.prediction_arr,
reference_arr=self.reference_arr,
n_prediction_instance=self.n_prediction_instance,
n_reference_instance=self.n_reference_instance,
) # type:ignore
Expand Down Expand Up @@ -237,6 +208,12 @@ def _check_array_integrity(
assert (
prediction_arr.shape == reference_arr.shape
), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}"

min_value = min(prediction_arr.min(), reference_arr.min())
assert min_value >= 0, "There are negative values in the semantic maps. This is not allowed!"

# if prediction_arr.dtype != reference_arr.dtype:
# print(f"Dtype is equal in prediction and reference, got {prediction_arr.dtype},{reference_arr.dtype}. Intended?")
# assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}"
if dtype is not None:
assert (
Expand All @@ -253,7 +230,7 @@ class SemanticPair(_ProcessingPair):
"""

def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> None:
super().__init__(prediction_arr, reference_arr, dtype=int_type)
super().__init__(prediction_arr, reference_arr)


class UnmatchedInstancePair(_ProcessingPairInstanced):
Expand All @@ -272,7 +249,6 @@ def __init__(
super().__init__(
prediction_arr,
reference_arr,
uint_type,
n_prediction_instance,
n_reference_instance,
) # type:ignore
Expand Down Expand Up @@ -320,24 +296,19 @@ def __init__(
super().__init__(
prediction_arr,
reference_arr,
uint_type,
n_prediction_instance,
n_reference_instance,
) # type:ignore
if matched_instances is None:
matched_instances = [i for i in self._pred_labels if i in self._ref_labels]
matched_instances = [i for i in self.pred_labels if i in self.ref_labels]
self.matched_instances = matched_instances

if missed_reference_labels is None:
missed_reference_labels = list(
[i for i in self._ref_labels if i not in self._pred_labels]
)
missed_reference_labels = list([i for i in self.ref_labels if i not in self.pred_labels])
self.missed_reference_labels = missed_reference_labels

if missed_prediction_labels is None:
missed_prediction_labels = list(
[i for i in self._pred_labels if i not in self._ref_labels]
)
missed_prediction_labels = list([i for i in self.pred_labels if i not in self.ref_labels])
self.missed_prediction_labels = missed_prediction_labels

@property
Expand All @@ -349,8 +320,8 @@ def copy(self):
Creates an exact copy of this object
"""
return type(self)(
prediction_arr=self._prediction_arr.copy(),
reference_arr=self._reference_arr.copy(),
prediction_arr=self.prediction_arr.copy(),
reference_arr=self.reference_arr.copy(),
n_prediction_instance=self.n_prediction_instance,
n_reference_instance=self.n_reference_instance,
missed_reference_labels=self.missed_reference_labels,
Expand Down
Loading

0 comments on commit b2758af

Please sign in to comment.