Skip to content

Commit

Permalink
Fix layer gradient shap pyre fixme issues (#1471)
Browse files Browse the repository at this point in the history
Summary:

Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: cyrjano

Differential Revision: D67705670
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 6df7e0a commit 477d737
Showing 1 changed file with 22 additions and 36 deletions.
58 changes: 22 additions & 36 deletions captum/attr/_core/layer/layer_gradient_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class LayerGradientShap(LayerAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -104,13 +103,12 @@ def __init__(
self._multiply_by_inputs = multiply_by_inputs

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `106`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
baselines: Union[
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
n_samples: int = 5,
stdevs: Union[float, Tuple[float, ...]] = 0.0,
target: TargetType = None,
Expand All @@ -121,13 +119,12 @@ def attribute(
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `120`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
baselines: Union[
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
n_samples: int = 5,
stdevs: Union[float, Tuple[float, ...]] = 0.0,
target: TargetType = None,
Expand All @@ -137,11 +134,14 @@ def attribute(
) -> Union[Tensor, Tuple[Tensor, ...]]: ...

@log_usage()
# pyre-fixme[43]: This definition does not have the same decorators as the
# preceding overload(s).
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
baselines: Union[
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
n_samples: int = 5,
stdevs: Union[float, Tuple[float, ...]] = 0.0,
target: TargetType = None,
Expand Down Expand Up @@ -294,17 +294,10 @@ def attribute(
"""
# since `baselines` is a distribution, we can generate it using a function
# rather than passing it as an input argument
# pyre-fixme[9]: baselines has type `Union[typing.Callable[..., typing.Any],
# Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor,
# ...]]]]`; used as `Tuple[Tensor, ...]`.
baselines = _format_callable_baseline(baselines, inputs)
# pyre-fixme[16]: Item `Callable` of `Union[(...) -> Any,
# TensorOrTupleOfTensorsGeneric]` has no attribute `__getitem__`.
assert isinstance(baselines[0], torch.Tensor), (
formatted_baselines = _format_callable_baseline(baselines, inputs)
assert isinstance(formatted_baselines[0], torch.Tensor), (
"Baselines distribution has to be provided in a form "
# pyre-fixme[16]: Item `Callable` of `Union[(...) -> Any,
# TensorOrTupleOfTensorsGeneric]` has no attribute `__getitem__`.
"of a torch.Tensor {}.".format(baselines[0])
"of a torch.Tensor {}.".format(formatted_baselines[0])
)

input_min_baseline_x_grad = LayerInputBaselineXGradient(
Expand All @@ -323,7 +316,7 @@ def attribute(
nt_samples=n_samples,
stdevs=stdevs,
draw_baseline_from_distrib=True,
baselines=baselines,
baselines=formatted_baselines,
target=target,
additional_forward_args=additional_forward_args,
return_convergence_delta=return_convergence_delta,
Expand All @@ -343,8 +336,7 @@ def multiplies_by_inputs(self) -> bool:
class LayerInputBaselineXGradient(LayerAttribution, GradientAttribution):
def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -436,7 +428,7 @@ def attribute( # type: ignore
)
grads, _ = compute_layer_gradients_and_eval(
self.forward_func,
self.layer,
cast(Module, self.layer),
input_baseline_scaled,
target,
additional_forward_args,
Expand All @@ -448,7 +440,7 @@ def attribute( # type: ignore
attr_baselines = _forward_layer_eval(
self.forward_func,
baselines,
self.layer,
cast(Module, self.layer),
additional_forward_args=additional_forward_args,
device_ids=self.device_ids,
attribute_to_layer_input=attribute_to_layer_input,
Expand All @@ -457,19 +449,15 @@ def attribute( # type: ignore
attr_inputs = _forward_layer_eval(
self.forward_func,
inputs,
self.layer,
cast(Module, self.layer),
additional_forward_args=additional_forward_args,
device_ids=self.device_ids,
attribute_to_layer_input=attribute_to_layer_input,
)

attributions: Tuple[Tensor, ...]
if self.multiplies_by_inputs:
input_baseline_diffs = tuple(
# pyre-fixme[58]: `-` is not supported for operand types
# `typing.Tuple[torch._tensor.Tensor, ...]` and
# `typing.Tuple[torch._tensor.Tensor, ...]`.
input - baseline
for input, baseline in zip(attr_inputs, attr_baselines)
input - baseline for input, baseline in zip(attr_inputs, attr_baselines)
)
attributions = tuple(
input_baseline_diff * grad
Expand All @@ -481,8 +469,6 @@ def attribute( # type: ignore
return _compute_conv_delta_and_format_attrs(
self,
return_convergence_delta,
# pyre-fixme[6]: For 3rd argument expected `Tuple[Tensor, ...]` but got
# `Union[List[typing.Tuple[Tensor, ...]], tuple[Tensor]]`.
attributions,
baselines,
inputs,
Expand Down

0 comments on commit 477d737

Please sign in to comment.