Skip to content

Commit

Permalink
layer_integrated_gradients is too complex (#1407)
Browse files Browse the repository at this point in the history
Summary:

This diff addresses the C901 in visualization.py by breaking down the method

Reviewed By: vivekmig

Differential Revision: D64565179
  • Loading branch information
jjuncho authored and facebook-github-bot committed Oct 23, 2024
1 parent b80e488 commit 77e2bc4
Showing 1 changed file with 121 additions and 109 deletions.
230 changes: 121 additions & 109 deletions captum/attr/_core/layer/layer_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,124 @@ def __init__(
stacklevel=2,
)

def _make_gradient_func(
self,
# pyre-fixme[2]: Parameter needs type annotation.
num_outputs_cumsum,
attribute_to_layer_input: bool,
) -> Callable[..., Tuple[Tensor, ...]]:

def _gradient_func(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
if self.device_ids is None or len(self.device_ids) == 0:
scattered_inputs = (inputs,)
else:
# scatter method does not have a precise enough return type in its
# stub, so suppress the type warning.
scattered_inputs = scatter( # type:ignore
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Tensor, typing.Tuple[Tensor, ...]]`.
inputs,
target_gpus=self.device_ids,
)

scattered_inputs_dict = {
scattered_input[0].device: scattered_input
for scattered_input in scattered_inputs
}

with torch.autograd.set_grad_enabled(True):

# pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
# annotated.
# pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
# annotated.
# pyre-fixme[3]: Return type must be annotated.
def layer_forward_hook(
# pyre-fixme[2]: Parameter must be annotated.
module,
# pyre-fixme[2]: Parameter must be annotated.
hook_inputs,
# pyre-fixme[2]: Parameter must be annotated.
hook_outputs=None,
# pyre-fixme[2]: Parameter must be annotated.
layer_idx=0,
):
device = _extract_device(module, hook_inputs, hook_outputs)
is_layer_tuple = (
isinstance(hook_outputs, tuple)
# hook_outputs is None if attribute_to_layer_input == True
if hook_outputs is not None
else isinstance(hook_inputs, tuple)
)

if is_layer_tuple:
return scattered_inputs_dict[device][
num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
layer_idx + 1
]
]

return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]

hooks = []
try:

layers = self.layer
if not isinstance(layers, list):
layers = [self.layer]

for layer_idx, layer in enumerate(layers):
hook = None
# TODO:
# Allow multiple attribute_to_layer_input flags for
# each layer, i.e. attribute_to_layer_input[layer_idx]
if attribute_to_layer_input:
hook = layer.register_forward_pre_hook(
functools.partial(
layer_forward_hook, layer_idx=layer_idx
)
)
else:
hook = layer.register_forward_hook(
functools.partial(
layer_forward_hook, layer_idx=layer_idx
)
)

hooks.append(hook)

# the inputs is an empty tuple
# coz it is prepended into additional_forward_args
output = _run_forward(
self.forward_func, (), target_ind, additional_forward_args
)
finally:
for hook in hooks:
if hook is not None:
hook.remove()

# _run_forward may return future of Tensor,
# but we don't support it here now
# And it will fail before here.
output = cast(Tensor, output)
assert output[0].numel() == 1, (
"Target not provided when necessary, cannot"
" take gradient with respect to multiple outputs."
)
# torch.unbind(forward_out) is a list of scalar tensor tuples and
# contains batch_size * #steps elements
grads = torch.autograd.grad(torch.unbind(output), inputs)
return grads

return _gradient_func

@overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `112`.
Expand Down Expand Up @@ -415,116 +533,10 @@ def flatten_tuple(tup):
baselines_layer = flatten_tuple(baselines_layer)

# inputs -> these inputs are scaled
def gradient_func(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
if self.device_ids is None or len(self.device_ids) == 0:
scattered_inputs = (inputs,)
else:
# scatter method does not have a precise enough return type in its
# stub, so suppress the type warning.
scattered_inputs = scatter( # type:ignore
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Tensor, typing.Tuple[Tensor, ...]]`.
inputs,
target_gpus=self.device_ids,
)

scattered_inputs_dict = {
scattered_input[0].device: scattered_input
for scattered_input in scattered_inputs
}

with torch.autograd.set_grad_enabled(True):

# pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
# annotated.
# pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
# annotated.
# pyre-fixme[3]: Return type must be annotated.
def layer_forward_hook(
# pyre-fixme[2]: Parameter must be annotated.
module,
# pyre-fixme[2]: Parameter must be annotated.
hook_inputs,
# pyre-fixme[2]: Parameter must be annotated.
hook_outputs=None,
# pyre-fixme[2]: Parameter must be annotated.
layer_idx=0,
):
device = _extract_device(module, hook_inputs, hook_outputs)
is_layer_tuple = (
isinstance(hook_outputs, tuple)
# hook_outputs is None if attribute_to_layer_input == True
if hook_outputs is not None
else isinstance(hook_inputs, tuple)
)

if is_layer_tuple:
return scattered_inputs_dict[device][
num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
layer_idx + 1
]
]

return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]

hooks = []
try:

layers = self.layer
if not isinstance(layers, list):
layers = [self.layer]

for layer_idx, layer in enumerate(layers):
hook = None
# TODO:
# Allow multiple attribute_to_layer_input flags for
# each layer, i.e. attribute_to_layer_input[layer_idx]
if attribute_to_layer_input:
hook = layer.register_forward_pre_hook(
functools.partial(
layer_forward_hook, layer_idx=layer_idx
)
)
else:
hook = layer.register_forward_hook(
functools.partial(
layer_forward_hook, layer_idx=layer_idx
)
)

hooks.append(hook)

# the inputs is an empty tuple
# coz it is prepended into additional_forward_args
output = _run_forward(
self.forward_func, (), target_ind, additional_forward_args
)
finally:
for hook in hooks:
if hook is not None:
hook.remove()

# _run_forward may return future of Tensor,
# but we don't support it here now
# And it will fail before here.
output = cast(Tensor, output)
assert output[0].numel() == 1, (
"Target not provided when necessary, cannot"
" take gradient with respect to multiple outputs."
)
# torch.unbind(forward_out) is a list of scalar tensor tuples and
# contains batch_size * #steps elements
grads = torch.autograd.grad(torch.unbind(output), inputs)
return grads

self.ig.gradient_func = gradient_func
self.ig.gradient_func = self._make_gradient_func(
num_outputs_cumsum, attribute_to_layer_input
)
all_inputs = (
(inps + additional_forward_args)
if additional_forward_args is not None
Expand Down

0 comments on commit 77e2bc4

Please sign in to comment.