diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e2bd17b325..8a9a7c49d0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) ## [1.3.2] - 2024-03-18 diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index feb906810bc..b1ec17597b7 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -307,9 +307,6 @@ def plot_curve( if isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 1 and y.ndim == 1: label = f"AUC={score.item():0.3f}" if score is not None else None ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label) - if label_names is not None: - ax.set_xlabel(label_names[0]) - ax.set_ylabel(label_names[1]) if label is not None: ax.legend() elif (isinstance(x, list) and isinstance(y, list)) or ( @@ -318,12 +315,15 @@ def plot_curve( for i, (x_, y_) in enumerate(zip(x, y)): label = f"{legend_name}_{i}" if legend_name is not None else str(i) label += f" AUC={score[i].item():0.3f}" if score is not None else "" - ax.plot(x_.detach().cpu(), y_.detach().cpu(), label=label) + ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label) ax.legend() else: raise ValueError( f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}." ) + if label_names is not None: + ax.set_xlabel(label_names[0]) + ax.set_ylabel(label_names[1]) ax.grid(True) ax.set_title(name)