From f86f87cb98b110b7cb65052c008909f6fedd367e Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 31 Jan 2025 10:44:00 +0000 Subject: [PATCH 1/2] added some features while doing #157 --- panoptica/panoptica_aggregator.py | 26 ++++---- panoptica/panoptica_evaluator.py | 66 +++++-------------- panoptica/panoptica_statistics.py | 106 +++++++++++++++--------------- 3 files changed, 79 insertions(+), 119 deletions(-) diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index e5782c6..14d4132 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -59,16 +59,15 @@ def __init__( self.__panoptica_evaluator = panoptica_evaluator self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names self.__evaluation_metrics = panoptica_evaluator.resulting_metric_keys + self.__log_times = log_times - if log_times: + if log_times and COMPUTATION_TIME_KEY not in self.__evaluation_metrics: self.__evaluation_metrics.append(COMPUTATION_TIME_KEY) if isinstance(output_file, str): output_file = Path(output_file) # uses tsv - assert ( - output_file.parent.exists() - ), f"Directory {str(output_file.parent)} does not exist" + assert output_file.parent.exists(), f"Directory {str(output_file.parent)} does not exist" out_file_path = str(output_file) @@ -82,19 +81,13 @@ def __init__( else: out_file_path += ".tsv" # add extension - out_buffer_file: Path = Path(out_file_path).parent.joinpath( - "panoptica_aggregator_tmp.tsv" - ) + out_buffer_file: Path = Path(out_file_path).parent.joinpath("panoptica_aggregator_tmp.tsv") self.__output_buffer_file = out_buffer_file Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) self.__output_file = out_file_path - header = ["subject_name"] + [ - f"{g}-{m}" - for g in self.__class_group_names - for m in self.__evaluation_metrics - ] + header = ["subject_name"] + [f"{g}-{m}" for g in self.__class_group_names for m in self.__evaluation_metrics] header_hash = hash("+".join(header)) if not output_file.exists(): @@ -108,9 +101,7 @@ def __init__( continue_file = True else: # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file - assert header_hash == hash( - "+".join(header_list) - ), "Hash of header not the same! You are using a different setup!" + assert header_hash == hash("+".join(header_list)), "Hash of header not the same! You are using a different setup!" if out_buffer_file.exists(): os.remove(out_buffer_file) @@ -176,6 +167,7 @@ def evaluate( result_all=True, verbose=False, log_times=False, + save_group_times=self.__log_times, ) # Add to file @@ -208,6 +200,10 @@ def _save_one_subject(self, subject_name, result_grouped): def panoptica_evaluator(self): return self.__panoptica_evaluator + @property + def evaluation_metrics(self): + return self.__evaluation_metrics + def _read_first_row(file: str | Path): """Reads the first row of a TSV file. 
diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 6849f82..8b630aa 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -82,13 +82,9 @@ def __init__( segmentation_class_groups = _NoSegmentationClassGroups() self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = ( - edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() - ) + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -122,16 +118,10 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, PanopticaResult]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance( - processing_pair, self.__expected_input.value - ), f"input not of expected type {self.__expected_input}" + assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.prediction_arr, raise_error=True - ) - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.reference_arr, raise_error=True - ) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) result_grouped: dict[str, PanopticaResult] = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -140,11 +130,7 @@ def evaluate( label_group, processing_pair, result_all, - save_group_times=( - self.__save_group_times - if save_group_times is None - else save_group_times - ), + save_group_times=(self.__save_group_times if save_group_times is None else save_group_times), log_times=log_times, verbose=verbose, ) @@ -166,9 +152,7 @@ def _set_instance_matcher(self, matcher: InstanceMatchingAlgorithm): @property def resulting_metric_keys(self) -> list[str]: if self.__resulting_metric_keys is None: - dummy_input = MatchedInstancePair( - np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8) - ) + dummy_input = MatchedInstancePair(np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8)) res = self._evaluate_group( group_name="", label_group=LabelGroup(1, single_instance=False), @@ -193,7 +177,7 @@ def _evaluate_group( save_group_times: bool = False, ) -> PanopticaResult: assert isinstance(label_group, LabelGroup) - if self.__save_group_times: + if self.__save_group_times or save_group_times: start_time = perf_counter() prediction_arr_grouped = label_group(processing_pair.prediction_arr) @@ -202,9 +186,7 @@ def _evaluate_group( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance( - processing_pair, MatchedInstancePair - ): + if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, 
reference_arr=processing_pair_grouped.reference_arr, @@ -225,7 +207,7 @@ def _evaluate_group( verbose=True if verbose is None else verbose, verbose_calc=self.__verbose if verbose is None else verbose, ) - if save_group_times: + if self.__save_group_times or save_group_times: duration = perf_counter() - start_time result.computation_time = duration return result @@ -284,22 +266,12 @@ def panoptic_evaluate( # Crops away unecessary space of zeroes input_pair.crop_data() - processing_pair: ( - SemanticPair - | UnmatchedInstancePair - | MatchedInstancePair - | EvaluateInstancePair - | PanopticaResult - ) = input_pair.copy() + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | EvaluateInstancePair | PanopticaResult = input_pair.copy() # First Phase: Instance Approximation if isinstance(processing_pair, SemanticPair): - intermediate_steps_data.add_intermediate_arr_data( - processing_pair.copy(), InputType.SEMANTIC - ) - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.SEMANTIC) + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -309,9 +281,7 @@ def panoptic_evaluate( # Second Phase: Instance Matching if isinstance(processing_pair, UnmatchedInstancePair): - intermediate_steps_data.add_intermediate_arr_data( - processing_pair.copy(), InputType.UNMATCHED_INSTANCE - ) + intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.UNMATCHED_INSTANCE) processing_pair = _handle_zero_instances_cases( processing_pair, eval_metrics=instance_metrics, @@ -322,9 +292,7 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, @@ -334,9 +302,7 @@ def panoptic_evaluate( # Third Phase: Instance Evaluation if isinstance(processing_pair, MatchedInstancePair): - intermediate_steps_data.add_intermediate_arr_data( - processing_pair.copy(), InputType.MATCHED_INSTANCE - ) + intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.MATCHED_INSTANCE) processing_pair = _handle_zero_instances_cases( processing_pair, eval_metrics=instance_metrics, diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index c191cbb..fe8e9b8 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -15,10 +15,16 @@ class ValueSummary: def __init__(self, value_list: list[float]) -> None: self.__value_list = value_list - self.__avg = float(np.average(value_list)) - self.__std = float(np.std(value_list)) - self.__min = min(value_list) - self.__max = max(value_list) + if len(value_list) == 0: + self.__avg = np.nan + self.__std = np.nan + self.__min = np.nan + self.__max = np.nan + else: + self.__avg = float(np.average(value_list)) + self.__std = float(np.std(value_list)) + self.__min = min(value_list) + self.__max = max(value_list) @property def values(self) -> list[float]: @@ -40,6 +46,9 @@ def min(self) -> float: def max(self) -> float: return self.__max 
+    def __repr__(self):
+        return str(self)
+
     def __str__(self):
         return f"[{round(self.min, 3)}, {round(self.max, 3)}], avg = {round(self.avg, 3)} +- {round(self.std, 3)}"
 
@@ -88,9 +97,7 @@ def from_file(cls, file: str):
             rows = [row for row in rd]
 
             header = rows[0]
-            assert (
-                header[0] == "subject_name"
-            ), "First column is not subject_names, something wrong with the file?"
+            assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?"
             keys_in_order = list([tuple(c.split("-")) for c in header[1:]])
             metric_names = []
 
@@ -131,19 +138,13 @@ def from_file(cls, file: str):
         return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)
 
     def _assertgroup(self, group):
-        assert (
-            group in self.__groupnames
-        ), f"group {group} not existent, only got groups {self.__groupnames}"
+        assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}"
 
     def _assertmetric(self, metric):
-        assert (
-            metric in self.__metricnames
-        ), f"metric {metric} not existent, only got metrics {self.__metricnames}"
+        assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}"
 
     def _assertsubject(self, subjectname):
-        assert (
-            subjectname in self.__subj_names
-        ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}"
+        assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}"
 
     def get(self, group, metric, remove_nones: bool = False) -> list[float]:
         """Returns the list of values for given group and metric
@@ -176,10 +177,31 @@ def get_one_subject(self, subjectname: str):
         """
         self._assertsubject(subjectname)
         sidx = self.__subj_names.index(subjectname)
-        return {
-            g: {m: self.get(g, m)[sidx] for m in self.__metricnames}
-            for g in self.__groupnames
-        }
+        return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames}
+
+    def get_one_metric(self, metricname: str):
+        """Gets the dictionary mapping each group to its list of values for the specified metric
+
+        Args:
+            metricname (str): Name of the metric to collect across all groups
+
+        Returns:
+            dict[str, list[float]]: Mapping from group name to that group's values for the metric
+        """
+        self._assertmetric(metricname)
+        return {g: self.get(g, metricname) for g in self.__groupnames}
+
+    def get_one_group(self, groupname: str):
+        """Gets the dictionary mapping each metric to its list of values for ONE group
+
+        Args:
+            groupname (str): Name of the group whose metric values are collected
+
+        Returns:
+            dict[str, list[float]]: Mapping from metric name to that metric's values for the group
+        """
+        self._assertgroup(groupname)
+        return {m: self.get(groupname, m) for m in self.__metricnames}
+
     def get_across_groups(self, metric) -> list[float]:
         """Given metric, gives list of all values (even across groups!) Treat with care!
@@ -208,13 +230,8 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]: summary_dict[m] = ValueSummary(value_list) return summary_dict - def get_summary_dict( - self, include_across_group: bool = True - ) -> dict[str, dict[str, ValueSummary]]: - summary_dict = { - g: {m: self.get_summary(g, m) for m in self.__metricnames} - for g in self.__groupnames - } + def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: + summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} if include_across_group: summary_dict["across_groups"] = self.get_summary_across_groups() return summary_dict @@ -258,10 +275,7 @@ def get_summary_figure( Returns: _type_: _description_ """ - data_plot = { - g: np.asarray(self.get(g, metric, remove_nones=True)) - for g in self.__groupnames - } + data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} if manual_metric_range is not None: assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range change = (manual_metric_range[1] - manual_metric_range[0]) / 100 @@ -313,14 +327,10 @@ def make_curve_over_setups( if isinstance(alternate_groupnames, str): alternate_groupnames = [alternate_groupnames] - assert ( - plot_as_barchart or len(groups) == 1 - ), "When plotting without barcharts, you cannot plot more than one group at the same time" + assert plot_as_barchart or len(groups) == 1, "When plotting without barcharts, you cannot plot more than one group at the same time" # for setupname, stat in statistics_dict.items(): - assert ( - metric in stat.metricnames - ), f"metric {metric} not in statistic obj {setupname}" + assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -340,25 +350,17 @@ def make_curve_over_setups( # Y values are average metric values in that group and metric for idx, g in enumerate(groups): - Y = [ - ValueSummary(stat.get(g, metric, remove_nones=True)).avg - for stat in statistics_dict.values() - ] + Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] name = g if alternate_groupnames is None else alternate_groupnames[idx] if plot_std: - Ystd = [ - ValueSummary(stat.get(g, metric, remove_nones=True)).std - for stat in statistics_dict.values() - ] + Ystd = [ValueSummary(stat.get(g, metric, remove_nones=True)).std for stat in statistics_dict.values()] else: Ystd = None if plot_as_barchart: - fig.add_trace( - go.Bar(name=name, x=X, y=Y, error_y=dict(type="data", array=Ystd)) - ) + fig.add_trace(go.Bar(name=name, x=X, y=Y, error_y=dict(type="data", array=Ystd))) else: # lineplot fig.add_trace( @@ -378,9 +380,7 @@ def make_curve_over_setups( height=height, showlegend=True, yaxis_title=metric if yaxis_title is None else yaxis_title, - xaxis_title=( - "Different setups and groups" if xaxis_title is None else xaxis_title - ), + xaxis_title=("Different setups and groups" if xaxis_title is None else xaxis_title), font={"family": "Arial"}, title=figure_title, ) @@ -422,9 +422,7 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(name_method).mean() df_by_spec_count = dict(df_by_spec_count[name_metric].items()) - df_data["mean"] = df_data[name_method].apply( - lambda x: df_by_spec_count[x] * (1 if orientation_horizontal else -1) - ) + df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation_horizontal else -1)) df_data = 
df_data.sort_values(by="mean") if not orientation_horizontal: fig = px.strip( From eb284e069418d98b4701110f5d6aa928cbbf0ece Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 10:45:39 +0000 Subject: [PATCH 2/2] Autoformat with black --- panoptica/panoptica_aggregator.py | 18 +++++++-- panoptica/panoptica_evaluator.py | 62 ++++++++++++++++++++++------- panoptica/panoptica_statistics.py | 65 ++++++++++++++++++++++++------- 3 files changed, 112 insertions(+), 33 deletions(-) diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index 14d4132..9e08b3f 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -67,7 +67,9 @@ def __init__( if isinstance(output_file, str): output_file = Path(output_file) # uses tsv - assert output_file.parent.exists(), f"Directory {str(output_file.parent)} does not exist" + assert ( + output_file.parent.exists() + ), f"Directory {str(output_file.parent)} does not exist" out_file_path = str(output_file) @@ -81,13 +83,19 @@ def __init__( else: out_file_path += ".tsv" # add extension - out_buffer_file: Path = Path(out_file_path).parent.joinpath("panoptica_aggregator_tmp.tsv") + out_buffer_file: Path = Path(out_file_path).parent.joinpath( + "panoptica_aggregator_tmp.tsv" + ) self.__output_buffer_file = out_buffer_file Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) self.__output_file = out_file_path - header = ["subject_name"] + [f"{g}-{m}" for g in self.__class_group_names for m in self.__evaluation_metrics] + header = ["subject_name"] + [ + f"{g}-{m}" + for g in self.__class_group_names + for m in self.__evaluation_metrics + ] header_hash = hash("+".join(header)) if not output_file.exists(): @@ -101,7 +109,9 @@ def __init__( continue_file = True else: # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file - assert header_hash == hash("+".join(header_list)), "Hash of header not the same! You are using a different setup!" + assert header_hash == hash( + "+".join(header_list) + ), "Hash of header not the same! You are using a different setup!" 
if out_buffer_file.exists(): os.remove(out_buffer_file) diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 8b630aa..f37a8d1 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -82,9 +82,13 @@ def __init__( segmentation_class_groups = _NoSegmentationClassGroups() self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -118,10 +122,16 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, PanopticaResult]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" + assert isinstance( + processing_pair, self.__expected_input.value + ), f"input not of expected type {self.__expected_input}" - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.prediction_arr, raise_error=True + ) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.reference_arr, raise_error=True + ) result_grouped: dict[str, PanopticaResult] = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -130,7 +140,11 @@ def evaluate( label_group, processing_pair, result_all, - save_group_times=(self.__save_group_times if save_group_times is None else save_group_times), + save_group_times=( + self.__save_group_times + if save_group_times is None + else save_group_times + ), log_times=log_times, verbose=verbose, ) @@ -152,7 +166,9 @@ def _set_instance_matcher(self, matcher: InstanceMatchingAlgorithm): @property def resulting_metric_keys(self) -> list[str]: if self.__resulting_metric_keys is None: - dummy_input = MatchedInstancePair(np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8)) + dummy_input = MatchedInstancePair( + np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8) + ) res = self._evaluate_group( group_name="", label_group=LabelGroup(1, single_instance=False), @@ -186,7 +202,9 @@ def _evaluate_group( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + if single_instance_mode and not isinstance( + processing_pair, MatchedInstancePair + ): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -266,12 +284,22 @@ def panoptic_evaluate( # Crops away unecessary space of zeroes input_pair.crop_data() - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | EvaluateInstancePair | 
PanopticaResult = input_pair.copy() + processing_pair: ( + SemanticPair + | UnmatchedInstancePair + | MatchedInstancePair + | EvaluateInstancePair + | PanopticaResult + ) = input_pair.copy() # First Phase: Instance Approximation if isinstance(processing_pair, SemanticPair): - intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.SEMANTIC) - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + intermediate_steps_data.add_intermediate_arr_data( + processing_pair.copy(), InputType.SEMANTIC + ) + assert ( + instance_approximator is not None + ), "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -281,7 +309,9 @@ def panoptic_evaluate( # Second Phase: Instance Matching if isinstance(processing_pair, UnmatchedInstancePair): - intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.UNMATCHED_INSTANCE) + intermediate_steps_data.add_intermediate_arr_data( + processing_pair.copy(), InputType.UNMATCHED_INSTANCE + ) processing_pair = _handle_zero_instances_cases( processing_pair, eval_metrics=instance_metrics, @@ -292,7 +322,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, @@ -302,7 +334,9 @@ def panoptic_evaluate( # Third Phase: Instance Evaluation if isinstance(processing_pair, MatchedInstancePair): - intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.MATCHED_INSTANCE) + intermediate_steps_data.add_intermediate_arr_data( + processing_pair.copy(), InputType.MATCHED_INSTANCE + ) processing_pair = _handle_zero_instances_cases( processing_pair, eval_metrics=instance_metrics, diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index fe8e9b8..1ed9315 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -97,7 +97,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" 
keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -138,13 +140,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric, remove_nones: bool = False) -> list[float]: """Returns the list of values for given group and metric @@ -177,7 +185,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_one_metric(self, metricname: str): """Gets the dictionary mapping the group to the metrics specified @@ -230,8 +241,13 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]: summary_dict[m] = ValueSummary(value_list) return summary_dict - def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: - summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + def get_summary_dict( + self, include_across_group: bool = True + ) -> dict[str, dict[str, ValueSummary]]: + summary_dict = { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } if include_across_group: summary_dict["across_groups"] = self.get_summary_across_groups() return summary_dict @@ -275,7 +291,10 @@ def get_summary_figure( Returns: _type_: _description_ """ - data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} + data_plot = { + g: np.asarray(self.get(g, metric, remove_nones=True)) + for g in self.__groupnames + } if manual_metric_range is not None: assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range change = (manual_metric_range[1] - manual_metric_range[0]) / 100 @@ -327,10 +346,14 @@ def make_curve_over_setups( if isinstance(alternate_groupnames, str): alternate_groupnames = [alternate_groupnames] - assert plot_as_barchart or len(groups) == 1, "When plotting without barcharts, you cannot plot more than one group at the same time" + assert ( + plot_as_barchart or len(groups) == 1 + ), "When plotting without barcharts, you cannot plot more than one group at the same time" # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -350,17 +373,25 @@ def make_curve_over_setups( # Y values are average metric values in that group and 
metric for idx, g in enumerate(groups): - Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] + Y = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).avg + for stat in statistics_dict.values() + ] name = g if alternate_groupnames is None else alternate_groupnames[idx] if plot_std: - Ystd = [ValueSummary(stat.get(g, metric, remove_nones=True)).std for stat in statistics_dict.values()] + Ystd = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).std + for stat in statistics_dict.values() + ] else: Ystd = None if plot_as_barchart: - fig.add_trace(go.Bar(name=name, x=X, y=Y, error_y=dict(type="data", array=Ystd))) + fig.add_trace( + go.Bar(name=name, x=X, y=Y, error_y=dict(type="data", array=Ystd)) + ) else: # lineplot fig.add_trace( @@ -380,7 +411,9 @@ def make_curve_over_setups( height=height, showlegend=True, yaxis_title=metric if yaxis_title is None else yaxis_title, - xaxis_title=("Different setups and groups" if xaxis_title is None else xaxis_title), + xaxis_title=( + "Different setups and groups" if xaxis_title is None else xaxis_title + ), font={"family": "Arial"}, title=figure_title, ) @@ -422,7 +455,9 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(name_method).mean() df_by_spec_count = dict(df_by_spec_count[name_metric].items()) - df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation_horizontal else -1)) + df_data["mean"] = df_data[name_method].apply( + lambda x: df_by_spec_count[x] * (1 if orientation_horizontal else -1) + ) df_data = df_data.sort_values(by="mean") if not orientation_horizontal: fig = px.strip(
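
For reviewers, a minimal usage sketch of the additions in this patch: the new Panoptica_Statistic.get_one_metric and get_one_group accessors and the ValueSummary that now tolerates empty value lists. The import path is inferred from the file layout in this diff, and the TSV path, the group name "lesion", and the metric name "DSC" are hypothetical placeholders, not values taken from the repository.

# Hedged sketch, not part of the patch; placeholder names are assumptions.
from panoptica.panoptica_statistics import Panoptica_Statistic, ValueSummary

# Load a TSV written by the aggregator: first column "subject_name",
# remaining columns named "<group>-<metric>".
stat = Panoptica_Statistic.from_file("panoptica_results.tsv")  # placeholder path

# New accessors from this patch:
dsc_per_group = stat.get_one_metric("DSC")     # {group_name: [value per subject]}
lesion_metrics = stat.get_one_group("lesion")  # {metric_name: [value per subject]}

# ValueSummary now returns NaN statistics for empty lists and has a __repr__,
# so printing a possibly-empty selection no longer raises.
print(ValueSummary(dsc_per_group.get("lesion", [])))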