Fix gradcam pyre fixme issues (pytorch#1466)
Summary:

Fixing unresolved pyre-fixme issues in the corresponding file.

Reviewed By: cyrjano

Differential Revision: D67705191
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 945357d commit 1f41895
Showing 1 changed file with 3 additions and 19 deletions.
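For context, a `# pyre-fixme[N]` comment suppresses Pyre error code N on the line that follows; this commit removes those suppressions by giving Pyre precise types instead. Below is a minimal sketch of the first fix, with the hypothetical helper call_forward standing in for Captum's internals:

    from typing import Any, Callable

    from torch import Tensor

    # A bare `Callable` annotation raises Pyre error [24] ("Generic type
    # `Callable` expects 2 type parameters"), which the old code silenced
    # with a `# pyre-fixme[24]` comment. Parameterizing the annotation as
    # `Callable[..., Tensor]` declares "any arguments, returns a Tensor",
    # so the suppression can simply be deleted.
    def call_forward(forward_func: Callable[..., Tensor], *args: Any) -> Tensor:
        return forward_func(*args)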
captum/attr/_core/layer/grad_cam.py (3 additions, 19 deletions)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -54,8 +54,7 @@ class LayerGradCam(LayerAttribution, GradientAttribution):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -201,7 +200,7 @@ def attribute(
         # hidden layer and hidden layer evaluated at each input.
         layer_gradients, layer_evals = compute_layer_gradients_and_eval(
             self.forward_func,
-            self.layer,
+            cast(Module, self.layer),
             inputs,
             target,
             additional_forward_args,
@@ -213,10 +212,7 @@
         summed_grads = tuple(
             (
                 torch.mean(
-                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                    # `Tuple[Tensor, ...]`.
                     layer_grad,
-                    # pyre-fixme[16]: `tuple` has no attribute `shape`.
                     dim=tuple(x for x in range(2, len(layer_grad.shape))),
                     keepdim=True,
                 )
@@ -228,27 +224,15 @@

         if attr_dim_summation:
             scaled_acts = tuple(
-                # pyre-fixme[58]: `*` is not supported for operand types
-                # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
-                # `Tuple[Tensor, ...]`.
-                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                # `Tuple[Tensor, ...]`.
                 torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         else:
             scaled_acts = tuple(
-                # pyre-fixme[58]: `*` is not supported for operand types
-                # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
-                # `Tuple[Tensor, ...]`.
                 summed_grad * layer_eval
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )

         if relu_attributions:
-            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-            # `Union[tuple[Tensor], Tensor]`.
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
-            # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
-            # `Tuple[Union[tuple[Tensor], Tensor], ...]`.
         return _format_output(len(scaled_acts) > 1, scaled_acts)
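
Note that the other change, `cast(Module, self.layer)`, exists only for the type checker: typing.cast returns its argument unchanged at runtime, so the commit alters no behavior. A minimal usage sketch of the class being edited (the resnet18 model and layer4 choice are illustrative placeholders, not part of this commit):

    import torch
    from torchvision.models import resnet18

    from captum.attr import LayerGradCam

    model = resnet18(weights=None).eval()

    # forward_func is the model itself, matching the new
    # `Callable[..., Tensor]` annotation; layer4 is the block to inspect.
    layer_gc = LayerGradCam(model, model.layer4)

    inputs = torch.randn(1, 3, 224, 224)
    # GradCAM map, channel-summed over layer4's 7x7 spatial grid
    # (attr_dim_summation defaults to True).
    attributions = layer_gc.attribute(inputs, target=0)
    print(attributions.shape)  # torch.Size([1, 1, 7, 7])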
