def _make_gradient_func(
    self,
    # pyre-fixme[2]: Parameter needs type annotation.
    num_outputs_cumsum,
    attribute_to_layer_input: bool,
) -> Callable[..., Tuple[Tensor, ...]]:
    """Create the gradient callable used for layer-level attribution.

    The returned function runs ``self.forward_func`` while forward hooks
    (or forward pre-hooks) on ``self.layer`` substitute the layer's
    outputs (or inputs) with the tensors passed in as ``inputs``. It then
    returns the gradients of the forward output with respect to those
    tensors.

    Args:
        num_outputs_cumsum: Indexable offsets into ``inputs``. The layer
            at position ``i`` is fed
            ``inputs[num_outputs_cumsum[i]:num_outputs_cumsum[i + 1]]``
            when it works with tuples, or
            ``inputs[num_outputs_cumsum[i]]`` otherwise.
        attribute_to_layer_input: When true, forward pre-hooks are
            registered (the layer input is substituted); otherwise
            regular forward hooks are registered (the layer output is
            substituted).

    Returns:
        A callable ``(forward_fn, inputs, target_ind=None,
        additional_forward_args=None)`` producing a tuple of gradient
        tensors. The ``forward_fn`` argument is accepted but not used;
        ``self.forward_func`` is always the function that gets run.
    """

    def _gradient_func(
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        forward_fn: Callable,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        target_ind: TargetType = None,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        additional_forward_args: Any = None,
    ) -> Tuple[Tensor, ...]:
        device_ids = self.device_ids
        if device_ids is not None and len(device_ids) > 0:
            # scatter method does not have a precise enough return type in its
            # stub, so suppress the type warning.
            per_device_inputs = scatter(  # type:ignore
                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
                #  `Union[Tensor, typing.Tuple[Tensor, ...]]`.
                inputs,
                target_gpus=device_ids,
            )
        else:
            # No device ids configured: everything stays in a single chunk.
            per_device_inputs = (inputs,)

        # Lets a hook pick the chunk living on the device it is running on.
        inputs_by_device = {}
        for chunk in per_device_inputs:
            inputs_by_device[chunk[0].device] = chunk

        # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
        #  annotated.
        # pyre-fixme[53]: Captured variable `inputs_by_device` is not
        #  annotated.
        # pyre-fixme[3]: Return type must be annotated.
        def layer_forward_hook(
            # pyre-fixme[2]: Parameter must be annotated.
            module,
            # pyre-fixme[2]: Parameter must be annotated.
            hook_inputs,
            # pyre-fixme[2]: Parameter must be annotated.
            hook_outputs=None,
            # pyre-fixme[2]: Parameter must be annotated.
            layer_idx=0,
        ):
            device = _extract_device(module, hook_inputs, hook_outputs)
            # hook_outputs is None if attribute_to_layer_input == True
            if hook_outputs is None:
                is_layer_tuple = isinstance(hook_inputs, tuple)
            else:
                is_layer_tuple = isinstance(hook_outputs, tuple)

            replacement = inputs_by_device[device]
            start = num_outputs_cumsum[layer_idx]
            if not is_layer_tuple:
                return replacement[start]
            return replacement[start : num_outputs_cumsum[layer_idx + 1]]

        layers = self.layer
        if not isinstance(layers, list):
            layers = [layers]

        with torch.autograd.set_grad_enabled(True):
            handles = []
            try:
                for layer_idx, layer in enumerate(layers):
                    # TODO:
                    # Allow multiple attribute_to_layer_input flags for
                    # each layer, i.e. attribute_to_layer_input[layer_idx]
                    register = (
                        layer.register_forward_pre_hook
                        if attribute_to_layer_input
                        else layer.register_forward_hook
                    )
                    handles.append(
                        register(
                            functools.partial(
                                layer_forward_hook, layer_idx=layer_idx
                            )
                        )
                    )

                # the inputs is an empty tuple
                # coz it is prepended into additional_forward_args
                output = _run_forward(
                    self.forward_func, (), target_ind, additional_forward_args
                )
            finally:
                # Always detach the hooks, even when the forward pass fails.
                for handle in handles:
                    handle.remove()

            # _run_forward may return future of Tensor,
            # but we don't support it here now
            # And it will fail before here.
            output = cast(Tensor, output)
            assert output[0].numel() == 1, (
                "Target not provided when necessary, cannot"
                " take gradient with respect to multiple outputs."
            )
            # torch.unbind(forward_out) is a list of scalar tensor tuples and
            # contains batch_size * #steps elements
            return torch.autograd.grad(torch.unbind(output), inputs)

    return _gradient_func
