From 22d66cc4d3d5f5a3c63d1d9588170eccaeeb4297 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Fri, 15 Mar 2024 21:40:37 +0100
Subject: [PATCH] typing: mark line ignores instead of ignore whole files

---
 pyproject.toml                             | 20 ---------
 .../classification/exact_match.py          |  2 +-
 src/torchmetrics/classification/f_beta.py  |  4 +-
 .../classification/precision_recall.py     |  6 +--
 src/torchmetrics/classification/roc.py     |  8 ++--
 .../classification/stat_scores.py          | 10 ++---
 src/torchmetrics/detection/_mean_ap.py     | 42 ++++++++++---------
 src/torchmetrics/detection/helpers.py      |  6 ++-
 src/torchmetrics/detection/mean_ap.py      | 40 +++++++++---------
 src/torchmetrics/functional/image/psnr.py  |  8 ++--
 src/torchmetrics/functional/image/ssim.py  |  8 ++--
 11 files changed, 69 insertions(+), 85 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 70d3725592f..efc34e27a8f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -147,26 +147,6 @@ disable_error_code = "attr-defined"
 # style choices
 warn_no_return = "False"
 
-# Ignore mypy errors for these files
-# TODO: the goal is for this to be empty
-[[tool.mypy.overrides]]
-module = [
-  "torchmetrics.classification.exact_match",
-  "torchmetrics.classification.f_beta",
-  "torchmetrics.classification.precision_recall",
-  "torchmetrics.classification.ranking",
-  "torchmetrics.classification.recall_at_fixed_precision",
-  "torchmetrics.classification.roc",
-  "torchmetrics.classification.stat_scores",
-  "torchmetrics.detection._mean_ap",
-  "torchmetrics.detection.mean_ap",
-  "torchmetrics.functional.image.psnr",
-  "torchmetrics.functional.image.ssim",
-  "torchmetrics.image.psnr",
-  "torchmetrics.image.ssim",
-]
-ignore_errors = "True"
-
 [tool.typos.default]
 extend-ignore-identifiers-re = [
   # *sigh* this just isn't worth the cost of fixing
diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py
index b37a0850f4e..10b9aedc2fc 100644
--- a/src/torchmetrics/classification/exact_match.py
+++ b/src/torchmetrics/classification/exact_match.py
@@ -393,7 +393,7 @@ class ExactMatch(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["ExactMatch"],
         task: Literal["binary", "multiclass", "multilabel"],
         threshold: float = 0.5,
diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py
index eec4c33bd8b..93f26441c2a 100644
--- a/src/torchmetrics/classification/f_beta.py
+++ b/src/torchmetrics/classification/f_beta.py
@@ -1058,7 +1058,7 @@ class FBetaScore(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["FBetaScore"],
         task: Literal["binary", "multiclass", "multilabel"],
         beta: float = 1.0,
@@ -1122,7 +1122,7 @@ class F1Score(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["F1Score"],
         task: Literal["binary", "multiclass", "multilabel"],
         threshold: float = 0.5,
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index 44c6304d316..124215f4e03 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -930,7 +930,7 @@ class Precision(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["Precision"],
         task: Literal["binary", "multiclass", "multilabel"],
         threshold: float = 0.5,
@@ -995,7 +995,7 @@ class Recall(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["Recall"],
         task: Literal["binary", "multiclass", "multilabel"],
         threshold: float = 0.5,
@@ -1028,4 +1028,4 @@ def __new__(
             if not isinstance(num_labels, int):
                 raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
             return MultilabelRecall(num_labels, threshold, average, **kwargs)
-        return None
+        return None  # type: ignore[return-value]
diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py
index 40bd8c36327..68edf6e8fcc 100644
--- a/src/torchmetrics/classification/roc.py
+++ b/src/torchmetrics/classification/roc.py
@@ -120,7 +120,7 @@ class BinaryROC(BinaryPrecisionRecallCurve):
     def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
         """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
-        return _binary_roc_compute(state, self.thresholds)
+        return _binary_roc_compute(state, self.thresholds)  # type: ignore[arg-type]
 
     def plot(
         self,
@@ -290,7 +290,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
     def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
         """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
-        return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)
+        return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)  # type: ignore[arg-type]
 
     def plot(
         self,
@@ -449,7 +449,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
    def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
         """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
-        return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)
+        return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)  # type: ignore[arg-type]
 
     def plot(
         self,
@@ -564,7 +564,7 @@ class ROC(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["ROC"],
         task: Literal["binary", "multiclass", "multilabel"],
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index b70ee9ddeac..1ae0d4285e6 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -69,10 +69,10 @@ def _create_state(
     def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None:
         """Update states depending on multidim_average argument."""
         if self.multidim_average == "samplewise":
-            self.tp.append(tp)
-            self.fp.append(fp)
-            self.tn.append(tn)
-            self.fn.append(fn)
+            self.tp.append(tp)  # type: ignore[union-attr]
+            self.fp.append(fp)  # type: ignore[union-attr]
+            self.tn.append(tn)  # type: ignore[union-attr]
+            self.fn.append(fn)  # type: ignore[union-attr]
         else:
             self.tp += tp
             self.fp += fp
@@ -515,7 +515,7 @@ class StatScores(_ClassificationTaskWrapper):
 
     """
 
-    def __new__(
+    def __new__(  # type: ignore[misc]
         cls: Type["StatScores"],
         task: Literal["binary", "multiclass", "multilabel"],
         threshold: float = 0.5,
diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py
index 1fce9e37599..6d6a5c3dfa4 100644
--- a/src/torchmetrics/detection/_mean_ap.py
+++ b/src/torchmetrics/detection/_mean_ap.py
@@ -366,18 +366,18 @@ def __init__(
 
     def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
         """Update state with predictions and targets."""
-        _input_validator(preds, target, iou_type=self.iou_type)
+        _input_validator(preds, target, iou_type=self.iou_type)  # type: ignore[arg-type]
 
         for item in preds:
             detections = self._get_safe_item_values(item)
 
-            self.detections.append(detections)
+            self.detections.append(detections)  # type: ignore[arg-type]
             self.detection_labels.append(item["labels"])
             self.detection_scores.append(item["scores"])
 
         for item in target:
             groundtruths = self._get_safe_item_values(item)
-            self.groundtruths.append(groundtruths)
+            self.groundtruths.append(groundtruths)  # type: ignore[arg-type]
             self.groundtruth_labels.append(item["labels"])
 
     def _move_list_states_to_cpu(self) -> None:
@@ -640,13 +640,13 @@ def _find_best_gt_match(
             Id of current detection.
 
         """
-        previously_matched = gt_matches[idx_iou]
+        previously_matched = gt_matches[idx_iou]  # type: ignore[index]
         # Remove previously matched or ignored gts
         remove_mask = previously_matched | gt_ignore
         gt_ious = ious[idx_det] * ~remove_mask
         match_idx = gt_ious.argmax().item()
-        if gt_ious[match_idx] > thr:
-            return match_idx
+        if gt_ious[match_idx] > thr:  # type: ignore[index]
+            return match_idx  # type: ignore[return-value]
         return -1
 
     def _summarize(
@@ -713,7 +713,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
         }
 
         eval_imgs = [
-            self._evaluate_image(img_id, class_id, area, max_detections, ious)
+            self._evaluate_image(img_id, class_id, area, max_detections, ious)  # type: ignore[arg-type]
             for class_id in class_ids
             for area in area_ranges
             for img_id in img_ids
@@ -750,7 +750,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
             num_bbox_areas=num_bbox_areas,
         )
 
-        return precision, recall
+        return precision, recall  # type: ignore[return-value]
 
     def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]:
         """Summarizes the precision and recall values to calculate mAP/mAR.
@@ -820,8 +820,8 @@ def __calculate_recall_precision_scores(
         inds = torch.argsort(det_scores.to(dtype), descending=True)
         det_scores_sorted = det_scores[inds]
 
-        det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
-        det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
+        det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]  # type: ignore[call-overload]
+        det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]  # type: ignore[call-overload]
         gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox])
         npig = torch.count_nonzero(gt_ignore == False)  # noqa: E712
         if npig == 0:
@@ -849,9 +849,9 @@ def __calculate_recall_precision_scores(
 
             inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
             num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs
-            inds = inds[:num_inds]
-            prec[:num_inds] = pr[inds]
-            score[:num_inds] = det_scores_sorted[inds]
+            inds = inds[:num_inds]  # type: ignore[misc]
+            prec[:num_inds] = pr[inds]  # type: ignore[misc]
+            score[:num_inds] = det_scores_sorted[inds]  # type: ignore[misc]
             precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = prec
             scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = score
 
@@ -861,7 +861,7 @@ def compute(self) -> dict:
         """Compute metric."""
         classes = self._get_classes()
         precisions, recalls = self._calculate(classes)
-        map_val, mar_val = self._summarize_results(precisions, recalls)
+        map_val, mar_val = self._summarize_results(precisions, recalls)  # type: ignore[arg-type]
 
         # if class mode is enabled, evaluate metrics per class
         map_per_class_values: Tensor = torch.tensor([-1.0])
@@ -888,7 +888,7 @@ def compute(self) -> dict:
         metrics.classes = torch.tensor(classes, dtype=torch.int)
         return metrics
 
-    def _apply(self, fn: Callable) -> torch.nn.Module:
+    def _apply(self, fn: Callable) -> torch.nn.Module:  # type: ignore[override]
         """Custom apply function.
 
         Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is
@@ -908,14 +908,16 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
         to gather the list of tuples and then convert it back to a list of tuples.
 
         """
-        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
+        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)  # type: ignore[arg-type]
 
         if self.iou_type == "segm":
-            self.detections = self._gather_tuple_list(self.detections, process_group)
-            self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)
+            self.detections = self._gather_tuple_list(self.detections, process_group)  # type: ignore[arg-type]
+            self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)  # type: ignore[arg-type]
 
     @staticmethod
-    def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
+    def _gather_tuple_list(
+        list_to_gather: List[Union[tuple, Tensor]], process_group: Optional[Any] = None
+    ) -> List[Any]:
         """Gather a list of tuples over multiple devices."""
         world_size = dist.get_world_size(group=process_group)
         dist.barrier(group=process_group)
@@ -923,7 +925,7 @@ def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any]
         list_gathered = [None for _ in range(world_size)]
         dist.all_gather_object(list_gathered, list_to_gather, group=process_group)
 
-        return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]
+        return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]  # type: ignore[arg-type,index]
 
     def plot(
         self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None
diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py
index 2c8c35b7ace..dc31a7c7497 100644
--- a/src/torchmetrics/detection/helpers.py
+++ b/src/torchmetrics/detection/helpers.py
@@ -87,13 +87,15 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor:
     return boxes
 
 
-def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox") -> Tuple[str]:
+def _validate_iou_type_arg(
+    iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox",
+) -> Tuple[str]:
     """Validate that iou type argument is correct."""
     allowed_iou_types = ("segm", "bbox")
     if isinstance(iou_type, str):
         iou_type = (iou_type,)
     if any(tp not in allowed_iou_types for tp in iou_type):
         raise ValueError(
-            f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}"
+            f"Expected argument `iou_type` to be one of {allowed_iou_types} or a tuple of, but got {iou_type}"
         )
     return iou_type
diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index d3643411b9d..468300386c9 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -489,14 +489,14 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
             If any score is not type float and of length 1
 
         """
-        _input_validator(preds, target, iou_type=self.iou_type)
+        _input_validator(preds, target, iou_type=self.iou_type)  # type: ignore[arg-type]
 
         for item in preds:
             bbox_detection, mask_detection = self._get_safe_item_values(item, warn=self.warn_on_many_detections)
             if bbox_detection is not None:
                 self.detection_box.append(bbox_detection)
             if mask_detection is not None:
-                self.detection_mask.append(mask_detection)
+                self.detection_mask.append(mask_detection)  # type: ignore[arg-type]
             self.detection_labels.append(item["labels"])
             self.detection_scores.append(item["scores"])
 
@@ -505,7 +505,7 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
             if bbox_groundtruth is not None:
                 self.groundtruth_box.append(bbox_groundtruth)
             if mask_groundtruth is not None:
-                self.groundtruth_mask.append(mask_groundtruth)
+                self.groundtruth_mask.append(mask_groundtruth)  # type: ignore[arg-type]
             self.groundtruth_labels.append(item["labels"])
             self.groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"])))
             self.groundtruth_area.append(item.get("area", torch.zeros_like(item["labels"])))
@@ -524,7 +524,7 @@ def compute(self) -> dict:
                 for anno in coco_preds.dataset["annotations"]:
                     anno["area"] = anno[f"area_{i_type}"]
 
-            coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)
+            coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)  # type: ignore[operator]
             coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
             coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
             coco_eval.params.maxDets = self.max_detection_thresholds
@@ -553,7 +553,7 @@ def compute(self) -> dict:
                 # since micro averaging have all the data in one class, we need to reinitialize the coco_eval
                 # object in macro mode to get the per class stats
                 coco_preds, coco_target = self._get_coco_datasets(average="macro")
-                coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)
+                coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)  # type: ignore[operator]
                 coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
                 coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
                 coco_eval.params.maxDets = self.max_detection_thresholds
@@ -597,7 +597,7 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object
             groundtruth_labels = self.groundtruth_labels
             detection_labels = self.detection_labels
 
-        coco_target, coco_preds = self.coco(), self.coco()
+        coco_target, coco_preds = self.coco(), self.coco()  # type: ignore[operator]
 
         coco_target.dataset = self._get_coco_format(
             labels=groundtruth_labels,
@@ -671,17 +671,17 @@ def coco_to_tm(
         ... )  # doctest: +SKIP
 
         """
-        iou_type = _validate_iou_type_arg(iou_type)
+        iou_type = _validate_iou_type_arg(iou_type)  # type: ignore[arg-type]
         coco, _, _ = _load_backend_tools(backend)
 
         with contextlib.redirect_stdout(io.StringIO()):
-            gt = coco(coco_target)
+            gt = coco(coco_target)  # type: ignore[operator]
             dt = gt.loadRes(coco_preds)
 
         gt_dataset = gt.dataset["annotations"]
         dt_dataset = dt.dataset["annotations"]
 
-        target = {}
+        target: dict = {}
         for t in gt_dataset:
             if t["image_id"] not in target:
                 target[t["image_id"]] = {
@@ -702,7 +702,7 @@ def coco_to_tm(
             target[t["image_id"]]["iscrowd"].append(t["iscrowd"])
             target[t["image_id"]]["area"].append(t["area"])
 
-        preds = {}
+        preds: dict = {}
         for p in dt_dataset:
             if p["image_id"] not in preds:
                 preds[p["image_id"]] = {"scores": [], "labels": []}
@@ -820,18 +820,18 @@ def _get_safe_item_values(
             boxes = _fix_empty_tensors(item["boxes"])
             if boxes.numel() > 0:
                 boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xywh")
-            output[0] = boxes
+            output[0] = boxes  # type: ignore[call-overload]
         if "segm" in self.iou_type:
             masks = []
             for i in item["masks"].cpu().numpy():
                 rle = self.mask_utils.encode(np.asfortranarray(i))
                 masks.append((tuple(rle["size"]), rle["counts"]))
-            output[1] = tuple(masks)
+            output[1] = tuple(masks)  # type: ignore[call-overload]
         if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or (
             output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1]
         ):
             _warning_on_too_many_detections(self.max_detection_thresholds[-1])
-        return output
+        return output  # type: ignore[return-value]
 
     def _get_classes(self) -> List:
         """Return a list of unique classes found in ground truth and detection data."""
@@ -866,11 +866,11 @@ def _get_coco_format(
                 image_masks = masks[image_id]
                 if len(image_masks) == 0 and boxes is None:
                     continue
-            image_labels = image_labels.cpu().tolist()
+            image_labels = image_labels.cpu().tolist()  # type: ignore[assignment]
 
             images.append({"id": image_id})
             if "segm" in self.iou_type and len(image_masks) > 0:
-                images[-1]["height"], images[-1]["width"] = image_masks[0][0][0], image_masks[0][0][1]
+                images[-1]["height"], images[-1]["width"] = image_masks[0][0][0], image_masks[0][0][1]  # type: ignore[assignment]
 
             for k, image_label in enumerate(image_labels):
                 if boxes is not None:
@@ -892,7 +892,7 @@ def _get_coco_format(
 
                 area_stat_box = None
                 area_stat_mask = None
-                if area is not None and area[image_id][k].cpu().tolist() > 0:
+                if area is not None and area[image_id][k].cpu().tolist() > 0:  # type: ignore[operator]
                     area_stat = area[image_id][k].cpu().tolist()
                 else:
                     area_stat = (
@@ -1011,11 +1011,11 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
         to gather the list of tuples and then convert it back to a list of tuples.
 
         """
-        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
+        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)  # type: ignore[arg-type]
 
         if "segm" in self.iou_type:
-            self.detection_mask = self._gather_tuple_list(self.detection_mask, process_group)
-            self.groundtruth_mask = self._gather_tuple_list(self.groundtruth_mask, process_group)
+            self.detection_mask = self._gather_tuple_list(self.detection_mask, process_group)  # type: ignore[arg-type]
+            self.groundtruth_mask = self._gather_tuple_list(self.groundtruth_mask, process_group)  # type: ignore[arg-type]
 
     @staticmethod
     def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
@@ -1035,7 +1035,7 @@ def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any]
         list_gathered = [None for _ in range(world_size)]
         dist.all_gather_object(list_gathered, list_to_gather, group=process_group)
 
-        return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]
+        return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]  # type: ignore[arg-type,index]
 
 
 def _warning_on_too_many_detections(limit: int) -> None:
diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py
index d4b12ff94bd..8f6a3f20dba 100644
--- a/src/torchmetrics/functional/image/psnr.py
+++ b/src/torchmetrics/functional/image/psnr.py
@@ -142,13 +142,13 @@ def peak_signal_noise_ratio(
             # `data_range` in the future.
             raise ValueError("The `data_range` must be given when `dim` is not None.")
 
-        data_range = target.max() - target.min()
+        data_range = target.max() - target.min()  # type: ignore[assignment]
     elif isinstance(data_range, tuple):
         preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
         target = torch.clamp(target, min=data_range[0], max=data_range[1])
-        data_range = tensor(data_range[1] - data_range[0])
+        data_range = tensor(data_range[1] - data_range[0])  # type: ignore[assignment]
     else:
-        data_range = tensor(float(data_range))
+        data_range = tensor(float(data_range))  # type: ignore[assignment]
 
     sum_squared_error, num_obs = _psnr_update(preds, target, dim=dim)
-    return _psnr_compute(sum_squared_error, num_obs, data_range, base=base, reduction=reduction)
+    return _psnr_compute(sum_squared_error, num_obs, data_range, base=base, reduction=reduction)  # type: ignore[arg-type]
diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py
index 14edac63b79..e33e3943fab 100644
--- a/src/torchmetrics/functional/image/ssim.py
+++ b/src/torchmetrics/functional/image/ssim.py
@@ -110,14 +110,14 @@ def _ssim_update(
         raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
 
     if data_range is None:
-        data_range = max(preds.max() - preds.min(), target.max() - target.min())
+        data_range = max(preds.max() - preds.min(), target.max() - target.min())  # type: ignore[call-overload]
     elif isinstance(data_range, tuple):
         preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
         target = torch.clamp(target, min=data_range[0], max=data_range[1])
         data_range = data_range[1] - data_range[0]
 
-    c1 = pow(k1 * data_range, 2)
-    c2 = pow(k2 * data_range, 2)
+    c1 = pow(k1 * data_range, 2)  # type: ignore[operator]
+    c2 = pow(k2 * data_range, 2)  # type: ignore[operator]
 
     device = preds.device
     channel = preds.size(1)
@@ -421,7 +421,7 @@ def _multiscale_ssim_update(
 
     betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1)
     mcs_weighted = mcs_stack**betas
-    return torch.prod(mcs_weighted, axis=0)
+    return torch.prod(mcs_weighted, axis=0)  # type: ignore[call-overload]
 
 
 def _multiscale_ssim_compute(