From 4ef567d29644ce30af33231b5ed95909e960e04c Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 21 Oct 2024 17:24:58 -0700 Subject: [PATCH] Fix pyre errors in DeepLift (#1391) Summary: Initial work on fixing Pyre errors in DeepLift Differential Revision: D64677338 --- captum/attr/_core/deep_lift.py | 209 ++++++++++----------------------- 1 file changed, 61 insertions(+), 148 deletions(-) diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index 6a71605d3..7166595dd 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -3,7 +3,7 @@ # pyre-strict import typing import warnings -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Type, Union import torch import torch.nn as nn @@ -25,12 +25,7 @@ apply_gradient_requirements, undo_gradient_requirements, ) -from captum._utils.typing import ( - BaselineType, - Literal, - TargetType, - TensorOrTupleOfTensorsGeneric, -) +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.attr._utils.common import ( _call_custom_attribution_func, @@ -117,35 +112,24 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `131`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `120`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + additional_forward_args: object = None, return_convergence_delta: Literal[False] = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -156,7 +140,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - additional_forward_args: Any = None, + additional_forward_args: object = None, return_convergence_delta: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Union[ @@ -302,24 +286,14 @@ def attribute( # type: ignore # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. 
is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - baselines = _format_baseline(baselines, inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + baselines = _format_baseline(baselines, inputs_tuple) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - _validate_input(inputs, baselines) + _validate_input(inputs_tuple, baselines) # set hooks for baselines warnings.warn( @@ -328,9 +302,7 @@ def attribute( # type: ignore after the attribution is finished""", stacklevel=2, ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - baselines = _tensorize_baseline(inputs, baselines) + baselines = _tensorize_baseline(inputs_tuple, baselines) main_model_hooks = [] try: main_model_hooks = self._hook_main_model() @@ -347,17 +319,17 @@ def attribute( # type: ignore wrapped_forward_func = self._construct_forward_func( self.model, - (inputs, baselines), + (inputs_tuple, baselines), expanded_target, additional_forward_args, ) - gradients = self.gradient_func(wrapped_forward_func, inputs) + gradients = self.gradient_func(wrapped_forward_func, inputs_tuple) if custom_attribution_func is None: if self.multiplies_by_inputs: attributions = tuple( (input - baseline) * gradient for input, baseline, gradient in zip( - inputs, baselines, gradients + inputs_tuple, baselines, gradients ) ) else: @@ -366,25 +338,21 @@ def attribute( # type: ignore attributions = _call_custom_attribution_func( custom_attribution_func, gradients, - # pyre-fixme[6]: For 3rd argument expected `Tuple[Tensor, ...]` - # but got `TensorOrTupleOfTensorsGeneric`. - inputs, + inputs_tuple, baselines, ) finally: # Even if any error is raised, remove all hooks before raising self._remove_hooks(main_model_hooks) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric... return _compute_conv_delta_and_format_attrs( self, return_convergence_delta, attributions, baselines, - inputs, + inputs_tuple, additional_forward_args, target, is_inputs_tuple, @@ -399,24 +367,17 @@ def attribute_future(self) -> Callable: def _construct_forward_func( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - inputs: Tuple, + forward_func: Callable[..., Tensor], + inputs: Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]], target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - ) -> Callable: - # pyre-fixme[3]: Return type must be annotated. 
- def forward_fn(): - model_out = _run_forward( - forward_func, inputs, None, additional_forward_args + additional_forward_args: object = None, + ) -> Callable[[], Tensor]: + def forward_fn() -> Tensor: + model_out = cast( + Tensor, + _run_forward(forward_func, inputs, None, additional_forward_args), ) return _select_targets( - # pyre-fixme[16]: Item `Future` of - # `Union[Future[torch._tensor.Tensor], Tensor]` has no attribute - # `__getitem__`. torch.cat((model_out[:, 0], model_out[:, 1])), target, ) @@ -539,8 +500,10 @@ def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None: backward_handle.remove() def _hook_main_model(self) -> List[RemovableHandle]: - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple: + def pre_hook( + module: Module, + baseline_inputs_add_args: Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]], + ) -> Tuple[object, ...]: inputs = baseline_inputs_add_args[0] baselines = baseline_inputs_add_args[1] additional_args = None @@ -553,9 +516,7 @@ def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple: ) if additional_args is not None: expanded_additional_args = cast( - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type - # parameter. - Tuple, + Tuple[object], _expand_additional_forward_args( additional_args, 2, ExpansionTypes.repeat ), @@ -565,9 +526,9 @@ def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple: return (*baseline_input_tsr, *expanded_additional_args) return baseline_input_tsr - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - def forward_hook(module: Module, inputs: Tuple, outputs: Tensor): + def forward_hook( + module: Module, inputs: Tuple[Tensor, ...], outputs: Tensor + ) -> Tensor: return torch.stack(torch.chunk(outputs, 2), dim=1) if isinstance( @@ -636,8 +597,6 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: # There's a mismatch between the signatures of DeepLift.attribute and # DeepLiftShap.attribute, so we ignore typing here @typing.overload # type: ignore - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `597`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -645,18 +604,13 @@ def attribute( TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `584`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -664,11 +618,7 @@ def attribute( TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. 
- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + additional_forward_args: object = None, return_convergence_delta: Literal[False] = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -681,7 +631,7 @@ def attribute( # type: ignore TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], target: TargetType = None, - additional_forward_args: Any = None, + additional_forward_args: object = None, return_convergence_delta: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Union[ @@ -821,42 +771,28 @@ def attribute( # type: ignore >>> # Computes shap values using deeplift for class 3. >>> attribution = dl.attribute(input, target=3) """ - # pyre-fixme[9]: baselines has type `Union[typing.Callable[..., - # Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor, - # ...]]]], Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. - baselines = _format_callable_baseline(baselines, inputs) - - # pyre-fixme[16]: Item `Callable` of `Union[(...) -> - # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no - # attribute `__getitem__`. - assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, ( + formatted_baselines = _format_callable_baseline(baselines, inputs) + + assert ( + isinstance(formatted_baselines[0], torch.Tensor) + and formatted_baselines[0].shape[0] > 1 + ), ( "Baselines distribution has to be provided in form of a torch.Tensor" " with more than one example but found: {}." " If baselines are provided in shape of scalars or with a single" " baseline example, `DeepLift`" - # pyre-fixme[16]: Item `Callable` of `Union[(...) -> - # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no - # attribute `__getitem__`. - " approach can be used instead.".format(baselines[0]) + " approach can be used instead.".format(formatted_baselines[0]) ) # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) # batch sizes - inp_bsz = inputs[0].shape[0] - # pyre-fixme[16]: Item `Callable` of `Union[(...) -> - # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no - # attribute `__getitem__`. - base_bsz = baselines[0].shape[0] + inp_bsz = inputs_tuple[0].shape[0] + base_bsz = formatted_baselines[0].shape[0] ( exp_inp, @@ -864,13 +800,8 @@ def attribute( # type: ignore exp_tgt, exp_addit_args, ) = self._expand_inputs_baselines_targets( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got `... - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - baselines, - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. 
- inputs, + formatted_baselines, + inputs_tuple, target, additional_forward_args, ) @@ -881,15 +812,12 @@ def attribute( # type: ignore target=exp_tgt, additional_forward_args=exp_addit_args, return_convergence_delta=cast( - # pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid - # type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take - # parameters. Literal[True, False], return_convergence_delta, ), custom_attribution_func=custom_attribution_func, ) + delta: Tensor = torch.tensor(0) if return_convergence_delta: attributions, delta = cast(Tuple[Tuple[Tensor, ...], Tensor], attributions) @@ -902,21 +830,18 @@ def attribute( # type: ignore if return_convergence_delta: # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... - # pyre-fixme[61]: `delta` is undefined, or not always defined. return _format_output(is_inputs_tuple, attributions), delta else: # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... return _format_output(is_inputs_tuple, attributions) - # pyre-fixme[3]: Return annotation cannot contain `Any`. def _expand_inputs_baselines_targets( self, baselines: Tuple[Tensor, ...], inputs: Tuple[Tensor, ...], target: TargetType, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any, - ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, Any]: + additional_forward_args: object, + ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, object]: inp_bsz = inputs[0].shape[0] base_bsz = baselines[0].shape[0] @@ -957,12 +882,9 @@ def _compute_mean_across_baselines( self, inp_bsz: int, base_bsz: int, attribution: Tensor ) -> Tensor: # Average for multiple references - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - attr_shape: Tuple = (inp_bsz, base_bsz) + attr_shape: Tuple[int, ...] = (inp_bsz, base_bsz) if len(attribution.shape) > 1: - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, - # int]` and `Size`. - attr_shape += attribution.shape[1:] + attr_shape += tuple(attribution.shape[1:]) return torch.mean(attribution.view(attr_shape), dim=1, keepdim=False) @@ -1051,14 +973,10 @@ def maxpool2d( def maxpool3d( module: Module, - # pyre-fixme[2]: Parameter must be annotated. - inputs, - # pyre-fixme[2]: Parameter must be annotated. - outputs, - # pyre-fixme[2]: Parameter must be annotated. - grad_input, - # pyre-fixme[2]: Parameter must be annotated. - grad_output, + inputs: Tensor, + outputs: Tensor, + grad_input: Tensor, + grad_output: Tensor, eps: float = 1e-10, ) -> Tensor: return maxpool( @@ -1079,14 +997,10 @@ def maxpool( pool_func: Callable, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. unpool_func: Callable, - # pyre-fixme[2]: Parameter must be annotated. - inputs, - # pyre-fixme[2]: Parameter must be annotated. - outputs, - # pyre-fixme[2]: Parameter must be annotated. - grad_input, - # pyre-fixme[2]: Parameter must be annotated. - grad_output, + inputs: Tensor, + outputs: Tensor, + grad_input: Tensor, + grad_output: Tensor, eps: float = 1e-10, ) -> Tensor: with torch.no_grad(): @@ -1158,8 +1072,7 @@ def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]: return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out]) -# pyre-fixme[5]: Global expression must be annotated. -SUPPORTED_NON_LINEAR = { +SUPPORTED_NON_LINEAR: Dict[Type[Module], Callable[..., Tensor]] = { nn.ReLU: nonlinear, nn.ELU: nonlinear, nn.LeakyReLU: nonlinear,
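
Notes on the recurring fixes in this change (the sketches below are minimal
illustrations written for review, not code from Captum; stand-in helpers are
marked as such):

1. Literal overloads. Importing Literal from `typing` instead of the
   `captum._utils.typing` shim is what lets the pyre-fixme[24]/[31]/[43]
   suppressions around the `attribute` overloads go away. A minimal,
   self-contained sketch of the same pattern, using a toy `attribute`
   function:

       import typing
       from typing import Literal, Tuple, Union

       import torch
       from torch import Tensor

       @typing.overload
       def attribute(
           inputs: Tensor, *, return_convergence_delta: Literal[True]
       ) -> Tuple[Tensor, Tensor]: ...

       @typing.overload
       def attribute(
           inputs: Tensor, return_convergence_delta: Literal[False] = False
       ) -> Tensor: ...

       def attribute(
           inputs: Tensor, return_convergence_delta: bool = False
       ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
           # Toy attribution; the overloads let a type checker infer the
           # return type from the flag at each call site.
           attr = inputs * 2
           if return_convergence_delta:
               return attr, torch.zeros(1)
           return attr

       out = attribute(torch.ones(3))  # checked as Tensor
       out, delta = attribute(torch.ones(3), return_convergence_delta=True)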
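
2. Fresh names instead of re-assignment. Most of the removed
   pyre-fixme[6]/[9] suppressions came from re-assigning a parameter typed
   with the `TensorOrTupleOfTensorsGeneric` TypeVar to the plain tuple
   returned by `_format_tensor_into_tuples`; binding the normalized value to
   `inputs_tuple` (and `formatted_baselines` in DeepLiftShap) keeps both
   types precise. A sketch, with a stand-in for the formatting helper:

       from typing import Tuple, TypeVar, Union

       from torch import Tensor

       # Mirrors captum._utils.typing.TensorOrTupleOfTensorsGeneric.
       TensorOrTupleOfTensorsGeneric = TypeVar(
           "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
       )

       def format_into_tuple(
           inputs: Union[Tensor, Tuple[Tensor, ...]],
       ) -> Tuple[Tensor, ...]:
           # Stand-in for _format_tensor_into_tuples: lone tensor -> 1-tuple.
           return (inputs,) if isinstance(inputs, Tensor) else inputs

       def attribute(
           inputs: TensorOrTupleOfTensorsGeneric,
       ) -> TensorOrTupleOfTensorsGeneric:
           # Re-assigning `inputs` would narrow the TypeVar to
           # Tuple[Tensor, ...] (the old pyre-fixme[9]); a fresh name keeps
           # the generic parameter intact for the return type.
           inputs_tuple: Tuple[Tensor, ...] = format_into_tuple(inputs)
           ...  # all internal work uses inputs_tuple
           return inputs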
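
3. Narrowing `_run_forward`'s union. The wrapped forward built in
   `_construct_forward_func` runs synchronously, so `model_out` is always a
   Tensor even though `_run_forward` is typed as possibly returning a
   Future; one explicit cast replaces the old pyre-fixme[16] on
   `__getitem__`. Sketch with a stand-in forward runner:

       from typing import Union, cast

       import torch
       from torch import Tensor
       from torch.futures import Future

       def run_forward_stub() -> Union[Tensor, "Future[Tensor]"]:
           # Stand-in for _run_forward, whose return type is a union.
           return torch.ones(4, 2)

       def forward_fn() -> Tensor:
           # The synchronous path always yields a Tensor, so narrow the
           # union once instead of suppressing each indexing operation.
           model_out = cast(Tensor, run_forward_stub())
           return torch.cat((model_out[:, 0], model_out[:, 1]))

       print(forward_fn().shape)  # torch.Size([8])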
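
4. `torch.Size` concatenation. The pyre-fixme[58] in
   `_compute_mean_across_baselines` existed because `Tuple[int, int] + Size`
   is not typed; `torch.Size` is a tuple subclass at runtime, so converting
   the slice with `tuple(...)` is behavior-preserving and keeps `attr_shape`
   a plain `Tuple[int, ...]`. Worked example with illustrative sizes:

       from typing import Tuple

       import torch

       inp_bsz, base_bsz = 2, 3
       attribution = torch.rand(inp_bsz * base_bsz, 5, 4)

       attr_shape: Tuple[int, ...] = (inp_bsz, base_bsz)
       if len(attribution.shape) > 1:
           # tuple(Size) sidesteps the untyped Tuple + Size concatenation.
           attr_shape += tuple(attribution.shape[1:])

       mean_attr = torch.mean(attribution.view(attr_shape), dim=1, keepdim=False)
       print(mean_attr.shape)  # torch.Size([2, 5, 4])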
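
5. Definite assignment for `delta`. In `DeepLiftShap.attribute`, `delta`
   was previously bound only inside the `return_convergence_delta` branch,
   which drew a pyre-fixme[61] ("undefined, or not always defined") at the
   return; the patch binds `torch.tensor(0)` up front. Separately,
   annotating the module-level registry as
   `SUPPORTED_NON_LINEAR: Dict[Type[Module], Callable[..., Tensor]]` clears
   the pyre-fixme[5] on the global. A minimal sketch of the
   definite-assignment fix:

       import torch
       from torch import Tensor

       def finalize(return_convergence_delta: bool) -> Tensor:
           # A default binding removes the "not always defined" path at
           # the return statement below.
           delta: Tensor = torch.tensor(0)
           if return_convergence_delta:
               delta = torch.rand(3)  # placeholder for the real delta
           return delta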