From 03ddcc7f65229a0c09841f024499c089fb3f1b4f Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 12 Dec 2024 12:23:29 +0000 Subject: [PATCH 1/2] increased panoptica statistics utility --- panoptica/panoptica_statistics.py | 55 ++++++------------- panoptica/utils/processing_pair.py | 84 +++++++----------------------- 2 files changed, 36 insertions(+), 103 deletions(-) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 7495054..ff66f37 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -86,9 +86,7 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert ( - header[0] == "subject_name" - ), "First column is not subject_names, something wrong with the file?" + assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -129,19 +127,13 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert ( - group in self.__groupnames - ), f"group {group} not existent, only got groups {self.__groupnames}" + assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert ( - metric in self.__metricnames - ), f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert ( - subjectname in self.__subj_names - ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric, remove_nones: bool = False) -> list[float]: """Returns the list of values for given group and metric @@ -174,10 +166,7 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return { - g: {m: self.get(g, m)[sidx] for m in self.__metricnames} - for g in self.__groupnames - } + return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} def get_across_groups(self, metric) -> list[float]: """Given metric, gives list of all values (even across groups!) Treat with care! @@ -206,13 +195,8 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]: summary_dict[m] = ValueSummary(value_list) return summary_dict - def get_summary_dict( - self, include_across_group: bool = True - ) -> dict[str, dict[str, ValueSummary]]: - summary_dict = { - g: {m: self.get_summary(g, m) for m in self.__metricnames} - for g in self.__groupnames - } + def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: + summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} if include_across_group: summary_dict["across_groups"] = self.get_summary_across_groups() return summary_dict @@ -257,10 +241,7 @@ def get_summary_figure( _type_: _description_ """ orientation = "h" if horizontal else "v" - data_plot = { - g: np.asarray(self.get(g, metric, remove_nones=True)) - for g in self.__groupnames - } + data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} if manual_metric_range is not None: assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range change = (manual_metric_range[1] - manual_metric_range[0]) / 100 @@ -293,6 +274,7 @@ def make_curve_over_setups( fig: None = None, plot_dotsize: int | None = None, plot_lines: bool = True, + plot_std: bool = False, ): if groups is None: groups = list(statistics_dict.values())[0].groupnames @@ -303,9 +285,7 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert ( - metric in stat.metricnames - ), f"metric {metric} not in statistic obj {setupname}" + assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -330,18 +310,19 @@ def make_curve_over_setups( plt.grid("major") # Y values are average metric values in that group and metric for idx, g in enumerate(groups): - Y = [ - ValueSummary(stat.get(g, metric, remove_nones=True)).avg - for stat in statistics_dict.values() - ] + Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] + Ystd = [ValueSummary(stat.get(g, metric, remove_nones=True)).std for stat in statistics_dict.values()] if plot_lines: - plt.plot( + p = plt.plot( X, Y, label=g if alternate_groupnames is None else alternate_groupnames[idx], ) + if plot_std: + plt.fill_between(X, np.subtract(Y, Ystd), np.add(Y, Ystd), alpha=0.25, edgecolor=p[-1].get_color()) + if plot_dotsize is not None: plt.scatter(X, Y, s=plot_dotsize) @@ -380,9 +361,7 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(name_method).mean() df_by_spec_count = dict(df_by_spec_count[name_metric].items()) - df_data["mean"] = df_data[name_method].apply( - lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) - ) + df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) df_data = df_data.sort_values(by="mean") if orientation == "v": fig = px.strip( diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index f6901b7..09d5668 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -33,9 +33,7 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None - ) -> None: + def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: """Initializes the processing pair with prediction and reference arrays. Args: @@ -48,12 +46,8 @@ def __init__( self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple( - _unique_without_zeros(reference_arr) - ) # type:ignore - self._pred_labels: tuple[int, ...] = tuple( - _unique_without_zeros(prediction_arr) - ) # type:ignore + self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore + self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape @@ -75,13 +69,7 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - ( - print( - f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" - ) - if verbose - else None - ) + (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) self.is_cropped = True def uncrop_data(self, verbose: bool = False): @@ -92,22 +80,14 @@ def uncrop_data(self, verbose: bool = False): """ if self.is_cropped == False: return - assert ( - self.uncropped_shape is not None - ), "Calling uncrop_data() without having cropped first" + assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - ( - print( - f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" - ) - if verbose - else None - ) + (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) self._reference_arr = reference_arr self.is_cropped = False @@ -117,9 +97,7 @@ def set_dtype(self, type): Args: dtype (type): Expected integer type for the arrays. """ - assert np.issubdtype( - type, int_type - ), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -211,9 +189,7 @@ def copy(self): ) # type:ignore -def _check_array_integrity( - prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None -): +def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): """Validates integrity between two arrays, checking shape, dtype, and consistency with `dtype`. Args: @@ -234,12 +210,8 @@ def _check_array_integrity( assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert ( - prediction_arr.shape == reference_arr.shape - ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert ( - prediction_arr.dtype == reference_arr.dtype - ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + # assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -331,15 +303,11 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list( - [i for i in self._ref_labels if i not in self._pred_labels] - ) + missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list( - [i for i in self._pred_labels if i not in self._ref_labels] - ) + missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) self.missed_prediction_labels = missed_prediction_labels @property @@ -412,9 +380,7 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray - ) -> _ProcessingPair: + def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) @@ -432,9 +398,7 @@ def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} - def add_intermediate_arr_data( - self, processing_pair: _ProcessingPair, inputtype: InputType - ): + def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType): type_name = inputtype.name self.add_intermediate_data(type_name, processing_pair) @@ -444,36 +408,26 @@ def add_intermediate_data(self, key, value): @property def original_prediction_arr(self): - assert ( - self._original_input is not None - ), "Original prediction_arr is None, there are no intermediate steps" + assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps" return self._original_input.prediction_arr @property def original_reference_arr(self): - assert ( - self._original_input is not None - ), "Original reference_arr is None, there are no intermediate steps" + assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps" return self._original_input.reference_arr def prediction_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance( - procpair, _ProcessingPair - ), f"step {type_name} is not a processing pair, error" + assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" return procpair.prediction_arr def reference_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance( - procpair, _ProcessingPair - ), f"step {type_name} is not a processing pair, error" + assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" return procpair.reference_arr def __getitem__(self, key): - assert ( - key in self._intermediatesteps - ), f"key {key} not in intermediate steps, maybe the step was skipped?" + assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?" return self._intermediatesteps[key] From 309574a7181b383b8ba21104bd3b56b0c1b46618 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 12:38:58 +0000 Subject: [PATCH 2/2] Autoformat with black --- panoptica/panoptica_statistics.py | 61 ++++++++++++++++++----- panoptica/utils/processing_pair.py | 80 +++++++++++++++++++++++------- 2 files changed, 110 insertions(+), 31 deletions(-) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index ff66f37..9a35a49 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -86,7 +86,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -127,13 +129,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric, remove_nones: bool = False) -> list[float]: """Returns the list of values for given group and metric @@ -166,7 +174,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_across_groups(self, metric) -> list[float]: """Given metric, gives list of all values (even across groups!) Treat with care! @@ -195,8 +206,13 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]: summary_dict[m] = ValueSummary(value_list) return summary_dict - def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: - summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + def get_summary_dict( + self, include_across_group: bool = True + ) -> dict[str, dict[str, ValueSummary]]: + summary_dict = { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } if include_across_group: summary_dict["across_groups"] = self.get_summary_across_groups() return summary_dict @@ -241,7 +257,10 @@ def get_summary_figure( _type_: _description_ """ orientation = "h" if horizontal else "v" - data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} + data_plot = { + g: np.asarray(self.get(g, metric, remove_nones=True)) + for g in self.__groupnames + } if manual_metric_range is not None: assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range change = (manual_metric_range[1] - manual_metric_range[0]) / 100 @@ -285,7 +304,9 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -310,8 +331,14 @@ def make_curve_over_setups( plt.grid("major") # Y values are average metric values in that group and metric for idx, g in enumerate(groups): - Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] - Ystd = [ValueSummary(stat.get(g, metric, remove_nones=True)).std for stat in statistics_dict.values()] + Y = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).avg + for stat in statistics_dict.values() + ] + Ystd = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).std + for stat in statistics_dict.values() + ] if plot_lines: p = plt.plot( @@ -321,7 +348,13 @@ def make_curve_over_setups( ) if plot_std: - plt.fill_between(X, np.subtract(Y, Ystd), np.add(Y, Ystd), alpha=0.25, edgecolor=p[-1].get_color()) + plt.fill_between( + X, + np.subtract(Y, Ystd), + np.add(Y, Ystd), + alpha=0.25, + edgecolor=p[-1].get_color(), + ) if plot_dotsize is not None: plt.scatter(X, Y, s=plot_dotsize) @@ -361,7 +394,9 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(name_method).mean() df_by_spec_count = dict(df_by_spec_count[name_metric].items()) - df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data["mean"] = df_data[name_method].apply( + lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) + ) df_data = df_data.sort_values(by="mean") if orientation == "v": fig = px.strip( diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 09d5668..1e24062 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -33,7 +33,9 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: + def __init__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None + ) -> None: """Initializes the processing pair with prediction and reference arrays. Args: @@ -46,8 +48,12 @@ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore - self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore + self._ref_labels: tuple[int, ...] = tuple( + _unique_without_zeros(reference_arr) + ) # type:ignore + self._pred_labels: tuple[int, ...] = tuple( + _unique_without_zeros(prediction_arr) + ) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape @@ -69,7 +75,13 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) + ( + print( + f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" + ) + if verbose + else None + ) self.is_cropped = True def uncrop_data(self, verbose: bool = False): @@ -80,14 +92,22 @@ def uncrop_data(self, verbose: bool = False): """ if self.is_cropped == False: return - assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" + assert ( + self.uncropped_shape is not None + ), "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) + ( + print( + f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" + ) + if verbose + else None + ) self._reference_arr = reference_arr self.is_cropped = False @@ -97,7 +117,9 @@ def set_dtype(self, type): Args: dtype (type): Expected integer type for the arrays. """ - assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype( + type, int_type + ), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -189,7 +211,9 @@ def copy(self): ) # type:ignore -def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): +def _check_array_integrity( + prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None +): """Validates integrity between two arrays, checking shape, dtype, and consistency with `dtype`. Args: @@ -210,7 +234,9 @@ def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert ( + prediction_arr.shape == reference_arr.shape + ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" # assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( @@ -303,11 +329,15 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) + missed_reference_labels = list( + [i for i in self._ref_labels if i not in self._pred_labels] + ) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) + missed_prediction_labels = list( + [i for i in self._pred_labels if i not in self._ref_labels] + ) self.missed_prediction_labels = missed_prediction_labels @property @@ -380,7 +410,9 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) @@ -398,7 +430,9 @@ def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} - def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType): + def add_intermediate_arr_data( + self, processing_pair: _ProcessingPair, inputtype: InputType + ): type_name = inputtype.name self.add_intermediate_data(type_name, processing_pair) @@ -408,26 +442,36 @@ def add_intermediate_data(self, key, value): @property def original_prediction_arr(self): - assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps" + assert ( + self._original_input is not None + ), "Original prediction_arr is None, there are no intermediate steps" return self._original_input.prediction_arr @property def original_reference_arr(self): - assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps" + assert ( + self._original_input is not None + ), "Original reference_arr is None, there are no intermediate steps" return self._original_input.reference_arr def prediction_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" + assert isinstance( + procpair, _ProcessingPair + ), f"step {type_name} is not a processing pair, error" return procpair.prediction_arr def reference_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" + assert isinstance( + procpair, _ProcessingPair + ), f"step {type_name} is not a processing pair, error" return procpair.reference_arr def __getitem__(self, key): - assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?" + assert ( + key in self._intermediatesteps + ), f"key {key} not in intermediate steps, maybe the step was skipped?" return self._intermediatesteps[key]