From 3856b515f410e579e584a61dcc26cb5885a89c67 Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Mon, 30 Dec 2024 07:38:08 -0800
Subject: [PATCH] Fix gradcam pyre fixme issues (#1466)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1466

Differential Revision: D67705191
---
 captum/attr/_core/layer/grad_cam.py | 22 +++-------------------
 1 file changed, 3 insertions(+), 19 deletions(-)

diff --git a/captum/attr/_core/layer/grad_cam.py b/captum/attr/_core/layer/grad_cam.py
index eed6397609..d57049ad8e 100644
--- a/captum/attr/_core/layer/grad_cam.py
+++ b/captum/attr/_core/layer/grad_cam.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -54,8 +54,7 @@ class LayerGradCam(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -201,7 +200,7 @@
         # hidden layer and hidden layer evaluated at each input.
         layer_gradients, layer_evals = compute_layer_gradients_and_eval(
             self.forward_func,
-            self.layer,
+            cast(Module, self.layer),
             inputs,
             target,
             additional_forward_args,
@@ -213,10 +212,7 @@
         summed_grads = tuple(
             (
                 torch.mean(
-                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                    # `Tuple[Tensor, ...]`.
                     layer_grad,
-                    # pyre-fixme[16]: `tuple` has no attribute `shape`.
                     dim=tuple(x for x in range(2, len(layer_grad.shape))),
                     keepdim=True,
                 )
@@ -228,27 +224,15 @@
         )
 
         if attr_dim_summation:
             scaled_acts = tuple(
-                # pyre-fixme[58]: `*` is not supported for operand types
-                # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
-                # `Tuple[Tensor, ...]`.
-                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                # `Tuple[Tensor, ...]`.
                 torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         else:
             scaled_acts = tuple(
-                # pyre-fixme[58]: `*` is not supported for operand types
-                # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
-                # `Tuple[Tensor, ...]`.
                 summed_grad * layer_eval
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         if relu_attributions:
-            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-            # `Union[tuple[Tensor], Tensor]`.
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
-        # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
-        # `Tuple[Union[tuple[Tensor], Tensor], ...]`.
         return _format_output(len(scaled_acts) > 1, scaled_acts)
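
Note for reviewers (illustration, not part of the patch): the tightened
annotation `Callable[..., Tensor]` captures the one assumption `LayerGradCam`
makes about `forward_func`: whatever arguments it accepts, it returns a
Tensor of model outputs, which any ordinary `nn.Module` satisfies. Below is a
minimal usage sketch under that annotation; the `ToyCNN` model, layer choice,
and shapes are hypothetical, while `LayerGradCam` and its `attribute` call
are the actual captum API:

    import torch
    import torch.nn as nn

    from captum.attr import LayerGradCam


    class ToyCNN(nn.Module):
        """Hypothetical model, used only to exercise the annotation."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(8, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.pool(torch.relu(self.conv(x)))
            return self.fc(x.flatten(1))


    model = ToyCNN().eval()
    # An nn.Module instance is itself a Callable[..., Tensor], so it
    # type-checks as forward_func under the new annotation.
    gradcam = LayerGradCam(model, model.conv)

    inputs = torch.randn(2, 3, 32, 32)
    # With the default attr_dim_summation=True, channel importances are
    # summed out, leaving one attribution map per example at the chosen
    # layer's spatial resolution.
    attrs = gradcam.attribute(inputs, target=3)
    print(attrs.shape)  # torch.Size([2, 1, 32, 32])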
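
The other substantive change, `cast(Module, self.layer)`, replaces a fixme
with a real narrowing: `LayerAttribution` types its `layer` attribute as a
union of a single `Module` and a list of modules, while `LayerGradCam` only
ever holds a single `Module`. A small self-contained sketch of the same
pattern, with a hypothetical helper name (`pick_single_module`) chosen for
illustration:

    from typing import List, Union, cast

    from torch.nn import Module

    def pick_single_module(layer: Union[Module, List[Module]]) -> Module:
        # cast() is a static-typing assertion with no runtime effect; it
        # documents the invariant instead of silencing the type checker
        # with a fixme comment.
        return cast(Module, layer)

Since `cast` is free at runtime, this resolves the pyre error without
changing behavior or adding a branch.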