Skip to content

Commit

Permalink
Fix axis names (#2462)
Browse files Browse the repository at this point in the history
* Always add axis label names if argument is specified
* Consistent plotting style
* chlog

---------

Co-authored-by: Jirka <[email protected]>
  • Loading branch information
baskrahmer and Borda authored Mar 19, 2024
1 parent 3c5ceeb commit e0f239b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


## [1.3.2] - 2024-03-18
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,6 @@ def plot_curve(
if isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 1 and y.ndim == 1:
label = f"AUC={score.item():0.3f}" if score is not None else None
ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label)
if label_names is not None:
ax.set_xlabel(label_names[0])
ax.set_ylabel(label_names[1])
if label is not None:
ax.legend()
elif (isinstance(x, list) and isinstance(y, list)) or (
Expand All @@ -318,12 +315,15 @@ def plot_curve(
for i, (x_, y_) in enumerate(zip(x, y)):
label = f"{legend_name}_{i}" if legend_name is not None else str(i)
label += f" AUC={score[i].item():0.3f}" if score is not None else ""
ax.plot(x_.detach().cpu(), y_.detach().cpu(), label=label)
ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
ax.legend()
else:
raise ValueError(
f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}."
)
if label_names is not None:
ax.set_xlabel(label_names[0])
ax.set_ylabel(label_names[1])
ax.grid(True)
ax.set_title(name)

Expand Down

0 comments on commit e0f239b

Please sign in to comment.