@@ -415,116 +533,10 @@ def flatten_tuple(tup): baselines_layer = flatten_tuple(baselines_layer) # inputs -> these inputs are scaled - def gradient_func( - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_fn: Callable, - inputs: Union[Tensor, Tuple[Tensor, ...]], - target_ind: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - ) -> Tuple[Tensor, ...]: - if self.device_ids is None or len(self.device_ids) == 0: - scattered_inputs = (inputs,) - else: - # scatter method does not have a precise enough return type in its - # stub, so suppress the type warning. - scattered_inputs = scatter( # type:ignore - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `Union[Tensor, typing.Tuple[Tensor, ...]]`. - inputs, - target_gpus=self.device_ids, - ) - - scattered_inputs_dict = { - scattered_input[0].device: scattered_input - for scattered_input in scattered_inputs - } - - with torch.autograd.set_grad_enabled(True): - - # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not - # annotated. - # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not - # annotated. - # pyre-fixme[3]: Return type must be annotated. - def layer_forward_hook( - # pyre-fixme[2]: Parameter must be annotated. - module, - # pyre-fixme[2]: Parameter must be annotated. - hook_inputs, - # pyre-fixme[2]: Parameter must be annotated. - hook_outputs=None, - # pyre-fixme[2]: Parameter must be annotated. 
- layer_idx=0, - ): - device = _extract_device(module, hook_inputs, hook_outputs) - is_layer_tuple = ( - isinstance(hook_outputs, tuple) - # hook_outputs is None if attribute_to_layer_input == True - if hook_outputs is not None - else isinstance(hook_inputs, tuple) - ) - - if is_layer_tuple: - return scattered_inputs_dict[device][ - num_outputs_cumsum[layer_idx] : num_outputs_cumsum[ - layer_idx + 1 - ] - ] - - return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]] - - hooks = [] - try: - layers = self.layer - if not isinstance(layers, list): - layers = [self.layer] - - for layer_idx, layer in enumerate(layers): - hook = None - # TODO: - # Allow multiple attribute_to_layer_input flags for - # each layer, i.e. attribute_to_layer_input[layer_idx] - if attribute_to_layer_input: - hook = layer.register_forward_pre_hook( - functools.partial( - layer_forward_hook, layer_idx=layer_idx - ) - ) - else: - hook = layer.register_forward_hook( - functools.partial( - layer_forward_hook, layer_idx=layer_idx - ) - ) - - hooks.append(hook) - - # the inputs is an empty tuple - # coz it is prepended into additional_forward_args - output = _run_forward( - self.forward_func, (), target_ind, additional_forward_args - ) - finally: - for hook in hooks: - if hook is not None: - hook.remove() - - # _run_forward may return future of Tensor, - # but we don't support it here now - # And it will fail before here. - output = cast(Tensor, output) - assert output[0].numel() == 1, ( - "Target not provided when necessary, cannot" - " take gradient with respect to multiple outputs." 
- ) - # torch.unbind(forward_out) is a list of scalar tensor tuples and - # contains batch_size * #steps elements - grads = torch.autograd.grad(torch.unbind(output), inputs) - return grads - - self.ig.gradient_func = gradient_func + self.ig.gradient_func = self._make_gradient_func( + num_outputs_cumsum, attribute_to_layer_input + ) all_inputs = ( (inps + additional_forward_args) if additional_forward_args is not None