From d7a3309b14da3611ecb306cfbce38e1eb45aa13a Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 07:38:08 -0800 Subject: [PATCH] Fix approximation utils pyre fix me issues Differential Revision: D67706741 --- captum/attr/_utils/approximation_methods.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/captum/attr/_utils/approximation_methods.py b/captum/attr/_utils/approximation_methods.py index 8debc95540..9af3cf9580 100644 --- a/captum/attr/_utils/approximation_methods.py +++ b/captum/attr/_utils/approximation_methods.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Callable, List, Tuple +from typing import Callable, cast, List, Tuple import torch @@ -121,19 +121,20 @@ def gauss_legendre_builders() -> ( # allow using riemann even without np import numpy as np + from numpy.typing import NDArray def step_sizes(n: int) -> List[float]: assert n > 0, "The number of steps has to be larger than zero" # Scaling from 2 to 1 - # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got - # `float`. - return list(0.5 * np.polynomial.legendre.leggauss(n)[1]) + return cast( + NDArray[np.float64], 0.5 * np.polynomial.legendre.leggauss(n)[1] + ).tolist() def alphas(n: int) -> List[float]: assert n > 0, "The number of steps has to be larger than zero" # Scaling from [-1, 1] to [0, 1] - # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got - # `float`. - return list(0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])) + return cast( + NDArray[np.float64], 0.5 * (1 + np.polynomial.legendre.leggauss(n)[0]) + ).tolist() return step_sizes, alphas