Skip to content

Commit

Permalink
Add pasta bindings (#3)
Browse files Browse the repository at this point in the history
* add pasta bindings

* add tests and benches

* refactor: re-organize build script

---------

Co-authored-by: Hanting Zhang <[email protected]>
Co-authored-by: François Garillot <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2024
1 parent a584fbb commit 0965aeb
Show file tree
Hide file tree
Showing 13 changed files with 547 additions and 85 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ jobs:
with:
command: test

- name: Run msm example
run: cargo run --release --example msm
- name: Run grumpkin_msm example
run: cargo run --release --example grumpkin_msm

- name: Run pasta_msm example
run: cargo run --release --example pasta_msm

- name: Check benches build
run: cargo check --benches
12 changes: 9 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@ include = [
default = []
# Compile in portable mode, without ISA extensions.
# Binary can be executed on all systems.
portable = [ "blst/portable" ]
portable = [ "blst/portable", "semolina/portable" ]
# Enable ADX even if the host CPU doesn't support it.
# Binary can be executed on Broadwell+ and Ryzen+ systems.
force-adx = [ "blst/force-adx" ]
force-adx = [ "blst/force-adx", "semolina/force-adx" ]
cuda-mobile = []
# Build with __MSM_SORT_DONT_IMPLEMENT__ to prevent redefining
# symbols that breaks compilation during linking.
dont-implement-sort = []

[dependencies]
blst = "~0.3.11"
semolina = "~0.1.3"
sppark = "~0.1.2"
halo2curves = { version = "0.5.0" }
pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] }
rand = "^0"
rand_chacha = "^0"
rayon = "1.5"
Expand All @@ -46,5 +48,9 @@ which = "^4.0"
criterion = { version = "0.3", features = [ "html_reports" ] }

[[bench]]
name = "msm"
name = "grumpkin_msm"
harness = false

[[bench]]
name = "pasta_msm"
harness = false
1 change: 1 addition & 0 deletions benches/msm.rs → benches/grumpkin_msm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#![allow(unused_mut)]

use criterion::{criterion_group, criterion_main, Criterion};
use grumpkin_msm::utils::{gen_points, gen_scalars};
Expand Down
66 changes: 66 additions & 0 deletions benches/pasta_msm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#![allow(unused_mut)]

use criterion::{criterion_group, criterion_main, Criterion};
use grumpkin_msm::pasta::utils::{gen_points, gen_scalars};

#[cfg(feature = "cuda")]
use grumpkin_msm::cuda_available;

fn criterion_benchmark(c: &mut Criterion) {
let bench_npow: usize = std::env::var("BENCH_NPOW")
.unwrap_or("17".to_string())
.parse()
.unwrap();
let npoints: usize = 1 << bench_npow;

// println!("generating {} random points, just hang on...", npoints);
let mut points = gen_points(npoints);
let mut scalars = gen_scalars(npoints);

#[cfg(feature = "cuda")]
{
unsafe { grumpkin_msm::CUDA_OFF = true };
}

let mut group = c.benchmark_group("CPU");
group.sample_size(10);

group.bench_function(format!("2**{} points", bench_npow), |b| {
b.iter(|| {
let _ = grumpkin_msm::pasta::pallas(&points, &scalars);
})
});

group.finish();

#[cfg(feature = "cuda")]
if unsafe { cuda_available() } {
unsafe { grumpkin_msm::CUDA_OFF = false };

const EXTRA: usize = 5;
let bench_npow = bench_npow + EXTRA;
let npoints: usize = 1 << bench_npow;

while points.len() < npoints {
points.append(&mut points.clone());
}
scalars.append(&mut gen_scalars(npoints - scalars.len()));

let mut group = c.benchmark_group("GPU");
group.sample_size(20);

group.bench_function(format!("2**{} points", bench_npow), |b| {
b.iter(|| {
let _ = grumpkin_msm::pasta::pallas(&points, &scalars);
})
});

group.finish();
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
186 changes: 107 additions & 79 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,105 +2,133 @@ use std::env;
use std::path::PathBuf;

fn main() {
// account for cross-compilation [by examining environment variable]
let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();

// Set CXX environment variable to choose alternative C compiler.
// Optimization level depends on whether or not --release is passed
// or implied.
compile_source(
"grumpkin_pippenger.cpp",
"__BLST_PORTABLE__",
"grumpkin_msm",
&target_arch,
);
compile_source(
"pasta_pippenger.cpp",
"__PASTA_PORTABLE__",
"pasta_msm",
&target_arch,
);

if cfg!(target_os = "windows") && !cfg!(target_env = "msvc") {
return;
}

if cuda_available() {
let mut implement_sort: bool = true;
compile_cuda("cuda/bn254.cu", "bn256_msm_cuda", implement_sort);
implement_sort = false;
compile_cuda("cuda/grumpkin.cu", "grumpkin_msm_cuda", implement_sort);
compile_cuda("cuda/pallas.cu", "pallas_msm_cuda", implement_sort);
compile_cuda("cuda/vesta.cu", "vesta_msm_cuda", implement_sort);
println!("cargo:rerun-if-changed=cuda");
}
println!("cargo:rerun-if-env-changed=NVCC");
}

fn compile_source(
file_name: &str,
def: &str,
output_name: &str,
target_arch: &str,
) {
let mut cc = cc::Build::new();
cc.cpp(true);

let c_src_dir = PathBuf::from("src");
let files = vec![c_src_dir.join("pippenger.cpp")];
let mut cc_def = None;
let file = c_src_dir.join(file_name);
let cc_def = determine_cc_def(target_arch, def);

match (cfg!(feature = "portable"), cfg!(feature = "force-adx")) {
(true, false) => {
println!("Compiling in portable mode without ISA extensions");
cc_def = Some("__BLST_PORTABLE__");
}
(false, true) => {
if target_arch.eq("x86_64") {
println!("Enabling ADX support via `force-adx` feature");
cc_def = Some("__ADX__");
} else {
println!("`force-adx` is ignored for non-x86_64 targets");
}
}
(false, false) => {
#[cfg(target_arch = "x86_64")]
if target_arch.eq("x86_64") && std::is_x86_feature_detected!("adx")
{
println!("Enabling ADX because it was detected on the host");
cc_def = Some("__ADX__");
}
}
(true, true) => panic!(
"Cannot compile with both `portable` and `force-adx` features"
),
common_build_configurations(&mut cc);
if let Some(cc_def) = cc_def {
cc.define(&cc_def, None);
}
if let Some(include) = env::var_os("DEP_BLST_C_SRC") {
cc.include(include);
}
if let Some(include) = env::var_os("DEP_SEMOLINA_C_INCLUDE") {
cc.include(include);
}
if let Some(include) = env::var_os("DEP_SPPARK_ROOT") {
cc.include(include);
}
cc.file(file).compile(output_name);
}

cc.flag_if_supported("-mno-avx") // avoid costly transitions
fn common_build_configurations(cc: &mut cc::Build) {
cc.flag_if_supported("-mno-avx")
.flag_if_supported("-fno-builtin")
.flag_if_supported("-std=c++11")
.flag_if_supported("-Wno-unused-command-line-argument");
if !cfg!(debug_assertions) {
cc.define("NDEBUG", None);
}
if let Some(def) = cc_def {
cc.define(def, None);
}

fn determine_cc_def(target_arch: &str, default_def: &str) -> Option<String> {
match (cfg!(feature = "portable"), cfg!(feature = "force-adx")) {
(true, false) => Some(default_def.to_string()),
(false, true) if target_arch == "x86_64" => Some("__ADX__".to_string()),
(false, false)
if target_arch == "x86_64"
&& std::is_x86_feature_detected!("adx") =>
{
Some("__ADX__".to_string())
}
(true, true) => panic!(
"Cannot compile with both `portable` and `force-adx` features"
),
_ => None,
}
}

fn cuda_available() -> bool {
match env::var("NVCC") {
Ok(var) => which::which(var).is_ok(),
Err(_) => which::which("nvcc").is_ok(),
}
}

fn compile_cuda(file_name: &str, output_name: &str, implement_sort: bool) {
let mut nvcc = cc::Build::new();
nvcc.cuda(true);
nvcc.flag("-arch=sm_80");
nvcc.flag("-gencode").flag("arch=compute_70,code=sm_70");
nvcc.flag("-t0");
#[cfg(not(target_env = "msvc"))]
nvcc.flag("-Xcompiler").flag("-Wno-unused-function");
nvcc.define("TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE", None);
#[cfg(feature = "cuda-mobile")]
nvcc.define("NTHREADS", "128");

if let Some(def) = determine_cc_def(
&env::var("CARGO_CFG_TARGET_ARCH").unwrap(),
"__CUDA_PORTABLE__",
) {
nvcc.define(&def, None);
}

if let Some(include) = env::var_os("DEP_BLST_C_SRC") {
cc.include(include);
nvcc.include(include);
}
if let Some(include) = env::var_os("DEP_SPPARK_ROOT") {
cc.include(include);
if let Some(include) = env::var_os("DEP_SEMOLINA_C_INCLUDE") {
nvcc.include(include);
}
cc.files(&files).compile("grumpkin_msm");

if cfg!(target_os = "windows") && !cfg!(target_env = "msvc") {
return;
if let Some(include) = env::var_os("DEP_SPPARK_ROOT") {
nvcc.include(include);
}
// Detect if there is CUDA compiler and engage "cuda" feature accordingly
let nvcc = match env::var("NVCC") {
Ok(var) => which::which(var),
Err(_) => which::which("nvcc"),
};
if nvcc.is_ok() {
let mut nvcc = cc::Build::new();
nvcc.cuda(true);
nvcc.flag("-arch=sm_80");
nvcc.flag("-gencode").flag("arch=compute_70,code=sm_70");
nvcc.flag("-t0");
#[cfg(not(target_env = "msvc"))]
nvcc.flag("-Xcompiler").flag("-Wno-unused-function");
nvcc.define("TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE", None);
#[cfg(feature = "cuda-mobile")]
nvcc.define("NTHREADS", "128");
if let Some(def) = cc_def {
nvcc.define(def, None);
}
if let Some(include) = env::var_os("DEP_BLST_C_SRC") {
nvcc.include(include);
}
if let Some(include) = env::var_os("DEP_SPPARK_ROOT") {
nvcc.include(include);
}
#[cfg(not(feature = "dont-implement-sort"))]
nvcc.clone().file("cuda/bn254.cu").compile("bn256_msm_cuda");
#[cfg(feature = "dont-implement-sort")]
nvcc.clone()
.define("__MSM_SORT_DONT_IMPLEMENT__", None)
.file("cuda/bn254.cu")
.compile("bn256_msm_cuda");
if implement_sort {
nvcc.file(file_name).compile(output_name);
} else {
nvcc.define("__MSM_SORT_DONT_IMPLEMENT__", None)
.file("cuda/grumpkin.cu")
.compile("grumpkin_msm_cuda");

println!("cargo:rerun-if-changed=cuda");
println!("cargo:rerun-if-env-changed=CXXFLAGS");
println!("cargo:rustc-cfg=feature=\"cuda\"");
.file(file_name)
.compile(output_name);
}
println!("cargo:rerun-if-env-changed=NVCC");
}
24 changes: 24 additions & 0 deletions cuda/pallas.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#include <cuda.h>

#include <ec/jacobian_t.hpp>
#include <ec/xyzz_t.hpp>

#include <ff/pasta.hpp>

typedef jacobian_t<pallas_t> point_t;
typedef xyzz_t<pallas_t> bucket_t;
typedef bucket_t::affine_t affine_t;
typedef vesta_t scalar_t;

#include <msm/pippenger.cuh>

#ifndef __CUDA_ARCH__
extern "C"
RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints,
const scalar_t scalars[])
{ return mult_pippenger<bucket_t>(out, points, npoints, scalars); }
#endif
24 changes: 24 additions & 0 deletions cuda/vesta.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#include <cuda.h>

#include <ec/jacobian_t.hpp>
#include <ec/xyzz_t.hpp>

#include <ff/pasta.hpp>

typedef jacobian_t<vesta_t> point_t;
typedef xyzz_t<vesta_t> bucket_t;
typedef bucket_t::affine_t affine_t;
typedef pallas_t scalar_t;

#include <msm/pippenger.cuh>

#ifndef __CUDA_ARCH__
extern "C"
RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints,
const scalar_t scalars[])
{ return mult_pippenger<bucket_t>(out, points, npoints, scalars); }
#endif
File renamed without changes.
Loading

0 comments on commit 0965aeb

Please sign in to comment.