Skip to content

Commit

Permalink
Fix pyre errors in Feature Ablation (pytorch#1392)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Feature Ablation

Differential Revision: D64677337
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 21, 2024
1 parent 78d9bdc commit 2f2f3f5
Showing 1 changed file with 13 additions and 31 deletions.
44 changes: 13 additions & 31 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ class FeatureAblation(PerturbationAttribution):
first dimension (i.e. a feature mask requires to be applied to all inputs).
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(
self, forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]]
) -> None:
r"""
Args:
Expand All @@ -74,8 +75,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down Expand Up @@ -261,17 +261,13 @@ def attribute(
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
formatted_additional_forward_args = _format_additional_forward_args(
additional_forward_args
)
num_examples = formatted_inputs[0].shape[0]
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)

assert (
Expand Down Expand Up @@ -384,8 +380,6 @@ def attribute(
# pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
# [Tensor, typing.Tuple[Tensor, ...]]]`
# but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
# pyre-fixme[6]: In call `FeatureAblation._generate_result`,
# for 3rd positional argument, expected `bool` but got `Literal[]`.
return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long

def _initial_eval_to_processed_initial_eval_fut(
Expand Down Expand Up @@ -414,8 +408,7 @@ def attribute_future(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand All @@ -428,8 +421,6 @@ def attribute_future(

# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
formatted_additional_forward_args = _format_additional_forward_args(
Expand Down Expand Up @@ -660,13 +651,11 @@ def _eval_fut_to_ablated_out_fut(
) from e
return result

# pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
def _ith_input_ablation_generator(
self,
i: int,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_args: Any,
additional_args: object,
target: TargetType,
baselines: BaselineType,
input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
Expand All @@ -675,7 +664,7 @@ def _ith_input_ablation_generator(
) -> Generator[
Tuple[
Tuple[Tensor, ...],
Any,
object,
TargetType,
Tensor,
],
Expand Down Expand Up @@ -705,10 +694,9 @@ def _ith_input_ablation_generator(
perturbations_per_eval = min(perturbations_per_eval, num_features)
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
if isinstance(baseline, torch.Tensor):
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
# and `Size`.
baseline = baseline.reshape((1,) + baseline.shape)
baseline = baseline.reshape((1,) + tuple(baseline.shape))

additional_args_repeated: object
if perturbations_per_eval > 1:
# Repeat features and additional args for batch size.
all_features_repeated = [
Expand All @@ -727,6 +715,7 @@ def _ith_input_ablation_generator(
target_repeated = target

num_features_processed = min_feature
current_additional_args: object
while num_features_processed < num_features:
current_num_ablated_features = min(
perturbations_per_eval, num_features - num_features_processed
Expand Down Expand Up @@ -762,9 +751,7 @@ def _ith_input_ablation_generator(
# dimension of this tensor.
current_reshaped = current_features[i].reshape(
(current_num_ablated_features, -1)
# pyre-fixme[58]: `+` is not supported for operand types
# `Tuple[int, int]` and `Size`.
+ current_features[i].shape[1:]
+ tuple(current_features[i].shape[1:])
)

ablated_features, current_mask = self._construct_ablated_input(
Expand All @@ -780,10 +767,7 @@ def _ith_input_ablation_generator(
# (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
# which can be provided to the model as input.
current_features[i] = ablated_features.reshape(
(-1,)
# pyre-fixme[58]: `+` is not supported for operand types
# `Tuple[int]` and `Size`.
+ ablated_features.shape[2:]
(-1,) + tuple(ablated_features.shape[2:])
)
yield tuple(
current_features
Expand Down Expand Up @@ -818,9 +802,7 @@ def _construct_ablated_input(
thus counted towards ablations for that feature) and 0s otherwise.
"""
current_mask = torch.stack(
# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
# Tuple[Tensor, ...]]` but got `List[Union[bool, Tensor]]`.
[input_mask == j for j in range(start_feature, end_feature)], # type: ignore # noqa: E501 line too long
cast(List[Tensor], [input_mask == j for j in range(start_feature, end_feature)]), # type: ignore # noqa: E501 line too long
dim=0,
).long()
current_mask = current_mask.to(expanded_input.device)
Expand Down

0 comments on commit 2f2f3f5

Please sign in to comment.