Skip to content

Commit

Permalink
Fix benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Aug 26, 2024
1 parent 7bc3bab commit 6c3747c
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 102 deletions.
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,10 @@ serde = { version = "1.0.195", features = ["derive"] }
tokio = { version = "1.35.1", features = ["full"] }
trait-variant = "0.1.1"
async-trait = "0.1.77"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio"] }

[[bench]]
name = "join"
harness = false
29 changes: 29 additions & 0 deletions benches/.join.garble.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
const ROWS_0: usize = PARTY_0::ROWS;
const ROWS_1: usize = PARTY_1::ROWS;

enum ScreeningStatus {
Recent,
OutOfDate,
Missing,
}

pub fn main(
screenings: [([u8; 20], ScreeningStatus); ROWS_0],
school_examinations: [([u8; 20], u8); ROWS_1],
) -> [(u16, u16); 1] {
let mut missing_screenings_with_special_ed_needs = 0u16;
let mut total = ROWS_1 as u16;
for joined in join(screenings, school_examinations) {
let ((_, screening), (_, special_ed_needs)) = joined;
if special_ed_needs <= 2u8 {
match screening {
ScreeningStatus::Missing => {
missing_screenings_with_special_ed_needs =
missing_screenings_with_special_ed_needs + 1u16;
}
_ => {}
}
}
}
[(missing_screenings_with_special_ed_needs, total)]
}
94 changes: 94 additions & 0 deletions benches/join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use std::{collections::HashMap, time::Instant};

use criterion::{criterion_group, criterion_main, Criterion};
use garble_lang::{
compile_with_constants,
literal::{Literal, VariantLiteral},
token::UnsignedNumType,
};
use parlay::protocol::simulate_mpc_async;

fn join_benchmark(c: &mut Criterion) {
let n_records = 100;
let code = include_str!(".join.garble.rs");

let bench_id = format!("join {n_records} records");
c.bench_function(&bench_id, |b| {
b.to_async(tokio::runtime::Runtime::new().unwrap())
.iter(|| async {
let now = Instant::now();
println!("\n\nRUNNING MPC SIMULATION FOR {n_records} RECORDS:\n");
println!("Compiling circuit...");
let consts = HashMap::from_iter(vec![
(
"PARTY_0".into(),
HashMap::from_iter(vec![(
"ROWS".into(),
Literal::NumUnsigned(n_records, UnsignedNumType::Usize),
)]),
),
(
"PARTY_1".into(),
HashMap::from_iter(vec![(
"ROWS".into(),
Literal::NumUnsigned(n_records, UnsignedNumType::Usize),
)]),
),
]);
let prg = compile_with_constants(&code, consts).unwrap();
println!("{}", prg.circuit.report_gates());
let elapsed = now.elapsed();
println!(
"Compilation took {} minute(s), {} second(s)",
elapsed.as_secs() / 60,
elapsed.as_secs() % 60,
);

let id = Literal::ArrayRepeat(
Box::new(Literal::NumUnsigned(0, UnsignedNumType::U8)),
20,
);

let screening_status = Literal::Enum(
"ScreeningStatus".into(),
"Missing".into(),
VariantLiteral::Unit,
);

let rows0 = Literal::ArrayRepeat(
Box::new(Literal::Tuple(vec![id.clone(), screening_status])),
n_records as usize,
);

let rows1 = Literal::ArrayRepeat(
Box::new(Literal::Tuple(vec![
id.clone(),
Literal::NumUnsigned(0, UnsignedNumType::U8),
])),
n_records as usize,
);

let input0 = prg.literal_arg(0, rows0).unwrap().as_bits();
let input1 = prg.literal_arg(1, rows1).unwrap().as_bits();

let inputs = vec![input0.as_slice(), input1.as_slice()];

simulate_mpc_async(&prg.circuit, &inputs, &[0])
.await
.unwrap();
let elapsed = now.elapsed();
println!(
"MPC computation took {} minute(s), {} second(s)",
elapsed.as_secs() / 60,
elapsed.as_secs() % 60,
);
})
});
}

criterion_group! {
name = benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = join_benchmark
}
criterion_main!(benches);
24 changes: 16 additions & 8 deletions examples/sql-integration/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ impl Channel for HttpChannel {
type RecvError = anyhow::Error;

async fn send_bytes_to(&mut self, p: usize, msg: Vec<u8>) -> Result<(), Self::SendError> {
let simulated_delay_in_ms = 500;
let simulated_delay_in_ms = 300;
let expected_msgs = (self.urls.len() - 2) * 2 + 3;
let client = reqwest::Client::new();
let url = format!("{}msg/{}", self.urls[p], self.party);
Expand All @@ -798,29 +798,37 @@ impl Channel for HttpChannel {
self.send_count, expected_msgs, p, mb
);
}
let chunk_size = 10 * 1024 * 1024;
let chunk_size = 100 * 1024 * 1024;
let mut chunks: Vec<_> = msg.chunks(chunk_size).collect();
if chunks.is_empty() {
chunks.push(&[]);
}
let length = chunks.len();
let len = chunks.len();
for (i, chunk) in chunks.into_iter().enumerate() {
if length > 1 && self.enable_log {
info!(" (Sending chunk {}/{} to party {})", i + 1, length, p);
if len > 1 && self.enable_log {
info!(" (Sending chunk {}/{} to party {})", i + 1, len, p);
}
let mut msg = Vec::with_capacity(2 * 4 + chunk.len());
msg.extend((i as u32).to_be_bytes());
msg.extend((length as u32).to_be_bytes());
msg.extend((len as u32).to_be_bytes());
msg.extend(chunk);
loop {
sleep(Duration::from_millis(simulated_delay_in_ms)).await;
match client.post(&url).body(msg.clone()).send().await?.status() {
let req = client.post(&url).body(msg.clone()).send();
let Ok(Ok(res)) = timeout(Duration::from_secs(1), req).await else {
warn!(" req timeout: chunk {}/{} for party {}", i + 1, len, p);
continue;
};
match res.status() {
StatusCode::OK => break,
StatusCode::NOT_FOUND => {
error!("Could not reach party {p} at {url}...");
sleep(Duration::from_millis(1000)).await;
}
status => anyhow::bail!("Unexpected status code: {status}"),
status => {
error!("Unexpected status code: {status}");
anyhow::bail!("Unexpected status code: {status}")
}
}
}
}
Expand Down
18 changes: 15 additions & 3 deletions examples/sql-integration/tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ fn simulate() {
println!("\n\n--- {millis}ms ---\n\n");
let mut children = vec![];
let mut cmd = Command::new("cargo")
.args(["run", "--", "--port=8002", "--config=preprocessor.json"])
.args([
"run",
"--release",
"--",
"--port=8002",
"--config=preprocessor.json",
])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
Expand All @@ -35,7 +41,7 @@ fn simulate() {

let port = "--port=8001".to_string();
let config = "--config=policies1.json".to_string();
let args = vec!["run", "--", &port, &config];
let args = vec!["run", "--release", "--", &port, &config];
let mut cmd = Command::new("cargo")
.args(args)
.stdout(Stdio::piped())
Expand All @@ -57,7 +63,13 @@ fn simulate() {
children.push(cmd);

sleep(Duration::from_millis(millis));
let args = vec!["run", "--", "--port=8000", "--config=policies0.json"];
let args = vec![
"run",
"--release",
"--",
"--port=8000",
"--config=policies0.json",
];
let mut cmd = Command::new("cargo")
.args(args)
.stdout(Stdio::piped())
Expand Down
63 changes: 41 additions & 22 deletions src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,31 +159,50 @@ impl Channel for SimpleChannel {
type SendError = SendError<Vec<u8>>;
type RecvError = AsyncRecvError;

async fn send_bytes_to(
&mut self,
party: usize,
msg: Vec<u8>,
) -> Result<(), SendError<Vec<u8>>> {
self.s[party]
.as_ref()
.unwrap_or_else(|| panic!("No sender for party {party}"))
.send(msg)
.await
async fn send_bytes_to(&mut self, p: usize, msg: Vec<u8>) -> Result<(), SendError<Vec<u8>>> {
let mb = msg.len() as f64 / 1024.0 / 1024.0;
println!("Sending msg to party {p} ({mb:.2}MB)...");
let chunk_size = 100 * 1024 * 1024;
let mut chunks: Vec<_> = msg.chunks(chunk_size).collect();
if chunks.is_empty() {
chunks.push(&[]);
}
let length = chunks.len();
for (i, chunk) in chunks.into_iter().enumerate() {
if length > 1 {
println!(" (Sending chunk {}/{} to party {})", i + 1, length, p);
}
let mut msg = Vec::with_capacity(2 * 4 + chunk.len());
msg.extend((i as u32).to_be_bytes());
msg.extend((length as u32).to_be_bytes());
msg.extend(chunk);
self.s[p]
.as_ref()
.unwrap_or_else(|| panic!("No sender for party {p}"))
.send(msg)
.await?;
}
Ok(())
}

async fn recv_bytes_from(&mut self, party: usize) -> Result<Vec<u8>, AsyncRecvError> {
match timeout(
Duration::from_secs(1),
self.r[party]
async fn recv_bytes_from(&mut self, p: usize) -> Result<Vec<u8>, AsyncRecvError> {
let mut msg: Vec<u8> = vec![];
loop {
let chunk = self.r[p]
.as_mut()
.unwrap_or_else(|| panic!("No receiver for party {party}"))
.recv(),
)
.await
{
Ok(Some(bytes)) => Ok(bytes),
Ok(None) => Err(AsyncRecvError::Closed),
Err(_) => Err(AsyncRecvError::TimeoutElapsed),
.unwrap_or_else(|| panic!("No receiver for party {p}"))
.recv();
let chunk = match timeout(Duration::from_secs(10 * 60), chunk).await {
Ok(Some(bytes)) => bytes,
Ok(None) => return Err(AsyncRecvError::Closed),
Err(_) => return Err(AsyncRecvError::TimeoutElapsed),
};
let i = u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
let length = u32::from_be_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
msg.extend(&chunk[8..]);
if i == length - 1 {
break Ok(msg);
}
}
}
}
Loading

0 comments on commit 6c3747c

Please sign in to comment.