fix clippy + tests
Hanting Zhang committed Dec 12, 2023
1 parent 7fc9efd commit 82da843
Showing 7 changed files with 175 additions and 313 deletions.
30 changes: 30 additions & 0 deletions .cargo/config
@@ -0,0 +1,30 @@
[alias]
# Collection of project wide clippy lints. This is done via an alias because
# clippy doesn't currently allow for specifying project-wide lints in a
# configuration file. This is a similar workaround to the ones presented here:
# <https://github.com/EmbarkStudios/rust-ecosystem/issues/59>
# TODO: add support for --all-features
xclippy = [
"clippy", "--all-targets", "--",
"-Wclippy::all",
"-Wclippy::match_same_arms",
"-Wclippy::cast_lossless",
"-Wclippy::checked_conversions",
"-Wclippy::dbg_macro",
"-Wclippy::disallowed_methods",
"-Wclippy::derive_partial_eq_without_eq",
"-Wclippy::filter_map_next",
"-Wclippy::flat_map_option",
"-Wclippy::inefficient_to_string",
"-Wclippy::large_types_passed_by_value",
"-Wclippy::manual_assert",
"-Wclippy::manual_ok_or",
"-Wclippy::map_flatten",
"-Wclippy::map_unwrap_or",
"-Wclippy::needless_borrow",
"-Wclippy::needless_pass_by_value",
"-Wclippy::unnecessary_mut_passed",
"-Wrust_2018_idioms",
"-Wtrivial_numeric_casts",
"-Wunused_qualifications",
]
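With this alias checked in, the whole lint set runs as a single command, cargo xclippy, which cargo expands to cargo clippy --all-targets -- followed by the -W flags listed above.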
6 changes: 3 additions & 3 deletions Cargo.toml
@@ -31,16 +31,16 @@ cuda-mobile = []
blst = "~0.3.11"
sppark = "~0.1.2"
halo2curves = { version = "0.4.0" }
rand = "^0"
rand_chacha = "^0"
rayon = "1.5"

[build-dependencies]
cc = "^1.0.70"
which = "^4.0"

[dev-dependencies]
criterion = { version = "0.3", features = [ "html_reports" ] }
rand = "^0"
rand_chacha = "^0"
rayon = "1.5"

[[bench]]
name = "msm"
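rand, rand_chacha, and rayon appear in both hunks because they change sections: with 3 additions and 3 deletions, and with the new src/utils.rs needing them at library build time (see the pub mod utils addition below), they are presumably promoted from [dev-dependencies] to regular [dependencies].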
20 changes: 4 additions & 16 deletions benches/msm.rs
@@ -2,20 +2,8 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#![allow(dead_code)]
#![allow(unused_imports)]
#![allow(unused_mut)]

use criterion::{criterion_group, criterion_main, Criterion};

use grumpkin_msm;

#[cfg(feature = "cuda")]
extern "C" {
fn cuda_available() -> bool;
}

include!("../src/tests.rs");
use grumpkin_msm::{utils::{gen_points, gen_scalars}, cuda_available};

fn criterion_benchmark(c: &mut Criterion) {
let bench_npow: usize = std::env::var("BENCH_NPOW")
@@ -25,8 +13,8 @@ fn criterion_benchmark(c: &mut Criterion) {
let npoints: usize = 1 << bench_npow;

//println!("generating {} random points, just hang on...", npoints);
let mut points = crate::tests::gen_points(npoints);
let mut scalars = crate::tests::gen_scalars(npoints);
let mut points = gen_points(npoints);
let mut scalars = gen_scalars(npoints);

#[cfg(feature = "cuda")]
{
@@ -55,7 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) {
while points.len() < npoints {
points.append(&mut points.clone());
}
scalars.append(&mut crate::tests::gen_scalars(npoints - scalars.len()));
scalars.append(&mut gen_scalars(npoints - scalars.len()));

let mut group = c.benchmark_group("GPU");
group.sample_size(20);
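The benchmark sizes itself from the BENCH_NPOW environment variable (npoints = 1 << BENCH_NPOW), so a larger run is, for example, BENCH_NPOW=20 cargo bench; when built with the cuda feature on a machine where cuda_available() reports a device, the separate "GPU" benchmark group is exercised as well.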
131 changes: 3 additions & 128 deletions examples/msm.rs
@@ -1,139 +1,14 @@
use core::cell::UnsafeCell;
use core::mem::transmute;
use core::sync::atomic::*;
use halo2curves::bn256;
use halo2curves::ff::Field;
use grumpkin_msm::{utils::{gen_points, gen_scalars, naive_multiscalar_mul}, cuda_available};
use halo2curves::group::Curve;
use halo2curves::CurveExt;
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha20Rng;

#[cfg(feature = "cuda")]
extern "C" {
fn cuda_available() -> bool;
}

pub fn gen_points(npoints: usize) -> Vec<bn256::G1Affine> {
let mut ret: Vec<bn256::G1Affine> = Vec::with_capacity(npoints);
unsafe { ret.set_len(npoints) };

let mut rnd: Vec<u8> = Vec::with_capacity(32 * npoints);
unsafe { rnd.set_len(32 * npoints) };
ChaCha20Rng::from_entropy().fill_bytes(&mut rnd);

let n_workers = rayon::current_num_threads();
let work = AtomicUsize::new(0);
rayon::scope(|s| {
for _ in 0..n_workers {
s.spawn(|_| {
let hash = bn256::G1::hash_to_curve("foobar");

let mut stride = 1024;
let mut tmp: Vec<bn256::G1> = Vec::with_capacity(stride);
unsafe { tmp.set_len(stride) };

loop {
let work = work.fetch_add(stride, Ordering::Relaxed);
if work >= npoints {
break;
}
if work + stride > npoints {
stride = npoints - work;
unsafe { tmp.set_len(stride) };
}
for i in 0..stride {
let off = (work + i) * 32;
tmp[i] = hash(&rnd[off..off + 32]);
}
#[allow(mutable_transmutes)]
bn256::G1::batch_normalize(&tmp, unsafe {
transmute::<&[bn256::G1Affine], &mut [bn256::G1Affine]>(
&ret[work..work + stride],
)
});
}
})
}
});

ret
}

fn as_mut<T>(x: &T) -> &mut T {
unsafe { &mut *UnsafeCell::raw_get(x as *const _ as *const _) }
}

pub fn gen_scalars(npoints: usize) -> Vec<bn256::Fr> {
let mut ret: Vec<bn256::Fr> = Vec::with_capacity(npoints);
unsafe { ret.set_len(npoints) };

let n_workers = rayon::current_num_threads();
let work = AtomicUsize::new(0);

rayon::scope(|s| {
for _ in 0..n_workers {
s.spawn(|_| {
let mut rng = ChaCha20Rng::from_entropy();
loop {
let work = work.fetch_add(1, Ordering::Relaxed);
if work >= npoints {
break;
}
*as_mut(&ret[work]) = bn256::Fr::random(&mut rng);
}
})
}
});

ret
}

pub fn naive_multiscalar_mul(
points: &[bn256::G1Affine],
scalars: &[bn256::Fr],
) -> bn256::G1Affine {
let n_workers = rayon::current_num_threads();

let mut rets: Vec<bn256::G1> = Vec::with_capacity(n_workers);
unsafe { rets.set_len(n_workers) };

let npoints = points.len();
let work = AtomicUsize::new(0);
let tid = AtomicUsize::new(0);
rayon::scope(|s| {
for _ in 0..n_workers {
s.spawn(|_| {
let mut ret = bn256::G1::default();

loop {
let work = work.fetch_add(1, Ordering::Relaxed);
if work >= npoints {
break;
}
ret += points[work] * scalars[work];
}

*as_mut(&rets[tid.fetch_add(1, Ordering::Relaxed)]) = ret;
})
}
});

let mut ret = bn256::G1::default();
for i in 0..n_workers {
ret += rets[i];
}

ret.to_affine()
}

fn main() {
let bench_npow: usize = std::env::var("BENCH_NPOW")
.unwrap_or("10".to_string())
.unwrap_or("17".to_string())
.parse()
.unwrap();
let npoints: usize = 1 << bench_npow;

// println!("generating {} random points, just hang on...", npoints);
println!("generating {} random points, just hang on...", npoints);
let points = gen_points(npoints);
let scalars = gen_scalars(npoints);

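The generator and naive-MSM helpers deleted above are not gone: the new import at the top of the file, use grumpkin_msm::{utils::{gen_points, gen_scalars, naive_multiscalar_mul}, cuda_available};, shows they now live in the crate's utils module, whose diff is not loaded in this view. A minimal sketch of what the trimmed example boils down to, assuming the relocated helpers keep the signatures of the removed definitions:

use grumpkin_msm::utils::{gen_points, gen_scalars, naive_multiscalar_mul};
use halo2curves::group::Curve;

fn main() {
    // 1 << 17 matches the example's new BENCH_NPOW default of 17.
    let npoints = 1 << 17;
    println!("generating {} random points, just hang on...", npoints);
    let points = gen_points(npoints);
    let scalars = gen_scalars(npoints);

    // Cross-check the fast MSM against the naive reference implementation.
    let naive = naive_multiscalar_mul(&points, &scalars);
    let fast = grumpkin_msm::bn256(&points, &scalars).to_affine();
    assert_eq!(fast, naive);
}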
55 changes: 39 additions & 16 deletions src/lib.rs
@@ -2,14 +2,17 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#![allow(improper_ctypes)]
#![allow(unused)]

pub mod utils;

extern crate blst;

#[cfg(feature = "cuda")]
sppark::cuda_error!();
#[cfg(feature = "cuda")]
extern "C" {
fn cuda_available() -> bool;
pub fn cuda_available() -> bool;
}
#[cfg(feature = "cuda")]
pub static mut CUDA_OFF: bool = false;
@@ -29,9 +32,8 @@ extern "C" {

pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 {
let npoints = points.len();
if npoints != scalars.len() {
panic!("length mismatch")
}
assert!(npoints == scalars.len(), "length mismatch");

#[cfg(feature = "cuda")]
if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
extern "C" {
@@ -47,10 +49,9 @@ pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 {
let err = unsafe {
cuda_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0])
};
if err.code != 0 {
panic!("{}", String::from(err));
}
return ret;
assert!(err.code == 0, "{}", String::from(err));

return bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap();
}
let mut ret = bn256::G1::default();
unsafe { mult_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) };
@@ -74,9 +75,8 @@ pub fn grumpkin(
scalars: &[grumpkin::Fr],
) -> grumpkin::G1 {
let npoints = points.len();
if npoints != scalars.len() {
panic!("length mismatch")
}
assert!(npoints == scalars.len(), "length mismatch");

#[cfg(feature = "cuda")]
if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
extern "C" {
@@ -92,10 +92,9 @@ let err = unsafe {
let err = unsafe {
cuda_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0])
};
if err.code != 0 {
panic!("{}", String::from(err));
}
return ret;
assert!(err.code == 0, "{}", String::from(err));

return grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap();
}
let mut ret = grumpkin::G1::default();
unsafe {
@@ -104,4 +103,28 @@
grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap()
}

include!("tests.rs");
#[cfg(test)]
mod tests {
use halo2curves::group::Curve;

use crate::utils::{gen_points, gen_scalars, naive_multiscalar_mul};

#[test]
fn it_works() {
#[cfg(not(debug_assertions))]
const NPOINTS: usize = 128 * 1024;
#[cfg(debug_assertions)]
const NPOINTS: usize = 8 * 1024;

let points = gen_points(NPOINTS);
let scalars = gen_scalars(NPOINTS);

let naive = naive_multiscalar_mul(&points, &scalars);
println!("{:?}", naive);

let ret = crate::bn256(&points, &scalars).to_affine();
println!("{:?}", ret);

assert_eq!(ret, naive);
}
}
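One behavioural detail worth noting from the hunks above: the GPU kernel is only used when the MSM has at least 1 << 16 points, cuda_available() reports a device, and the CUDA_OFF flag is still false. Because CUDA_OFF is a pub static mut, a caller can force the CPU path at runtime; a small sketch, assuming the cuda feature is enabled:

fn force_cpu_path() {
    #[cfg(feature = "cuda")]
    unsafe {
        // The dispatch guard in bn256()/grumpkin() is
        // `npoints >= 1 << 16 && !CUDA_OFF && cuda_available()`,
        // so setting CUDA_OFF skips the GPU kernel even when a device exists.
        grumpkin_msm::CUDA_OFF = true;
    }
}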
(Diffs for the remaining two changed files are not loaded in this view.)
