diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py
index 354d5f2..25eac85 100644
--- a/examples/example_spine_instance.py
+++ b/examples/example_spine_instance.py
@@ -32,7 +32,7 @@ def main():
     with cProfile.Profile() as pr:
         results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)
 
-        for groupname, (result, intermediate_steps_data) in results.items():
+        for groupname, result in results.items():
             print()
             print("### Group", groupname)
             print(result)
diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py
index 4fbe74e..a038386 100644
--- a/examples/example_spine_instance_config.py
+++ b/examples/example_spine_instance_config.py
@@ -18,7 +18,7 @@ def main():
     with cProfile.Profile() as pr:
         results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)
 
-        for groupname, (result, intermediate_steps_data) in results.items():
+        for groupname, result in results.items():
             print()
             print("### Group", groupname)
             print(result)
diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index 0ec88f6..ae427ee 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -27,13 +27,13 @@ def main():
 
     with cProfile.Profile() as pr:
-        result, intermediate_steps_data = evaluator.evaluate(
-            prediction_mask, reference_mask
-        )["ungrouped"]
+        result = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
 
         # To print the results, just call print
         print(result)
 
+        intermediate_steps_data = result.intermediate_steps_data
+        assert intermediate_steps_data is not None
         # To get the different intermediate arrays, just use the second returned object
         intermediate_steps_data.original_prediction_arr  # Input prediction array, untouched
         intermediate_steps_data.original_reference_arr  # Input reference array, untouched
diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py
index 0c0c527..e5782c6 100644
--- a/panoptica/panoptica_aggregator.py
+++ b/panoptica/panoptica_aggregator.py
@@ -192,7 +192,7 @@ def _save_one_subject(self, subject_name, result_grouped):
         #
         content = [subject_name]
         for groupname in self.__class_group_names:
-            result: PanopticaResult = result_grouped[groupname][0]
+            result: PanopticaResult = result_grouped[groupname]
             result_dict = result.to_dict()
             if result.computation_time is not None:
                 result_dict[COMPUTATION_TIME_KEY] = result.computation_time
diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py
index 7b03151..a30a680 100644
--- a/panoptica/panoptica_evaluator.py
+++ b/panoptica/panoptica_evaluator.py
@@ -3,7 +3,7 @@
 from panoptica.instance_approximator import InstanceApproximator
 from panoptica.instance_evaluator import evaluate_matched_instance
 from panoptica.instance_matcher import InstanceMatchingAlgorithm
-from panoptica.metrics import Metric, _Metric
+from panoptica.metrics import Metric
 from panoptica.panoptica_result import PanopticaResult
 from panoptica.utils.timing import measure_time
 from panoptica.utils import EdgeCaseHandler
@@ -12,7 +12,6 @@
     MatchedInstancePair,
     SemanticPair,
     UnmatchedInstancePair,
-    _ProcessingPair,
     InputType,
     EvaluateInstancePair,
     IntermediateStepsData,
@@ -121,7 +120,7 @@ def evaluate(
         save_group_times: bool | None = None,
         log_times: bool | None = None,
         verbose: bool | None = None,
-    ) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]:
+    ) -> dict[str, PanopticaResult]:
         processing_pair = self.__expected_input(prediction_arr, reference_arr)
         assert isinstance(
             processing_pair, self.__expected_input.value
@@ -134,21 +133,17 @@ def evaluate(
             processing_pair.reference_arr, raise_error=True
         )
 
-        result_grouped: dict[str, tuple[PanopticaResult, IntermediateStepsData]] = {}
+        result_grouped: dict[str, PanopticaResult] = {}
         for group_name, label_group in self.__segmentation_class_groups.items():
             result_grouped[group_name] = self._evaluate_group(
                 group_name,
                 label_group,
                 processing_pair,
                 result_all,
-                save_group_times=(
-                    self.__save_group_times
-                    if save_group_times is None
-                    else save_group_times
-                ),
+                save_group_times=(self.__save_group_times if save_group_times is None else save_group_times),
                 log_times=log_times,
                 verbose=verbose,
-            )[1:]
+            )
         return result_grouped
 
     @property
@@ -170,7 +165,7 @@ def resulting_metric_keys(self) -> list[str]:
         dummy_input = MatchedInstancePair(
             np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8)
         )
-        _, res, _ = self._evaluate_group(
+        res = self._evaluate_group(
             group_name="",
             label_group=LabelGroup(1, single_instance=False),
             processing_pair=dummy_input,
@@ -192,7 +187,7 @@ def _evaluate_group(
         verbose: bool | None = None,
         log_times: bool | None = None,
         save_group_times: bool = False,
-    ):
+    ) -> PanopticaResult:
         assert isinstance(label_group, LabelGroup)
         if self.__save_group_times:
             start_time = perf_counter()
@@ -212,7 +207,7 @@ def _evaluate_group(
             )
             decision_threshold = 0.0
 
-        result, intermediate_steps_data = panoptic_evaluate(
+        result = panoptic_evaluate(
             input_pair=processing_pair_grouped,
             edge_case_handler=self.__edge_case_handler,
             instance_approximator=self.__instance_approximator,
@@ -229,7 +224,7 @@ def _evaluate_group(
         if save_group_times:
             duration = perf_counter() - start_time
             result.computation_time = duration
-        return group_name, result, intermediate_steps_data
+        return result
 
 
 def panoptic_evaluate(
@@ -246,7 +241,7 @@ def panoptic_evaluate(
     verbose=False,
     verbose_calc=False,
     **kwargs,
-) -> tuple[PanopticaResult, IntermediateStepsData]:
+) -> PanopticaResult:
     """
     Perform panoptic evaluation on the given processing pair.
@@ -368,13 +363,14 @@ def panoptic_evaluate(
             list_metrics=processing_pair.list_metrics,
             global_metrics=global_metrics,
             edge_case_handler=edge_case_handler,
+            intermediate_steps_data=intermediate_steps_data,
         )
 
     if isinstance(processing_pair, PanopticaResult):
         processing_pair._global_metrics = global_metrics
         if result_all:
             processing_pair.calculate_all(print_errors=verbose_calc)
-        return processing_pair, intermediate_steps_data
+        return processing_pair
 
     raise RuntimeError("End of panoptic pipeline reached without results")
 
diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py
index da9c884..901b8d6 100644
--- a/panoptica/panoptica_result.py
+++ b/panoptica/panoptica_result.py
@@ -13,6 +13,7 @@
     MetricType,
 )
 from panoptica.utils import EdgeCaseHandler
+from panoptica.utils.processing_pair import IntermediateStepsData
 
 
 class PanopticaResult(object):
@@ -27,6 +28,7 @@ def __init__(
         list_metrics: dict[Metric, list[float]],
         edge_case_handler: EdgeCaseHandler,
         global_metrics: list[Metric] = [],
+        intermediate_steps_data: IntermediateStepsData | None = None,
         computation_time: float | None = None,
     ):
         """Result object for Panoptica, contains all calculatable metrics
@@ -45,6 +47,7 @@ def __init__(
         empty_list_std = self._edge_case_handler.handle_empty_list_std().value
         self._global_metrics: list[Metric] = global_metrics
         self.computation_time = computation_time
+        self.intermediate_steps_data = intermediate_steps_data
         ######################
         # Evaluation Metrics #
         ######################
diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py
index fce49f4..dea5c13 100644
--- a/unit_tests/test_panoptic_evaluator.py
+++ b/unit_tests/test_panoptic_evaluator.py
@@ -63,7 +63,7 @@ def test_simple_evaluation(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -83,7 +83,7 @@ def test_simple_evaluation_instance_multiclass(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571)
         self.assertEqual(result.tp, 1)
@@ -104,7 +104,7 @@ def test_simple_evaluation_DSC(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -124,7 +124,7 @@ def test_simple_evaluation_DSC_partial(self):
             instance_metrics=[Metric.DSC],
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -150,7 +150,7 @@ def test_simple_evaluation_ASSD(self):
             ),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -172,7 +172,7 @@ def test_simple_evaluation_ASSD_negative(self):
             ),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 1)
@@ -192,7 +192,7 @@ def test_pred_empty(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 0)
@@ -214,7 +214,7 @@ def test_no_TP_but_overlap(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 1)
@@ -237,7 +237,7 @@ def test_ref_empty(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 1)
@@ -258,7 +258,7 @@ def test_both_empty(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 0)
@@ -291,7 +291,7 @@ def test_dtype_evaluation(self):
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -310,7 +310,7 @@ def test_simple_evaluation_maximize_matcher(self):
             instance_matcher=MaximizeMergeMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -330,7 +330,7 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self):
             instance_matcher=MaximizeMergeMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -352,7 +352,7 @@ def test_simple_evaluation_maximize_matcher_overlap(self):
             instance_matcher=MaximizeMergeMatching(),
        )
 
-        result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+        result = evaluator.evaluate(b, a)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 1)
@@ -374,7 +374,7 @@ def test_single_instance_mode(self):
             segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["organ"]
+        result = evaluator.evaluate(b, a)["organ"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -394,7 +394,7 @@ def test_single_instance_mode_nooverlap(self):
             segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
         )
 
-        result, debug_data = evaluator.evaluate(b, a)["organ"]
+        result = evaluator.evaluate(b, a)["organ"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
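
Taken together, these hunks change evaluate() to return a plain dict of group name to PanopticaResult, with the intermediate arrays attached to the result object itself. Below is a minimal usage sketch of the new calling convention; the evaluator construction (Panoptica_Evaluator, InputType.SEMANTIC, ConnectedComponentsInstanceApproximator, NaiveThresholdMatching) and the toy masks are assumptions borrowed from the existing examples and tests, not part of this diff.

import numpy as np

from panoptica import InputType, Panoptica_Evaluator
from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
from panoptica.instance_matcher import NaiveThresholdMatching

# Hypothetical toy masks standing in for prediction_mask / reference_mask.
reference_mask = np.zeros((10, 10, 10), dtype=np.uint8)
reference_mask[2:6, 2:6, 2:6] = 1
prediction_mask = reference_mask.copy()

# Setup assumed to follow the existing example scripts / unit tests.
evaluator = Panoptica_Evaluator(
    expected_input=InputType.SEMANTIC,
    instance_approximator=ConnectedComponentsInstanceApproximator(),
    instance_matcher=NaiveThresholdMatching(),
)

# evaluate() now maps group name -> PanopticaResult; there is no tuple to unpack.
result = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
print(result)

# The intermediate arrays now live on the result object and may be None.
intermediate_steps_data = result.intermediate_steps_data
if intermediate_steps_data is not None:
    print(intermediate_steps_data.original_prediction_arr.shape)
    print(intermediate_steps_data.original_reference_arr.shape)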