Fix internal influence pyre fixme issues
Differential Revision: D67705214
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 44a1de9 commit 5f6173d
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions captum/attr/_core/layer/internal_influence.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -41,8 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
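A minimal sketch (not Captum's actual code) of what this first change buys: a bare `Callable` has no type parameters, which is exactly what pyre-fixme[24] flags, while `Callable[..., Tensor]` leaves the argument list open but tells pyre that every call returns a Tensor, so downstream tensor methods type-check without a suppression.

from typing import Callable

import torch
from torch import Tensor


def forward_func(x: Tensor) -> Tensor:
    # A stand-in model forward; any signature is accepted by `...`.
    return x.sum(dim=-1)


def attribute(forward_func: Callable[..., Tensor], x: Tensor) -> Tensor:
    # pyre now knows forward_func(...) is a Tensor, so .sum() checks cleanly.
    return forward_func(x).sum()


print(attribute(forward_func, torch.ones(2, 3)))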
@@ -293,7 +292,7 @@ def _attribute(
         # Returns gradient of output with respect to hidden layer.
         layer_gradients, _ = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             target_ind=expanded_target,
             additional_forward_args=input_additional_args,
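A hedged sketch of the `cast` fix, assuming (as in Captum's layer attributions) that `self.layer` is declared as a union of a single Module and a list of Modules, while this call path only ever sees a single Module. `typing.cast` is a no-op at runtime; it only records that narrowing for the checker.

from typing import List, Union, cast

from torch.nn import Linear, Module

# Hypothetical attribute typed the way LayerAttribution declares it.
layer: Union[Module, List[Module]] = Linear(4, 2)

# cast() does nothing at runtime; it just asserts the narrower type to pyre.
single_layer: Module = cast(Module, layer)
print(single_layer)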
@@ -304,9 +303,7 @@
         # flattening grads so that we can multiply it with step-size
         # calling contiguous to avoid `memory whole` problems
         scaled_grads = tuple(
-            # pyre-fixme[16]: `tuple` has no attribute `contiguous`.
             layer_grad.contiguous().view(n_steps, -1)
-            # pyre-fixme[16]: `tuple` has no attribute `device`.
             * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
             for layer_grad in layer_gradients
         )
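A standalone sketch of the scaling step above: once `layer_gradients` is known to be a tuple of Tensors rather than an untyped tuple, each `layer_grad` is a Tensor and the old pyre-fixme[16] suppressions for `.contiguous()` and `.device` become unnecessary. Shapes and names here are illustrative only.

from typing import Tuple

import torch

n_steps = 4
step_sizes = [1.0 / n_steps] * n_steps

# One illustrative per-layer gradient with a leading n_steps * batch axis.
layer_gradients: Tuple[torch.Tensor, ...] = (torch.randn(n_steps * 2, 3),)

scaled_grads = tuple(
    # flatten each gradient to (n_steps, -1), then weight each row by its step size
    layer_grad.contiguous().view(n_steps, -1)
    * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
    for layer_grad in layer_gradients
)
print(scaled_grads[0].shape)  # torch.Size([4, 6])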
@@ -317,8 +314,7 @@
                 scaled_grad,
                 n_steps,
                 inputs[0].shape[0],
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_grad.shape[1:],
+                tuple(layer_grad.shape[1:]),
             )
             for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
         )
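Why the explicit `tuple(...)` wrapper in the last hunk: `torch.Size` is a tuple subclass, so slicing a shape already yields a tuple-like value at runtime; the conversion just makes the type unambiguous to pyre, replacing the old fixme. A tiny check:

import torch

g = torch.randn(4, 6, 5)
print(isinstance(g.shape[1:], tuple))  # True; torch.Size subclasses tuple
print(tuple(g.shape[1:]))              # (6, 5)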
