Address Pyre FixMe's in layer_feature_ablation.py (#1411)
Summary:
Pull Request resolved: #1411

This diff reduces the number of pyre-fixme suppressions in layer_feature_ablation.py by replacing them with explicit type annotations.

Reviewed By: jjuncho

Differential Revision: D64796530

fbshipit-source-id: 804cbe43a1b110f9b1377c499d9c30ee12aa0898
Ayush-Warikoo authored and facebook-github-bot committed Oct 24, 2024
1 parent 568434a commit cbe45aa
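The fixes below follow one pattern: delete the # pyre-fixme suppression and give the symbol a real annotation (Callable[..., Tensor] instead of bare Callable, object instead of Any). A minimal before/after sketch of that pattern, using hypothetical function names rather than the actual Captum signatures:

from typing import Any, Callable

from torch import Tensor


# Before: Pyre flags the bare generic and the `Any` parameter, and the
# errors were silenced with pyre-fixme comments.
def ablate_before(
    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
    forward_func: Callable,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    additional_forward_args: Any = None,
) -> None: ...


# After: the annotation itself carries the information, so no suppression
# is needed and call sites are actually checked.
def ablate_after(
    forward_func: Callable[..., Tensor],
    additional_forward_args: object = None,
) -> None: ...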
Showing 2 changed files with 25 additions and 20 deletions.
40 changes: 23 additions & 17 deletions captum/attr/_core/layer/layer_feature_ablation.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, List, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Tuple, Type, Union
 
 import torch
 from captum._utils.common import (
@@ -37,8 +37,7 @@ class LayerFeatureAblation(LayerAttribution, PerturbationAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -70,8 +69,7 @@ def attribute(
         inputs: Union[Tensor, Tuple[Tensor, ...]],
         layer_baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         layer_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         attribute_to_layer_input: bool = False,
         perturbations_per_eval: int = 1,
@@ -225,29 +223,33 @@ def attribute(
             >>> layer_mask=layer_mask)
         """
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def layer_forward_func(*args):
-            layer_length = args[-1]
-            layer_input = args[:layer_length]
-            original_inputs = args[layer_length:-1]
+        def layer_forward_func(*args: Any) -> Union[Tensor]:
+            r"""
+            Args:
+                args (Any): Tensors comprising the layer input and the original
+                    inputs, and an int representing the length of the layer input
+            """
+            layer_length: int = args[-1]
+            layer_input: Tuple[Tensor, ...] = args[:layer_length]
+            original_inputs: Tuple[Tensor, ...] = args[layer_length:-1]
 
             device_ids = self.device_ids
             if device_ids is None:
                 device_ids = getattr(self.forward_func, "device_ids", None)
 
-            all_layer_inputs = {}
+            all_layer_inputs: Dict[torch.device, Tuple[Tensor, ...]] = {}
             if device_ids is not None:
                 scattered_layer_input = scatter(layer_input, target_gpus=device_ids)
                 for device_tensors in scattered_layer_input:
                     all_layer_inputs[device_tensors[0].device] = device_tensors
             else:
                 all_layer_inputs[layer_input[0].device] = layer_input
 
-            # pyre-fixme[53]: Captured variable `all_layer_inputs` is not annotated.
-            # pyre-fixme[3]: Return type must be annotated.
-            # pyre-fixme[2]: Parameter must be annotated.
-            def forward_hook(module, inp, out=None):
+            def forward_hook(
+                module: Module,
+                inp: Union[None, Tensor, Tuple[Tensor, ...]],
+                out: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+            ) -> Union[Tensor, Tuple[Tensor, ...]]:
                 device = _extract_device(module, inp, out)
                 is_layer_tuple = (
                     isinstance(out, tuple)
@@ -275,7 +277,11 @@ def forward_hook(module, inp, out=None):
             finally:
                 if hook is not None:
                     hook.remove()
-            return eval
+
+            # _run_forward may return future of Tensor,
+            # but we don't support it here now
+            # And it will fail before here.
+            return cast(Tensor, eval)
 
         with torch.no_grad():
             inputs = _format_tensor_into_tuples(inputs)
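The cast(Tensor, eval) change above is the usual way to narrow a union type for Pyre without adding a runtime check. A small sketch of the idea, assuming a hypothetical helper whose declared return type is Union[Tensor, Future] even though this code path only ever sees a Tensor:

from typing import Union, cast

import torch
from torch import Tensor
from torch.futures import Future


def run_forward_stub(x: Tensor) -> Union[Tensor, Future]:
    # Hypothetical stand-in for a helper declared to return either a
    # Tensor or a Future; this sketch always returns a plain Tensor.
    return x * 2


def forward_sum(x: Tensor) -> Tensor:
    out = run_forward_stub(x)
    # cast() only informs the type checker and performs no conversion,
    # which matches the comment in the diff: the Future branch is not
    # supported here and would have failed earlier.
    return cast(Tensor, out).sum()


print(forward_sum(torch.ones(3)))  # tensor(6.)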
5 changes: 2 additions & 3 deletions captum/attr/_core/layer/layer_feature_permutation.py
@@ -13,12 +13,11 @@
 )
 
 from captum._utils.gradient import _forward_layer_eval
-
 from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.feature_permutation import FeaturePermutation
 from captum.attr._utils.attribution import LayerAttribution
 from captum.log import log_usage
-from torch import device, Tensor
+from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter
 
@@ -171,7 +170,7 @@ def layer_forward_func(*args: Any) -> Tensor:
             if device_ids is None:
                 device_ids = getattr(self.forward_func, "device_ids", None)
 
-            all_layer_inputs: Dict[device, Tuple[Tensor, ...]] = {}
+            all_layer_inputs: Dict[torch.device, Tuple[Tensor, ...]] = {}
             if device_ids is not None:
                 scattered_layer_input = scatter(layer_input, target_gpus=device_ids)
                 for device_tensors in scattered_layer_input:
