Skip to content

Commit

Permalink
Fix pyre errors in Kernel Shap (pytorch#1399)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in KernelShap

Reviewed By: csauper

Differential Revision: D64677350
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 748f133 commit 665b430
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions captum/attr/_core/kernel_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict

from typing import Any, Callable, Generator, Tuple, Union
from typing import Callable, cast, Generator, Tuple, Union

import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
Expand All @@ -27,8 +27,7 @@ class KernelShap(Lime):
https://arxiv.org/abs/1705.07874
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:
Expand All @@ -50,8 +49,7 @@ def attribute( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
Expand Down Expand Up @@ -279,10 +277,7 @@ def attribute( # type: ignore
)
num_features_list = torch.arange(num_interp_features, dtype=torch.float)
denom = num_features_list * (num_interp_features - num_features_list)
# pyre-fixme[58]: `/` is not supported for operand types
# `int` and `torch._tensor.Tensor`.
probs = (num_interp_features - 1) / denom
# pyre-fixme[16]: `float` has no attribute `__setitem__`.
probs = torch.tensor((num_interp_features - 1)) / denom
probs[0] = 0.0
return self._attribute_kwargs(
inputs=inputs,
Expand All @@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel(
_,
__,
interpretable_sample: Tensor,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Tensor:
assert (
"num_interp_features" in kwargs
Expand All @@ -332,8 +326,7 @@ def kernel_shap_similarity_kernel(
def kernel_shap_perturb_generator(
self,
original_inp: Union[Tensor, Tuple[Tensor, ...]],
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Generator[Tensor, None, None]:
r"""
Perturbations are sampled by the following process:
Expand Down Expand Up @@ -361,11 +354,13 @@ def kernel_shap_perturb_generator(
device = original_inp.device
else:
device = original_inp[0].device
num_features = kwargs["num_interp_features"]
num_features = cast(int, kwargs["num_interp_features"])
yield torch.ones(1, num_features, device=device, dtype=torch.long)
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
while True:
num_selected_features = kwargs["num_select_distribution"].sample()
num_selected_features = cast(
Categorical, kwargs["num_select_distribution"]
).sample()
rand_vals = torch.randn(1, num_features)
threshold = torch.kthvalue(
rand_vals, num_features - num_selected_features
Expand Down

0 comments on commit 665b430

Please sign in to comment.