diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 58a3fd36f3..0cd31b89c2 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -61,6 +61,7 @@ def write_metrics_reports( summary_ops: str | Sequence[str] | None, deli: str = ",", output_type: str = "csv", + class_labels: list[str] | None = None, ) -> None: """ Utility function to write the metrics into files, contains 3 parts: @@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans deli: the delimiter character in the saved file, default to "," as the default output type is `csv`. to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. output_type: expected output file type, supported types: ["csv"], default to "csv". + class_labels: list of class names used to name the classes in the output report, if None, + "class0", ..., "classn" are used, default to None. """ if output_type.lower() != "csv": @@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans v = v.reshape((-1, 1)) # add the average value of all classes to v - class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"] + if class_labels is None: + class_labels = ["class" + str(i) for i in range(v.shape[1])] + else: + class_labels = [str(i) for i in class_labels] # ensure to have a list of str + + class_labels += ["mean"] v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1) with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f: