Fix layer integrated gradients pyre-fixme issues #1473

Closed
wants to merge 1 commit
72 changes: 40 additions & 32 deletions captum/attr/_core/layer/layer_integrated_gradients.py
@@ -33,6 +33,7 @@
 )
 from captum.log import log_usage
 from torch import Tensor
+from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter


@@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
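
Note: the annotation pattern used throughout this diff replaces a bare `Callable` (which triggers pyre-fixme[24] because `Callable` expects parameter and return types) with `Callable[..., Tensor]`, meaning "accepts any arguments, returns a Tensor". A minimal standalone sketch of the idea (`apply_forward` is a hypothetical helper, not captum API):

    from typing import Callable

    import torch
    from torch import Tensor

    # Parameterizing Callable documents that forward_func may take any
    # arguments but must return a Tensor; a bare Callable says neither.
    def apply_forward(forward_func: Callable[..., Tensor], *args: object) -> Tensor:
        return forward_func(*args)

    out = apply_forward(lambda x: x * 2, torch.ones(3))  # tensor([2., 2., 2.])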
@@ -128,8 +128,7 @@ def _make_gradient_func(
     ) -> Callable[..., Tuple[Tensor, ...]]:
 
         def _gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
+            forward_fn: Callable[..., Tensor],
             inputs: Union[Tensor, Tuple[Tensor, ...]],
             target_ind: TargetType = None,
             additional_forward_args: Optional[object] = None,
@@ -146,28 +145,21 @@ def _gradient_func(
                     target_gpus=self.device_ids,
                 )
 
-                scattered_inputs_dict = {
+                scattered_inputs_dict: Dict[
+                    torch.device, Union[Tensor, Tuple[Tensor, ...]]
+                ] = {
                     scattered_input[0].device: scattered_input
                     for scattered_input in scattered_inputs
                 }
 
             with torch.autograd.set_grad_enabled(True):
 
-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                #  annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                #  annotated.
-                # pyre-fixme[3]: Return type must be annotated.
                 def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
+                    module: Module,
+                    hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
+                    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+                    layer_idx: int = 0,
+                ) -> Union[Tensor, Tuple[Tensor, ...]]:
                     device = _extract_device(module, hook_inputs, hook_outputs)
                     is_layer_tuple = (
                         isinstance(hook_outputs, tuple)
@@ -177,11 +169,14 @@ def layer_forward_hook(
                     )
 
                     if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
+                        return cast(
+                            Union[Tensor, Tuple[Tensor, ...]],
+                            scattered_inputs_dict[device][
+                                num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                    layer_idx + 1
+                                ]
+                            ],
+                        )
 
                     return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]

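Note: the `cast` here is purely for the type checker; pyre cannot narrow the type of the slice `scattered_inputs_dict[device][...]` on its own, so the code asserts it as `Union[Tensor, Tuple[Tensor, ...]]`. At runtime `typing.cast` is an identity function, as this standalone sketch (hypothetical values) shows:

    from typing import Tuple, Union, cast

    import torch
    from torch import Tensor

    values: Tuple[Tensor, ...] = (torch.zeros(2), torch.ones(2))

    # cast() performs no conversion or check; it only tells the checker
    # to treat the expression as the stated type.
    sliced = cast(Union[Tensor, Tuple[Tensor, ...]], values[0:1])
    print(type(sliced))  # <class 'tuple'> -- unchanged at runtime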
@@ -502,11 +497,22 @@ def attribute(
             additional_forward_args
         )
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def flatten_tuple(tup):
+        def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
             return tuple(
-                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+                cast(
+                    List[Tensor],
+                    sum(
+                        (
+                            (
+                                list(x)
+                                if isinstance(x, (tuple, list))
+                                else cast(List[Tensor], [x])
+                            )
+                            for x in tup
+                        ),
+                        [],
+                    ),
+                )
             )
 
         if self.device_ids is None:
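
Note: the rewrite of `flatten_tuple` keeps the original runtime behaviour; the added `cast` calls only name the intermediate types for pyre. The function splices a list of tensors and tensor tuples into one flat tuple, roughly as in this sketch (the tensors are hypothetical placeholders):

    import torch

    t1, t2, t3 = torch.zeros(1), torch.ones(1), torch.full((1,), 2.0)
    mixed = [(t1,), (t2, t3)]

    # sum(..., []) concatenates the per-layer lists, then tuple() seals them:
    flat = tuple(
        sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in mixed), [])
    )
    assert len(flat) == 3 and flat[0] is t1 and flat[2] is t3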
@@ -520,16 +526,18 @@ def flatten_tuple(tup):
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
         )
 
+        input_layer_list: List[Tuple[Tensor, ...]]
         # if we have one output
         if not isinstance(self.layer, list):
-            inputs_layer = (inputs_layer,)
+            input_layer_list = [cast(Tuple[Tensor, ...], inputs_layer)]
+        else:
+            input_layer_list = inputs_layer
 
-        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
         num_outputs_cumsum = torch.cumsum(
             torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
         )
-        inputs_layer = flatten_tuple(inputs_layer)
+        inputs_layer = flatten_tuple(input_layer_list)
 
         baselines_layer = _forward_layer_eval(
             self.forward_func,
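
Note: for context on the bookkeeping this hunk touches, `num_outputs_cumsum` converts per-layer output counts into slice offsets, which `layer_forward_hook` above indexes via `num_outputs_cumsum[layer_idx]`. A small sketch with assumed layer outputs:

    import torch
    from torch import Tensor

    # Two hypothetical layers: one with a single output, one with three.
    input_layer_list = [(torch.zeros(1),), (torch.zeros(1),) * 3]
    num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]

    # Prepending 0 and cumulatively summing yields per-layer boundaries:
    num_outputs_cumsum = torch.cumsum(torch.IntTensor([0] + num_outputs), dim=0)
    print(num_outputs_cumsum.tolist())  # [0, 1, 4]: layer 0 owns [0:1], layer 1 [1:4]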