Skip to content

Commit

Permalink
Fix remaining pyre errors in infidelity.py (#1414)
Browse files Browse the repository at this point in the history
Summary:

Fix pyre/mypy errors in infidelity.py. Introduce new BaselineTupleType

Differential Revision: D64998803
  • Loading branch information
craymichael authored and facebook-github-bot committed Oct 25, 2024
1 parent 183b820 commit d101194
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 72 deletions.
3 changes: 2 additions & 1 deletion captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
BaselineType = Union[None, Tensor, int, float, BaselineTupleType]

TensorLikeList1D = List[float]
TensorLikeList2D = List[TensorLikeList1D]
Expand Down
132 changes: 61 additions & 71 deletions captum/metrics/_core/infidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict

from typing import Any, Callable, cast, Optional, Tuple, Union
from typing import Callable, cast, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -15,7 +15,12 @@
ExpansionTypes,
safe_div,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import (
BaselineTupleType,
BaselineType,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum.log import log_usage
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
from torch import Tensor
Expand All @@ -35,14 +40,14 @@ def infidelity_perturb_func_decorator(
]:
r"""An auxiliary, decorator function that helps with computing
perturbations given perturbed inputs. It can be useful for cases
when `pertub_func` returns only perturbed inputs and we
when `perturb_func` returns only perturbed inputs and we
internally compute the perturbations as
(input - perturbed_input) / (input - baseline) if
multiply_by_inputs is set to True and
(input - perturbed_input) otherwise.
If users decorate their `pertub_func` with
`@infidelity_perturb_func_decorator` function then their `pertub_func`
If users decorate their `perturb_func` with
`@infidelity_perturb_func_decorator` function then their `perturb_func`
needs to only return perturbed inputs.
Args:
Expand All @@ -54,15 +59,15 @@ def infidelity_perturb_func_decorator(
"""

def sub_infidelity_perturb_func_decorator(
pertub_func: Callable[..., TensorOrTupleOfTensorsGeneric]
perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric]
) -> Callable[
[TensorOrTupleOfTensorsGeneric, BaselineType],
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
]:
r"""
Args:
pertub_func(Callable): Input perturbation function that takes inputs
perturb_func(Callable): Input perturbation function that takes inputs
and optionally baselines and returns perturbed inputs
Returns:
Expand All @@ -87,9 +92,9 @@ def default_perturb_func(
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
r""" """
inputs_perturbed: TensorOrTupleOfTensorsGeneric = (
pertub_func(inputs, baselines)
perturb_func(inputs, baselines)
if baselines is not None
else pertub_func(inputs)
else perturb_func(inputs)
)
inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)
inputs_formatted = _format_tensor_into_tuples(inputs)
Expand Down Expand Up @@ -135,16 +140,14 @@ def default_perturb_func(

@log_usage()
def infidelity(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
perturb_func: Callable[
..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
],
inputs: TensorOrTupleOfTensorsGeneric,
attributions: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
target: TargetType = None,
n_perturb_samples: int = 10,
max_examples_per_batch: Optional[int] = None,
Expand Down Expand Up @@ -417,38 +420,35 @@ def infidelity(
>>> infid = infidelity(net, perturb_fn, input, attribution)
"""
# perform argument formattings
inputs = _format_tensor_into_tuples(inputs) # type: ignore
inputs_formatted = _format_tensor_into_tuples(inputs)
baselines_formatted: BaselineTupleType = None
if baselines is not None:
baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
baselines_formatted = _format_baseline(baselines, inputs_formatted)
additional_forward_args = _format_additional_forward_args(additional_forward_args)
attributions = _format_tensor_into_tuples(attributions) # type: ignore
attributions_formatted = _format_tensor_into_tuples(attributions)

# Make sure that inputs and corresponding attributions have matching sizes.
assert len(inputs) == len(attributions), (
"""The number of tensors in the inputs and
attributions must match. Found number of tensors in the inputs is: {} and in the
attributions: {}"""
).format(len(inputs), len(attributions))
for inp, attr in zip(inputs, attributions):
assert len(inputs_formatted) == len(attributions_formatted), (
"The number of tensors in the inputs and attributions must match. "
f"Found number of tensors in the inputs is: {len(inputs_formatted)} and in "
f"the attributions: {len(attributions_formatted)}"
)
for inp, attr in zip(inputs_formatted, attributions_formatted):
assert inp.shape == attr.shape, (
"""Inputs and attributions must have
matching shapes. One of the input tensor's shape is {} and the
attribution tensor's shape is: {}"""
# pyre-fixme[16]: Module `attr` has no attribute `shape`.
).format(inp.shape, attr.shape)
"Inputs and attributions must have matching shapes. "
f"One of the input tensor's shape is {inp.shape} and the "
f"attribution tensor's shape is: {attr.shape}"
)

