From 422c30d9c5930c12b53b1e3799e715f76b0d2f74 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 21 Oct 2024 00:07:52 -0700 Subject: [PATCH] Fix pyre errors in Guided Backprop Summary: Initial work on fixing Pyre errors in Guided Backprop Differential Revision: D64677346 --- .../attr/_core/guided_backprop_deconvnet.py | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py index 60359bc0d..aca369d3c 100644 --- a/captum/attr/_core/guided_backprop_deconvnet.py +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -45,8 +45,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Computes attribution by overriding relu gradients. Based on constructor @@ -58,16 +57,10 @@ def attribute( # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) # set hooks for overriding ReLU gradients warnings.warn( @@ -79,14 +72,12 @@ def attribute( self.model.apply(self._register_hooks) gradients = self.gradient_func( - self.forward_func, inputs, target, additional_forward_args + self.forward_func, inputs_tuple, target, additional_forward_args ) finally: self._remove_hooks() - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, gradients) @@ -155,8 +146,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -265,8 +255,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: