Skip to content

Commit

Permalink
mypy issues
Browse files Browse the repository at this point in the history
Signed-off-by: thibaultdvx <[email protected]>
  • Loading branch information
thibaultdvx committed Sep 18, 2024
1 parent b01b38b commit c9ee60c
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions monai/metrics/r2_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int
self.multi_output = multi_output
self.p = p

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
_check_dim(y_pred, y)
return y_pred, y

Expand Down Expand Up @@ -100,7 +100,7 @@ def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None:
)


def _check_r2_params(multi_output, p) -> tuple[MultiOutput, int]:
def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutput | str, int]:
multi_output = look_up_option(multi_output, MultiOutput)
if not isinstance(p, int) or p < 0:
raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.")
Expand All @@ -115,7 +115,7 @@ def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float:
r2 = 1 - (rss / tss)
r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1)

return r2_adjusted
return r2_adjusted # type: ignore[no-any-return]


def compute_r2_score(
Expand Down Expand Up @@ -154,28 +154,29 @@ def compute_r2_score(
_check_dim(y_pred, y)
dim = y.ndimension()
n = y.shape[0]
y = y.cpu().numpy()
y_pred = y_pred.cpu().numpy()
y = y.cpu().numpy() # type: ignore[assignment]
y_pred = y_pred.cpu().numpy() # type: ignore[assignment]

if n < 2:
raise ValueError("There is no enough data for computing. Needs at least two samples to calculate r2 score.")
if p >= n - 1:
raise ValueError("`p` must be smaller than n_samples - 1, " f"got p={p}, n_samples={n}.")

if dim == 2 and y_pred.shape[1] == 1:
y_pred = np.squeeze(y_pred, axis=-1)
y = np.squeeze(y, axis=-1)
y_pred = np.squeeze(y_pred, axis=-1) # type: ignore[assignment]
y = np.squeeze(y, axis=-1) # type: ignore[assignment]
dim = 1

if dim == 1:
return _calculate(y_pred, y, p)
return _calculate(y_pred, y, p) # type: ignore[arg-type]

y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0))
y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0)) # type: ignore[assignment]
r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)]
if multi_output == MultiOutput.RAW:
return r2_values
if multi_output == MultiOutput.UNIFORM:
return np.mean(r2_values)
if multi_output == multi_output.VARIANCE:
if multi_output == MultiOutput.VARIANCE:
weights = np.var(y, axis=1)
return np.average(r2_values, weights=weights)
return np.average(r2_values, weights=weights) # type: ignore[no-any-return]
raise ValueError(f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].')

0 comments on commit c9ee60c

Please sign in to comment.