diff --git a/src/poly.rs b/src/poly.rs index a541002a..7c2c76d1 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -202,13 +202,48 @@ fn verify_point(p: Public, i: u32, m: u32, x: Scalar) -> bool { lhs == rhs } +// Polynomial sum +fn poly_sum(x: Vec, y: Vec) -> Vec { + let mut res = vec![Scalar::zero(); std::cmp::max(x.len(), y.len())]; + for (i, xi) in x.iter().enumerate() { + res[i] += xi + } + for (i, yi) in y.iter().enumerate() { + res[i] += yi + } + res +} + // Polynomial product fn poly_prod(x: Vec, y: Vec) -> Vec { let mut res = vec![Scalar::zero(); x.len() + y.len() - 1]; for (i, xi) in x.iter().enumerate() { for (j, yj) in y.iter().enumerate() { - res[i+j] += xi * yj + res[i + j] += xi * yj } } res } + +// lagrange basis polynomial L_n_j(x) +fn lagrange_basis(j: usize, xs: &Vec) -> Vec { + let mut num = vec![Scalar::one()]; // numerator + let mut den = Scalar::one(); + for (k, xk) in xs.iter().enumerate() { + if k != j { + num = poly_prod(num, vec![-xk, Scalar::one()]); + den *= xs[j] - xk + } + } + den = den.invert().unwrap(); + num.iter().map(|v| v * den).collect() +} + +fn lagrange_interpolate(points: Vec<(Scalar, Scalar)>) -> Vec { + let (xs, ys): (Vec, Vec) = points.iter().cloned().unzip(); + let mut res = Vec::new(); + for (j, yj) in ys.iter().enumerate() { + res = poly_sum(res, lagrange_basis(j, &xs).iter().map(|v| v * yj).collect()) + } + res +}