diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 7495054..9a35a49 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -293,6 +293,7 @@ def make_curve_over_setups( fig: None = None, plot_dotsize: int | None = None, plot_lines: bool = True, + plot_std: bool = False, ): if groups is None: groups = list(statistics_dict.values())[0].groupnames @@ -334,14 +335,27 @@ def make_curve_over_setups( ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values() ] + Ystd = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).std + for stat in statistics_dict.values() + ] if plot_lines: - plt.plot( + p = plt.plot( X, Y, label=g if alternate_groupnames is None else alternate_groupnames[idx], ) + if plot_std: + plt.fill_between( + X, + np.subtract(Y, Ystd), + np.add(Y, Ystd), + alpha=0.25, + edgecolor=p[-1].get_color(), + ) + if plot_dotsize is not None: plt.scatter(X, Y, s=plot_dotsize) diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index f6901b7..1e24062 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -237,9 +237,7 @@ def _check_array_integrity( assert ( prediction_arr.shape == reference_arr.shape ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert ( - prediction_arr.dtype == reference_arr.dtype - ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + # assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype)