From d0f6543f6778a78fb2389776cd862d2bd5613e11 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 16:00:37 -0800 Subject: [PATCH] Fix layer gradient shap pyre fixme issues (#1471) Summary: Fixing unresolved pyre fixme issues in corresponding file Reviewed By: cyrjano Differential Revision: D67705670 --- .../attr/_core/layer/layer_gradient_shap.py | 58 +++++++------------ 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/captum/attr/_core/layer/layer_gradient_shap.py b/captum/attr/_core/layer/layer_gradient_shap.py index c9987eb00..e0e213997 100644 --- a/captum/attr/_core/layer/layer_gradient_shap.py +++ b/captum/attr/_core/layer/layer_gradient_shap.py @@ -61,8 +61,7 @@ class LayerGradientShap(LayerAttribution, GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], layer: Module, device_ids: Union[None, List[int]] = None, multiply_by_inputs: bool = True, @@ -104,13 +103,12 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `106`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, @@ -121,13 +119,12 @@ def attribute( ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `120`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, @@ -137,11 +134,14 @@ def attribute( ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, @@ -294,17 +294,10 @@ def attribute( """ # since `baselines` is a distribution, we can generate it using a function # rather than passing it as an input argument - # pyre-fixme[9]: baselines has type `Union[typing.Callable[..., typing.Any], - # Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor, - # ...]]]]`; used as `Tuple[Tensor, ...]`. - baselines = _format_callable_baseline(baselines, inputs) - # pyre-fixme[16]: Item `Callable` of `Union[(...) -> Any, - # TensorOrTupleOfTensorsGeneric]` has no attribute `__getitem__`. - assert isinstance(baselines[0], torch.Tensor), ( + formatted_baselines = _format_callable_baseline(baselines, inputs) + assert isinstance(formatted_baselines[0], torch.Tensor), ( "Baselines distribution has to be provided in a form " - # pyre-fixme[16]: Item `Callable` of `Union[(...) -> Any, - # TensorOrTupleOfTensorsGeneric]` has no attribute `__getitem__`. - "of a torch.Tensor {}.".format(baselines[0]) + "of a torch.Tensor {}.".format(formatted_baselines[0]) ) input_min_baseline_x_grad = LayerInputBaselineXGradient( @@ -323,7 +316,7 @@ def attribute( nt_samples=n_samples, stdevs=stdevs, draw_baseline_from_distrib=True, - baselines=baselines, + baselines=formatted_baselines, target=target, additional_forward_args=additional_forward_args, return_convergence_delta=return_convergence_delta, @@ -343,8 +336,7 @@ def multiplies_by_inputs(self) -> bool: class LayerInputBaselineXGradient(LayerAttribution, GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], layer: Module, device_ids: Union[None, List[int]] = None, multiply_by_inputs: bool = True, @@ -436,7 +428,7 @@ def attribute( # type: ignore ) grads, _ = compute_layer_gradients_and_eval( self.forward_func, - self.layer, + cast(Module, self.layer), input_baseline_scaled, target, additional_forward_args, @@ -448,7 +440,7 @@ def attribute( # type: ignore attr_baselines = _forward_layer_eval( self.forward_func, baselines, - self.layer, + cast(Module, self.layer), additional_forward_args=additional_forward_args, device_ids=self.device_ids, attribute_to_layer_input=attribute_to_layer_input, @@ -457,19 +449,15 @@ def attribute( # type: ignore attr_inputs = _forward_layer_eval( self.forward_func, inputs, - self.layer, + cast(Module, self.layer), additional_forward_args=additional_forward_args, device_ids=self.device_ids, attribute_to_layer_input=attribute_to_layer_input, ) - + attributions: Tuple[Tensor, ...] if self.multiplies_by_inputs: input_baseline_diffs = tuple( - # pyre-fixme[58]: `-` is not supported for operand types - # `typing.Tuple[torch._tensor.Tensor, ...]` and - # `typing.Tuple[torch._tensor.Tensor, ...]`. - input - baseline - for input, baseline in zip(attr_inputs, attr_baselines) + input - baseline for input, baseline in zip(attr_inputs, attr_baselines) ) attributions = tuple( input_baseline_diff * grad @@ -481,8 +469,6 @@ def attribute( # type: ignore return _compute_conv_delta_and_format_attrs( self, return_convergence_delta, - # pyre-fixme[6]: For 3rd argument expected `Tuple[Tensor, ...]` but got - # `Union[List[typing.Tuple[Tensor, ...]], tuple[Tensor]]`. attributions, baselines, inputs,