From 8343177bf7d91306e029af2462e4893bf6383796 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 16:01:03 -0800 Subject: [PATCH] Fix layer deeplift pyre fixme issues (#1470) Summary: Fixing unresolved pyre fixme issues in corresponding file Reviewed By: cyrjano Differential Revision: D67705583 --- captum/attr/_core/layer/layer_deep_lift.py | 33 ++++++++++++---------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py index a126971cf2..da24e7cb48 100644 --- a/captum/attr/_core/layer/layer_deep_lift.py +++ b/captum/attr/_core/layer/layer_deep_lift.py @@ -321,8 +321,9 @@ def attribute( additional_forward_args, ) - # pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter. - def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence: + def chunk_output_fn( + out: TensorOrTupleOfTensorsGeneric, + ) -> Sequence[Union[Tensor, Sequence[Tensor]]]: if isinstance(out, Tensor): return out.chunk(2) return tuple(out_sub.chunk(2) for out_sub in out) @@ -434,8 +435,6 @@ def __init__( # Ignoring mypy error for inconsistent signature with DeepLiftShap @typing.overload # type: ignore - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `453`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -450,9 +449,7 @@ def attribute( custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... - @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `439`. + @typing.overload # type: ignore def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -654,7 +651,7 @@ def attribute( ) = DeepLiftShap._expand_inputs_baselines_targets( self, baselines, inputs, target, additional_forward_args ) - attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore + attribs_layer_deeplift = LayerDeepLift.attribute.__wrapped__( # type: ignore self, exp_inp, exp_base, @@ -667,8 +664,12 @@ def attribute( attribute_to_layer_input=attribute_to_layer_input, custom_attribution_func=custom_attribution_func, ) + delta: Tensor + attributions: Union[Tensor, Tuple[Tensor, ...]] if return_convergence_delta: - attributions, delta = attributions + attributions, delta = attribs_layer_deeplift + else: + attributions = attribs_layer_deeplift if isinstance(attributions, tuple): attributions = tuple( DeepLiftShap._compute_mean_across_baselines( @@ -681,15 +682,17 @@ def attribute( self, inp_bsz, base_bsz, attributions ) if return_convergence_delta: - # pyre-fixme[61]: `delta` is undefined, or not always defined. return attributions, delta else: - # pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor, - # typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]` - # but got `Union[tuple[Tensor], Tensor]`. - return attributions + return cast( + Union[ + Tensor, + Tuple[Tensor, ...], + Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor], + ], + attributions, + ) @property - # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs