From 848e805f18cb08c85a8f75b6fe66db92f0fd7b5a Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:08:31 -0700 Subject: [PATCH] Fix pyre errors in Integrated Gradients (#1398) Summary: Initial work on fixing Pyre errors in Integrated Gradients Reviewed By: csauper Differential Revision: D64677345 --- captum/attr/_core/integrated_gradients.py | 45 +++++++---------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py index e80326293..1abbcc69f 100644 --- a/captum/attr/_core/integrated_gradients.py +++ b/captum/attr/_core/integrated_gradients.py @@ -2,7 +2,7 @@ # pyre-strict import typing -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Literal, Tuple, Union import torch from captum._utils.common import ( @@ -12,12 +12,7 @@ _format_output, _is_tuple, ) -from captum._utils.typing import ( - BaselineType, - Literal, - TargetType, - TensorOrTupleOfTensorsGeneric, -) +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.approximation_methods import approximation_parameters from captum.attr._utils.attribution import GradientAttribution from captum.attr._utils.batching import _batch_attribution @@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], multiply_by_inputs: bool = True, ) -> None: r""" @@ -80,21 +74,16 @@ def __init__( # and when return_convergence_delta is True, the return type is # a tuple with both attributions and deltas. @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `95`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @@ -111,9 +100,6 @@ def attribute( n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -275,37 +261,35 @@ def attribute( # type: ignore """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as # `Tuple[Tensor, ...]`. - inputs, baselines = _format_input_baseline(inputs, baselines) + formatted_inputs, formatted_baselines = _format_input_baseline( + inputs, baselines + ) # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got # `TensorOrTupleOfTensorsGeneric`. - _validate_input(inputs, baselines, n_steps, method) + _validate_input(formatted_inputs, formatted_baselines, n_steps, method) if internal_batch_size is not None: - num_examples = inputs[0].shape[0] + num_examples = formatted_inputs[0].shape[0] attributions = _batch_attribution( self, num_examples, internal_batch_size, n_steps, - inputs=inputs, - baselines=baselines, + inputs=formatted_inputs, + baselines=formatted_baselines, target=target, additional_forward_args=additional_forward_args, method=method, ) else: attributions = self._attribute( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but - # got `TensorOrTupleOfTensorsGeneric`. - inputs=inputs, - baselines=baselines, + inputs=formatted_inputs, + baselines=formatted_baselines, target=target, additional_forward_args=additional_forward_args, n_steps=n_steps, @@ -344,8 +328,7 @@ def _attribute( inputs: Tuple[Tensor, ...], baselines: Tuple[Union[Tensor, int, float], ...], target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, n_steps: int = 50, method: str = "gausslegendre", step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,