Skip to content

Commit

Permalink
Fix layer gradient x activation pyre fixme issues
Browse files Browse the repository at this point in the history
Differential Revision: D67705758
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 313746d commit 1cec2a3
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions captum/attr/_core/layer/layer_gradient_x_activation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

from captum._utils.common import (
_format_additional_forward_args,
Expand All @@ -24,8 +24,7 @@ class LayerGradientXActivation(LayerAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: ModuleOrModuleList,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -186,11 +185,10 @@ def attribute(
if isinstance(self.layer, Module):
return _format_output(
len(layer_evals) > 1,
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
# got `List[typing.Tuple[Tensor, ...]]`.
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
# got `List[typing.Tuple[Tensor, ...]]`.
self.multiply_gradient_acts(layer_gradients, layer_evals),
self.multiply_gradient_acts(
cast(Tuple[Tensor, ...], layer_gradients),
cast(Tuple[Tensor, ...], layer_evals),
),
)
else:
return [
Expand Down

0 comments on commit 1cec2a3

Please sign in to comment.