From c92279305d4211f16e0445374c3c7315ee11699b Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Mon, 30 Dec 2024 08:55:00 -0800
Subject: [PATCH] Fix layer gradient x activation pyre fixme issues (#1472)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1472

Differential Revision: D67705758
---
 .../_core/layer/layer_gradient_x_activation.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/captum/attr/_core/layer/layer_gradient_x_activation.py b/captum/attr/_core/layer/layer_gradient_x_activation.py
index c828a262e5..f56265c2e8 100644
--- a/captum/attr/_core/layer/layer_gradient_x_activation.py
+++ b/captum/attr/_core/layer/layer_gradient_x_activation.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 from captum._utils.common import (
     _format_additional_forward_args,
@@ -24,8 +24,7 @@ class LayerGradientXActivation(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -186,11 +185,10 @@ def attribute(
         if isinstance(self.layer, Module):
             return _format_output(
                 len(layer_evals) > 1,
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                # got `List[typing.Tuple[Tensor, ...]]`.
-                # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
-                # got `List[typing.Tuple[Tensor, ...]]`.
-                self.multiply_gradient_acts(layer_gradients, layer_evals),
+                self.multiply_gradient_acts(
+                    cast(Tuple[Tensor, ...], layer_gradients),
+                    cast(Tuple[Tensor, ...], layer_evals),
+                ),
             )
         else:
             return [
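
Note (not part of the patch): a minimal standalone sketch of the typing pattern the change relies on. typing.cast only narrows the static type for the checker and performs no runtime conversion or check, so it resolves the pyre-fixme[6] mismatch between List[Tuple[Tensor, ...]] and Tuple[Tensor, ...] without altering behavior. The multiply_gradient_acts function below is a simplified stand-in for the class method of the same name, not Captum's implementation.

    from typing import Tuple, cast

    import torch
    from torch import Tensor


    def multiply_gradient_acts(
        gradients: Tuple[Tensor, ...], evals: Tuple[Tensor, ...]
    ) -> Tuple[Tensor, ...]:
        # Element-wise product of each gradient with its matching activation.
        return tuple(g * e for g, e in zip(gradients, evals))


    # Hypothetical values standing in for the per-layer gradients and
    # activations; cast(...) changes only the static type, not the objects.
    grads = (torch.ones(2, 3),)
    acts = (torch.full((2, 3), 2.0),)

    result = multiply_gradient_acts(
        cast(Tuple[Tensor, ...], grads),
        cast(Tuple[Tensor, ...], acts),
    )
    print(result[0])  # tensor of 2.0s with shape (2, 3)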