diff --git a/.gitignore b/.gitignore index 6985cf1..555be08 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +# VSCode +.vscode \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 629fe2e..6f02639 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ cuda-mobile = [] [dependencies] blst = "~0.3.11" semolina = "~0.1.3" -sppark = "~0.1.2" +sppark = { git = "https://github.com/lurk-lab/sppark", branch = "indices" } halo2curves = { version = "0.6.0" } pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] } rand = "^0" diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 1575f48..b4c9f13 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -30,10 +30,23 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256(&points, &scalars); + let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, None); }) }); + let context = grumpkin_msm::bn256::init(&points); + + group.bench_function( + format!("\"preallocate\" 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::with_context_aux( + &context, &scalars, None, + ); + }) + }, + ); + group.finish(); #[cfg(feature = "cuda")] @@ -52,12 +65,53 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("GPU"); group.sample_size(20); + let indices = (0..(npoints as u32)).rev().collect::>(); + group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256(&points, &scalars); + let _ = grumpkin_msm::bn256::msm_aux(&points, &scalars, None); + }) + }); + + scalars.reverse(); + group.bench_function(format!("2**{} points rev", bench_npow), |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::msm_aux( + &points, + &scalars, + Some(indices.as_slice()), + ); }) }); + let context = grumpkin_msm::bn256::init(&points); + + scalars.reverse(); + group.bench_function( + format!("preallocate 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::with_context_aux( + &context, &scalars, None, + ); + }) + }, + ); + + scalars.reverse(); + group.bench_function( + format!("preallocate 2**{} points rev", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ); + }) + }, + ); + group.finish(); } } diff --git a/benches/pasta_msm.rs b/benches/pasta_msm.rs index f77610f..1aa542a 100644 --- a/benches/pasta_msm.rs +++ b/benches/pasta_msm.rs @@ -30,10 +30,24 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas(&points, &scalars); + let _ = + grumpkin_msm::pasta::pallas::msm_aux(&points, &scalars, None); }) }); + let context = grumpkin_msm::pasta::pallas::init(&points); + + group.bench_function( + format!("\"preallocate\" 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas::with_context_aux( + &context, &scalars, None, + ); + }) + }, + ); + group.finish(); #[cfg(feature = "cuda")] @@ -54,10 +68,39 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas(&points, &scalars); + let _ = grumpkin_msm::pasta::pallas::msm_aux( + &points, &scalars, None, + ); }) }); + let context = grumpkin_msm::pasta::pallas::init(&points); + + let indices = (0..(npoints as u32)).rev().collect::>(); + group.bench_function( + format!("preallocate 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ); + }) + }, + ); + + group.bench_function( + format!("preallocate 2**{} points rev", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas::with_context_aux( + &context, &scalars, None, + ); + }) + }, + ); + group.finish(); } } diff --git a/cuda/bn254.cu b/cuda/bn254.cu index 3c0c1c9..e6682e6 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -17,8 +17,27 @@ typedef fr_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_bn254(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_bn254(msm_context_t &ref) +{ + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_bn254_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); +} + +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); +} #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 861610d..1c68d6a 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -17,8 +17,27 @@ typedef fp_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_grumpkin(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) +{ + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); +} + +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); +} #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index f3897bb..ecfc26f 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -17,8 +17,28 @@ typedef vesta_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_pallas(msm_context_t &ref) +{ + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_pallas_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); +} + +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); +} + #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index a926c6d..ec24f66 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -17,8 +17,28 @@ typedef pallas_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_vesta(msm_context_t &ref) +{ + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_vesta(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger(out, points, npoints, scalars, nscalars, pidx); +} + +extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, + const scalar_t scalars[], size_t nscalars, uint32_t pidx[]) +{ + return mult_pippenger_with(out, msm_context, scalars, nscalars, pidx); +} + #endif diff --git a/examples/grumpkin_msm.rs b/examples/grumpkin_msm.rs index 3ca237a..1ab89c5 100644 --- a/examples/grumpkin_msm.rs +++ b/examples/grumpkin_msm.rs @@ -6,22 +6,39 @@ use grumpkin_msm::cuda_available; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("17".to_string()) + .unwrap_or("23".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; println!("generating {} random points, just hang on...", npoints); let points = gen_points(npoints); - let scalars = gen_scalars(npoints); + let mut scalars = gen_scalars(npoints); #[cfg(feature = "cuda")] if unsafe { cuda_available() } { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let res = grumpkin_msm::bn256(&points, &scalars).to_affine(); + let context = grumpkin_msm::bn256::init(&points); + + let indices = (0..(npoints as u32)).rev().collect::>(); + let res = grumpkin_msm::bn256::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res: {:?}", res); + let res2 = grumpkin_msm::bn256::msm_aux( + &points, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res2: {:?}", res2); + scalars.reverse(); let native = naive_multiscalar_mul(&points, &scalars); - assert_eq!(res, native); + println!("native: {:?}", native); println!("success!") } diff --git a/examples/pasta_msm.rs b/examples/pasta_msm.rs index a682c35..0e917aa 100644 --- a/examples/pasta_msm.rs +++ b/examples/pasta_msm.rs @@ -8,22 +8,39 @@ use pasta_curves::group::Curve; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("17".to_string()) + .unwrap_or("22".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; println!("generating {} random points, just hang on...", npoints); let points = gen_points(npoints); - let scalars = gen_scalars(npoints); + let mut scalars = gen_scalars(npoints); #[cfg(feature = "cuda")] if unsafe { cuda_available() } { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let res = grumpkin_msm::pasta::pallas(&points, &scalars).to_affine(); + let context = grumpkin_msm::pasta::pallas::init(&points); + + let indices = (0..(npoints as u32)).rev().collect::>(); + let res = grumpkin_msm::pasta::pallas::with_context_aux( + &context, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res: {:?}", res); + let res2 = grumpkin_msm::pasta::pallas::msm_aux( + &points, + &scalars, + Some(indices.as_slice()), + ) + .to_affine(); + println!("res2: {:?}", res2); + scalars.reverse(); let native = naive_multiscalar_mul(&points, &scalars); - assert_eq!(res, native); + println!("native: {:?}", native); println!("success!") } diff --git a/src/lib.rs b/src/lib.rs index 16fb22a..82577da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,90 +18,323 @@ extern "C" { #[cfg(feature = "cuda")] pub static mut CUDA_OFF: bool = false; -use halo2curves::bn256; -use halo2curves::CurveExt; +pub mod bn256 { + use halo2curves::{ + bn256::{Fr as Scalar, G1Affine as Affine, G1 as Point}, + CurveExt, + }; -extern "C" { - fn mult_pippenger_bn254( - out: *mut bn256::G1, - points: *const bn256::G1Affine, - npoints: usize, - scalars: *const bn256::Fr, + use crate::impl_msm; + + impl_msm!( + cuda_bn254, + cuda_bn254_init, + cuda_bn254_with, + mult_pippenger_bn254, + Point, + Affine, + Scalar ); +} + +pub mod grumpkin { + use halo2curves::{ + grumpkin::{Fr as Scalar, G1Affine as Affine, G1 as Point}, + CurveExt, + }; + + use crate::impl_msm; + impl_msm!( + cuda_grumpkin, + cuda_grumpkin_init, + cuda_grumpkin_with, + mult_pippenger_grumpkin, + Point, + Affine, + Scalar + ); } -pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 { - let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); +#[macro_export] +macro_rules! impl_msm { + ( + $name:ident, + $name_init:ident, + $name_with:ident, + $name_cpu:ident, + $point:ident, + $affine:ident, + $scalar:ident + ) => { + #[cfg(feature = "cuda")] + use $crate::{cuda, cuda_available, CUDA_OFF}; - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { - extern "C" { - fn cuda_pippenger_bn254( - out: *mut bn256::G1, - points: *const bn256::G1Affine, - npoints: usize, - scalars: *const bn256::Fr, - ) -> cuda::Error; + #[repr(C)] + #[derive(Debug, Clone)] + pub struct CudaMSMContext { + context: *const std::ffi::c_void, + npoints: usize, + } + unsafe impl Send for CudaMSMContext {} + + unsafe impl Sync for CudaMSMContext {} + + impl Default for CudaMSMContext { + fn default() -> Self { + Self { + context: std::ptr::null(), + npoints: 0, + } + } } - let mut ret = bn256::G1::default(); - let err = unsafe { - cuda_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) - }; - assert!(err.code == 0, "{}", String::from(err)); - return bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); - } - let mut ret = bn256::G1::default(); - unsafe { mult_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) }; - bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() -} + #[cfg(feature = "cuda")] + // TODO: check for device-side memory leaks + impl Drop for CudaMSMContext { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_bn254(by_ref: &CudaMSMContext); + } + unsafe { + drop_msm_context_bn254(std::mem::transmute::<&_, &_>(self)) + }; + self.context = core::ptr::null(); + } + } -use halo2curves::grumpkin; + #[derive(Default, Debug, Clone)] + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], + } -extern "C" { - fn mult_pippenger_grumpkin( - out: *mut grumpkin::G1, - points: *const grumpkin::G1Affine, - npoints: usize, - scalars: *const grumpkin::Fr, - ); + unsafe impl<'a> Send for MSMContext<'a> {} -} + unsafe impl<'a> Sync for MSMContext<'a> {} + + impl<'a> MSMContext<'a> { + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } + } -pub fn grumpkin( - points: &[grumpkin::G1Affine], - scalars: &[grumpkin::Fr], -) -> grumpkin::G1 { - let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); + } + self.cpu_context.len() + } + + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context + } + + fn points(&self) -> &[$affine] { + &self.cpu_context + } + } - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_grumpkin( - out: *mut grumpkin::G1, - points: *const grumpkin::G1Affine, + fn $name_cpu( + out: *mut $point, + points: *const $affine, npoints: usize, - scalars: *const grumpkin::Fr, - ) -> cuda::Error; + scalars: *const $scalar, + ); } - let mut ret = grumpkin::G1::default(); - let err = unsafe { - cuda_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0]) - }; - assert!(err.code == 0, "{}", String::from(err)); - return grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); - } - let mut ret = grumpkin::G1::default(); - unsafe { - mult_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0]) + pub fn msm_aux( + points: &[$affine], + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { + let npoints = points.len(); + let nscalars = scalars.len(); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints == nscalars, "length mismatch"); + } + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name( + out: *mut $point, + points: *const $affine, + npoints: usize, + scalars: *const $scalar, + nscalars: usize, + indices: *const u32, + ) -> cuda::Error; + + } + + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + + let mut ret = $point::default(); + let err = unsafe { + $name( + &mut ret, + &points[0], + npoints, + &scalars[0], + nscalars, + indices, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); + } + + assert!(indices.is_none(), "no cpu support for indexed MSMs"); + let mut ret = $point::default(); + unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; + $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() + } + + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + msm_aux(points, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. + /// for i in 0..nscalars, 0 <= indices[i] < npoints + /// + /// Also, this version carries a performance penalty that all the + /// points must be moved onto the GPU once instead of in batches. + /// If the points are to be reused, please use the [`MSMContext`] API + pub fn indexed_msm( + points: &[$affine], + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + msm_aux(points, scalars, Some(indices)) + } + + pub fn init(points: &[$affine]) -> MSMContext<'_> { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_init( + points: *const $affine, + npoints: usize, + msm_context: &mut CudaMSMContext, + ) -> cuda::Error; + } + + let npoints = points.len(); + let err = unsafe { + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + ret.on_gpu = true; + return ret; + } + + ret + } + + pub fn with_context_aux( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { + let npoints = context.npoints(); + let nscalars = scalars.len(); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints >= nscalars, "not enough points"); + } + + let mut ret = $point::default(); + + #[cfg(feature = "cuda")] + if nscalars >= 1 << 16 + && context.on_gpu + && unsafe { !CUDA_OFF && cuda_available() } + { + extern "C" { + fn $name_with( + out: *mut $point, + context: &CudaMSMContext, + scalars: *const $scalar, + nscalars: usize, + indices: *const u32, + ) -> cuda::Error; + } + + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + + let err = unsafe { + $name_with( + &mut ret, + &context.cuda_context, + &scalars[0], + nscalars, + indices, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); + } + + assert!(indices.is_none(), "no cpu support for indexed MSMs"); + unsafe { + $name_cpu( + &mut ret, + &context.cpu_context[0], + nscalars, + &scalars[0], + ) + }; + $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() + } + + pub fn with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + ) -> $point { + with_context_aux(context, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. + /// for i in 0..nscalars, 0 <= indices[i] < npoints + pub fn indexed_with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + with_context_aux(context, scalars, Some(indices)) + } }; - grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() } #[cfg(test)] @@ -111,7 +344,7 @@ mod tests { use crate::utils::{gen_points, gen_scalars, naive_multiscalar_mul}; #[test] - fn it_works() { + fn test_simple() { #[cfg(not(debug_assertions))] const NPOINTS: usize = 128 * 1024; #[cfg(debug_assertions)] @@ -123,9 +356,16 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = crate::bn256(&points, &scalars).to_affine(); + let ret = crate::bn256::msm_aux(&points, &scalars, None).to_affine(); println!("{:?}", ret); + let context = crate::bn256::init(&points); + let ret_other = + crate::bn256::with_context_aux(&context, &scalars, None) + .to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } } diff --git a/src/pasta.rs b/src/pasta.rs index c69f6be..52f71cf 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -6,111 +6,324 @@ extern crate semolina; -use pasta_curves::pallas; - -#[cfg(feature = "cuda")] -use crate::{cuda, cuda_available, CUDA_OFF}; - -extern "C" { - fn mult_pippenger_pallas( - out: *mut pallas::Point, - points: *const pallas::Affine, - npoints: usize, - scalars: *const pallas::Scalar, - is_mont: bool, +pub mod pallas { + use pasta_curves::pallas::{Affine, Point, Scalar}; + + use crate::impl_pasta; + + impl_pasta!( + cuda_pallas, + cuda_pallas_init, + cuda_pallas_with, + mult_pippenger_pallas, + Point, + Affine, + Scalar ); } -pub fn pallas( - points: &[pallas::Affine], - scalars: &[pallas::Scalar], -) -> pallas::Point { - let npoints = points.len(); - assert_eq!(npoints, scalars.len(), "length mismatch"); +pub mod vesta { + use pasta_curves::vesta::{Affine, Point, Scalar}; - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { - extern "C" { - fn cuda_pippenger_pallas( - out: *mut pallas::Point, - points: *const pallas::Affine, - npoints: usize, - scalars: *const pallas::Scalar, - is_mont: bool, - ) -> cuda::Error; + use crate::impl_pasta; - } - let mut ret = pallas::Point::default(); - let err = unsafe { - cuda_pippenger_pallas( - &mut ret, - &points[0], - npoints, - &scalars[0], - true, - ) - }; - assert!(err.code == 0, "{}", String::from(err)); - - return ret; - } - let mut ret = pallas::Point::default(); - unsafe { - mult_pippenger_pallas(&mut ret, &points[0], npoints, &scalars[0], true) - }; - ret + impl_pasta!( + cuda_vesta, + cuda_vesta_init, + cuda_vesta_with, + mult_pippenger_vesta, + Point, + Affine, + Scalar + ); } -use pasta_curves::vesta; +#[macro_export] +macro_rules! impl_pasta { + ( + $name:ident, + $name_init:ident, + $name_with:ident, + $name_cpu:ident, + $point:ident, + $affine:ident, + $scalar:ident + ) => { + #[cfg(feature = "cuda")] + use $crate::{cuda, cuda_available, CUDA_OFF}; + + #[repr(C)] + #[derive(Debug, Clone)] + pub struct CudaMSMContext { + context: *const std::ffi::c_void, + npoints: usize, + } -extern "C" { - fn mult_pippenger_vesta( - out: *mut vesta::Point, - points: *const vesta::Affine, - npoints: usize, - scalars: *const vesta::Scalar, - is_mont: bool, - ); -} + unsafe impl Send for CudaMSMContext {} -pub fn vesta( - points: &[vesta::Affine], - scalars: &[vesta::Scalar], -) -> vesta::Point { - let npoints = points.len(); - assert_eq!(npoints, scalars.len(), "length mismatch"); + unsafe impl Sync for CudaMSMContext {} + + impl Default for CudaMSMContext { + fn default() -> Self { + Self { + context: std::ptr::null(), + npoints: 0, + } + } + } + + #[cfg(feature = "cuda")] + // TODO: check for device-side memory leaks + impl Drop for CudaMSMContext { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_bn254(by_ref: &CudaMSMContext); + } + unsafe { + drop_msm_context_bn254(std::mem::transmute::<&_, &_>(self)) + }; + self.context = core::ptr::null(); + } + } + + #[derive(Default, Debug, Clone)] + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], + } + + unsafe impl<'a> Send for MSMContext<'a> {} + + unsafe impl<'a> Sync for MSMContext<'a> {} + + impl<'a> MSMContext<'a> { + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } + } + + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); + } + self.cpu_context.len() + } + + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context + } + + fn points(&self) -> &[$affine] { + &self.cpu_context + } + } - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_vesta( - out: *mut vesta::Point, - points: *const vesta::Affine, + fn $name_cpu( + out: *mut $point, + points: *const $affine, npoints: usize, - scalars: *const vesta::Scalar, + scalars: *const $scalar, is_mont: bool, - ) -> cuda::Error; + ); } - let mut ret = vesta::Point::default(); - let err = unsafe { - cuda_pippenger_vesta( - &mut ret, - &points[0], - npoints, - &scalars[0], - true, - ) - }; - assert!(err.code == 0, "{}", String::from(err)); - - return ret; - } - let mut ret = vesta::Point::default(); - unsafe { - mult_pippenger_vesta(&mut ret, &points[0], npoints, &scalars[0], true) + + pub fn msm_aux( + points: &[$affine], + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { + let npoints = points.len(); + let nscalars = scalars.len(); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints == nscalars, "length mismatch"); + } + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name( + out: *mut $point, + points: *const $affine, + npoints: usize, + scalars: *const $scalar, + nscalars: usize, + indices: *const u32, + is_mont: bool, + ) -> cuda::Error; + + } + + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + + let mut ret = $point::default(); + let err = unsafe { + $name( + &mut ret, + &points[0], + npoints, + &scalars[0], + nscalars, + indices, + true, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + + return ret; + } + + assert!(indices.is_none(), "no cpu support for indexed MSMs"); + let mut ret = $point::default(); + unsafe { + $name_cpu(&mut ret, &points[0], npoints, &scalars[0], true) + }; + ret + } + + pub fn init(points: &[$affine]) -> MSMContext<'_> { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_init( + points: *const $affine, + npoints: usize, + msm_context: &mut CudaMSMContext, + ) -> cuda::Error; + } + + let npoints = points.len(); + let err = unsafe { + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + ret.on_gpu = true; + return ret; + } + + ret + } + + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + msm_aux(points, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. + /// for i in 0..nscalars, 0 <= indices[i] < npoints + /// + /// Also, this version carries a performance penalty that all the + /// points must be moved onto the GPU once instead of in batches. + /// If the points are to be reused, please use the [`MSMContext`] API + pub fn indexed_msm( + points: &[$affine], + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + msm_aux(points, scalars, Some(indices)) + } + + pub fn with_context_aux( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: Option<&[u32]>, + ) -> $point { + let npoints = context.npoints(); + let nscalars = scalars.len(); + if let Some(indices) = indices { + assert!(nscalars == indices.len(), "length mismatch"); + } else { + assert!(npoints >= nscalars, "not enough points"); + } + + let mut ret = $point::default(); + + #[cfg(feature = "cuda")] + if nscalars >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_with( + out: *mut $point, + context: &CudaMSMContext, + scalars: *const $scalar, + nscalars: usize, + indices: *const u32, + is_mont: bool, + ) -> cuda::Error; + } + + let indices = if let Some(inner) = indices { + inner.as_ptr() + } else { + std::ptr::null() + }; + + let err = unsafe { + $name_with( + &mut ret, + context.cuda(), + &scalars[0], + nscalars, + indices, + true, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + return ret; + } + + assert!(indices.is_none(), "no cpu support for indexed MSMs"); + + unsafe { + $name_cpu( + &mut ret, + &context.cpu_context[0], + nscalars, + &scalars[0], + true, + ) + }; + + ret + } + + pub fn with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + ) -> $point { + with_context_aux(context, scalars, None) + } + + /// An indexed MSM. We do not check if the indices are valid, i.e. + /// for i in 0..nscalars, 0 <= indices[i] < npoints + pub fn indexed_with_context( + context: &MSMContext<'_>, + scalars: &[$scalar], + indices: &[u32], + ) -> $point { + with_context_aux(context, scalars, Some(indices)) + } }; - ret } pub mod utils { @@ -234,7 +447,7 @@ mod tests { }; #[test] - fn it_works() { + fn test_simple() { #[cfg(not(debug_assertions))] const NPOINTS: usize = 128 * 1024; #[cfg(debug_assertions)] @@ -246,9 +459,15 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = pallas(&points, &scalars).to_affine(); + let ret = pallas::msm_aux(&points, &scalars, None).to_affine(); println!("{:?}", ret); + let context = pallas::init(&points); + let ret_other = + pallas::with_context_aux(&context, &scalars, None).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } }