diff --git a/captum/attr/_core/layer/layer_activation.py b/captum/attr/_core/layer/layer_activation.py
index 076323a27..d9aea9b27 100644
--- a/captum/attr/_core/layer/layer_activation.py
+++ b/captum/attr/_core/layer/layer_activation.py
@@ -20,8 +20,7 @@ class LayerActivation(LayerAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Union[int, float, Tensor]],
        layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -132,8 +131,6 @@ def attribute(
             )
         else:
             return [
-                # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
-                #  got `Tensor`.
                 _format_output(len(single_layer_eval) > 1, single_layer_eval)
                 for single_layer_eval in layer_eval
             ]