Skip to content

Commit

Permalink
Fix pyre errors in Occlusion (pytorch#1403)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Occlusion

Reviewed By: craymichael

Differential Revision: D64677342
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 5575bdd commit 81e0218
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions captum/attr/_core/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class Occlusion(FeatureAblation):
/tensorflow/methods.py#L401
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:
Expand All @@ -58,8 +57,7 @@ def attribute( # type: ignore
] = None,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
Expand Down Expand Up @@ -377,9 +375,7 @@ def _occlusion_mask(
padded_tensor = torch.nn.functional.pad(
sliding_window_tsr, tuple(pad_values) # type: ignore
)
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and
# `Size`.
return padded_tensor.reshape((1,) + padded_tensor.shape)
return padded_tensor.reshape((1,) + tuple(padded_tensor.shape))

def _get_feature_range_and_mask(
self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any
Expand All @@ -389,8 +385,7 @@ def _get_feature_range_and_mask(

def _get_feature_counts(
self,
# pyre-fixme[2]: Parameter must be annotated.
inputs,
inputs: TensorOrTupleOfTensorsGeneric,
feature_mask: Tuple[Tensor, ...],
**kwargs: Any,
) -> Tuple[int, ...]:
Expand Down

0 comments on commit 81e0218

Please sign in to comment.