Skip to content

Commit

Permalink
Fix approximation utils pyre fix me issues
Browse files Browse the repository at this point in the history
Differential Revision: D67706741
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 5b8d4a3 commit d7a3309
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions captum/attr/_utils/approximation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
from enum import Enum
from typing import Callable, List, Tuple
from typing import Callable, cast, List, Tuple

import torch

Expand Down Expand Up @@ -121,19 +121,20 @@ def gauss_legendre_builders() -> (

# allow using riemann even without np
import numpy as np
from numpy.typing import NDArray

def step_sizes(n: int) -> List[float]:
assert n > 0, "The number of steps has to be larger than zero"
# Scaling from 2 to 1
# pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
# `float`.
return list(0.5 * np.polynomial.legendre.leggauss(n)[1])
return cast(
NDArray[np.float64], 0.5 * np.polynomial.legendre.leggauss(n)[1]
).tolist()

def alphas(n: int) -> List[float]:
assert n > 0, "The number of steps has to be larger than zero"
# Scaling from [-1, 1] to [0, 1]
# pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
# `float`.
return list(0.5 * (1 + np.polynomial.legendre.leggauss(n)[0]))
return cast(
NDArray[np.float64], 0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])
).tolist()

return step_sizes, alphas

0 comments on commit d7a3309

Please sign in to comment.