Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing: mark line ignores instead of ignore whole files #2452

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,6 @@ disable_error_code = "attr-defined"
# style choices
warn_no_return = "False"

# Ignore mypy errors for these files
# TODO: the goal is for this to be empty
[[tool.mypy.overrides]]
module = [
"torchmetrics.classification.exact_match",
"torchmetrics.classification.f_beta",
"torchmetrics.classification.precision_recall",
"torchmetrics.classification.ranking",
"torchmetrics.classification.recall_at_fixed_precision",
"torchmetrics.classification.roc",
"torchmetrics.classification.stat_scores",
"torchmetrics.detection._mean_ap",
"torchmetrics.detection.mean_ap",
"torchmetrics.functional.image.psnr",
"torchmetrics.functional.image.ssim",
"torchmetrics.image.psnr",
"torchmetrics.image.ssim",
]
ignore_errors = "True"

[tool.typos.default]
extend-ignore-identifiers-re = [
# *sigh* this just isn't worth the cost of fixing
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class ExactMatch(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["ExactMatch"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,7 @@ class FBetaScore(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["FBetaScore"],
task: Literal["binary", "multiclass", "multilabel"],
beta: float = 1.0,
Expand Down Expand Up @@ -1122,7 +1122,7 @@ class F1Score(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["F1Score"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ class Precision(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["Precision"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
Expand Down Expand Up @@ -995,7 +995,7 @@ class Recall(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["Recall"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
Expand Down Expand Up @@ -1028,4 +1028,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelRecall(num_labels, threshold, average, **kwargs)
return None
return None # type: ignore[return-value]
8 changes: 4 additions & 4 deletions src/torchmetrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class BinaryROC(BinaryPrecisionRecallCurve):
def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_roc_compute(state, self.thresholds)
return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type]

def plot(
self,
Expand Down Expand Up @@ -290,7 +290,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) # type: ignore[arg-type]

def plot(
self,
Expand Down Expand Up @@ -449,7 +449,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)
return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) # type: ignore[arg-type]

def plot(
self,
Expand Down Expand Up @@ -564,7 +564,7 @@ class ROC(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["ROC"],
task: Literal["binary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, List[float], Tensor]] = None,
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def _create_state(
def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None:
"""Update states depending on multidim_average argument."""
if self.multidim_average == "samplewise":
self.tp.append(tp)
self.fp.append(fp)
self.tn.append(tn)
self.fn.append(fn)
self.tp.append(tp) # type: ignore[union-attr]
self.fp.append(fp) # type: ignore[union-attr]
self.tn.append(tn) # type: ignore[union-attr]
self.fn.append(fn) # type: ignore[union-attr]
else:
self.tp += tp
self.fp += fp
Expand Down Expand Up @@ -515,7 +515,7 @@ class StatScores(_ClassificationTaskWrapper):

"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["StatScores"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
Expand Down
42 changes: 22 additions & 20 deletions src/torchmetrics/detection/_mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,18 @@ def __init__(

def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
"""Update state with predictions and targets."""
_input_validator(preds, target, iou_type=self.iou_type)
_input_validator(preds, target, iou_type=self.iou_type) # type: ignore[arg-type]

for item in preds:
detections = self._get_safe_item_values(item)

self.detections.append(detections)
self.detections.append(detections) # type: ignore[arg-type]
self.detection_labels.append(item["labels"])
self.detection_scores.append(item["scores"])

for item in target:
groundtruths = self._get_safe_item_values(item)
self.groundtruths.append(groundtruths)
self.groundtruths.append(groundtruths) # type: ignore[arg-type]
self.groundtruth_labels.append(item["labels"])

def _move_list_states_to_cpu(self) -> None:
Expand Down Expand Up @@ -640,13 +640,13 @@ def _find_best_gt_match(
Id of current detection.

"""
previously_matched = gt_matches[idx_iou]
previously_matched = gt_matches[idx_iou] # type: ignore[index]
# Remove previously matched or ignored gts
remove_mask = previously_matched | gt_ignore
gt_ious = ious[idx_det] * ~remove_mask
match_idx = gt_ious.argmax().item()
if gt_ious[match_idx] > thr:
return match_idx
if gt_ious[match_idx] > thr: # type: ignore[index]
return match_idx # type: ignore[return-value]
return -1

def _summarize(
Expand Down Expand Up @@ -713,7 +713,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
}

eval_imgs = [
self._evaluate_image(img_id, class_id, area, max_detections, ious)
self._evaluate_image(img_id, class_id, area, max_detections, ious) # type: ignore[arg-type]
for class_id in class_ids
for area in area_ranges
for img_id in img_ids
Expand Down Expand Up @@ -750,7 +750,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
num_bbox_areas=num_bbox_areas,
)

return precision, recall
return precision, recall # type: ignore[return-value]

def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]:
"""Summarizes the precision and recall values to calculate mAP/mAR.
Expand Down Expand Up @@ -820,8 +820,8 @@ def __calculate_recall_precision_scores(
inds = torch.argsort(det_scores.to(dtype), descending=True)
det_scores_sorted = det_scores[inds]

det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] # type: ignore[call-overload]
det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] # type: ignore[call-overload]
gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox])
npig = torch.count_nonzero(gt_ignore == False) # noqa: E712
if npig == 0:
Expand Down Expand Up @@ -849,9 +849,9 @@ def __calculate_recall_precision_scores(

inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs
inds = inds[:num_inds]
prec[:num_inds] = pr[inds]
score[:num_inds] = det_scores_sorted[inds]
inds = inds[:num_inds] # type: ignore[misc]
prec[:num_inds] = pr[inds] # type: ignore[misc]
score[:num_inds] = det_scores_sorted[inds] # type: ignore[misc]
precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = prec
scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = score

Expand All @@ -861,7 +861,7 @@ def compute(self) -> dict:
"""Compute metric."""
classes = self._get_classes()
precisions, recalls = self._calculate(classes)
map_val, mar_val = self._summarize_results(precisions, recalls)
map_val, mar_val = self._summarize_results(precisions, recalls) # type: ignore[arg-type]

# if class mode is enabled, evaluate metrics per class
map_per_class_values: Tensor = torch.tensor([-1.0])
Expand All @@ -888,7 +888,7 @@ def compute(self) -> dict:
metrics.classes = torch.tensor(classes, dtype=torch.int)
return metrics

def _apply(self, fn: Callable) -> torch.nn.Module:
def _apply(self, fn: Callable) -> torch.nn.Module: # type: ignore[override]
"""Custom apply function.

Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is
Expand All @@ -908,22 +908,24 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
to gather the list of tuples and then convert it back to a list of tuples.

"""
super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) # type: ignore[arg-type]

if self.iou_type == "segm":
self.detections = self._gather_tuple_list(self.detections, process_group)
self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)
self.detections = self._gather_tuple_list(self.detections, process_group) # type: ignore[arg-type]
self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) # type: ignore[arg-type]

@staticmethod
def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
def _gather_tuple_list(
list_to_gather: List[Union[tuple, Tensor]], process_group: Optional[Any] = None
) -> List[Any]:
"""Gather a list of tuples over multiple devices."""
world_size = dist.get_world_size(group=process_group)
dist.barrier(group=process_group)

list_gathered = [None for _ in range(world_size)]
dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]
return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)] # type: ignore[arg-type,index]

def plot(
self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/detection/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor:
return boxes


def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox") -> Tuple[str]:
def _validate_iou_type_arg(
iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox",
) -> Tuple[str]:
"""Validate that iou type argument is correct."""
allowed_iou_types = ("segm", "bbox")
if isinstance(iou_type, str):
iou_type = (iou_type,)
if any(tp not in allowed_iou_types for tp in iou_type):
raise ValueError(
f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}"
f"Expected argument `iou_type` to be one of {allowed_iou_types} or a tuple of, but got {iou_type}"
)
return iou_type
Loading
Loading