Skip to content

Commit

Permalink
Fix pyre errors in Lime (pytorch#1400)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Lime

Differential Revision: D64677340
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 21, 2024
1 parent 9f76ebd commit 7e09289
Showing 1 changed file with 43 additions and 57 deletions.
100 changes: 43 additions & 57 deletions captum/attr/_core/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
import warnings
from collections.abc import Iterator
from typing import Any, Callable, cast, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Generator, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -23,12 +23,7 @@
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.batching import _batch_example_iterator
from captum.attr._utils.common import (
Expand Down Expand Up @@ -73,18 +68,18 @@ class LimeBase(PerturbationAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
interpretable_model: Model,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
similarity_func: Callable,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
perturb_func: Callable,
similarity_func: Callable[
...,
Union[float, Tensor],
],
perturb_func: Callable[..., object],
perturb_interpretable_space: bool,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
from_interp_rep_transform: Optional[Callable],
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
to_interp_rep_transform: Optional[Callable],
from_interp_rep_transform: Optional[
Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
],
to_interp_rep_transform: Optional[Callable[..., Tensor]],
) -> None:
r"""
Expand Down Expand Up @@ -249,13 +244,11 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
n_samples: int = 50,
perturbations_per_eval: int = 1,
show_progress: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Tensor:
r"""
This method attributes the output of the model with given target index
Expand Down Expand Up @@ -551,7 +544,7 @@ def generate_perturbation() -> (
curr_sample, inputs, **kwargs
)

return interpretable_inp, curr_model_input
return interpretable_inp, curr_model_input # type: ignore

return generate_perturbation

Expand All @@ -568,8 +561,7 @@ def _evaluate_batch(
self,
curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
expanded_target: TargetType,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
expanded_additional_args: Any,
expanded_additional_args: object,
device: torch.device,
) -> Tensor:
model_out = _run_forward(
Expand Down Expand Up @@ -630,8 +622,7 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
def get_exp_kernel_similarity_function(
distance_mode: str = "cosine",
kernel_width: float = 1.0,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
) -> Callable:
) -> Callable[..., float]:
r"""
This method constructs an appropriate similarity function to compute
weights for perturbed sample in LIME. Distance between the original
Expand Down Expand Up @@ -680,8 +671,9 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
return default_exp_kernel


# pyre-fixme[2]: Parameter must be annotated.
def default_perturb_func(original_inp, **kwargs) -> Tensor:
def default_perturb_func(
original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object
) -> Tensor:
assert (
"num_interp_features" in kwargs
), "Must provide num_interp_features to use default interpretable sampling function"
Expand All @@ -690,25 +682,25 @@ def default_perturb_func(original_inp, **kwargs) -> Tensor:
else:
device = original_inp[0].device

probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5
probs = torch.ones(1, cast(int, kwargs["num_interp_features"])) * 0.5
return torch.bernoulli(probs).to(device=device).long()


def construct_feature_mask(
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask_tuple: Tuple[Tensor, ...]
if feature_mask is None:
feature_mask, num_interp_features = _construct_default_feature_mask(
feature_mask_tuple, num_interp_features = _construct_default_feature_mask(
formatted_inputs
)
else:
feature_mask = _format_tensor_into_tuples(feature_mask)
feature_mask_tuple = _format_tensor_into_tuples(feature_mask)
min_interp_features = int(
min(
torch.min(single_mask).item()
# pyre-fixme[16]: `None` has no attribute `__iter__`.
for single_mask in feature_mask
for single_mask in feature_mask_tuple
if single_mask.numel()
)
)
Expand All @@ -718,14 +710,12 @@ def construct_feature_mask(
" start at 0.",
stacklevel=2,
)
feature_mask = tuple(
single_mask - min_interp_features for single_mask in feature_mask
feature_mask_tuple = tuple(
single_mask - min_interp_features for single_mask in feature_mask_tuple
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `Optional[typing.Tuple[typing.Any, ...]]`.
num_interp_features = _get_max_feature_index(feature_mask) + 1
return feature_mask, num_interp_features
num_interp_features = _get_max_feature_index(feature_mask_tuple) + 1
return feature_mask_tuple, num_interp_features


class Lime(LimeBase):
Expand Down Expand Up @@ -766,8 +756,7 @@ class Lime(LimeBase):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
interpretable_model: Optional[Model] = None,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
similarity_func: Optional[Callable] = None,
Expand Down Expand Up @@ -887,8 +876,7 @@ def attribute( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
Expand Down Expand Up @@ -1133,18 +1121,14 @@ def _attribute_kwargs( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> TensorOrTupleOfTensorsGeneric:
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
bsz = formatted_inputs[0].shape[0]
Expand Down Expand Up @@ -1263,33 +1247,35 @@ def _attribute_kwargs( # type: ignore
return coefs

@typing.overload
# pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
# all possible arguments of overload defined on line `1201`.
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[True],
) -> Tuple[Tensor, ...]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
# all possible arguments of overload defined on line `1211`.
def _convert_output_shape( # type: ignore
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[False],
) -> Tensor: ...

@typing.overload
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]: ...

def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
Expand Down

0 comments on commit 7e09289

Please sign in to comment.