bsz = inputs[0].size(0)
bsz = inputs_formatted[0].size(0)

_next_infidelity_tensors = _make_next_infidelity_tensors_func(
forward_func,
bsz,
# error: Argument 3 to "_make_next_infidelity_tensors_func" has incompatible
# type "Callable[..., tuple[Tensor, Tensor]]"; expected
# "Callable[..., tuple[tuple[Tensor, ...], tuple[Tensor, ...]]]" [arg-type]
perturb_func, # type: ignore
inputs,
baselines,
attributions,
perturb_func,
inputs_formatted,
baselines_formatted,
attributions_formatted,
additional_forward_args,
target,
normalize,
Expand All @@ -458,7 +458,7 @@ def infidelity(
# if not normalize, directly return aggrgated MSE ((a-b)^2,)
# else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
agg_tensors = _divide_and_aggregate_metrics(
cast(Tuple[Tensor, ...], inputs),
inputs_formatted,
n_perturb_samples,
_next_infidelity_tensors,
agg_func=_sum_infidelity_tensors,
Expand All @@ -472,11 +472,7 @@ def infidelity(
beta = safe_div(beta_num, beta_denorm)

infidelity_values = (
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
beta**2 * agg_tensors[0]
- 2 * beta * agg_tensors[1]
+ agg_tensors[2]
beta * beta * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]
)
else:
infidelity_values = agg_tensors[0]
Expand All @@ -491,8 +487,8 @@ def _generate_perturbations(
perturb_func: Callable[
..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
],
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType,
inputs: Tuple[Tensor, ...],
baselines: BaselineTupleType,
) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
r"""
The perturbations are generated for each example
Expand All @@ -507,14 +503,12 @@ def call_perturb_func() -> (
Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
):
r""" """
baselines_pert = None
baselines_pert: BaselineType = None
inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
if len(inputs_expanded) == 1:
inputs_pert = inputs_expanded[0]
if baselines_expanded is not None:
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
# parameter.
baselines_pert = cast(Tuple, baselines_expanded)[0]
baselines_pert = baselines_expanded[0]
else:
inputs_pert = inputs_expanded
baselines_pert = baselines_expanded
Expand All @@ -539,9 +533,7 @@ def call_perturb_func() -> (
and baseline.shape[0] > 1
else baseline
)
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
# parameter.
for input, baseline in zip(inputs, cast(Tuple, baselines))
for input, baseline in zip(inputs, baselines)
)

return call_perturb_func()
Expand All @@ -554,34 +546,32 @@ def _validate_inputs_and_perturbations(
) -> None:
# asserts the sizes of the perturbations and inputs
assert len(perturbations) == len(inputs), (
"""The number of perturbed
inputs and corresponding perturbations must have the same number of
elements. Found number of inputs is: {} and perturbations:
{}"""
).format(len(perturbations), len(inputs))
"The number of perturbed "
"inputs and corresponding perturbations must have the same number of "
f"elements. Found number of inputs is: {len(perturbations)} and "
f"perturbations: {len(inputs)}"
)

# asserts the shapes of the perturbations and perturbed inputs
for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
assert perturb[0].shape == input_perturbed[0].shape, (
"""Perturbed input
and corresponding perturbation must have the same shape and
dimensionality. Found perturbation shape is: {} and the input shape
is: {}"""
).format(perturb[0].shape, input_perturbed[0].shape)
"Perturbed input "
"and corresponding perturbation must have the same shape and "
f"dimensionality. Found perturbation shape is: {perturb[0].shape} "
f"and the input shape is: {input_perturbed[0].shape}"
)


def _make_next_infidelity_tensors_func(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
bsz: int,
perturb_func: Callable[
..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
],
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType,
attributions: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
inputs: Tuple[Tensor, ...],
baselines: BaselineTupleType,
attributions: Tuple[Tensor, ...],
additional_forward_args: object = None,
target: TargetType = None,
normalize: bool = False,
) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]:
Expand All @@ -597,7 +587,7 @@ def _next_infidelity_tensors(
inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)

_validate_inputs_and_perturbations(
cast(Tuple[Tensor, ...], inputs),
inputs,
inputs_perturbed_formatted,
perturbations_formatted,
)
Expand Down Expand Up @@ -666,7 +656,7 @@ def _next_infidelity_tensors(
return _next_infidelity_tensors


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _sum_infidelity_tensors(agg_tensors, tensors):
def _sum_infidelity_tensors(
agg_tensors: Tuple[Tensor, ...], tensors: Tuple[Tensor, ...]
) -> Tuple[Tensor, ...]:
return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))

0 comments on commit d101194

Please sign in to comment.