Skip to content

Commit

Permalink
Fix layer integrated gradients pyre fixme issues
Browse files Browse the repository at this point in the history
Differential Revision: D67706224
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 81ea94c commit 5f482b0
Showing 1 changed file with 42 additions and 36 deletions.
78 changes: 42 additions & 36 deletions captum/attr/_core/layer/layer_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module
from torch.nn.parallel.scatter_gather import scatter


Expand All @@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: ModuleOrModuleList,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -128,8 +128,7 @@ def _make_gradient_func(
) -> Callable[..., Tuple[Tensor, ...]]:

def _gradient_func(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
forward_fn: Callable[..., Tensor],
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Optional[object] = None,
Expand All @@ -146,28 +145,21 @@ def _gradient_func(
target_gpus=self.device_ids,
)

scattered_inputs_dict = {
scattered_inputs_dict: Dict[
torch.device, Union[Tensor, Tuple[Tensor, ...]]
] = {
scattered_input[0].device: scattered_input
for scattered_input in scattered_inputs
}

with torch.autograd.set_grad_enabled(True):

# pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
# annotated.
# pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
# annotated.
# pyre-fixme[3]: Return type must be annotated.
def layer_forward_hook(
# pyre-fixme[2]: Parameter must be annotated.
module,
# pyre-fixme[2]: Parameter must be annotated.
hook_inputs,
# pyre-fixme[2]: Parameter must be annotated.
hook_outputs=None,
# pyre-fixme[2]: Parameter must be annotated.
layer_idx=0,
):
module: Module,
hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
layer_idx: int = 0,
) -> Union[Tensor, Tuple[Tensor, ...]]:
device = _extract_device(module, hook_inputs, hook_outputs)
is_layer_tuple = (
isinstance(hook_outputs, tuple)
Expand All @@ -177,11 +169,14 @@ def layer_forward_hook(
)

if is_layer_tuple:
return scattered_inputs_dict[device][
num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
layer_idx + 1
]
]
return cast(
Union[Tensor, Tuple[Tensor, ...]],
scattered_inputs_dict[device][
num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
layer_idx + 1
]
],
)

return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]

Expand Down Expand Up @@ -255,6 +250,7 @@ def attribute(
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...

@overload
@log_usage()
def attribute( # type: ignore
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -273,8 +269,7 @@ def attribute( # type: ignore
]: ...

@overload
# pyre-fixme[43]: This definition does not have the same decorators as the
# preceding overload(s).
@log_usage()
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -296,8 +291,6 @@ def attribute(
]: ...

@log_usage()
# pyre-fixme[43]: This definition does not have the same decorators as the
# preceding overload(s).
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand Down Expand Up @@ -502,11 +495,22 @@ def attribute(
additional_forward_args
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def flatten_tuple(tup):
def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
return tuple(
sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
cast(
List[Tensor],
sum(
(
(
list(x)
if isinstance(x, (tuple, list))
else cast(List[Tensor], [x])
)
for x in tup
),
[],
),
)
)

if self.device_ids is None:
Expand All @@ -520,16 +524,18 @@ def flatten_tuple(tup):
additional_forward_args=additional_forward_args,
attribute_to_layer_input=attribute_to_layer_input,
)

input_layer_list: List[Tuple[Tensor, ...]]
# if we have one output
if not isinstance(self.layer, list):
inputs_layer = (inputs_layer,)
input_layer_list = [cast(Tuple[Tensor, ...], inputs_layer)]
else:
input_layer_list = inputs_layer

num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
num_outputs_cumsum = torch.cumsum(
torch.IntTensor([0] + num_outputs), dim=0 # type: ignore
)
inputs_layer = flatten_tuple(inputs_layer)
inputs_layer = flatten_tuple(input_layer_list)

baselines_layer = _forward_layer_eval(
self.forward_func,
Expand Down

0 comments on commit 5f482b0

Please sign in to comment.