From 7e09289b15d4301338da249464076d91b7b8b136 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 21 Oct 2024 09:22:57 -0700 Subject: [PATCH] Fix pyre errors in Lime (#1400) Summary: Initial work on fixing Pyre errors in Lime Differential Revision: D64677340 --- captum/attr/_core/lime.py | 100 ++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 57 deletions(-) diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 21bae8677..6829ae8e6 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -6,7 +6,7 @@ import typing import warnings from collections.abc import Iterator -from typing import Any, Callable, cast, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Generator, List, Literal, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -23,12 +23,7 @@ from captum._utils.models.linear_model import SkLearnLasso from captum._utils.models.model import Model from captum._utils.progress import progress -from captum._utils.typing import ( - BaselineType, - Literal, - TargetType, - TensorOrTupleOfTensorsGeneric, -) +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution from captum.attr._utils.batching import _batch_example_iterator from captum.attr._utils.common import ( @@ -73,18 +68,18 @@ class LimeBase(PerturbationAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], interpretable_model: Model, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - similarity_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perturb_func: Callable, + similarity_func: Callable[ + ..., + Union[float, Tensor], + ], + perturb_func: Callable[..., object], perturb_interpretable_space: bool, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - from_interp_rep_transform: Optional[Callable], - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - to_interp_rep_transform: Optional[Callable], + from_interp_rep_transform: Optional[ + Callable[..., Union[Tensor, Tuple[Tensor, ...]]] + ], + to_interp_rep_transform: Optional[Callable[..., Tensor]], ) -> None: r""" @@ -249,13 +244,11 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, n_samples: int = 50, perturbations_per_eval: int = 1, show_progress: bool = False, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Tensor: r""" This method attributes the output of the model with given target index @@ -551,7 +544,7 @@ def generate_perturbation() -> ( curr_sample, inputs, **kwargs ) - return interpretable_inp, curr_model_input + return interpretable_inp, curr_model_input # type: ignore return generate_perturbation @@ -568,8 +561,7 @@ def _evaluate_batch( self, curr_model_inputs: List[TensorOrTupleOfTensorsGeneric], expanded_target: TargetType, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - expanded_additional_args: Any, + expanded_additional_args: object, device: torch.device, ) -> Tensor: model_out = _run_forward( @@ -630,8 +622,7 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs): def get_exp_kernel_similarity_function( distance_mode: str = "cosine", kernel_width: float = 1.0, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. -) -> Callable: +) -> Callable[..., float]: r""" This method constructs an appropriate similarity function to compute weights for perturbed sample in LIME. Distance between the original @@ -680,8 +671,9 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): return default_exp_kernel -# pyre-fixme[2]: Parameter must be annotated. -def default_perturb_func(original_inp, **kwargs) -> Tensor: +def default_perturb_func( + original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object +) -> Tensor: assert ( "num_interp_features" in kwargs ), "Must provide num_interp_features to use default interpretable sampling function" @@ -690,7 +682,7 @@ def default_perturb_func(original_inp, **kwargs) -> Tensor: else: device = original_inp[0].device - probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5 + probs = torch.ones(1, cast(int, kwargs["num_interp_features"])) * 0.5 return torch.bernoulli(probs).to(device=device).long() @@ -698,17 +690,17 @@ def construct_feature_mask( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]], formatted_inputs: Tuple[Tensor, ...], ) -> Tuple[Tuple[Tensor, ...], int]: + feature_mask_tuple: Tuple[Tensor, ...] if feature_mask is None: - feature_mask, num_interp_features = _construct_default_feature_mask( + feature_mask_tuple, num_interp_features = _construct_default_feature_mask( formatted_inputs ) else: - feature_mask = _format_tensor_into_tuples(feature_mask) + feature_mask_tuple = _format_tensor_into_tuples(feature_mask) min_interp_features = int( min( torch.min(single_mask).item() - # pyre-fixme[16]: `None` has no attribute `__iter__`. - for single_mask in feature_mask + for single_mask in feature_mask_tuple if single_mask.numel() ) ) @@ -718,14 +710,12 @@ def construct_feature_mask( " start at 0.", stacklevel=2, ) - feature_mask = tuple( - single_mask - min_interp_features for single_mask in feature_mask + feature_mask_tuple = tuple( + single_mask - min_interp_features for single_mask in feature_mask_tuple ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `Optional[typing.Tuple[typing.Any, ...]]`. - num_interp_features = _get_max_feature_index(feature_mask) + 1 - return feature_mask, num_interp_features + num_interp_features = _get_max_feature_index(feature_mask_tuple) + 1 + return feature_mask_tuple, num_interp_features class Lime(LimeBase): @@ -766,8 +756,7 @@ class Lime(LimeBase): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], interpretable_model: Optional[Model] = None, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. similarity_func: Optional[Callable] = None, @@ -887,8 +876,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, @@ -1133,18 +1121,14 @@ def _attribute_kwargs( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, return_input_shape: bool = True, show_progress: bool = False, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> TensorOrTupleOfTensorsGeneric: - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) bsz = formatted_inputs[0].shape[0] @@ -1263,33 +1247,35 @@ def _attribute_kwargs( # type: ignore return coefs @typing.overload - # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept - # all possible arguments of overload defined on line `1201`. def _convert_output_shape( self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], coefs: Tensor, num_interp_features: int, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[True], ) -> Tuple[Tensor, ...]: ... @typing.overload - # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept - # all possible arguments of overload defined on line `1211`. def _convert_output_shape( # type: ignore self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], coefs: Tensor, num_interp_features: int, - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[False], ) -> Tensor: ... + @typing.overload + def _convert_output_shape( + self, + formatted_inp: Tuple[Tensor, ...], + feature_mask: Tuple[Tensor, ...], + coefs: Tensor, + num_interp_features: int, + is_inputs_tuple: bool, + ) -> Union[Tensor, Tuple[Tensor, ...]]: ... + def _convert_output_shape( self, formatted_inp: Tuple[Tensor, ...],