Skip to content

Commit

Permalink
GPU accelerated Grumpkin MSM (#187)
Browse files Browse the repository at this point in the history
* add grumpkin msm

* no sort feature

* add test

---------

Co-authored-by: Hanting Zhang <[email protected]>
  • Loading branch information
winston-h-zhang and Hanting Zhang authored Dec 22, 2023
1 parent 90a1269 commit e2b2008
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 2 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ rand = "0.8.5"
ref-cast = "1.0.20"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { git="https://github.com/lurk-lab/pasta-msm", branch="dev", version = "0.1.4" }
pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev", version = "0.1.4" }
# pasta-msm also calls into sppark, which already defines the common sorting code that grumpkin-msm
# would otherwise define as well; the `dont-implement-sort` feature avoids creating conflicting symbols.
grumpkin-msm = { git = "https://github.com/lurk-lab/grumpkin-msm", branch = "dev", features = ["dont-implement-sort"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
halo2curves = { version = "0.5.0", features = ["bits", "derive_serde"] }
Expand Down
209 changes: 208 additions & 1 deletion src/provider/bn256_grumpkin.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`.
use crate::{
impl_traits,
provider::{
traits::{CompressedGroup, DlogGroup},
util::msm::cpu_best_msm,
Expand Down Expand Up @@ -35,6 +34,170 @@ pub mod grumpkin {
pub use halo2curves::grumpkin::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point};
}

/// Implements the provider traits (`Group`, `DlogGroup`, `PrimeFieldExt`, `CompressedGroup`
/// and the `TranscriptReprTrait` impls) for one curve module.
///
/// Arguments, in order:
/// * `$name` - module exposing `Point`, `Affine`, `Scalar` and `Base` for the curve; the same
///   identifier selects the accelerated routine `grumpkin_msm::$name`.
/// * `$name_compressed` - compressed encoding type of a group element.
/// * `$name_curve` - projective curve type used for hashing to the curve and decoding bytes.
/// * `$name_curve_affine` - affine curve type.
/// * `$order_str` - hex string returned as the `order` component of `group_params`.
/// * `$base_str` - hex string returned as the `base` component of `group_params`.
macro_rules! impl_traits {
  (
    $name:ident,
    $name_compressed:ident,
    $name_curve:ident,
    $name_curve_affine:ident,
    $order_str:literal,
    $base_str:literal
  ) => {
    impl Group for $name::Point {
      type Base = $name::Base;
      type Scalar = $name::Scalar;

      // Returns the curve coefficients `a` and `b` together with the two moduli parsed from the
      // hex literals handed to the macro.
      fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
        let A = $name::Point::a();
        let B = $name::Point::b();
        let order = BigInt::from_str_radix($order_str, 16).unwrap();
        let base = BigInt::from_str_radix($base_str, 16).unwrap();

        (A, B, order, base)
      }
    }

    impl DlogGroup for $name::Point {
      type CompressedGroupElement = $name_compressed;
      type PreprocessedGroupElement = $name::Affine;

      // Multi-scalar multiplication. On x86_64/aarch64, inputs with at least 128 scalars are
      // dispatched to `grumpkin_msm`; everything else uses the CPU implementation.
      #[tracing::instrument(
        skip_all,
        level = "trace",
        name = "<_ as Group>::vartime_multiscalar_mul"
      )]
      fn vartime_multiscalar_mul(
        scalars: &[Self::Scalar],
        bases: &[Self::PreprocessedGroupElement],
      ) -> Self {
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        if scalars.len() >= 128 {
          grumpkin_msm::$name(bases, scalars)
        } else {
          cpu_best_msm(scalars, bases)
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        cpu_best_msm(scalars, bases)
      }

      fn preprocessed(&self) -> Self::PreprocessedGroupElement {
        self.to_affine()
      }

      fn compress(&self) -> Self::CompressedGroupElement {
        self.to_bytes()
      }

      // Derives `n` generators from `label`: SHAKE256 is used as an XOF to produce 32 bytes per
      // generator, each 32-byte string is hashed to the curve, and the projective results are
      // batch-normalized to affine form.
      fn from_label(label: &'static [u8], n: usize) -> Vec<Self::PreprocessedGroupElement> {
        let mut shake = Shake256::default();
        shake.update(label);
        let mut reader = shake.finalize_xof();
        // The XOF stream is read sequentially so that generator `i` always receives the same
        // bytes, independent of how the parallel steps below are scheduled.
        let mut uniform_bytes_vec = Vec::with_capacity(n);
        for _ in 0..n {
          let mut uniform_bytes = [0u8; 32];
          reader.read_exact(&mut uniform_bytes).unwrap();
          uniform_bytes_vec.push(uniform_bytes);
        }
        let gens_proj: Vec<$name_curve> = (0..n)
          .into_par_iter()
          .map(|i| {
            let hash = $name_curve::hash_to_curve("from_uniform_bytes");
            hash(&uniform_bytes_vec[i])
          })
          .collect();

        let num_threads = rayon::current_num_threads();
        if gens_proj.len() > num_threads {
          // Split the normalization into one contiguous chunk per thread; the indexed
          // `flat_map` + `collect` keeps the chunks in their original order.
          let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize;
          (0..num_threads)
            .into_par_iter()
            .flat_map(|i| {
              let start = i * chunk;
              let end = if i == num_threads - 1 {
                gens_proj.len()
              } else {
                core::cmp::min((i + 1) * chunk, gens_proj.len())
              };
              // With a rounded-up chunk size the trailing threads may have nothing to do.
              if end > start {
                let mut gens = vec![$name_curve_affine::identity(); end - start];
                <Self as Curve>::batch_normalize(&gens_proj[start..end], &mut gens);
                gens
              } else {
                vec![]
              }
            })
            .collect()
        } else {
          let mut gens = vec![$name_curve_affine::identity(); n];
          <Self as Curve>::batch_normalize(&gens_proj, &mut gens);
          gens
        }
      }

      fn zero() -> Self {
        $name::Point::identity()
      }

      // Returns `(x, y, is_infinity)`; the identity is reported as `(0, 0, true)`.
      fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) {
        // Normalize once and reuse the affine point for both checks.
        let affine = self.to_affine();
        let coordinates = affine.coordinates();
        if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != affine) {
          let c = coordinates.unwrap();
          (*c.x(), *c.y(), false)
        } else {
          (Self::Base::zero(), Self::Base::zero(), true)
        }
      }
    }

    impl PrimeFieldExt for $name::Scalar {
      // Panics unless `bytes` is exactly 64 bytes long.
      fn from_uniform(bytes: &[u8]) -> Self {
        let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
        $name::Scalar::from_uniform_bytes(&bytes_arr)
      }
    }

    impl<G: DlogGroup> TranscriptReprTrait<G> for $name_compressed {
      fn to_transcript_bytes(&self) -> Vec<u8> {
        self.as_ref().to_vec()
      }
    }

    impl CompressedGroup for $name_compressed {
      type GroupElement = $name::Point;

      // Returns `None` when the bytes do not decode to a group element. The decoder's
      // `CtOption` is converted directly instead of being unwrapped, so malformed input can
      // no longer trigger a panic.
      fn decompress(&self) -> Option<$name::Point> {
        $name_curve::from_bytes(self).into()
      }
    }

    impl<G: Group> TranscriptReprTrait<G> for $name::Scalar {
      fn to_transcript_bytes(&self) -> Vec<u8> {
        self.to_repr().to_vec()
      }
    }

    impl<G: DlogGroup> TranscriptReprTrait<G> for $name::Affine {
      // Serializes the point as `x || y || is_infinity`, with the identity encoded as zero
      // coordinates followed by an infinity byte of 1.
      fn to_transcript_bytes(&self) -> Vec<u8> {
        let (x, y, is_infinity_byte) = {
          let coordinates = self.coordinates();
          if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != *self) {
            let c = coordinates.unwrap();
            (*c.x(), *c.y(), u8::from(false))
          } else {
            // This branch is the point at infinity, so the flag must be set; it was
            // previously emitted as 0 in both branches.
            ($name::Base::zero(), $name::Base::zero(), u8::from(true))
          }
        };

        x.to_repr()
          .into_iter()
          .chain(y.to_repr().into_iter())
          .chain(std::iter::once(is_infinity_byte))
          .collect()
      }
    }
  };
}

impl_traits!(
bn256,
Bn256Compressed,
Expand All @@ -52,3 +215,47 @@ impl_traits!(
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
);

#[cfg(test)]
mod tests {
  use ff::Field;
  use rand::thread_rng;

  use crate::provider::{
    bn256_grumpkin::{bn256, grumpkin},
    traits::DlogGroup,
    util::msm::cpu_best_msm,
  };

  /// Number of (scalar, base) pairs per test; far above the 128-element threshold at which
  /// `vartime_multiscalar_mul` switches to `grumpkin_msm` on x86_64/aarch64.
  const NUM_POINTS: usize = 1usize << 16;

  /// The dispatched bn256 MSM must agree with the CPU reference on random scalars.
  #[test]
  fn test_bn256_msm_correctness() {
    let bases = bn256::Point::from_label(b"test", NUM_POINTS);

    let mut rng = thread_rng();
    let scalars: Vec<_> = std::iter::repeat_with(|| bn256::Scalar::random(&mut rng))
      .take(NUM_POINTS)
      .collect();

    let reference = cpu_best_msm(&scalars, &bases);
    let dispatched = bn256::Point::vartime_multiscalar_mul(&scalars, &bases);

    assert_eq!(reference, dispatched);
  }

  /// The dispatched grumpkin MSM must agree with the CPU reference on random scalars.
  #[test]
  fn test_grumpkin_msm_correctness() {
    let bases = grumpkin::Point::from_label(b"test", NUM_POINTS);

    let mut rng = thread_rng();
    let scalars: Vec<_> = std::iter::repeat_with(|| grumpkin::Scalar::random(&mut rng))
      .take(NUM_POINTS)
      .collect();

    let reference = cpu_best_msm(&scalars, &bases);
    let dispatched = grumpkin::Point::vartime_multiscalar_mul(&scalars, &bases);

    assert_eq!(reference, dispatched);
  }
}
41 changes: 41 additions & 0 deletions src/provider/pasta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,44 @@ impl_traits!(
"40000000000000000000000000000000224698fc094cf91b992d30ed00000001",
"40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001"
);

#[cfg(test)]
mod tests {
  use ff::Field;
  use pasta_curves::{pallas, vesta};
  use rand::thread_rng;

  use crate::provider::{traits::DlogGroup, util::msm::cpu_best_msm};

  /// Number of (scalar, base) pairs fed to each MSM under test.
  const NUM_POINTS: usize = 1usize << 16;

  /// `vartime_multiscalar_mul` on pallas must agree with the CPU reference on random scalars.
  #[test]
  fn test_pallas_msm_correctness() {
    let bases = pallas::Point::from_label(b"test", NUM_POINTS);

    let mut rng = thread_rng();
    let scalars: Vec<_> = std::iter::repeat_with(|| pallas::Scalar::random(&mut rng))
      .take(NUM_POINTS)
      .collect();

    let reference = cpu_best_msm(&scalars, &bases);
    let dispatched = pallas::Point::vartime_multiscalar_mul(&scalars, &bases);

    assert_eq!(reference, dispatched);
  }

  /// `vartime_multiscalar_mul` on vesta must agree with the CPU reference on random scalars.
  #[test]
  fn test_vesta_msm_correctness() {
    let bases = vesta::Point::from_label(b"test", NUM_POINTS);

    let mut rng = thread_rng();
    let scalars: Vec<_> = std::iter::repeat_with(|| vesta::Scalar::random(&mut rng))
      .take(NUM_POINTS)
      .collect();

    let reference = cpu_best_msm(&scalars, &bases);
    let dispatched = vesta::Point::vartime_multiscalar_mul(&scalars, &bases);

    assert_eq!(reference, dispatched);
  }
}

0 comments on commit e2b2008

Please sign in to comment.