diff --git a/pySDC/core/Lagrange.py b/pySDC/core/Lagrange.py index 6ef2d09653..3733eee1a0 100644 --- a/pySDC/core/Lagrange.py +++ b/pySDC/core/Lagrange.py @@ -88,7 +88,7 @@ class LagrangeApproximation(object): The associated barycentric weights """ - def __init__(self, points): + def __init__(self, points, fValues=None): points = np.asarray(points).ravel() diffs = points[:, None] - points[None, :] @@ -110,6 +110,20 @@ def analytic(diffs): self.points = points self.weights = weights + # Store function values if provided + if fValues is not None: + fValues = np.asarray(fValues) + if fValues.shape != points.shape: + raise ValueError(f'fValues {fValues.shape} has not the correct shape: {points.shape}') + self.fValues = fValues + + def __call__(self, t): + assert self.fValues is not None, "cannot evaluate polynomial without fValues" + t = np.asarray(t) + values = self.getInterpolationMatrix(t.ravel()).dot(self.fValues) + values.shape = t.shape + return values + @property def n(self): return self.points.size diff --git a/pySDC/projects/PinTSimE/switch_estimator.py b/pySDC/projects/PinTSimE/switch_estimator.py index 3c438c134a..8da86a9ee6 100644 --- a/pySDC/projects/PinTSimE/switch_estimator.py +++ b/pySDC/projects/PinTSimE/switch_estimator.py @@ -4,6 +4,7 @@ from pySDC.core.Collocation import CollBase from pySDC.core.ConvergenceController import ConvergenceController, Status from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence +from pySDC.core.Lagrange import LagrangeApproximation class SwitchEstimator(ConvergenceController): @@ -274,23 +275,8 @@ def get_switch(t_interp, state_function, m_guess): Time point of found event. """ - LagrangeInterpolator = LagrangeInterpolation(t_interp, state_function) - - def p(t): - """ - Simplifies the call of the interpolant. - - Parameters - ---------- - t : float - Time t at which the interpolant is called. - - Returns - ------- - p(t) : float - The value of the interpolated function at time t. - """ - return LagrangeInterpolator.eval(t) + LagrangeInterpolation = LagrangeApproximation(points=t_interp, fValues=state_function) + p = lambda t: LagrangeInterpolation.__call__(t) def fprime(t): """ @@ -385,47 +371,3 @@ def newton(x0, p, fprime, newton_tol, newton_maxiter): root = x0 return root - - -class LagrangeInterpolation(object): - def __init__(self, ti, yi): - """Initialization routine""" - self.ti = np.asarray(ti) - self.yi = np.asarray(yi) - self.n = len(ti) - - def get_Lagrange_polynomial(self, t, i): - """ - Computes the basis of the i-th Lagrange polynomial. - - Parameters - ---------- - t : float - Time where the polynomial is computed at. - i : int - Index of the Lagrange polynomial - - Returns - ------- - product : float - The product of the bases. - """ - product = np.prod([(t - self.ti[k]) / (self.ti[i] - self.ti[k]) for k in range(self.n) if k != i]) - return product - - def eval(self, t): - """ - Evaluates the Lagrange interpolation at time t. - - Parameters - ---------- - t : float - Time where interpolation is computed. - - Returns - ------- - p : float - Value of interpolant at time t. - """ - p = np.sum([self.yi[i] * self.get_Lagrange_polynomial(t, i) for i in range(self.n)]) - return p