From aed0ef8294cf17822f4154f0195bd40fed725898 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 21 Oct 2024 17:24:58 -0700 Subject: [PATCH] Fix pyre errors in Feature Ablation (#1392) Summary: Initial work on fixing Pyre errors in Feature Ablation Reviewed By: jjuncho Differential Revision: D64677337 --- captum/attr/_core/feature_ablation.py | 44 ++++++++------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 26e93846b..4eb46650b 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -47,8 +47,9 @@ class FeatureAblation(PerturbationAttribution): first dimension (i.e. a feature mask requires to be applied to all inputs). """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__( + self, forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]] + ) -> None: r""" Args: @@ -74,8 +75,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, @@ -261,8 +261,6 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) @@ -270,8 +268,6 @@ def attribute( additional_forward_args ) num_examples = formatted_inputs[0].shape[0] - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs) assert ( @@ -384,8 +380,6 @@ def attribute( # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: # [Tensor, typing.Tuple[Tensor, ...]]]` # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. - # pyre-fixme[6]: In call `FeatureAblation._generate_result`, - # for 3rd positional argument, expected `bool` but got `Literal[]`. return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long def _initial_eval_to_processed_initial_eval_fut( @@ -414,8 +408,7 @@ def attribute_future( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, @@ -428,8 +421,6 @@ def attribute_future( # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) formatted_additional_forward_args = _format_additional_forward_args( @@ -660,13 +651,11 @@ def _eval_fut_to_ablated_out_fut( ) from e return result - # pyre-fixme[3]: Return type must be specified as type that does not contain `Any` def _ith_input_ablation_generator( self, i: int, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_args: Any, + additional_args: object, target: TargetType, baselines: BaselineType, input_mask: Union[None, Tensor, Tuple[Tensor, ...]], @@ -675,7 +664,7 @@ def _ith_input_ablation_generator( ) -> Generator[ Tuple[ Tuple[Tensor, ...], - Any, + object, TargetType, Tensor, ], @@ -705,10 +694,9 @@ def _ith_input_ablation_generator( perturbations_per_eval = min(perturbations_per_eval, num_features) baseline = baselines[i] if isinstance(baselines, tuple) else baselines if isinstance(baseline, torch.Tensor): - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` - # and `Size`. - baseline = baseline.reshape((1,) + baseline.shape) + baseline = baseline.reshape((1,) + tuple(baseline.shape)) + additional_args_repeated: object if perturbations_per_eval > 1: # Repeat features and additional args for batch size. all_features_repeated = [ @@ -727,6 +715,7 @@ def _ith_input_ablation_generator( target_repeated = target num_features_processed = min_feature + current_additional_args: object while num_features_processed < num_features: current_num_ablated_features = min( perturbations_per_eval, num_features - num_features_processed @@ -762,9 +751,7 @@ def _ith_input_ablation_generator( # dimension of this tensor. current_reshaped = current_features[i].reshape( (current_num_ablated_features, -1) - # pyre-fixme[58]: `+` is not supported for operand types - # `Tuple[int, int]` and `Size`. - + current_features[i].shape[1:] + + tuple(current_features[i].shape[1:]) ) ablated_features, current_mask = self._construct_ablated_input( @@ -780,10 +767,7 @@ def _ith_input_ablation_generator( # (current_num_ablated_features * num_examples, inputs[i].shape[1:]), # which can be provided to the model as input. current_features[i] = ablated_features.reshape( - (-1,) - # pyre-fixme[58]: `+` is not supported for operand types - # `Tuple[int]` and `Size`. - + ablated_features.shape[2:] + (-1,) + tuple(ablated_features.shape[2:]) ) yield tuple( current_features @@ -818,9 +802,7 @@ def _construct_ablated_input( thus counted towards ablations for that feature) and 0s otherwise. """ current_mask = torch.stack( - # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor], - # Tuple[Tensor, ...]]` but got `List[Union[bool, Tensor]]`. - [input_mask == j for j in range(start_feature, end_feature)], # type: ignore # noqa: E501 line too long + cast(List[Tensor], [input_mask == j for j in range(start_feature, end_feature)]), # type: ignore # noqa: E501 line too long dim=0, ).long() current_mask = current_mask.to(expanded_input.device)