From 8ff04209aebaaf35da4f85aa072ea9fef36ce22f Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:55 -0700 Subject: [PATCH] Fix pyre errors in NoiseTunnel (#1402) Summary: Initial work on fixing Pyre errors in Noise Tunnel Reviewed By: craymichael Differential Revision: D64677341 --- captum/attr/_core/noise_tunnel.py | 49 ++++++++++++------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index 7247ccc00..5d9eb1962 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -27,8 +27,7 @@ class NoiseTunnelType(Enum): vargrad = 3 -# pyre-fixme[5]: Global expression must be annotated. -SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys()) +SUPPORTED_NOISE_TUNNEL_TYPES: List[str] = list(NoiseTunnelType.__members__.keys()) class NoiseTunnel(Attribution): @@ -58,6 +57,10 @@ class NoiseTunnel(Attribution): It is assumed that the batch size is the first dimension of input tensors. """ + is_delta_supported: bool + _multiply_by_inputs: bool + is_gradient_method: bool + def __init__(self, attribution_method: Attribution) -> None: r""" Args: @@ -66,19 +69,15 @@ def __init__(self, attribution_method: Attribution) -> None: Conductance or Saliency. """ self.attribution_method = attribution_method - # pyre-fixme[4]: Attribute must be annotated. self.is_delta_supported = self.attribution_method.has_convergence_delta() - # pyre-fixme[4]: Attribute must be annotated. self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs - # pyre-fixme[4]: Attribute must be annotated. self.is_gradient_method = isinstance( self.attribution_method, GradientAttribution ) Attribution.__init__(self, self.attribution_method.forward_func) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @log_usage() @@ -205,9 +204,10 @@ def attribute( nt_samples_batch_size, kwargs_copy, inputs, draw_baseline_from_distrib ) - sum_attributions: List[Union[None, Tensor]] = [] - sum_attributions_sq: List[Union[None, Tensor]] = [] + sum_attributions: Sequence[Union[None, Tensor]] = [] + sum_attributions_sq: Sequence[Union[None, Tensor]] = [] delta_partial_list: List[Tensor] = [] + is_attrib_tuple = is_inputs_tuple for _ in range(nt_samples_partition): inputs_with_noise = self._add_noise_to_inputs( @@ -225,11 +225,7 @@ def attribute( ) if len(sum_attributions) == 0: - # pyre-fixme[9]: sum_attributions has type - # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions = [None] * len(attributions_partial) - # pyre-fixme[9]: sum_attributions_sq has type - # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions_sq = [None] * len(attributions_partial) self._update_partial_attribution_and_delta( @@ -297,7 +293,6 @@ def attribute( return self._apply_checks_and_return_attributions( attributions, - # pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined. is_attrib_tuple, return_convergence_delta, delta, @@ -348,9 +343,7 @@ def _add_noise_to_input( bsz = input.shape[0] # expand input size by the number of drawn samples - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` - # and `Size`. - input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:] + input_expanded_size = (bsz * nt_samples_partition,) + tuple(input.shape[1:]) # expand stdev for the shape of the input and number of drawn samples stdev_expanded = torch.tensor(stdev, device=input.device).repeat( @@ -375,14 +368,13 @@ def _update_sum_attribution_and_sq( bsz = attribution.shape[0] // nt_samples_batch_size_inter attribution_shape = cast(Tuple[int, ...], (bsz, nt_samples_batch_size_inter)) if len(attribution.shape) > 1: - # pyre-fixme[22]: The cast is redundant. - attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:])) + attribution_shape += tuple(attribution.shape[1:]) attribution = attribution.view(attribution_shape) current_attribution_sum = attribution.sum(dim=1, keepdim=False) - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False) + current_attribution_sq = torch.sum( + torch.pow(attribution, 2), dim=1, keepdim=False + ) sum_attribution[i] = ( current_attribution_sum @@ -398,8 +390,7 @@ def _update_sum_attribution_and_sq( def _compute_partial_attribution( self, inputs_with_noise_partition: Tuple[Tensor, ...], - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - kwargs_partition: Any, + kwargs_partition: object, is_inputs_tuple: bool, return_convergence_delta: bool, ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]: @@ -505,14 +496,12 @@ def _apply_checks_and_return_attributions( ) -> Union[ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] ]: - # pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor, - # ...]]`, expected a tuple. - attributions = _format_output(is_attrib_tuple, attributions) + attributions_tuple = _format_output(is_attrib_tuple, attributions) ret = ( - (attributions, cast(Tensor, delta)) + (attributions_tuple, cast(Tensor, delta)) if self.is_delta_supported and return_convergence_delta - else attributions + else attributions_tuple ) ret = cast( # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: