From 069cbf3fb2a1c54a9d62d19c7a7aba7595936b21 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 07:28:41 -0700 Subject: [PATCH] Fix pyre errors in Kernel Shap (#1399) Summary: Initial work on fixing Pyre errors in KernelShap Reviewed By: csauper Differential Revision: D64677350 --- captum/attr/_core/kernel_shap.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py index 8b6fb44cb..89d22990d 100644 --- a/captum/attr/_core/kernel_shap.py +++ b/captum/attr/_core/kernel_shap.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import Any, Callable, Generator, Tuple, Union +from typing import Callable, cast, Generator, Tuple, Union import torch from captum._utils.models.linear_model import SkLearnLinearRegression @@ -27,8 +27,7 @@ class KernelShap(Lime): https://arxiv.org/abs/1705.07874 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -50,8 +49,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, @@ -279,10 +277,7 @@ def attribute( # type: ignore ) num_features_list = torch.arange(num_interp_features, dtype=torch.float) denom = num_features_list * (num_interp_features - num_features_list) - # pyre-fixme[58]: `/` is not supported for operand types - # `int` and `torch._tensor.Tensor`. - probs = (num_interp_features - 1) / denom - # pyre-fixme[16]: `float` has no attribute `__setitem__`. + probs = torch.tensor((num_interp_features - 1)) / denom probs[0] = 0.0 return self._attribute_kwargs( inputs=inputs, @@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel( _, __, interpretable_sample: Tensor, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Tensor: assert ( "num_interp_features" in kwargs @@ -332,8 +326,7 @@ def kernel_shap_similarity_kernel( def kernel_shap_perturb_generator( self, original_inp: Union[Tensor, Tuple[Tensor, ...]], - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Generator[Tensor, None, None]: r""" Perturbations are sampled by the following process: @@ -361,11 +354,13 @@ def kernel_shap_perturb_generator( device = original_inp.device else: device = original_inp[0].device - num_features = kwargs["num_interp_features"] + num_features = cast(int, kwargs["num_interp_features"]) yield torch.ones(1, num_features, device=device, dtype=torch.long) yield torch.zeros(1, num_features, device=device, dtype=torch.long) while True: - num_selected_features = kwargs["num_select_distribution"].sample() + num_selected_features = cast( + Categorical, kwargs["num_select_distribution"] + ).sample() rand_vals = torch.randn(1, num_features) threshold = torch.kthvalue( rand_vals, num_features - num_selected_features