From dfc9e33cc7fd06b535d29aae2886da10aaf054ba Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 31 Oct 2024 22:01:08 +0100 Subject: [PATCH] tests: cleaning classif. (#2815) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/torchmetrics/utilities/compute.py | 26 +++++++++++++- tests/unittests/_helpers/testers.py | 11 +++++- .../unittests/classification/test_accuracy.py | 16 ++++----- tests/unittests/classification/test_auroc.py | 4 +-- .../classification/test_average_precision.py | 4 +-- .../classification/test_calibration_error.py | 4 +-- .../classification/test_cohen_kappa.py | 4 +-- .../classification/test_confusion_matrix.py | 6 ++-- tests/unittests/classification/test_f_beta.py | 16 ++++----- .../classification/test_hamming_distance.py | 16 ++++----- tests/unittests/classification/test_hinge.py | 4 +-- .../unittests/classification/test_jaccard.py | 6 ++-- .../classification/test_matthews_corrcoef.py | 6 ++-- .../test_precision_fixed_recall.py | 4 +-- .../classification/test_precision_recall.py | 34 ++++++++++--------- .../test_precision_recall_curve.py | 4 +-- .../test_recall_fixed_precision.py | 4 +-- tests/unittests/classification/test_roc.py | 4 +-- .../test_sensitivity_specificity.py | 4 +-- .../test_specificity_sensitivity.py | 4 +-- .../classification/test_stat_scores.py | 12 +++---- 21 files changed, 114 insertions(+), 79 deletions(-) diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index ee11a36136f..e526ecc8456 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -53,6 +53,13 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens denom: denominator tensor, which may contain zeros zero_division: value to replace elements divided by zero + Example: + >>> import torch + >>> num = torch.tensor([1.0, 2.0, 3.0]) + >>> denom = torch.tensor([0.0, 1.0, 2.0]) + >>> _safe_divide(num, denom) + tensor([0.0000, 2.0000, 1.5000]) + """ num = num if num.is_floating_point() else num.float() denom = denom if denom.is_floating_point() else denom.float() @@ -102,6 +109,16 @@ def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: + """Compute area under the curve using the trapezoidal rule. + + Example: + >>> import torch + >>> x = torch.tensor([1, 2, 3, 4]) + >>> y = torch.tensor([1, 2, 3, 4]) + >>> _auc_compute(x, y) + tensor(7.5000) + + """ with torch.no_grad(): if reorder: x, x_idx = torch.sort(x, stable=True) @@ -139,7 +156,7 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: """One-dimensional linear interpolation for monotonically increasing sample points. - Returns the one-dimensional piecewise linear interpolant to a function with + Returns the one-dimensional piecewise linear interpolation to a function with given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. Adjusted version of this https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 @@ -152,6 +169,13 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: Returns: the interpolated values, same size as `x`. + Example: + >>> x = torch.tensor([0.5, 1.5, 2.5]) + >>> xp = torch.tensor([1, 2, 3]) + >>> fp = torch.tensor([1, 2, 3]) + >>> interp(x, xp, fp) + tensor([0.5000, 1.5000, 2.5000]) + """ m = _safe_divide(fp[1:] - fp[:-1], xp[1:] - xp[:-1]) b = fp[:-1] - (m * xp[:-1]) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 1b46d6f237f..98cc110a3ff 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -685,7 +685,16 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> Tuple[Tensor, Tensor]: - """Remove samples that are equal to the ignore_index in comparison functions.""" + """Remove samples that are equal to the ignore_index in comparison functions. + + Example: + >>> target = torch.tensor([0, 1, 2, 3, 4]) + >>> preds = torch.tensor([0, 1, 2, 3, 4]) + >>> ignore_index = 2 + >>> remove_ignore_index(target, preds, ignore_index) + (tensor([0, 1, 3, 4]), tensor([0, 1, 3, 4])) + + """ if ignore_index is not None: idx = target == ignore_index target, preds = deepcopy(target[~idx]), deepcopy(preds[~idx]) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 30d4a473a84..65e42c00b07 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -55,14 +55,14 @@ def _reference_sklearn_accuracy_binary(preds, target, ignore_index, multidim_ave preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_accuracy(target, preds) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(_reference_sklearn_accuracy(true, pred)) return np.stack(res) @@ -185,7 +185,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if average == "micro": return _reference_sklearn_accuracy(target, preds) confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES))) @@ -207,7 +207,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if average == "micro": res.append(_reference_sklearn_accuracy(true, pred)) else: @@ -445,13 +445,13 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_accuracy(target, preds) accuracy, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) accuracy.append(_reference_sklearn_accuracy(true, pred)) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -472,7 +472,7 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) accuracy.append(_reference_sklearn_accuracy(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -480,7 +480,7 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(_reference_sklearn_accuracy(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 3691c7305b7..30d4acb470c 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -38,7 +38,7 @@ def _reference_sklearn_auroc_binary(preds, target, max_fpr=None, ignore_index=No target = target.flatten().numpy() if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_roc_auc_score(target, preds, max_fpr=max_fpr) @@ -144,7 +144,7 @@ def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_i target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_roc_auc_score(target, preds, average=average, multi_class="ovr", labels=list(range(NUM_CLASSES))) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 51d40839642..da0dc2f56b6 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -47,7 +47,7 @@ def _reference_sklearn_avg_precision_binary(preds, target, ignore_index=None): target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_average_precision_score(target, preds) @@ -156,7 +156,7 @@ def _reference_sklearn_avg_precision_multiclass(preds, target, average="macro", target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) res = [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 86cfded8246..b8b6bfc1646 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -43,7 +43,7 @@ def _reference_netcal_binary_calibration_error(preds, target, n_bins, norm, igno target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) metric = ECE if norm == "l1" else MCE return metric(n_bins).measure(preds, target) @@ -149,7 +149,7 @@ def _reference_netcal_multiclass_calibration_error(preds, target, n_bins, norm, if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) metric = ECE if norm == "l1" else MCE return metric(n_bins).measure(preds, target) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 40ebbc028bd..1f2585372bd 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -37,7 +37,7 @@ def _reference_sklearn_cohen_kappa_binary(preds, target, weights=None, ignore_in if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_cohen_kappa(y1=target, y2=preds, weights=weights) @@ -136,7 +136,7 @@ def _reference_sklearn_cohen_kappa_multiclass(preds, target, weights, ignore_ind preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_cohen_kappa(y1=target, y2=preds, weights=weights) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 12f21451949..4d27dfc2069 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -46,7 +46,7 @@ def _reference_sklearn_confusion_matrix_binary(preds, target, normalize=None, ig if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1], normalize=normalize) @@ -147,7 +147,7 @@ def _reference_sklearn_confusion_matrix_multiclass(preds, target, normalize=None preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) @@ -298,7 +298,7 @@ def _reference_sklearn_confusion_matrix_multilabel(preds, target, normalize=None confmat = [] for i in range(preds.shape[1]): pred, true = preds[:, i], target[:, i] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat.append(sk_confusion_matrix(true, pred, normalize=normalize, labels=[0, 1])) return np.stack(confmat, axis=0) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 03c39d336fe..075e37cc699 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -63,14 +63,14 @@ def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, mu preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(sk_fn(true, pred, zero_division=zero_division)) return np.stack(res) @@ -205,7 +205,7 @@ def _reference_sklearn_fbeta_score_multiclass( if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn( target, preds, @@ -220,7 +220,7 @@ def _reference_sklearn_fbeta_score_multiclass( for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if len(pred) == 0 and average == "weighted": # The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division) @@ -417,13 +417,13 @@ def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignor if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) fbeta_score, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) fbeta_score.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -446,7 +446,7 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) fbeta_score.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -454,7 +454,7 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index f7f3686c73b..a7a42db61b0 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -59,14 +59,14 @@ def _reference_sklearn_hamming_distance_binary(preds, target, ignore_index, mult preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_hamming_loss(target, preds) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(_reference_sklearn_hamming_loss(true, pred)) return np.stack(res) @@ -167,7 +167,7 @@ def test_binary_hamming_distance_dtype_gpu(self, inputs, dtype): def _reference_sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if average == "micro": return _reference_sklearn_hamming_loss(target, preds) confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) @@ -191,7 +191,7 @@ def _reference_sklearn_hamming_distance_multiclass_local(preds, target, ignore_i for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if average == "micro": res.append(_reference_sklearn_hamming_loss(true, pred)) else: @@ -331,13 +331,13 @@ def _reference_sklearn_hamming_distance_multilabel_global(preds, target, ignore_ if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_hamming_loss(target, preds) hamming, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) hamming.append(_reference_sklearn_hamming_loss(true, pred)) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -360,13 +360,13 @@ def _reference_sklearn_hamming_distance_multilabel_local(preds, target, ignore_i for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) hamming.append(_reference_sklearn_hamming_loss(true, pred)) else: scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(_reference_sklearn_hamming_loss(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index fb0c62838cc..8963177d7a0 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -38,7 +38,7 @@ def _reference_sklearn_binary_hinge_loss(preds, target, ignore_index): if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) target = 2 * target - 1 return sk_hinge(target, preds) @@ -125,7 +125,7 @@ def _reference_sklearn_multiclass_hinge_loss(preds, target, multiclass_mode, ign if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if multiclass_mode == "one-vs-all": enc = OneHotEncoder() diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index e7afdb557a6..0a20a2e458a 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -45,7 +45,7 @@ def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None, ze if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_jaccard_index(y_true=target, y_pred=preds, zero_division=zero_division) @@ -141,7 +141,7 @@ def _reference_sklearn_jaccard_index_multiclass(preds, target, ignore_index=None preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if average is None: return sk_jaccard_index( y_true=target, y_pred=preds, average=average, labels=list(range(NUM_CLASSES)), zero_division=zero_division @@ -269,7 +269,7 @@ def _reference_sklearn_jaccard_index_multilabel(preds, target, ignore_index=None scores, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i], target[:, i] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) scores.append(sk_jaccard_index(true, pred, zero_division=zero_division)) weights.append(confmat[1, 0] + confmat[1, 1]) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 2f881604d09..b340db8d713 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -46,7 +46,7 @@ def _reference_sklearn_matthews_corrcoef_binary(preds, target, ignore_index=None if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_matthews_corrcoef(y_true=target, y_pred=preds) @@ -138,7 +138,7 @@ def _reference_sklearn_matthews_corrcoef_multiclass(preds, target, ignore_index= preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_matthews_corrcoef(y_true=target, y_pred=preds) @@ -228,7 +228,7 @@ def _reference_sklearn_matthews_corrcoef_multilabel(preds, target, ignore_index= if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_matthews_corrcoef(y_true=target, y_pred=preds) diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index f320d2cf1e9..b6649ad869d 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -58,7 +58,7 @@ def _reference_sklearn_precision_at_fixed_recall_binary(preds, target, min_recal target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _precision_at_recall_x_multilabel(preds, target, min_recall) @@ -169,7 +169,7 @@ def _reference_sklearn_precision_at_fixed_recall_multiclass(preds, target, min_r target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) precision, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 00eee202cc0..7717ffa5b0d 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -49,7 +49,9 @@ seed_all(42) -def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average, zero_division=0): +def _reference_sklearn_precision_recall_binary( + preds, target, sk_fn, ignore_index, multidim_average, zero_division=0, prob_threshold: float = THRESHOLD +): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -60,17 +62,17 @@ def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_inde if np.issubdtype(preds.dtype, np.floating): if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - preds = (preds >= THRESHOLD).astype(np.uint8) + preds = (preds >= prob_threshold).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(sk_fn(true, pred, zero_division=zero_division)) return np.stack(res) @@ -197,7 +199,7 @@ def test_binary_precision_recall_half_gpu(self, inputs, module, functional, comp def _reference_sklearn_precision_recall_multiclass( - preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0, num_classes: int = NUM_CLASSES ): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) @@ -205,12 +207,12 @@ def _reference_sklearn_precision_recall_multiclass( if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn( target, preds, average=average, - labels=list(range(NUM_CLASSES)) if average is None else None, + labels=list(range(num_classes)) if average is None else None, zero_division=zero_division, ) @@ -220,7 +222,7 @@ def _reference_sklearn_precision_recall_multiclass( for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if len(pred) == 0 and average == "weighted": # The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division) # varies depending on the sklearn version: @@ -235,7 +237,7 @@ def _reference_sklearn_precision_recall_multiclass( true, pred, average=average, - labels=list(range(NUM_CLASSES)) if average is None else None, + labels=list(range(num_classes)) if average is None else None, zero_division=zero_division, ) res.append(0.0 if np.isnan(r).any() else r) @@ -422,13 +424,13 @@ def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) precision_recall, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) precision_recall.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -451,7 +453,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) precision_recall.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -459,7 +461,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) @@ -481,7 +483,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i def _reference_sklearn_precision_recall_multilabel( - preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0, num_classes: int = NUM_CLASSES ): preds = preds.numpy() target = target.numpy() @@ -493,8 +495,8 @@ def _reference_sklearn_precision_recall_multilabel( target = target.reshape(*target.shape[:2], -1) if ignore_index is None and multidim_average == "global": return sk_fn( - target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), - preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + target.transpose(0, 2, 1).reshape(-1, num_classes), + preds.transpose(0, 2, 1).reshape(-1, num_classes), average=average, zero_division=zero_division, ) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 6f78438007e..7c034c528e6 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -46,7 +46,7 @@ def _reference_sklearn_precision_recall_curve_binary(preds, target, ignore_index target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_precision_recall_curve(target, preds) @@ -159,7 +159,7 @@ def _reference_sklearn_precision_recall_curve_multiclass(preds, target, ignore_i target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) precision, recall, thresholds = [], [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 9bdca8950bd..2d73d64f264 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -58,7 +58,7 @@ def _reference_sklearn_recall_at_fixed_precision_binary(preds, target, min_preci target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _recall_at_precision_x_multilabel(preds, target, min_precision) @@ -173,7 +173,7 @@ def _reference_sklearn_recall_at_fixed_precision_multiclass(preds, target, min_p target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) recall, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 167ad4876f0..f6cbd173128 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -37,7 +37,7 @@ def _reference_sklearn_roc_binary(preds, target, ignore_index=None): target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) fpr, tpr, thresholds = sk_roc_curve(target, preds, drop_intermediate=False) thresholds[0] = 1.0 return [np.nan_to_num(x, nan=0.0) for x in [fpr, tpr, thresholds]] @@ -140,7 +140,7 @@ def _reference_sklearn_roc_multiclass(preds, target, ignore_index=None): target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) fpr, tpr, thresholds = [], [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index e0689859dd4..47884bae2a3 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -79,7 +79,7 @@ def _reference_sklearn_sensitivity_at_specificity_binary(preds, target, min_spec target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity) @@ -197,7 +197,7 @@ def _reference_sklearn_sensitivity_at_specificity_multiclass(preds, target, min_ target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) sensitivity, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 0bafdfe55ea..934d669678a 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -77,7 +77,7 @@ def _reference_sklearn_specificity_at_sensitivity_binary(preds, target, min_sens target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _specificity_at_sensitivity_x_multilabel(preds, target, min_sensitivity) @@ -192,7 +192,7 @@ def _reference_sklearn_specificity_at_sensitivity_multiclass(preds, target, min_ target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) specificity, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 5ea4c206bc0..2a5a53bb8aa 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -54,7 +54,7 @@ def _reference_sklearn_stat_scores_binary(preds, target, ignore_index, multidim_ preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() return np.array([tp, fp, tn, fn, tp + fn]) @@ -62,7 +62,7 @@ def _reference_sklearn_stat_scores_binary(preds, target, ignore_index, multidim_ for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() res.append(np.array([tp, fp, tn, fn, tp + fn])) return np.stack(res) @@ -164,7 +164,7 @@ def test_binary_stat_scores_dtype_gpu(self, inputs, dtype): def _reference_sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) tp = np.diag(confmat) fp = confmat.sum(0) - tp @@ -192,7 +192,7 @@ def _reference_sklearn_stat_scores_multiclass_local(preds, target, ignore_index, for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) tp = np.diag(confmat) fp = confmat.sum(0) - tp @@ -431,7 +431,7 @@ def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multi stat_scores = [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() stat_scores.append(np.array([tp, fp, tn, fn, tp + fn])) res = np.stack(stat_scores, axis=0) @@ -452,7 +452,7 @@ def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multi scores = [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() scores.append(np.array([tp, fp, tn, fn, tp + fn])) stat_scores.append(np.stack(scores, 1))