diff --git a/captum/attr/_utils/approximation_methods.py b/captum/attr/_utils/approximation_methods.py index 8debc9554..9af3cf958 100644 --- a/captum/attr/_utils/approximation_methods.py +++ b/captum/attr/_utils/approximation_methods.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Callable, List, Tuple +from typing import Callable, cast, List, Tuple import torch @@ -121,19 +121,20 @@ def gauss_legendre_builders() -> ( # allow using riemann even without np import numpy as np + from numpy.typing import NDArray def step_sizes(n: int) -> List[float]: assert n > 0, "The number of steps has to be larger than zero" # Scaling from 2 to 1 - # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got - # `float`. - return list(0.5 * np.polynomial.legendre.leggauss(n)[1]) + return cast( + NDArray[np.float64], 0.5 * np.polynomial.legendre.leggauss(n)[1] + ).tolist() def alphas(n: int) -> List[float]: assert n > 0, "The number of steps has to be larger than zero" # Scaling from [-1, 1] to [0, 1] - # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got - # `float`. - return list(0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])) + return cast( + NDArray[np.float64], 0.5 * (1 + np.polynomial.legendre.leggauss(n)[0]) + ).tolist() return step_sizes, alphas