Skip to content

Commit

Permalink
Fix pyre errors in Shapley Values
Browse files Browse the repository at this point in the history
Summary: Initial work on fixing Pyre errors in Shapley Values

Differential Revision: D64677339
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 21, 2024
1 parent 18c525e commit aa8ec5c
Showing 1 changed file with 28 additions and 66 deletions.
94 changes: 28 additions & 66 deletions captum/attr/_core/shapley_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
import math
import warnings
from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union
from typing import Callable, cast, Iterable, Sequence, Tuple, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -56,9 +56,7 @@ def _shape_feature_mask(
f"input shape: {inp.shape}, feature mask shape {mask.shape}"
)
if mask.dim() < inp.dim():
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int,
# ...]` and `Size`.
mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + mask.shape)
mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + tuple(mask.shape))

mask_list.append(mask)

Expand Down Expand Up @@ -89,8 +87,7 @@ class ShapleyValueSampling(PerturbationAttribution):
https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None:
r"""
Args:
Expand All @@ -111,8 +108,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
Expand Down Expand Up @@ -301,45 +297,25 @@ def attribute(
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs, baselines = _format_input_baseline(inputs, baselines)
inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
additional_forward_args = _format_additional_forward_args(
additional_forward_args
)
# pyre-fixme[9]: feature_mask has type
# `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
# typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
feature_mask = _format_feature_mask(feature_mask, inputs)
# pyre-fixme[9]: feature_mask has type
# `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
# typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
# typing.Tuple[Tensor, ...]]]]`.
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
feature_mask = _shape_feature_mask(feature_mask, inputs)
formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple)
reshaped_feature_mask = _shape_feature_mask(
formatted_feature_mask, inputs_tuple
)

assert (
isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
), "Ablations per evaluation must be at least 1."

with torch.no_grad():
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
baselines = _tensorize_baseline(inputs, baselines)
num_examples = inputs[0].shape[0]
baselines = _tensorize_baseline(inputs_tuple, baselines)
num_examples = inputs_tuple[0].shape[0]

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
# typing.Tuple[Tensor, ...]]]]`.
total_features = _get_max_feature_index(feature_mask) + 1
total_features = _get_max_feature_index(reshaped_feature_mask) + 1

if show_progress:
attr_progress = progress(
Expand All @@ -362,7 +338,7 @@ def attribute(
initial_eval,
num_examples,
perturbations_per_eval,
feature_mask,
reshaped_feature_mask,
allow_multi_outputs=True,
)

Expand All @@ -372,11 +348,11 @@ def attribute(
# attr shape (*output_shape, *input_feature_shape)
total_attrib = [
torch.zeros(
output_shape + input.shape[1:],
tuple(output_shape) + tuple(input.shape[1:]),
dtype=torch.float,
device=inputs[0].device,
device=inputs_tuple[0].device,
)
for input in inputs
for input in inputs_tuple
]

iter_count = 0
Expand All @@ -393,17 +369,11 @@ def attribute(
current_target,
current_masks,
) in self._perturbation_generator(
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]`
# but got `TensorOrTupleOfTensorsGeneric`.
inputs,
inputs_tuple,
additional_forward_args,
target,
baselines,
# pyre-fixme[6]: For 5th argument expected
# `TensorOrTupleOfTensorsGeneric` but got
# `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
# typing.Tuple[Tensor, ...]]]]`.
feature_mask,
reshaped_feature_mask,
feature_permutation,
perturbations_per_eval,
):
Expand Down Expand Up @@ -445,9 +415,7 @@ def attribute(
# have the same dim as the mask tensor.
formatted_eval_diff = eval_diff.reshape(
(-1,)
# pyre-fixme[58]: `+` is not supported for operand types
# `Tuple[int]` and `Size`.
+ output_shape
+ tuple(output_shape)
+ (len(inputs[j].shape) - 1) * (1,)
)

Expand All @@ -460,11 +428,9 @@ def attribute(
# )
cur_mask = current_masks[j]
cur_mask = cur_mask.reshape(
cur_mask.shape[:2]
tuple(cur_mask.shape[:2])
+ (len(output_shape) - 1) * (1,)
# pyre-fixme[58]: `+` is not supported for operand types
# `Tuple[int, ...]` and `Size`.
+ cur_mask.shape[2:]
+ tuple(cur_mask.shape[2:])
)

# aggregate n_perturb
Expand Down Expand Up @@ -495,18 +461,16 @@ def attribute_future(self) -> Callable:
"attribute_future is not implemented for ShapleyValueSampling"
)

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def _perturbation_generator(
self,
inputs: Tuple[Tensor, ...],
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_args: Any,
additional_args: object,
target: TargetType,
baselines: Tuple[Tensor, ...],
input_masks: TensorOrTupleOfTensorsGeneric,
feature_permutation: Sequence[int],
perturbations_per_eval: int,
) -> Iterable[Tuple[Tuple[Tensor, ...], Any, TargetType, Tuple[Tensor, ...]]]:
) -> Iterable[Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...]]]:
"""
This method is a generator which yields each perturbation to be evaluated
including inputs, additional_forward_args, targets, and mask.
Expand Down Expand Up @@ -578,9 +542,9 @@ def _perturbation_generator(
combined_masks,
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
def _get_n_evaluations(
self, total_features: int, n_samples: int, perturbations_per_eval: int
) -> int:
"""return the total number of forward evaluations needed"""
return math.ceil(total_features / perturbations_per_eval) * n_samples

Expand Down Expand Up @@ -642,8 +606,7 @@ class ShapleyValues(ShapleyValueSampling):
evaluations, and we plan to add this approach in the future.
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None:
r"""
Args:
Expand All @@ -664,8 +627,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down

0 comments on commit aa8ec5c

Please sign in to comment.