From 7a7dbb23b666897d3f7e686e8ac43d5b834faa93 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Tue, 9 Jan 2024 22:32:26 +0000 Subject: [PATCH 01/11] format --- Cargo.toml | 2 +- cuda/bn254.cu | 9 +++++---- cuda/grumpkin.cu | 9 +++++---- cuda/pallas.cu | 9 +++++---- cuda/vesta.cu | 9 +++++---- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 629fe2e..25b170e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ cuda-mobile = [] [dependencies] blst = "~0.3.11" semolina = "~0.1.3" -sppark = "~0.1.2" +sppark = { git = "https://github.com/lurk-lab/sppark", branch = "preallocate-msm" } halo2curves = { version = "0.6.0" } pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] } rand = "^0" diff --git a/cuda/bn254.cu b/cuda/bn254.cu index 3c0c1c9..7418990 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -17,8 +17,9 @@ typedef fr_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_bn254(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } +extern "C" RustError cuda_pippenger_bn254(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 861610d..b0ac132 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -17,8 +17,9 @@ typedef fp_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_grumpkin(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } +extern "C" RustError cuda_pippenger_grumpkin(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} #endif diff --git a/cuda/pallas.cu 
b/cuda/pallas.cu index f3897bb..6abd6d1 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -17,8 +17,9 @@ typedef vesta_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } +extern "C" RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index a926c6d..65189da 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -17,8 +17,9 @@ typedef pallas_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } +extern "C" RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} #endif From 3d7fabe6a8bcb6a39492ce815df373c279019f30 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Tue, 9 Jan 2024 23:14:14 +0000 Subject: [PATCH 02/11] add msm_context --- benches/grumpkin_msm.rs | 12 +++ benches/pasta_msm.rs | 13 +++ cuda/bn254.cu | 19 +++- cuda/grumpkin.cu | 21 +++- cuda/pallas.cu | 22 +++- cuda/vesta.cu | 20 +++- src/lib.rs | 219 +++++++++++++++++++++++++++++++++++++++- src/pasta.rs | 214 ++++++++++++++++++++++++++++++++++++--- 8 files changed, 514 insertions(+), 26 deletions(-) diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 1575f48..1a23e09 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -58,6 +58,18 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let context = grumpkin_msm::bn256_init(&points, npoints); + + group.bench_function( + format!("preallocate 2**{} points", bench_npow), + |b| { + b.iter(|| 
{ + let _ = + grumpkin_msm::bn256_with(&context, npoints, &scalars); + }) + }, + ); + group.finish(); } } diff --git a/benches/pasta_msm.rs b/benches/pasta_msm.rs index f77610f..f961742 100644 --- a/benches/pasta_msm.rs +++ b/benches/pasta_msm.rs @@ -58,6 +58,19 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let context = grumpkin_msm::pasta::pallas_init(&points, npoints); + + group.bench_function( + format!("preallocate 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas_with( + &context, npoints, &scalars, + ); + }) + }, + ); + group.finish(); } } diff --git a/cuda/bn254.cu b/cuda/bn254.cu index 7418990..fc97057 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -17,9 +17,26 @@ typedef fr_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" RustError cuda_pippenger_bn254(point_t *out, const affine_t points[], size_t npoints, + +extern "C" void drop_msm_context_bn254(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_bn254_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[]) { return mult_pippenger(out, points, npoints, scalars); } + +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index b0ac132..8403f92 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -17,9 +17,26 @@ typedef fp_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" RustError cuda_pippenger_grumpkin(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) + +extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError 
+cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) { return mult_pippenger(out, points, npoints, scalars); } + +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index 6abd6d1..62a375a 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -17,9 +17,27 @@ typedef vesta_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) + +extern "C" void drop_msm_context_pallas(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_pallas_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) { return mult_pippenger(out, points, npoints, scalars); } + +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} + #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index 65189da..13ca657 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -17,9 +17,27 @@ typedef pallas_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, + +extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_grumpkin_init(const affine_t points[], size_t 
npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[]) { return mult_pippenger(out, points, npoints, scalars); } + +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} + #endif diff --git a/src/lib.rs b/src/lib.rs index 16fb22a..3ba270e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,40 @@ pub static mut CUDA_OFF: bool = false; use halo2curves::bn256; use halo2curves::CurveExt; +#[cfg(feature = "cuda")] +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MSMContextBn256 { + context: *const std::ffi::c_void, +} + +#[cfg(feature = "cuda")] +unsafe impl Send for MSMContextBn256 {} + +#[cfg(feature = "cuda")] +unsafe impl Sync for MSMContextBn256 {} + +#[cfg(feature = "cuda")] +impl Default for MSMContextBn256 { + fn default() -> Self { + Self { + context: std::ptr::null(), + } + } +} + +#[cfg(feature = "cuda")] +// TODO: check for device-side memory leaks +impl Drop for MSMContextBn256 { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_bn254(by_ref: &MSMContextBn256); + } + unsafe { drop_msm_context_bn254(std::mem::transmute::<&_, &_>(self)) }; + self.context = core::ptr::null(); + } +} + extern "C" { fn mult_pippenger_bn254( out: *mut bn256::G1, @@ -38,7 +72,7 @@ pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 { #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_bn254( + fn cuda_bn254( out: *mut bn256::G1, points: *const bn256::G1Affine, npoints: usize, @@ -47,9 +81,8 @@ pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 { } let mut ret = bn256::G1::default(); - let err = unsafe { - cuda_pippenger_bn254(&mut ret, 
&points[0], npoints, &scalars[0]) - }; + let err = + unsafe { cuda_bn254(&mut ret, &points[0], npoints, &scalars[0]) }; assert!(err.code == 0, "{}", String::from(err)); return bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); @@ -59,8 +92,115 @@ pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 { bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() } +#[cfg(feature = "cuda")] +pub fn bn256_init( + points: &[bn256::G1Affine], + npoints: usize, +) -> MSMContextBn256 { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == points.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_bn254_init( + points: *const bn256::G1Affine, + npoints: usize, + msm_context: &mut MSMContextBn256, + ) -> cuda::Error; + + } + + let mut ret = MSMContextBn256::default(); + let err = unsafe { + cuda_bn254_init(points.as_ptr() as *const _, npoints, &mut ret) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + +#[cfg(feature = "cuda")] +pub fn bn256_with( + context: &MSMContextBn256, + npoints: usize, + scalars: &[bn256::Fr], +) -> bn256::G1 { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == scalars.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_bn254_with( + out: *mut bn256::G1, + context: &MSMContextBn256, + npoints: usize, + scalars: *const bn256::Fr, + is_mont: bool, + ) -> cuda::Error; + + } + + let mut ret = bn256::G1::default(); + let err = unsafe { + cuda_bn254_with(&mut ret, context, npoints, &scalars[0], true) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + use halo2curves::grumpkin; +#[cfg(feature = "cuda")] +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MSMContextGrumpkin { + context: *const std::ffi::c_void, +} + +#[cfg(feature = "cuda")] +unsafe impl Send 
for MSMContextGrumpkin {} + +#[cfg(feature = "cuda")] +unsafe impl Sync for MSMContextGrumpkin {} + +#[cfg(feature = "cuda")] +impl Default for MSMContextGrumpkin { + fn default() -> Self { + Self { + context: std::ptr::null(), + } + } +} + +#[cfg(feature = "cuda")] +// TODO: check for device-side memory leaks +impl Drop for MSMContextGrumpkin { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_grumpkin(by_ref: &MSMContextGrumpkin); + } + unsafe { + drop_msm_context_grumpkin(std::mem::transmute::<&_, &_>(self)) + }; + self.context = core::ptr::null(); + } +} + extern "C" { fn mult_pippenger_grumpkin( out: *mut grumpkin::G1, @@ -104,6 +244,77 @@ pub fn grumpkin( grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() } +#[cfg(feature = "cuda")] +pub fn grumpkin_init( + points: &[grumpkin::G1Affine], + npoints: usize, +) -> MSMContextGrumpkin { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == points.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_grumpkin_init( + points: *const grumpkin::G1Affine, + npoints: usize, + msm_context: &mut MSMContextGrumpkin, + ) -> cuda::Error; + + } + + let mut ret = MSMContextGrumpkin::default(); + let err = unsafe { + cuda_grumpkin_init(points.as_ptr() as *const _, npoints, &mut ret) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + +#[cfg(feature = "cuda")] +pub fn grumpkin_with( + context: &MSMContextGrumpkin, + npoints: usize, + scalars: &[grumpkin::Fr], +) -> grumpkin::G1 { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == scalars.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_grumpkin_with( + out: *mut grumpkin::G1, + context: &MSMContextGrumpkin, + npoints: usize, + scalars: *const grumpkin::Fr, + is_mont: bool, + ) -> cuda::Error; + + } + + 
let mut ret = grumpkin::G1::default(); + let err = unsafe { + cuda_grumpkin_with(&mut ret, context, npoints, &scalars[0], true) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + #[cfg(test)] mod tests { use halo2curves::group::Curve; diff --git a/src/pasta.rs b/src/pasta.rs index c69f6be..dc12f47 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -11,6 +11,40 @@ use pasta_curves::pallas; #[cfg(feature = "cuda")] use crate::{cuda, cuda_available, CUDA_OFF}; +#[cfg(feature = "cuda")] +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MSMContextPallas { + context: *const std::ffi::c_void, +} + +#[cfg(feature = "cuda")] +unsafe impl Send for MSMContextPallas {} + +#[cfg(feature = "cuda")] +unsafe impl Sync for MSMContextPallas {} + +#[cfg(feature = "cuda")] +impl Default for MSMContextPallas { + fn default() -> Self { + Self { + context: std::ptr::null(), + } + } +} + +#[cfg(feature = "cuda")] +// TODO: check for device-side memory leaks +impl Drop for MSMContextPallas { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_pallas(by_ref: &MSMContextPallas); + } + unsafe { drop_msm_context_pallas(std::mem::transmute::<&_, &_>(self)) }; + self.context = core::ptr::null(); + } +} + extern "C" { fn mult_pippenger_pallas( out: *mut pallas::Point, @@ -31,7 +65,7 @@ pub fn pallas( #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_pallas( + fn cuda_pallas( out: *mut pallas::Point, points: *const pallas::Affine, npoints: usize, @@ -42,13 +76,7 @@ pub fn pallas( } let mut ret = pallas::Point::default(); let err = unsafe { - cuda_pippenger_pallas( - &mut ret, - &points[0], - npoints, - &scalars[0], - true, - ) + cuda_pallas(&mut ret, &points[0], npoints, &scalars[0], true) }; assert!(err.code == 0, "{}", String::from(err)); @@ -61,8 +89,101 @@ pub fn pallas( ret } +#[cfg(feature = "cuda")] +pub fn pallas_init( + points: &[pallas::Affine], + npoints: usize, +) -> MSMContextPallas { + 
unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == points.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_pallas_init( + points: *const pallas::Affine, + npoints: usize, + msm_context: &mut MSMContextPallas, + ) -> cuda::Error; + + } + + let mut ret = MSMContextPallas::default(); + let err = unsafe { + cuda_pallas_init(points.as_ptr() as *const _, npoints, &mut ret) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + +#[cfg(feature = "cuda")] +pub fn pallas_with( + context: &MSMContextPallas, + npoints: usize, + scalars: &[pallas::Scalar], +) -> pallas::Point { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == scalars.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_pallas_with( + out: *mut pallas::Point, + context: &MSMContextPallas, + npoints: usize, + scalars: *const pallas::Scalar, + is_mont: bool, + ) -> cuda::Error; + + } + + let mut ret = pallas::Point::default(); + let err = unsafe { + cuda_pallas_with(&mut ret, context, npoints, &scalars[0], true) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + use pasta_curves::vesta; +#[cfg(feature = "cuda")] +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MSMContextVesta { + context: *const std::ffi::c_void, +} + +#[cfg(feature = "cuda")] +unsafe impl Send for MSMContextVesta {} + +#[cfg(feature = "cuda")] +unsafe impl Sync for MSMContextVesta {} + +#[cfg(feature = "cuda")] +impl Default for MSMContextVesta { + fn default() -> Self { + Self { + context: std::ptr::null(), + } + } +} + extern "C" { fn mult_pippenger_vesta( out: *mut vesta::Point, @@ -83,7 +204,7 @@ pub fn vesta( #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_vesta( + fn cuda_vesta( out: 
*mut vesta::Point, points: *const vesta::Affine, npoints: usize, @@ -94,13 +215,7 @@ pub fn vesta( } let mut ret = vesta::Point::default(); let err = unsafe { - cuda_pippenger_vesta( - &mut ret, - &points[0], - npoints, - &scalars[0], - true, - ) + cuda_vesta(&mut ret, &points[0], npoints, &scalars[0], true) }; assert!(err.code == 0, "{}", String::from(err)); @@ -113,6 +228,73 @@ pub fn vesta( ret } +#[cfg(feature = "cuda")] +pub fn vesta_init(points: &[vesta::Affine], npoints: usize) -> MSMContextVesta { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == points.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + extern "C" { + fn cuda_vesta_init( + points: *const vesta::Affine, + npoints: usize, + msm_context: &mut MSMContextVesta, + ) -> cuda::Error; + + } + + let mut ret = MSMContextVesta::default(); + let err = unsafe { + cuda_vesta_init(points.as_ptr() as *const _, npoints, &mut ret) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + +#[cfg(feature = "cuda")] +pub fn vesta_with( + context: &MSMContextVesta, + npoints: usize, + scalars: &[vesta::Scalar], +) -> vesta::Point { + unsafe { + assert!( + !CUDA_OFF && cuda_available(), + "feature = \"cuda\" must be enabled" + ) + }; + assert!( + npoints == scalars.len() && npoints >= 1 << 16, + "length mismatch or less than 10**16" + ); + + extern "C" { + fn cuda_vesta_with( + out: *mut vesta::Point, + context: &MSMContextVesta, + npoints: usize, + scalars: *const vesta::Scalar, + is_mont: bool, + ) -> cuda::Error; + + } + + let mut ret = vesta::Point::default(); + let err = unsafe { + cuda_vesta_with(&mut ret, context, npoints, &scalars[0], true) + }; + assert!(err.code == 0, "{}", String::from(err)); + + ret +} + pub mod utils { use std::{ mem::transmute, From 0c2b29d7c92d93b812571fea543844f3cd2df02d Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Wed, 10 Jan 2024 21:54:20 +0000 Subject: [PATCH 
03/11] reorg code --- benches/grumpkin_msm.rs | 19 +- benches/pasta_msm.rs | 22 +- cuda/bn254.cu | 4 +- cuda/grumpkin.cu | 4 +- cuda/pallas.cu | 4 +- cuda/vesta.cu | 4 +- examples/grumpkin_msm.rs | 2 +- examples/pasta_msm.rs | 6 +- src/lib.rs | 479 +++++++++++++++++---------------------- src/pasta.rs | 434 +++++++++++++++-------------------- 10 files changed, 429 insertions(+), 549 deletions(-) diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 1a23e09..0bd2380 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -30,10 +30,21 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256(&points, &scalars); + let _ = grumpkin_msm::bn256::msm(&points, &scalars); }) }); + let context = grumpkin_msm::bn256::init(points.clone()); + + group.bench_function( + format!("\"preallocate\" 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::with(&context, &scalars); + }) + }, + ); + group.finish(); #[cfg(feature = "cuda")] @@ -54,18 +65,18 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256(&points, &scalars); + let _ = grumpkin_msm::bn256::msm(&points, &scalars); }) }); - let context = grumpkin_msm::bn256_init(&points, npoints); + let context = grumpkin_msm::bn256::init(points); group.bench_function( format!("preallocate 2**{} points", bench_npow), |b| { b.iter(|| { let _ = - grumpkin_msm::bn256_with(&context, npoints, &scalars); + grumpkin_msm::bn256::with(&context, &scalars); }) }, ); diff --git a/benches/pasta_msm.rs b/benches/pasta_msm.rs index f961742..ca5970c 100644 --- a/benches/pasta_msm.rs +++ b/benches/pasta_msm.rs @@ -30,10 +30,21 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas(&points, &scalars); + let 
_ = grumpkin_msm::pasta::pallas::msm(&points, &scalars); }) }); + let context = grumpkin_msm::pasta::pallas::init(points.clone()); + + group.bench_function( + format!("\"preallocate\" 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas::with(&context, &scalars); + }) + }, + ); + group.finish(); #[cfg(feature = "cuda")] @@ -54,19 +65,18 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas(&points, &scalars); + let _ = grumpkin_msm::pasta::pallas::msm(&points, &scalars); }) }); - let context = grumpkin_msm::pasta::pallas_init(&points, npoints); + let context = grumpkin_msm::pasta::pallas::init(points); group.bench_function( format!("preallocate 2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas_with( - &context, npoints, &scalars, - ); + let _ = + grumpkin_msm::pasta::pallas::with(&context, &scalars); }) }, ); diff --git a/cuda/bn254.cu b/cuda/bn254.cu index fc97057..e0611e4 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -34,9 +34,9 @@ extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t np return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, scalars); } #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 8403f92..106be8f 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -34,9 +34,9 @@ extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" 
RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, scalars); } #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index 62a375a..dd76686 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -34,10 +34,10 @@ extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t n return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, scalars); } #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index 13ca657..ff9d594 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -34,10 +34,10 @@ extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, scalars); } #endif diff --git a/examples/grumpkin_msm.rs b/examples/grumpkin_msm.rs index 3ca237a..56eaed5 100644 --- a/examples/grumpkin_msm.rs +++ b/examples/grumpkin_msm.rs @@ -20,7 +20,7 @@ fn main() { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let res = grumpkin_msm::bn256(&points, &scalars).to_affine(); + let res = grumpkin_msm::bn256::msm(&points, &scalars).to_affine(); let native = naive_multiscalar_mul(&points, &scalars); assert_eq!(res, native); println!("success!") diff --git a/examples/pasta_msm.rs b/examples/pasta_msm.rs index 
a682c35..33a08de 100644 --- a/examples/pasta_msm.rs +++ b/examples/pasta_msm.rs @@ -8,7 +8,7 @@ use pasta_curves::group::Curve; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("17".to_string()) + .unwrap_or("22".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; @@ -22,8 +22,10 @@ fn main() { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let res = grumpkin_msm::pasta::pallas(&points, &scalars).to_affine(); let native = naive_multiscalar_mul(&points, &scalars); + let context = grumpkin_msm::pasta::pallas::init(points); + let res = grumpkin_msm::pasta::pallas::with(&context, &scalars).to_affine(); + assert_eq!(res, native); println!("success!") } diff --git a/src/lib.rs b/src/lib.rs index 3ba270e..6a39b39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,301 +18,228 @@ extern "C" { #[cfg(feature = "cuda")] pub static mut CUDA_OFF: bool = false; -use halo2curves::bn256; -use halo2curves::CurveExt; - -#[cfg(feature = "cuda")] -#[repr(C)] -#[derive(Debug, Clone)] -pub struct MSMContextBn256 { - context: *const std::ffi::c_void, -} - -#[cfg(feature = "cuda")] -unsafe impl Send for MSMContextBn256 {} - -#[cfg(feature = "cuda")] -unsafe impl Sync for MSMContextBn256 {} - -#[cfg(feature = "cuda")] -impl Default for MSMContextBn256 { - fn default() -> Self { - Self { - context: std::ptr::null(), - } - } -} - -#[cfg(feature = "cuda")] -// TODO: check for device-side memory leaks -impl Drop for MSMContextBn256 { - fn drop(&mut self) { - extern "C" { - fn drop_msm_context_bn254(by_ref: &MSMContextBn256); - } - unsafe { drop_msm_context_bn254(std::mem::transmute::<&_, &_>(self)) }; - self.context = core::ptr::null(); - } -} - -extern "C" { - fn mult_pippenger_bn254( - out: *mut bn256::G1, - points: *const bn256::G1Affine, - npoints: usize, - scalars: *const bn256::Fr, - ); - -} - -pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 { - let npoints = points.len(); - assert!(npoints == scalars.len(), "length 
mismatch"); - - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { - extern "C" { - fn cuda_bn254( - out: *mut bn256::G1, - points: *const bn256::G1Affine, - npoints: usize, - scalars: *const bn256::Fr, - ) -> cuda::Error; - - } - let mut ret = bn256::G1::default(); - let err = - unsafe { cuda_bn254(&mut ret, &points[0], npoints, &scalars[0]) }; - assert!(err.code == 0, "{}", String::from(err)); - - return bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); - } - let mut ret = bn256::G1::default(); - unsafe { mult_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) }; - bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() -} - -#[cfg(feature = "cuda")] -pub fn bn256_init( - points: &[bn256::G1Affine], - npoints: usize, -) -> MSMContextBn256 { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) +pub mod bn256 { + use halo2curves::{ + bn256::{Fr as Scalar, G1Affine as Affine, G1 as Point}, + CurveExt, }; - assert!( - npoints == points.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" + + use crate::impl_msm; + + impl_msm!( + cuda_bn254, + cuda_bn254_init, + cuda_bn254_with, + mult_pippenger_bn254, + Point, + Affine, + Scalar ); - - extern "C" { - fn cuda_bn254_init( - points: *const bn256::G1Affine, - npoints: usize, - msm_context: &mut MSMContextBn256, - ) -> cuda::Error; - - } - - let mut ret = MSMContextBn256::default(); - let err = unsafe { - cuda_bn254_init(points.as_ptr() as *const _, npoints, &mut ret) - }; - assert!(err.code == 0, "{}", String::from(err)); - - ret } -#[cfg(feature = "cuda")] -pub fn bn256_with( - context: &MSMContextBn256, - npoints: usize, - scalars: &[bn256::Fr], -) -> bn256::G1 { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) +pub mod grumpkin { + use halo2curves::{ + grumpkin::{Fr as Scalar, G1Affine as Affine, G1 as Point}, + CurveExt, }; - assert!( - npoints == 
scalars.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" + + use crate::impl_msm; + + impl_msm!( + cuda_grumpkin, + cuda_grumpkin_init, + cuda_grumpkin_with, + mult_pippenger_grumpkin, + Point, + Affine, + Scalar ); - - extern "C" { - fn cuda_bn254_with( - out: *mut bn256::G1, - context: &MSMContextBn256, - npoints: usize, - scalars: *const bn256::Fr, - is_mont: bool, - ) -> cuda::Error; - - } - - let mut ret = bn256::G1::default(); - let err = unsafe { - cuda_bn254_with(&mut ret, context, npoints, &scalars[0], true) - }; - assert!(err.code == 0, "{}", String::from(err)); - - ret } -use halo2curves::grumpkin; - -#[cfg(feature = "cuda")] -#[repr(C)] -#[derive(Debug, Clone)] -pub struct MSMContextGrumpkin { - context: *const std::ffi::c_void, -} +#[macro_export] +macro_rules! impl_msm { + ( + $name:ident, + $name_init:ident, + $name_with:ident, + $name_cpu:ident, + $point:ident, + $affine:ident, + $scalar:ident + ) => { + #[cfg(feature = "cuda")] + use $crate::{cuda, cuda_available, CUDA_OFF}; + + #[repr(C)] + #[derive(Debug, Clone)] + pub struct CudaMSMContext { + context: *const std::ffi::c_void, + npoints: usize, + } -#[cfg(feature = "cuda")] -unsafe impl Send for MSMContextGrumpkin {} + unsafe impl Send for CudaMSMContext {} -#[cfg(feature = "cuda")] -unsafe impl Sync for MSMContextGrumpkin {} + unsafe impl Sync for CudaMSMContext {} -#[cfg(feature = "cuda")] -impl Default for MSMContextGrumpkin { - fn default() -> Self { - Self { - context: std::ptr::null(), + impl Default for CudaMSMContext { + fn default() -> Self { + Self { + context: std::ptr::null(), + npoints: 0, + } + } } - } -} -#[cfg(feature = "cuda")] -// TODO: check for device-side memory leaks -impl Drop for MSMContextGrumpkin { - fn drop(&mut self) { - extern "C" { - fn drop_msm_context_grumpkin(by_ref: &MSMContextGrumpkin); + #[cfg(feature = "cuda")] + // TODO: check for device-side memory leaks + impl Drop for CudaMSMContext { + fn drop(&mut self) { + extern "C" { + fn 
drop_msm_context_bn254(by_ref: &CudaMSMContext); + } + unsafe { + drop_msm_context_bn254(std::mem::transmute::<&_, &_>(self)) + }; + self.context = core::ptr::null(); + } } - unsafe { - drop_msm_context_grumpkin(std::mem::transmute::<&_, &_>(self)) - }; - self.context = core::ptr::null(); - } -} -extern "C" { - fn mult_pippenger_grumpkin( - out: *mut grumpkin::G1, - points: *const grumpkin::G1Affine, - npoints: usize, - scalars: *const grumpkin::Fr, - ); - -} - -pub fn grumpkin( - points: &[grumpkin::G1Affine], - scalars: &[grumpkin::Fr], -) -> grumpkin::G1 { - let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); + #[derive(Debug, Clone)] + pub enum MSMContext { + CUDA(CudaMSMContext), + CPU(Vec<$affine>), + } - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + unsafe impl Send for MSMContext {} + + unsafe impl Sync for MSMContext {} + + impl MSMContext { + pub fn new_cpu(points: Vec<$affine>) -> Self { + Self::CPU(points) + } + + pub fn new_cuda(cuda_context: CudaMSMContext) -> Self { + Self::CUDA(cuda_context) + } + + pub fn npoints(&self) -> usize { + match self { + Self::CUDA(cuda_context) => cuda_context.npoints, + Self::CPU(points) => points.len(), + } + } + + pub fn cuda(&self) -> &CudaMSMContext { + match self { + Self::CUDA(cuda_context) => cuda_context, + Self::CPU(_) => panic!("not a cuda context"), + } + } + + pub fn points(&self) -> &Vec<$affine> { + match self { + Self::CUDA(_) => { + panic!("cuda context; no host side points") + } + Self::CPU(points) => points, + } + } + } + extern "C" { - fn cuda_pippenger_grumpkin( - out: *mut grumpkin::G1, - points: *const grumpkin::G1Affine, + fn $name_cpu( + out: *mut $point, + points: *const $affine, npoints: usize, - scalars: *const grumpkin::Fr, - ) -> cuda::Error; - + scalars: *const $scalar, + ); + } - let mut ret = grumpkin::G1::default(); - let err = unsafe { - cuda_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0]) 
- }; - assert!(err.code == 0, "{}", String::from(err)); - - return grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); - } - let mut ret = grumpkin::G1::default(); - unsafe { - mult_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0]) - }; - grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() -} - -#[cfg(feature = "cuda")] -pub fn grumpkin_init( - points: &[grumpkin::G1Affine], - npoints: usize, -) -> MSMContextGrumpkin { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) - }; - assert!( - npoints == points.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" - ); - - extern "C" { - fn cuda_grumpkin_init( - points: *const grumpkin::G1Affine, - npoints: usize, - msm_context: &mut MSMContextGrumpkin, - ) -> cuda::Error; - - } - - let mut ret = MSMContextGrumpkin::default(); - let err = unsafe { - cuda_grumpkin_init(points.as_ptr() as *const _, npoints, &mut ret) - }; - assert!(err.code == 0, "{}", String::from(err)); - - ret -} - -#[cfg(feature = "cuda")] -pub fn grumpkin_with( - context: &MSMContextGrumpkin, - npoints: usize, - scalars: &[grumpkin::Fr], -) -> grumpkin::G1 { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) - }; - assert!( - npoints == scalars.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" - ); - - extern "C" { - fn cuda_grumpkin_with( - out: *mut grumpkin::G1, - context: &MSMContextGrumpkin, - npoints: usize, - scalars: *const grumpkin::Fr, - is_mont: bool, - ) -> cuda::Error; - - } - - let mut ret = grumpkin::G1::default(); - let err = unsafe { - cuda_grumpkin_with(&mut ret, context, npoints, &scalars[0], true) + + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + let npoints = points.len(); + assert!(npoints == scalars.len(), "length mismatch"); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name( + out: *mut 
$point, + points: *const $affine, + npoints: usize, + scalars: *const $scalar, + ) -> cuda::Error; + + } + let mut ret = $point::default(); + let err = + unsafe { $name(&mut ret, &points[0], npoints, &scalars[0]) }; + assert!(err.code == 0, "{}", String::from(err)); + + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); + } + let mut ret = $point::default(); + unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; + $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() + } + + pub fn init(points: Vec<$affine>) -> MSMContext { + #[cfg(feature = "cuda")] + if unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_init( + points: *const $affine, + npoints: usize, + msm_context: &mut CudaMSMContext, + ) -> cuda::Error; + } + + let mut ret = CudaMSMContext::default(); + + let npoints = points.len(); + let err = unsafe { + $name_init(points.as_ptr() as *const _, npoints, &mut ret) + }; + assert!(err.code == 0, "{}", String::from(err)); + return MSMContext::new_cuda(ret); + } + + MSMContext::new_cpu(points) + } + + pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { + assert!(context.npoints() >= scalars.len(), "not enough points"); + + let mut ret = $point::default(); + + #[cfg(feature = "cuda")] + if unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_with( + out: *mut $point, + context: &CudaMSMContext, + scalars: *const $scalar, + ) -> cuda::Error; + } + + let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) }; + assert!(err.code == 0, "{}", String::from(err)); + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); + } + + unsafe { + $name_cpu( + &mut ret, + &context.points()[0], + context.npoints(), + &scalars[0], + ) + }; + $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() + } + }; - assert!(err.code == 0, "{}", String::from(err)); - - ret } #[cfg(test)] @@ -334,7 +261,7 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = 
crate::bn256(&points, &scalars).to_affine(); + let ret = crate::bn256::msm(&points, &scalars).to_affine(); println!("{:?}", ret); assert_eq!(ret, naive); diff --git a/src/pasta.rs b/src/pasta.rs index dc12f47..7203462 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -6,293 +6,223 @@ extern crate semolina; -use pasta_curves::pallas; - -#[cfg(feature = "cuda")] -use crate::{cuda, cuda_available, CUDA_OFF}; - -#[cfg(feature = "cuda")] -#[repr(C)] -#[derive(Debug, Clone)] -pub struct MSMContextPallas { - context: *const std::ffi::c_void, -} - -#[cfg(feature = "cuda")] -unsafe impl Send for MSMContextPallas {} - -#[cfg(feature = "cuda")] -unsafe impl Sync for MSMContextPallas {} - -#[cfg(feature = "cuda")] -impl Default for MSMContextPallas { - fn default() -> Self { - Self { - context: std::ptr::null(), - } - } -} - -#[cfg(feature = "cuda")] -// TODO: check for device-side memory leaks -impl Drop for MSMContextPallas { - fn drop(&mut self) { - extern "C" { - fn drop_msm_context_pallas(by_ref: &MSMContextPallas); - } - unsafe { drop_msm_context_pallas(std::mem::transmute::<&_, &_>(self)) }; - self.context = core::ptr::null(); - } -} - -extern "C" { - fn mult_pippenger_pallas( - out: *mut pallas::Point, - points: *const pallas::Affine, - npoints: usize, - scalars: *const pallas::Scalar, - is_mont: bool, +pub mod pallas { + use pasta_curves::pallas::{Affine, Point, Scalar}; + + use crate::impl_pasta; + + impl_pasta!( + cuda_pallas, + cuda_pallas_init, + cuda_pallas_with, + mult_pippenger_pallas, + Point, + Affine, + Scalar ); } -pub fn pallas( - points: &[pallas::Affine], - scalars: &[pallas::Scalar], -) -> pallas::Point { - let npoints = points.len(); - assert_eq!(npoints, scalars.len(), "length mismatch"); +pub mod vesta { + use pasta_curves::vesta::{Affine, Point, Scalar}; - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { - extern "C" { - fn cuda_pallas( - out: *mut pallas::Point, - points: *const pallas::Affine, - npoints: 
usize, - scalars: *const pallas::Scalar, - is_mont: bool, - ) -> cuda::Error; - - } - let mut ret = pallas::Point::default(); - let err = unsafe { - cuda_pallas(&mut ret, &points[0], npoints, &scalars[0], true) - }; - assert!(err.code == 0, "{}", String::from(err)); + use crate::impl_pasta; - return ret; - } - let mut ret = pallas::Point::default(); - unsafe { - mult_pippenger_pallas(&mut ret, &points[0], npoints, &scalars[0], true) - }; - ret -} - -#[cfg(feature = "cuda")] -pub fn pallas_init( - points: &[pallas::Affine], - npoints: usize, -) -> MSMContextPallas { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) - }; - assert!( - npoints == points.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" + impl_pasta!( + cuda_vesta, + cuda_vesta_init, + cuda_vesta_with, + mult_pippenger_vesta, + Point, + Affine, + Scalar ); +} - extern "C" { - fn cuda_pallas_init( - points: *const pallas::Affine, +#[macro_export] +macro_rules! 
impl_pasta { + ( + $name:ident, + $name_init:ident, + $name_with:ident, + $name_cpu:ident, + $point:ident, + $affine:ident, + $scalar:ident + ) => { + #[cfg(feature = "cuda")] + use $crate::{cuda, cuda_available, CUDA_OFF}; + + #[repr(C)] + #[derive(Debug, Clone)] + pub struct CudaMSMContext { + context: *const std::ffi::c_void, npoints: usize, - msm_context: &mut MSMContextPallas, - ) -> cuda::Error; - - } + } - let mut ret = MSMContextPallas::default(); - let err = unsafe { - cuda_pallas_init(points.as_ptr() as *const _, npoints, &mut ret) - }; - assert!(err.code == 0, "{}", String::from(err)); + unsafe impl Send for CudaMSMContext {} - ret -} + unsafe impl Sync for CudaMSMContext {} -#[cfg(feature = "cuda")] -pub fn pallas_with( - context: &MSMContextPallas, - npoints: usize, - scalars: &[pallas::Scalar], -) -> pallas::Point { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) - }; - assert!( - npoints == scalars.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" - ); + impl Default for CudaMSMContext { + fn default() -> Self { + Self { + context: std::ptr::null(), + npoints: 0, + } + } + } - extern "C" { - fn cuda_pallas_with( - out: *mut pallas::Point, - context: &MSMContextPallas, - npoints: usize, - scalars: *const pallas::Scalar, - is_mont: bool, - ) -> cuda::Error; + #[cfg(feature = "cuda")] + // TODO: check for device-side memory leaks + impl Drop for CudaMSMContext { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_bn254(by_ref: &CudaMSMContext); + } + unsafe { + drop_msm_context_bn254(std::mem::transmute::<&_, &_>(self)) + }; + self.context = core::ptr::null(); + } + } - } + #[derive(Debug, Clone)] + pub enum MSMContext { + CUDA(CudaMSMContext), + CPU(Vec<$affine>), + } - let mut ret = pallas::Point::default(); - let err = unsafe { - cuda_pallas_with(&mut ret, context, npoints, &scalars[0], true) - }; - assert!(err.code == 0, "{}", String::from(err)); + unsafe impl Send for 
MSMContext {} - ret -} + unsafe impl Sync for MSMContext {} -use pasta_curves::vesta; + impl MSMContext { + pub fn new_cpu(points: Vec<$affine>) -> Self { + Self::CPU(points) + } -#[cfg(feature = "cuda")] -#[repr(C)] -#[derive(Debug, Clone)] -pub struct MSMContextVesta { - context: *const std::ffi::c_void, -} + pub fn new_cuda(cuda_context: CudaMSMContext) -> Self { + Self::CUDA(cuda_context) + } -#[cfg(feature = "cuda")] -unsafe impl Send for MSMContextVesta {} + pub fn npoints(&self) -> usize { + match self { + Self::CUDA(cuda_context) => cuda_context.npoints, + Self::CPU(points) => points.len(), + } + } -#[cfg(feature = "cuda")] -unsafe impl Sync for MSMContextVesta {} + pub fn cuda(&self) -> &CudaMSMContext { + match self { + Self::CUDA(cuda_context) => cuda_context, + Self::CPU(_) => panic!("not a cuda context"), + } + } -#[cfg(feature = "cuda")] -impl Default for MSMContextVesta { - fn default() -> Self { - Self { - context: std::ptr::null(), + pub fn points(&self) -> &Vec<$affine> { + match self { + Self::CUDA(_) => { + panic!("cuda context; no host side points") + } + Self::CPU(points) => points, + } + } } - } -} -extern "C" { - fn mult_pippenger_vesta( - out: *mut vesta::Point, - points: *const vesta::Affine, - npoints: usize, - scalars: *const vesta::Scalar, - is_mont: bool, - ); -} - -pub fn vesta( - points: &[vesta::Affine], - scalars: &[vesta::Scalar], -) -> vesta::Point { - let npoints = points.len(); - assert_eq!(npoints, scalars.len(), "length mismatch"); - - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_vesta( - out: *mut vesta::Point, - points: *const vesta::Affine, + fn $name_cpu( + out: *mut $point, + points: *const $affine, npoints: usize, - scalars: *const vesta::Scalar, - is_mont: bool, - ) -> cuda::Error; + scalars: *const $scalar, + ); } - let mut ret = vesta::Point::default(); - let err = unsafe { - cuda_vesta(&mut ret, &points[0], npoints, &scalars[0], true) - }; - 
assert!(err.code == 0, "{}", String::from(err)); - return ret; - } - let mut ret = vesta::Point::default(); - unsafe { - mult_pippenger_vesta(&mut ret, &points[0], npoints, &scalars[0], true) - }; - ret -} + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + let npoints = points.len(); + assert!(npoints == scalars.len(), "length mismatch"); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name( + out: *mut $point, + points: *const $affine, + npoints: usize, + scalars: *const $scalar, + ) -> cuda::Error; + + } + let mut ret = $point::default(); + let err = unsafe { + $name(&mut ret, &points[0], npoints, &scalars[0]) + }; + assert!(err.code == 0, "{}", String::from(err)); + + return ret; + } + let mut ret = $point::default(); + unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; + ret + } -#[cfg(feature = "cuda")] -pub fn vesta_init(points: &[vesta::Affine], npoints: usize) -> MSMContextVesta { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) - }; - assert!( - npoints == points.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" - ); - extern "C" { - fn cuda_vesta_init( - points: *const vesta::Affine, - npoints: usize, - msm_context: &mut MSMContextVesta, - ) -> cuda::Error; + pub fn init(points: Vec<$affine>) -> MSMContext { + #[cfg(feature = "cuda")] + if unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_init( + points: *const $affine, + npoints: usize, + msm_context: &mut CudaMSMContext, + ) -> cuda::Error; + } + + let mut ret = CudaMSMContext::default(); + + let npoints = points.len(); + let err = unsafe { + $name_init(points.as_ptr() as *const _, npoints, &mut ret) + }; + assert!(err.code == 0, "{}", String::from(err)); + return MSMContext::new_cuda(ret); + } - } + MSMContext::new_cpu(points) + } - let mut ret = MSMContextVesta::default(); - let err = unsafe { - 
cuda_vesta_init(points.as_ptr() as *const _, npoints, &mut ret) - }; - assert!(err.code == 0, "{}", String::from(err)); + pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { + assert!(context.npoints() >= scalars.len(), "not enough points"); - ret -} + let mut ret = $point::default(); -#[cfg(feature = "cuda")] -pub fn vesta_with( - context: &MSMContextVesta, - npoints: usize, - scalars: &[vesta::Scalar], -) -> vesta::Point { - unsafe { - assert!( - !CUDA_OFF && cuda_available(), - "feature = \"cuda\" must be enabled" - ) - }; - assert!( - npoints == scalars.len() && npoints >= 1 << 16, - "length mismatch or less than 10**16" - ); + #[cfg(feature = "cuda")] + if unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_with( + out: *mut $point, + context: &CudaMSMContext, + scalars: *const $scalar, + ) -> cuda::Error; + } - extern "C" { - fn cuda_vesta_with( - out: *mut vesta::Point, - context: &MSMContextVesta, - npoints: usize, - scalars: *const vesta::Scalar, - is_mont: bool, - ) -> cuda::Error; + let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) }; + assert!(err.code == 0, "{}", String::from(err)); + return ret; + } - } + unsafe { + $name_cpu( + &mut ret, + &context.points()[0], + context.npoints(), + &scalars[0], + ) + }; - let mut ret = vesta::Point::default(); - let err = unsafe { - cuda_vesta_with(&mut ret, context, npoints, &scalars[0], true) + ret + } }; - assert!(err.code == 0, "{}", String::from(err)); - - ret } pub mod utils { @@ -428,7 +358,7 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = pallas(&points, &scalars).to_affine(); + let ret = pallas::msm(&points, &scalars).to_affine(); println!("{:?}", ret); assert_eq!(ret, naive); From cf202411870b85c1172b6549e52e2a957d3579e5 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Wed, 10 Jan 2024 22:46:50 +0000 Subject: [PATCH 04/11] use references --- benches/grumpkin_msm.rs | 4 ++-- benches/pasta_msm.rs | 
4 ++-- examples/pasta_msm.rs | 2 +- src/lib.rs | 16 ++++++++-------- src/pasta.rs | 16 ++++++++-------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 0bd2380..86b9edc 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -34,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let context = grumpkin_msm::bn256::init(points.clone()); + let context = grumpkin_msm::bn256::init(&points); group.bench_function( format!("\"preallocate\" 2**{} points", bench_npow), @@ -69,7 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let context = grumpkin_msm::bn256::init(points); + let context = grumpkin_msm::bn256::init(&points); group.bench_function( format!("preallocate 2**{} points", bench_npow), diff --git a/benches/pasta_msm.rs b/benches/pasta_msm.rs index ca5970c..2ddac1e 100644 --- a/benches/pasta_msm.rs +++ b/benches/pasta_msm.rs @@ -34,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let context = grumpkin_msm::pasta::pallas::init(points.clone()); + let context = grumpkin_msm::pasta::pallas::init(&points); group.bench_function( format!("\"preallocate\" 2**{} points", bench_npow), @@ -69,7 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let context = grumpkin_msm::pasta::pallas::init(points); + let context = grumpkin_msm::pasta::pallas::init(&points); group.bench_function( format!("preallocate 2**{} points", bench_npow), diff --git a/examples/pasta_msm.rs b/examples/pasta_msm.rs index 33a08de..f382756 100644 --- a/examples/pasta_msm.rs +++ b/examples/pasta_msm.rs @@ -23,7 +23,7 @@ fn main() { } let native = naive_multiscalar_mul(&points, &scalars); - let context = grumpkin_msm::pasta::pallas::init(points); + let context = grumpkin_msm::pasta::pallas::init(&points); let res = grumpkin_msm::pasta::pallas::with(&context, &scalars).to_affine(); assert_eq!(res, native); diff --git a/src/lib.rs b/src/lib.rs index 6a39b39..8070443 100644 
--- a/src/lib.rs +++ b/src/lib.rs @@ -105,17 +105,17 @@ macro_rules! impl_msm { } #[derive(Debug, Clone)] - pub enum MSMContext { + pub enum MSMContext<'a> { CUDA(CudaMSMContext), - CPU(Vec<$affine>), + CPU(&'a [$affine]), } - unsafe impl Send for MSMContext {} + unsafe impl<'a> Send for MSMContext<'a> {} - unsafe impl Sync for MSMContext {} + unsafe impl<'a> Sync for MSMContext<'a> {} - impl MSMContext { - pub fn new_cpu(points: Vec<$affine>) -> Self { + impl<'a> MSMContext<'a> { + pub fn new_cpu(points: &'a [$affine]) -> Self { Self::CPU(points) } @@ -137,7 +137,7 @@ macro_rules! impl_msm { } } - pub fn points(&self) -> &Vec<$affine> { + pub fn points(&self) -> &[$affine] { match self { Self::CUDA(_) => { panic!("cuda context; no host side points") @@ -184,7 +184,7 @@ macro_rules! impl_msm { $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } - pub fn init(points: Vec<$affine>) -> MSMContext { + pub fn init(points: &[$affine]) -> MSMContext { #[cfg(feature = "cuda")] if unsafe { !CUDA_OFF && cuda_available() } { extern "C" { diff --git a/src/pasta.rs b/src/pasta.rs index 7203462..256f58c 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -87,17 +87,17 @@ macro_rules! impl_pasta { } #[derive(Debug, Clone)] - pub enum MSMContext { + pub enum MSMContext<'a> { CUDA(CudaMSMContext), - CPU(Vec<$affine>), + CPU(&'a [$affine]), } - unsafe impl Send for MSMContext {} + unsafe impl<'a> Send for MSMContext<'a> {} - unsafe impl Sync for MSMContext {} + unsafe impl<'a> Sync for MSMContext<'a> {} - impl MSMContext { - pub fn new_cpu(points: Vec<$affine>) -> Self { + impl<'a> MSMContext<'a> { + pub fn new_cpu(points: &'a [$affine]) -> Self { Self::CPU(points) } @@ -119,7 +119,7 @@ macro_rules! impl_pasta { } } - pub fn points(&self) -> &Vec<$affine> { + pub fn points(&self) -> &[$affine] { match self { Self::CUDA(_) => { panic!("cuda context; no host side points") @@ -167,7 +167,7 @@ macro_rules! 
impl_pasta { ret } - pub fn init(points: Vec<$affine>) -> MSMContext { + pub fn init(points: &[$affine]) -> MSMContext { #[cfg(feature = "cuda")] if unsafe { !CUDA_OFF && cuda_available() } { extern "C" { From 44b29445544bc3a5d0fcb2e6eb6a0dce27c59100 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Thu, 11 Jan 2024 03:13:45 +0000 Subject: [PATCH 05/11] add default --- src/lib.rs | 7 ++++++- src/pasta.rs | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8070443..a1c9bac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,10 +104,12 @@ macro_rules! impl_msm { } } - #[derive(Debug, Clone)] + #[derive(Default, Debug, Clone)] pub enum MSMContext<'a> { CUDA(CudaMSMContext), CPU(&'a [$affine]), + #[default] + Uninit, } unsafe impl<'a> Send for MSMContext<'a> {} @@ -127,6 +129,7 @@ macro_rules! impl_msm { match self { Self::CUDA(cuda_context) => cuda_context.npoints, Self::CPU(points) => points.len(), + Self::Uninit => panic!("not initialized"), } } @@ -134,6 +137,7 @@ macro_rules! impl_msm { match self { Self::CUDA(cuda_context) => cuda_context, Self::CPU(_) => panic!("not a cuda context"), + Self::Uninit => panic!("not initialized"), } } @@ -143,6 +147,7 @@ macro_rules! impl_msm { panic!("cuda context; no host side points") } Self::CPU(points) => points, + Self::Uninit => panic!("not initialized"), } } } diff --git a/src/pasta.rs b/src/pasta.rs index 256f58c..5b072e2 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -86,10 +86,12 @@ macro_rules! impl_pasta { } } - #[derive(Debug, Clone)] + #[derive(Default, Debug, Clone)] pub enum MSMContext<'a> { CUDA(CudaMSMContext), CPU(&'a [$affine]), + #[default] + Uninit, } unsafe impl<'a> Send for MSMContext<'a> {} @@ -109,6 +111,7 @@ macro_rules! impl_pasta { match self { Self::CUDA(cuda_context) => cuda_context.npoints, Self::CPU(points) => points.len(), + Self::Uninit => panic!("not initialized"), } } @@ -116,6 +119,7 @@ macro_rules! 
impl_pasta { match self { Self::CUDA(cuda_context) => cuda_context, Self::CPU(_) => panic!("not a cuda context"), + Self::Uninit => panic!("not initialized"), } } @@ -125,6 +129,7 @@ macro_rules! impl_pasta { panic!("cuda context; no host side points") } Self::CPU(points) => points, + Self::Uninit => panic!("not initialized"), } } } From c9a2d654dc8e27ee16db20e2b83c6c899bdd691b Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Thu, 11 Jan 2024 03:38:00 +0000 Subject: [PATCH 06/11] points() --- cuda/bn254.cu | 4 +- cuda/grumpkin.cu | 4 +- cuda/pallas.cu | 4 +- cuda/vesta.cu | 10 ++-- src/lib.rs | 131 +++++++++++++++++++++++++---------------------- src/pasta.rs | 102 ++++++++++++++++++++---------------- 6 files changed, 141 insertions(+), 114 deletions(-) diff --git a/cuda/bn254.cu b/cuda/bn254.cu index e0611e4..fc97057 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -34,9 +34,9 @@ extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t np return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 106be8f..8403f92 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -34,9 +34,9 @@ extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff 
--git a/cuda/pallas.cu b/cuda/pallas.cu index dd76686..62a375a 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -34,10 +34,10 @@ extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t n return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index ff9d594..63db638 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -18,26 +18,26 @@ typedef pallas_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) { +extern "C" void drop_msm_context_vesta(msm_context_t &ref) { CUDA_OK(cudaFree(ref.d_points)); } extern "C" RustError -cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) { return mult_pippenger_init(points, npoints, msm_context); } -extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints, +extern "C" RustError cuda_vesta(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[]) { return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/src/lib.rs b/src/lib.rs index a1c9bac..cd51f3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,9 +23,9 @@ pub mod bn256 { bn256::{Fr as Scalar, G1Affine as Affine, G1 as 
Point}, CurveExt, }; - + use crate::impl_msm; - + impl_msm!( cuda_bn254, cuda_bn254_init, @@ -42,9 +42,9 @@ pub mod grumpkin { grumpkin::{Fr as Scalar, G1Affine as Affine, G1 as Point}, CurveExt, }; - + use crate::impl_msm; - + impl_msm!( cuda_grumpkin, cuda_grumpkin_init, @@ -105,11 +105,10 @@ macro_rules! impl_msm { } #[derive(Default, Debug, Clone)] - pub enum MSMContext<'a> { - CUDA(CudaMSMContext), - CPU(&'a [$affine]), - #[default] - Uninit, + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], } unsafe impl<'a> Send for MSMContext<'a> {} @@ -117,41 +116,33 @@ macro_rules! impl_msm { unsafe impl<'a> Sync for MSMContext<'a> {} impl<'a> MSMContext<'a> { - pub fn new_cpu(points: &'a [$affine]) -> Self { - Self::CPU(points) - } - - pub fn new_cuda(cuda_context: CudaMSMContext) -> Self { - Self::CUDA(cuda_context) + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } } - pub fn npoints(&self) -> usize { - match self { - Self::CUDA(cuda_context) => cuda_context.npoints, - Self::CPU(points) => points.len(), - Self::Uninit => panic!("not initialized"), + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); } + self.cpu_context.len() } - pub fn cuda(&self) -> &CudaMSMContext { - match self { - Self::CUDA(cuda_context) => cuda_context, - Self::CPU(_) => panic!("not a cuda context"), - Self::Uninit => panic!("not initialized"), - } + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context } - pub fn points(&self) -> &[$affine] { - match self { - Self::CUDA(_) => { - panic!("cuda context; no host side points") - } - Self::CPU(points) => points, - Self::Uninit => panic!("not initialized"), - } + fn points(&self) -> &[$affine] { + &self.cpu_context } } - + extern "C" { fn $name_cpu( out: *mut $point, @@ -159,13 +150,13 @@ macro_rules! 
impl_msm { npoints: usize, scalars: *const $scalar, ); - + } - + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { let npoints = points.len(); assert!(npoints == scalars.len(), "length mismatch"); - + #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { @@ -175,23 +166,28 @@ macro_rules! impl_msm { npoints: usize, scalars: *const $scalar, ) -> cuda::Error; - + } let mut ret = $point::default(); - let err = - unsafe { $name(&mut ret, &points[0], npoints, &scalars[0]) }; + let err = unsafe { + $name(&mut ret, &points[0], npoints, &scalars[0]) + }; assert!(err.code == 0, "{}", String::from(err)); - + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); } let mut ret = $point::default(); unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } - + pub fn init(points: &[$affine]) -> MSMContext { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { fn $name_init( points: *const $affine, @@ -200,50 +196,60 @@ macro_rules! 
impl_msm { ) -> cuda::Error; } - let mut ret = CudaMSMContext::default(); - let npoints = points.len(); let err = unsafe { - $name_init(points.as_ptr() as *const _, npoints, &mut ret) + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) }; assert!(err.code == 0, "{}", String::from(err)); - return MSMContext::new_cuda(ret); + ret.on_gpu = true; + return ret; } - MSMContext::new_cpu(points) + ret } - + pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { - assert!(context.npoints() >= scalars.len(), "not enough points"); - + let npoints = context.npoints(); + let nscalars = scalars.len(); + assert!(npoints >= nscalars, "not enough points"); + let mut ret = $point::default(); - + #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if nscalars >= 1 << 16 + && context.on_gpu + && unsafe { !CUDA_OFF && cuda_available() } + { extern "C" { fn $name_with( out: *mut $point, context: &CudaMSMContext, + npoints: usize, scalars: *const $scalar, ) -> cuda::Error; } - - let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) }; + + let err = unsafe { + $name_with(&mut ret, &context.cuda_context, nscalars, &scalars[0]) + }; assert!(err.code == 0, "{}", String::from(err)); return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); } - + unsafe { $name_cpu( &mut ret, - &context.points()[0], - context.npoints(), + &context.cpu_context[0], + nscalars, &scalars[0], ) }; $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } - }; } @@ -269,6 +275,11 @@ mod tests { let ret = crate::bn256::msm(&points, &scalars).to_affine(); println!("{:?}", ret); + let context = crate::bn256::init(&points); + let ret_other = crate::bn256::with(&context, &scalars).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } } diff --git a/src/pasta.rs b/src/pasta.rs index 5b072e2..de0e781 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -87,11 +87,10 @@ macro_rules! 
impl_pasta { } #[derive(Default, Debug, Clone)] - pub enum MSMContext<'a> { - CUDA(CudaMSMContext), - CPU(&'a [$affine]), - #[default] - Uninit, + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], } unsafe impl<'a> Send for MSMContext<'a> {} @@ -99,38 +98,30 @@ macro_rules! impl_pasta { unsafe impl<'a> Sync for MSMContext<'a> {} impl<'a> MSMContext<'a> { - pub fn new_cpu(points: &'a [$affine]) -> Self { - Self::CPU(points) - } - - pub fn new_cuda(cuda_context: CudaMSMContext) -> Self { - Self::CUDA(cuda_context) + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } } - pub fn npoints(&self) -> usize { - match self { - Self::CUDA(cuda_context) => cuda_context.npoints, - Self::CPU(points) => points.len(), - Self::Uninit => panic!("not initialized"), + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); } + self.cpu_context.len() } - pub fn cuda(&self) -> &CudaMSMContext { - match self { - Self::CUDA(cuda_context) => cuda_context, - Self::CPU(_) => panic!("not a cuda context"), - Self::Uninit => panic!("not initialized"), - } + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context } - pub fn points(&self) -> &[$affine] { - match self { - Self::CUDA(_) => { - panic!("cuda context; no host side points") - } - Self::CPU(points) => points, - Self::Uninit => panic!("not initialized"), - } + fn points(&self) -> &[$affine] { + &self.cpu_context } } @@ -140,6 +131,7 @@ macro_rules! impl_pasta { points: *const $affine, npoints: usize, scalars: *const $scalar, + is_mont: bool, ); } @@ -156,25 +148,32 @@ macro_rules! 
impl_pasta { points: *const $affine, npoints: usize, scalars: *const $scalar, + is_mont: bool, ) -> cuda::Error; } let mut ret = $point::default(); let err = unsafe { - $name(&mut ret, &points[0], npoints, &scalars[0]) + $name(&mut ret, &points[0], npoints, &scalars[0], true) }; assert!(err.code == 0, "{}", String::from(err)); return ret; } let mut ret = $point::default(); - unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; + unsafe { + $name_cpu(&mut ret, &points[0], npoints, &scalars[0], true) + }; ret } pub fn init(points: &[$affine]) -> MSMContext { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { fn $name_init( points: *const $affine, @@ -183,35 +182,46 @@ macro_rules! impl_pasta { ) -> cuda::Error; } - let mut ret = CudaMSMContext::default(); - let npoints = points.len(); let err = unsafe { - $name_init(points.as_ptr() as *const _, npoints, &mut ret) + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) }; assert!(err.code == 0, "{}", String::from(err)); - return MSMContext::new_cuda(ret); + ret.on_gpu = true; + return ret; } - MSMContext::new_cpu(points) + ret } pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { - assert!(context.npoints() >= scalars.len(), "not enough points"); + let npoints = context.npoints(); + let nscalars = scalars.len(); + assert!(npoints >= nscalars, "not enough points"); let mut ret = $point::default(); #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if nscalars >= 1 << 16 + && unsafe { !CUDA_OFF && cuda_available() } + { extern "C" { fn $name_with( out: *mut $point, context: &CudaMSMContext, + npoints: usize, scalars: *const $scalar, + is_mont: bool, ) -> cuda::Error; } - let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) }; + let err = unsafe { + 
$name_with(&mut ret, context.cuda(), nscalars, &scalars[0], true) + }; assert!(err.code == 0, "{}", String::from(err)); return ret; } @@ -219,9 +229,10 @@ macro_rules! impl_pasta { unsafe { $name_cpu( &mut ret, - &context.points()[0], - context.npoints(), + &context.cpu_context[0], + nscalars, &scalars[0], + true, ) }; @@ -366,6 +377,11 @@ mod tests { let ret = pallas::msm(&points, &scalars).to_affine(); println!("{:?}", ret); + let context = pallas::init(&points); + let ret_other = pallas::with(&context, &scalars).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } } From 5289c3ae1539ff8ea1fd840f14c4fa25dceb2a0d Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Tue, 6 Feb 2024 21:09:02 +0000 Subject: [PATCH 07/11] indices work --- .vscode/launch.json | 140 +++++++++++++++++++++++++++++++++++++++ .vscode/settings.json | 5 ++ Cargo.toml | 2 +- benches/grumpkin_msm.rs | 2 +- cuda/bn254.cu | 9 +-- cuda/grumpkin.cu | 9 +-- cuda/pallas.cu | 9 +-- cuda/vesta.cu | 9 +-- examples/grumpkin_msm.rs | 14 ++-- src/lib.rs | 25 +++++-- src/pasta.rs | 4 +- 11 files changed, 200 insertions(+), 28 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..7850a7c --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,140 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'grumpkin-msm'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=grumpkin-msm" + ], + "filter": { + "name": "grumpkin-msm", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'pasta_msm'", + "cargo": { + "args": [ + "build", + "--example=pasta_msm", + "--package=grumpkin-msm" + ], + "filter": { + "name": "pasta_msm", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'pasta_msm'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=pasta_msm", + "--package=grumpkin-msm" + ], + "filter": { + "name": "pasta_msm", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'grumpkin_msm'", + "cargo": { + "args": [ + "build", + "--example=grumpkin_msm", + "--package=grumpkin-msm", + "--features=cuda" + ], + "filter": { + "name": "grumpkin_msm", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'grumpkin_msm'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=grumpkin_msm", + "--package=grumpkin-msm" + ], + "filter": { + "name": "grumpkin_msm", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'grumpkin_msm'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=grumpkin_msm", + "--package=grumpkin-msm" + ], + "filter": { + "name": "grumpkin_msm", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" 
+ }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'pasta_msm'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=pasta_msm", + "--package=grumpkin-msm" + ], + "filter": { + "name": "pasta_msm", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..af13e3a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.associations": { + "cstdint": "cpp" + } +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 25b170e..6f02639 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ cuda-mobile = [] [dependencies] blst = "~0.3.11" semolina = "~0.1.3" -sppark = { git = "https://github.com/lurk-lab/sppark", branch = "preallocate-msm" } +sppark = { git = "https://github.com/lurk-lab/sppark", branch = "indices" } halo2curves = { version = "0.6.0" } pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] } rand = "^0" diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 86b9edc..c5e0492 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -40,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { format!("\"preallocate\" 2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256::with(&context, &scalars); + let _ = grumpkin_msm::bn256::with(&context, &scalars, None); }) }, ); diff --git a/cuda/bn254.cu b/cuda/bn254.cu index fc97057..08371d9 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -18,7 +18,8 @@ typedef fr_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_bn254(msm_context_t &ref) { +extern "C" void drop_msm_context_bn254(msm_context_t &ref) +{ CUDA_OK(cudaFree(ref.d_points)); } @@ -29,14 +30,14 @@ cuda_bn254_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); } 
extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[]) + const scalar_t scalars[], uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); } #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 8403f92..1ebceb4 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -18,7 +18,8 @@ typedef fp_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) { +extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) +{ CUDA_OK(cudaFree(ref.d_points)); } @@ -29,14 +30,14 @@ cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); } extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[]) + const scalar_t scalars[], uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); } #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index 62a375a..8cbb181 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -18,7 +18,8 @@ typedef vesta_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_pallas(msm_context_t &ref) { +extern "C" void drop_msm_context_pallas(msm_context_t &ref) +{ CUDA_OK(cudaFree(ref.d_points)); } @@ -29,15 +30,15 @@ cuda_pallas_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); } extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[]) + const scalar_t scalars[], uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); } #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index 63db638..6675499 100644 --- 
a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -18,7 +18,8 @@ typedef pallas_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_vesta(msm_context_t &ref) { +extern "C" void drop_msm_context_vesta(msm_context_t &ref) +{ CUDA_OK(cudaFree(ref.d_points)); } @@ -29,15 +30,15 @@ cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); } extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[]) + const scalar_t scalars[], uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); } #endif diff --git a/examples/grumpkin_msm.rs b/examples/grumpkin_msm.rs index 56eaed5..29eca3f 100644 --- a/examples/grumpkin_msm.rs +++ b/examples/grumpkin_msm.rs @@ -6,22 +6,28 @@ use grumpkin_msm::cuda_available; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("17".to_string()) + .unwrap_or("23".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; println!("generating {} random points, just hang on...", npoints); let points = gen_points(npoints); - let scalars = gen_scalars(npoints); + let mut scalars = gen_scalars(npoints); #[cfg(feature = "cuda")] if unsafe { cuda_available() } { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let res = grumpkin_msm::bn256::msm(&points, &scalars).to_affine(); - let native = naive_multiscalar_mul(&points, &scalars); + let context = grumpkin_msm::bn256::init(&points); + + let indices = (0..(npoints as u32)).rev().collect::>(); + let res = + grumpkin_msm::bn256::with(&context, &scalars, Some(indices.as_slice())) + .to_affine(); + scalars.reverse(); + let native = grumpkin_msm::bn256::msm(&points, &scalars).to_affine(); assert_eq!(res, native); println!("success!") } diff --git a/src/lib.rs b/src/lib.rs index cd51f3d..f2dd27d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,7 +181,7 @@ macro_rules! 
impl_msm { $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } - pub fn init(points: &[$affine]) -> MSMContext { + pub fn init(points: &[$affine]) -> MSMContext<'_> { let npoints = points.len(); let mut ret = MSMContext::new(points); @@ -212,7 +212,11 @@ macro_rules! impl_msm { ret } - pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { + pub fn with( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { let npoints = context.npoints(); let nscalars = scalars.len(); assert!(npoints >= nscalars, "not enough points"); @@ -230,11 +234,23 @@ macro_rules! impl_msm { context: &CudaMSMContext, npoints: usize, scalars: *const $scalar, + indices: *const u32, ) -> cuda::Error; } + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; let err = unsafe { - $name_with(&mut ret, &context.cuda_context, nscalars, &scalars[0]) + $name_with( + &mut ret, + &context.cuda_context, + nscalars, + &scalars[0], + indices, + ) }; assert!(err.code == 0, "{}", String::from(err)); return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); @@ -276,7 +292,8 @@ mod tests { println!("{:?}", ret); let context = crate::bn256::init(&points); - let ret_other = crate::bn256::with(&context, &scalars).to_affine(); + let ret_other = + crate::bn256::with(&context, &scalars, None).to_affine(); println!("{:?}", ret_other); assert_eq!(ret, naive); diff --git a/src/pasta.rs b/src/pasta.rs index de0e781..bfdeb7a 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -167,7 +167,7 @@ macro_rules! impl_pasta { ret } - pub fn init(points: &[$affine]) -> MSMContext { + pub fn init(points: &[$affine]) -> MSMContext<'_> { let npoints = points.len(); let mut ret = MSMContext::new(points); @@ -198,7 +198,7 @@ macro_rules! 
impl_pasta { ret } - pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { + pub fn with(context: &MSMContext<'_>, scalars: &[$scalar]) -> $point { let npoints = context.npoints(); let nscalars = scalars.len(); assert!(npoints >= nscalars, "not enough points"); From 51850c88a720c1c0fcc4ac5a459ba0e035d08853 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Wed, 7 Feb 2024 20:35:14 +0000 Subject: [PATCH 08/11] why no work??? --- benches/grumpkin_msm.rs | 14 +++++++++++++- cuda/bn254.cu | 10 +++++----- cuda/grumpkin.cu | 8 ++++---- cuda/pallas.cu | 8 ++++---- cuda/vesta.cu | 8 ++++---- 5 files changed, 30 insertions(+), 18 deletions(-) diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index c5e0492..662538f 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -71,12 +71,24 @@ fn criterion_benchmark(c: &mut Criterion) { let context = grumpkin_msm::bn256::init(&points); + let indices = (0..(npoints as u32)).rev().collect::>(); group.bench_function( format!("preallocate 2**{} points", bench_npow), |b| { b.iter(|| { let _ = - grumpkin_msm::bn256::with(&context, &scalars); + grumpkin_msm::bn256::with(&context, &scalars, Some(indices.as_slice())); + }) + }, + ); + + scalars.reverse(); + group.bench_function( + format!("preallocate 2**{} points rev", bench_npow), + |b| { + b.iter(|| { + let _ = + grumpkin_msm::bn256::with(&context, &scalars, None); }) }, ); diff --git a/cuda/bn254.cu b/cuda/bn254.cu index 08371d9..baebcac 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -30,14 +30,14 @@ cuda_bn254_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); + return mult_pippenger(out, points, npoints, scalars, nscalars); } -extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[], uint32_t pidx[]) +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, + const scalar_t scalars[], size_t nscalars, uint32_t 
pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); } #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 1ebceb4..d4c7d18 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -30,14 +30,14 @@ cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); + return mult_pippenger(out, points, npoints, scalars, nscalars); } extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[], uint32_t pidx[]) + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); + return mult_pippenger_with(out, msm_context, npoints, scalars, nscalars, pidx); } #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index 8cbb181..a8ed5dd 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -30,15 +30,15 @@ cuda_pallas_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); + return mult_pippenger(out, points, npoints, scalars, nscalars); } extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[], uint32_t pidx[]) + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); + return mult_pippenger_with(out, msm_context, npoints, scalars, nscalars, pidx); } #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index 6675499..40b8579 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -30,15 +30,15 @@ cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars); + return mult_pippenger(out, points, npoints, scalars, nscalars); } extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, - const scalar_t scalars[], uint32_t pidx[]) + const scalar_t 
scalars[], size_t nscalars, uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, pidx); + return mult_pippenger_with(out, msm_context, npoints, scalars, nscalars, pidx); } #endif From 047f8e67050b7745b8e641fb843f1baef6b1a327 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Thu, 8 Feb 2024 23:38:03 +0000 Subject: [PATCH 09/11] cleanup API --- .gitignore | 3 + .vscode/launch.json | 140 --------------------------------------- .vscode/settings.json | 5 -- benches/grumpkin_msm.rs | 20 ++++-- benches/pasta_msm.rs | 30 +++++++-- cuda/bn254.cu | 4 +- cuda/grumpkin.cu | 8 +-- cuda/pallas.cu | 8 +-- cuda/vesta.cu | 8 +-- examples/grumpkin_msm.rs | 23 +++++-- examples/pasta_msm.rs | 23 +++++-- src/lib.rs | 85 +++++++++++++++++++++--- src/pasta.rs | 108 ++++++++++++++++++++++++++---- 13 files changed, 261 insertions(+), 204 deletions(-) delete mode 100644 .vscode/launch.json delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 6985cf1..555be08 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +# VSCode +.vscode \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 7850a7c..0000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,140 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. 
- // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'grumpkin-msm'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=grumpkin-msm" - ], - "filter": { - "name": "grumpkin-msm", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug example 'pasta_msm'", - "cargo": { - "args": [ - "build", - "--example=pasta_msm", - "--package=grumpkin-msm" - ], - "filter": { - "name": "pasta_msm", - "kind": "example" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in example 'pasta_msm'", - "cargo": { - "args": [ - "test", - "--no-run", - "--example=pasta_msm", - "--package=grumpkin-msm" - ], - "filter": { - "name": "pasta_msm", - "kind": "example" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug example 'grumpkin_msm'", - "cargo": { - "args": [ - "build", - "--example=grumpkin_msm", - "--package=grumpkin-msm", - "--features=cuda" - ], - "filter": { - "name": "grumpkin_msm", - "kind": "example" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in example 'grumpkin_msm'", - "cargo": { - "args": [ - "test", - "--no-run", - "--example=grumpkin_msm", - "--package=grumpkin-msm" - ], - "filter": { - "name": "grumpkin_msm", - "kind": "example" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug benchmark 'grumpkin_msm'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bench=grumpkin_msm", - "--package=grumpkin-msm" - ], - "filter": { - "name": "grumpkin_msm", - "kind": "bench" - } - }, - "args": [], - "cwd": "${workspaceFolder}" 
- }, - { - "type": "lldb", - "request": "launch", - "name": "Debug benchmark 'pasta_msm'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bench=pasta_msm", - "--package=grumpkin-msm" - ], - "filter": { - "name": "pasta_msm", - "kind": "bench" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - } - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index af13e3a..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "files.associations": { - "cstdint": "cpp" - } -} \ No newline at end of file diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 662538f..993d5e4 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -30,7 +30,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256::msm(&points, &scalars); + let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, None); }) }); @@ -40,7 +40,9 @@ fn criterion_benchmark(c: &mut Criterion) { format!("\"preallocate\" 2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256::with(&context, &scalars, None); + let _ = grumpkin_msm::bn256::with_context_aux( + &context, &scalars, None, + ); }) }, ); @@ -65,7 +67,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256::msm(&points, &scalars); + let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, None); }) }); @@ -76,8 +78,11 @@ fn criterion_benchmark(c: &mut Criterion) { format!("preallocate 2**{} points", bench_npow), |b| { b.iter(|| { - let _ = - grumpkin_msm::bn256::with(&context, &scalars, Some(indices.as_slice())); + let _ = grumpkin_msm::bn256::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ); }) }, ); @@ -87,8 +92,9 @@ fn criterion_benchmark(c: &mut Criterion) { format!("preallocate 2**{} points rev", 
bench_npow), |b| { b.iter(|| { - let _ = - grumpkin_msm::bn256::with(&context, &scalars, None); + let _ = grumpkin_msm::bn256::with_context_aux( + &context, &scalars, None, + ); }) }, ); diff --git a/benches/pasta_msm.rs b/benches/pasta_msm.rs index 2ddac1e..1aa542a 100644 --- a/benches/pasta_msm.rs +++ b/benches/pasta_msm.rs @@ -30,7 +30,8 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas::msm(&points, &scalars); + let _ = + grumpkin_msm::pasta::pallas::msm_aux(&points, &scalars, None); }) }); @@ -40,7 +41,9 @@ fn criterion_benchmark(c: &mut Criterion) { format!("\"preallocate\" 2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas::with(&context, &scalars); + let _ = grumpkin_msm::pasta::pallas::with_context_aux( + &context, &scalars, None, + ); }) }, ); @@ -65,18 +68,35 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas::msm(&points, &scalars); + let _ = grumpkin_msm::pasta::pallas::msm_aux( + &points, &scalars, None, + ); }) }); let context = grumpkin_msm::pasta::pallas::init(&points); + let indices = (0..(npoints as u32)).rev().collect::>(); group.bench_function( format!("preallocate 2**{} points", bench_npow), |b| { b.iter(|| { - let _ = - grumpkin_msm::pasta::pallas::with(&context, &scalars); + let _ = grumpkin_msm::pasta::pallas::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ); + }) + }, + ); + + group.bench_function( + format!("preallocate 2**{} points rev", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas::with_context_aux( + &context, &scalars, None, + ); }) }, ); diff --git a/cuda/bn254.cu b/cuda/bn254.cu index baebcac..e6682e6 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -30,9 +30,9 @@ cuda_bn254_init(const affine_t points[], size_t npoints, 
msm_context_t(out, points, npoints, scalars, nscalars); + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); } extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index d4c7d18..1c68d6a 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -30,14 +30,14 @@ cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars, nscalars); + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, nscalars, pidx); + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); } #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index a8ed5dd..ecfc26f 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -30,15 +30,15 @@ cuda_pallas_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars, nscalars); + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); } -extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, nscalars, pidx); + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); } #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index 40b8579..ec24f66 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -30,15 +30,15 @@ cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t(out, points, npoints, scalars, nscalars); + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); 
} -extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, +extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) { - return mult_pippenger_with(out, msm_context, npoints, scalars, nscalars, pidx); + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); } #endif diff --git a/examples/grumpkin_msm.rs b/examples/grumpkin_msm.rs index 29eca3f..208183f 100644 --- a/examples/grumpkin_msm.rs +++ b/examples/grumpkin_msm.rs @@ -6,7 +6,7 @@ use grumpkin_msm::cuda_available; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("23".to_string()) + .unwrap_or("22".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; @@ -23,11 +23,22 @@ fn main() { let context = grumpkin_msm::bn256::init(&points); let indices = (0..(npoints as u32)).rev().collect::>(); - let res = - grumpkin_msm::bn256::with(&context, &scalars, Some(indices.as_slice())) - .to_affine(); + let res = grumpkin_msm::bn256::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res: {:?}", res); + let res2 = grumpkin_msm::bn256::msm_aux( + &points, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res2: {:?}", res2); scalars.reverse(); - let native = grumpkin_msm::bn256::msm(&points, &scalars).to_affine(); - assert_eq!(res, native); + let native = naive_multiscalar_mul(&points, &scalars); + println!("native: {:?}", native); println!("success!") } diff --git a/examples/pasta_msm.rs b/examples/pasta_msm.rs index f382756..0e917aa 100644 --- a/examples/pasta_msm.rs +++ b/examples/pasta_msm.rs @@ -15,17 +15,32 @@ fn main() { println!("generating {} random points, just hang on...", npoints); let points = gen_points(npoints); - let scalars = gen_scalars(npoints); + let mut scalars = gen_scalars(npoints); #[cfg(feature = "cuda")] if unsafe { cuda_available() } { unsafe { 
grumpkin_msm::CUDA_OFF = false }; } - let native = naive_multiscalar_mul(&points, &scalars); let context = grumpkin_msm::pasta::pallas::init(&points); - let res = grumpkin_msm::pasta::pallas::with(&context, &scalars).to_affine(); - assert_eq!(res, native); + let indices = (0..(npoints as u32)).rev().collect::>(); + let res = grumpkin_msm::pasta::pallas::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res: {:?}", res); + let res2 = grumpkin_msm::pasta::pallas::msm_aux( + &points, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res2: {:?}", res2); + scalars.reverse(); + let native = naive_multiscalar_mul(&points, &scalars); + println!("native: {:?}", native); println!("success!") } diff --git a/src/lib.rs b/src/lib.rs index f2dd27d..2582d55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -153,9 +153,18 @@ macro_rules! impl_msm { } - pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + pub fn msm_aux( + points: &[$affine], + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); + let nscalars = scalars.len(); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints == nscalars, "length mismatch"); + } #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { @@ -165,22 +174,54 @@ macro_rules! 
impl_msm { points: *const $affine, npoints: usize, scalars: *const $scalar, + nscalars: usize, + indices: *const u32, ) -> cuda::Error; } + + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + let mut ret = $point::default(); let err = unsafe { - $name(&mut ret, &points[0], npoints, &scalars[0]) + $name( + &mut ret, + &points[0], + npoints, + &scalars[0], + nscalars, + indices, + ) }; assert!(err.code == 0, "{}", String::from(err)); return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); } + + assert!(indices.is_none(), "no cpu support for indexed MSMs"); let mut ret = $point::default(); unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + msm_aux(points, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. + /// for i in 0..nscalars, 0 <= indices[i] < npoints + pub fn indexed_msm( + points: &[$affine], + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + msm_aux(points, scalars, Some(indices)) + } + pub fn init(points: &[$affine]) -> MSMContext<'_> { let npoints = points.len(); @@ -212,14 +253,18 @@ macro_rules! impl_msm { ret } - pub fn with( + pub fn with_context_aux( context: &MSMContext<'_>, scalars: &[$scalar], indices: Option<&[u32]>, ) -> $point { let npoints = context.npoints(); let nscalars = scalars.len(); - assert!(npoints >= nscalars, "not enough points"); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints >= nscalars, "not enough points"); + } let mut ret = $point::default(); @@ -232,8 +277,8 @@ macro_rules! impl_msm { fn $name_with( out: *mut $point, context: &CudaMSMContext, - npoints: usize, scalars: *const $scalar, + nscalars: usize, indices: *const u32, ) -> cuda::Error; } @@ -243,12 +288,13 @@ macro_rules! 
impl_msm { } else { std::ptr::null() }; + let err = unsafe { $name_with( &mut ret, &context.cuda_context, - nscalars, &scalars[0], + nscalars, indices, ) }; @@ -256,6 +302,7 @@ macro_rules! impl_msm { return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); } + assert!(indices.is_none(), "no cpu support for indexed MSMs"); unsafe { $name_cpu( &mut ret, @@ -266,6 +313,23 @@ macro_rules! impl_msm { }; $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } + + pub fn with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + ) -> $point { + with_context_aux(context, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. + /// for i in 0..nscalars, 0 <= indices[i] < npoints + pub fn indexed_with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + with_context_aux(context, scalars, Some(indices)) + } }; } @@ -276,7 +340,7 @@ mod tests { use crate::utils::{gen_points, gen_scalars, naive_multiscalar_mul}; #[test] - fn it_works() { + fn test_simple() { #[cfg(not(debug_assertions))] const NPOINTS: usize = 128 * 1024; #[cfg(debug_assertions)] @@ -288,12 +352,13 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = crate::bn256::msm(&points, &scalars).to_affine(); + let ret = crate::bn256::msm_aux(&points, &scalars, None).to_affine(); println!("{:?}", ret); let context = crate::bn256::init(&points); let ret_other = - crate::bn256::with(&context, &scalars, None).to_affine(); + crate::bn256::with_context_aux(&context, &scalars, None) + .to_affine(); println!("{:?}", ret_other); assert_eq!(ret, naive); diff --git a/src/pasta.rs b/src/pasta.rs index bfdeb7a..007b9de 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -136,9 +136,18 @@ macro_rules! 
impl_pasta { } - pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + pub fn msm_aux( + points: &[$affine], + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); + let nscalars = scalars.len(); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints == nscalars, "length mismatch"); + } #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { @@ -148,18 +157,37 @@ macro_rules! impl_pasta { points: *const $affine, npoints: usize, scalars: *const $scalar, + nscalars: usize, + indices: *const u32, is_mont: bool, ) -> cuda::Error; } + + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + let mut ret = $point::default(); let err = unsafe { - $name(&mut ret, &points[0], npoints, &scalars[0], true) + $name( + &mut ret, + &points[0], + npoints, + &scalars[0], + nscalars, + indices, + true, + ) }; assert!(err.code == 0, "{}", String::from(err)); return ret; } + + assert!(indices.is_none(), "no cpu support for indexed MSMs"); let mut ret = $point::default(); unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0], true) @@ -198,34 +226,70 @@ macro_rules! impl_pasta { ret } - pub fn with(context: &MSMContext<'_>, scalars: &[$scalar]) -> $point { + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + msm_aux(points, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. 
+ /// for i in 0..nscalars, 0 <= indices[i] < npoints + pub fn indexed_msm( + points: &[$affine], + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + msm_aux(points, scalars, Some(indices)) + } + + pub fn with_context_aux( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { let npoints = context.npoints(); let nscalars = scalars.len(); - assert!(npoints >= nscalars, "not enough points"); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints >= nscalars, "not enough points"); + } let mut ret = $point::default(); #[cfg(feature = "cuda")] - if nscalars >= 1 << 16 - && unsafe { !CUDA_OFF && cuda_available() } - { + if nscalars >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { fn $name_with( out: *mut $point, context: &CudaMSMContext, - npoints: usize, scalars: *const $scalar, + nscalars: usize, + indices: *const u32, is_mont: bool, ) -> cuda::Error; } + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + let err = unsafe { - $name_with(&mut ret, context.cuda(), nscalars, &scalars[0], true) + $name_with( + &mut ret, + context.cuda(), + &scalars[0], + nscalars, + indices, + true, + ) }; assert!(err.code == 0, "{}", String::from(err)); return ret; } + assert!(indices.is_none(), "no cpu support for indexed MSMs"); + unsafe { $name_cpu( &mut ret, @@ -238,6 +302,23 @@ macro_rules! impl_pasta { ret } + + pub fn with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + ) -> $point { + with_context_aux(context, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. 
+ /// for i in 0..nscalars, 0 <= indices[i] < npoints + pub fn indexed_with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + with_context_aux(context, scalars, Some(indices)) + } }; } @@ -362,7 +443,7 @@ mod tests { }; #[test] - fn it_works() { + fn test_simple() { #[cfg(not(debug_assertions))] const NPOINTS: usize = 128 * 1024; #[cfg(debug_assertions)] @@ -374,11 +455,12 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = pallas::msm(&points, &scalars).to_affine(); + let ret = pallas::msm_aux(&points, &scalars, None).to_affine(); println!("{:?}", ret); let context = pallas::init(&points); - let ret_other = pallas::with(&context, &scalars).to_affine(); + let ret_other = + pallas::with_context_aux(&context, &scalars, None).to_affine(); println!("{:?}", ret_other); assert_eq!(ret, naive); From 325c29dc7204f95242c598feec9c382ddc15ebe7 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Fri, 9 Feb 2024 02:00:56 +0000 Subject: [PATCH 10/11] benchmarks --- benches/grumpkin_msm.rs | 15 ++++++++++++--- examples/grumpkin_msm.rs | 2 +- src/lib.rs | 4 ++++ src/pasta.rs | 4 ++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 993d5e4..72c6d12 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -65,15 +65,24 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("GPU"); group.sample_size(20); + let indices = (0..(npoints as u32)).rev().collect::>(); + group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, None); }) }); + scalars.reverse(); + group.bench_function(format!("2**{} points rev", bench_npow), |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, Some(indices.as_slice())); + }) + }); + let context = grumpkin_msm::bn256::init(&points); - let indices =
(0..(npoints as u32)).rev().collect::>(); + scalars.reverse(); group.bench_function( format!("preallocate 2**{} points", bench_npow), |b| { @@ -81,7 +90,7 @@ fn criterion_benchmark(c: &mut Criterion) { let _ = grumpkin_msm::bn256::with_context_aux( &context, &scalars, - Some(indices.as_slice()), + None, ); }) }, @@ -93,7 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let _ = grumpkin_msm::bn256::with_context_aux( - &context, &scalars, None, + &context, &scalars, Some(indices.as_slice()), ); }) }, diff --git a/examples/grumpkin_msm.rs b/examples/grumpkin_msm.rs index 208183f..1ab89c5 100644 --- a/examples/grumpkin_msm.rs +++ b/examples/grumpkin_msm.rs @@ -6,7 +6,7 @@ use grumpkin_msm::cuda_available; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("22".to_string()) + .unwrap_or("23".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; diff --git a/src/lib.rs b/src/lib.rs index 2582d55..e7b7a24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -214,6 +214,10 @@ macro_rules! impl_msm { /// An indexed MSM. We do not check if the indices are valid, i.e. /// for i in 0..nscalars, 0 <= indices[i] < npoints + /// + /// Also, this version carries a performance penalty that all the + /// points must be moved onto the GPU once instead of in batches. + /// If the points are to be reused, please use the [`MSMContext`] API pub fn indexed_msm( points: &[$affine], scalars: &[$scalar], diff --git a/src/pasta.rs b/src/pasta.rs index 007b9de..be8682d 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -232,6 +232,10 @@ macro_rules! impl_pasta { /// An indexed MSM. We do not check if the indices are valid, i.e. /// for i in 0..nscalars, 0 <= indices[i] < npoints + /// + /// Also, this version carries a performance penalty that all the + /// points must be moved onto the GPU once instead of in batches. 
+ /// If the points are to be reused, please use the [`MSMContext`] API pub fn indexed_msm( points: &[$affine], scalars: &[$scalar], From 02d059a70f0e837af39d9da04fa9d980a259aa20 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Fri, 9 Feb 2024 21:43:17 +0000 Subject: [PATCH 11/11] fmt --- benches/grumpkin_msm.rs | 14 +++++++++----- src/lib.rs | 2 +- src/pasta.rs | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 72c6d12..b4c9f13 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -76,7 +76,11 @@ fn criterion_benchmark(c: &mut Criterion) { scalars.reverse(); group.bench_function(format!("2**{} points rev", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, Some(indices.as_slice())); + let _ = grumpkin_msm::bn256::msm_aux( + &points, + &scalars, + Some(indices.as_slice()), + ); }) }); @@ -88,9 +92,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let _ = grumpkin_msm::bn256::with_context_aux( - &context, - &scalars, - None, + &context, &scalars, None, ); }) }, @@ -102,7 +104,9 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let _ = grumpkin_msm::bn256::with_context_aux( - &context, &scalars, Some(indices.as_slice()), + &context, + &scalars, + Some(indices.as_slice()), ); }) }, diff --git a/src/lib.rs b/src/lib.rs index e7b7a24..82577da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -214,7 +214,7 @@ macro_rules! impl_msm { /// An indexed MSM. We do not check if the indices are valid, i.e. /// for i in 0..nscalars, 0 <= indices[i] < npoints - /// + /// /// Also, this version carries a performance penalty that all the /// points must be moved onto the GPU once instead of in batches. /// If the points are to be reused, please use the [`MSMContext`] API diff --git a/src/pasta.rs b/src/pasta.rs index be8682d..52f71cf 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -232,7 +232,7 @@ macro_rules! 
impl_pasta { /// An indexed MSM. We do not check if the indices are valid, i.e. /// for i in 0..nscalars, 0 <= indices[i] < npoints - /// + /// /// Also, this version carries a performance penalty that all the /// points must be moved onto the GPU once instead of in batches. /// If the points are to be reused, please use the [`MSMContext`] API