From 176d985170de4e519d4cfeaa490cc6dc33d7350c Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 10 Apr 2024 18:31:54 +0100 Subject: [PATCH 01/25] Add example based on accepted/rejected programs --- Cargo.toml | 1 + examples/postgres-integration/.add.garble.rs | 3 + examples/postgres-integration/Cargo.toml | 19 ++ examples/postgres-integration/policies0.json | 15 + examples/postgres-integration/policies1.json | 15 + .../postgres-integration/preprocessor.json | 15 + examples/postgres-integration/src/main.rs | 311 ++++++++++++++++++ examples/postgres-integration/tests/cli.rs | 92 ++++++ src/fpre.rs | 7 +- src/protocol.rs | 36 +- 10 files changed, 501 insertions(+), 13 deletions(-) create mode 100644 examples/postgres-integration/.add.garble.rs create mode 100644 examples/postgres-integration/Cargo.toml create mode 100644 examples/postgres-integration/policies0.json create mode 100644 examples/postgres-integration/policies1.json create mode 100644 examples/postgres-integration/preprocessor.json create mode 100644 examples/postgres-integration/src/main.rs create mode 100644 examples/postgres-integration/tests/cli.rs diff --git a/Cargo.toml b/Cargo.toml index e27479d..a68405a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ workspace = { members = [ "examples/http-multi-server-channels", "examples/http-single-server-channels", "examples/iroh-p2p-channels", + "examples/postgres-integration", ] } [package] name = "parlay" diff --git a/examples/postgres-integration/.add.garble.rs b/examples/postgres-integration/.add.garble.rs new file mode 100644 index 0000000..ca6e9c6 --- /dev/null +++ b/examples/postgres-integration/.add.garble.rs @@ -0,0 +1,3 @@ +pub fn main(x: u32, y: u32) -> u32 { + x + y +} diff --git a/examples/postgres-integration/Cargo.toml b/examples/postgres-integration/Cargo.toml new file mode 100644 index 0000000..e1e4724 --- /dev/null +++ b/examples/postgres-integration/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "parlay-postgres-integration" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.79" +axum = "0.7.4" +blake3 = "1.5.1" +clap = { version = "4.4.18", features = ["derive"] } +parlay = { path = "../../", version = "0.1.0" } +reqwest = { version = "0.11.23", features = ["json"] } +serde = "1.0.197" +serde_json = "1.0.115" +tokio = { version = "1.35.1", features = ["macros", "rt", "rt-multi-thread"] } +tower-http = { version = "0.5.1", features = ["fs", "trace"] } +tracing = "0.1.40" +tracing-subscriber = "0.3.18" +url = { version = "2.5.0", features = ["serde"] } diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json new file mode 100644 index 0000000..396cff6 --- /dev/null +++ b/examples/postgres-integration/policies0.json @@ -0,0 +1,15 @@ +{ + "accepted": [ + { + "participants": [ + "http://localhost:8000", + "http://localhost:8001", + "http://localhost:8002" + ], + "program": ".add.garble.rs", + "leader": 0, + "party": 0, + "input": "2u32" + } + ] +} diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json new file mode 100644 index 0000000..f9dfe7c --- /dev/null +++ b/examples/postgres-integration/policies1.json @@ -0,0 +1,15 @@ +{ + "accepted": [ + { + "participants": [ + "http://localhost:8000", + "http://localhost:8001", + "http://localhost:8002" + ], + "program": ".add.garble.rs", + "leader": 0, + "party": 1, + "input": "3u32" + } + ] +} diff --git a/examples/postgres-integration/preprocessor.json b/examples/postgres-integration/preprocessor.json new file mode 100644 index 0000000..b1e8251 --- /dev/null +++ b/examples/postgres-integration/preprocessor.json @@ -0,0 +1,15 @@ +{ + "accepted": [ + { + "participants": [ + "http://localhost:8000", + "http://localhost:8001", + "http://localhost:8002" + ], + "program": "", + "leader": 0, + "party": 0, + "input": "" + } + ] +} diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs new file mode 100644 index 0000000..47509a1 --- /dev/null +++ b/examples/postgres-integration/src/main.rs @@ -0,0 +1,311 @@ +use anyhow::anyhow; +use anyhow::{Context, Error}; +use axum::Json; +use axum::{ + body::Bytes, + extract::{Path, State}, + routing::post, + Router, +}; +use clap::Parser; +use parlay::garble_lang::literal::Literal; +use parlay::{ + channel::Channel, + fpre::fpre, + garble_lang::compile, + protocol::{mpc, Preprocessor}, +}; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use std::borrow::BorrowMut; +use std::process::exit; +use std::sync::Arc; +use std::{net::SocketAddr, path::PathBuf, result::Result, time::Duration}; +use tokio::sync::Mutex; +use tokio::time::timeout; +use tokio::{ + fs, + sync::mpsc::{channel, Receiver, Sender}, + time::sleep, +}; +use tower_http::trace::TraceLayer; +use url::Url; + +/// A CLI for Multi-Party Computation using the Parlay engine. +#[derive(Debug, Parser)] +#[command(name = "parlay")] +struct Cli { + /// The port to listen on for connection attempts from other parties. + #[arg(required = true, long, short)] + port: u16, + /// The location of the file with the policy configuration. + #[arg(long, short)] + config: PathBuf, +} + +#[derive(Debug, Clone, Deserialize)] +struct Policies { + accepted: Vec, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +struct Policy { + participants: Vec, + program: PathBuf, + leader: usize, + party: usize, + // TODO: replace with db connection info + input: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +struct PolicyRequest { + participants: Vec, + program_hash: String, + leader: usize, +} + +const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(5); + +type MpcState = Arc>>>>; + +#[tokio::main] +async fn main() -> Result<(), Error> { + let Cli { port, config } = Cli::parse(); + let policies = load_policies(config).await?; + if policies.accepted.len() == 1 { + let policy = &policies.accepted[0]; + if policy.input.is_empty() { + println!("Running as preprocessor..."); + return run_fpre(port, policy.participants.clone()).await; + } + } + tracing_subscriber::fmt::init(); + + let state = Arc::new(Mutex::new(vec![])); + + let app = Router::new() + .route("/run", post(run)) + .route("/msg/:from", post(msg)) + .with_state((policies.clone(), Arc::clone(&state))) + .layer(TraceLayer::new_for_http()); + + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + tracing::info!("listening on {}", addr); + let listener = tokio::net::TcpListener::bind(&addr).await?; + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + loop { + for policy in &policies.accepted { + if policy.leader == policy.party { + println!( + "Acting as leader (party {}) for program {}", + policy.leader, + policy.program.display() + ); + let Ok(code) = fs::read_to_string(&policy.program).await else { + eprintln!("Could not load program {:?}", &policy.program); + continue; + }; + let hash = blake3::hash(code.as_bytes()).to_string(); + let client = reqwest::Client::new(); + let policy_request = PolicyRequest { + participants: policy.participants.clone(), + leader: policy.leader, + program_hash: hash, + }; + for party in policy.participants.iter().rev().skip(1).rev() { + if party != &policy.participants[policy.party] { + println!("Waiting for confirmation from party {party}"); + let url = format!("{party}run"); + match client + .post(&url) + .json(&policy_request) + .send() + .await? + .status() + { + StatusCode::OK => {} + code => { + eprintln!("Unexpected response while trying to trigger execution for {url}: {code}"); + } + } + } + } + println!("All participants have accepted the session, starting calculation now..."); + match execute_mpc(Arc::clone(&state), code, policy).await { + Ok(Some(output)) => { + println!("Output is {output}") + } + Ok(None) => {} + Err(e) => { + eprintln!("Error while executing MPC: {e}") + } + } + } + } + sleep(TIME_BETWEEN_EXECUTIONS).await; + } +} + +async fn run_fpre(port: u16, urls: Vec) -> Result<(), Error> { + let parties = urls.len() - 1; + + tracing_subscriber::fmt::init(); + + let mut senders = vec![]; + let mut receivers = vec![]; + for _ in 0..parties { + let (s, r) = channel(1); + senders.push(s); + receivers.push(r); + } + + let policies = Policies { accepted: vec![] }; + let app = Router::new() + .route("/msg/:from", post(msg)) + .with_state((policies, Arc::new(Mutex::new(senders)))) + .layer(TraceLayer::new_for_http()); + + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + tracing::info!("listening on {}", addr); + let listener = tokio::net::TcpListener::bind(&addr).await?; + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let mut channel = HttpChannel { + urls, + party: parties, + recv: receivers, + }; + loop { + println!("Running FPre..."); + channel = fpre(channel, parties).await.context("FPre")?; + } +} + +async fn load_policies(path: PathBuf) -> Result { + let Ok(policies) = fs::read_to_string(&path).await else { + eprintln!("Could not find '{}', exiting...", path.display()); + exit(-1); + }; + let Ok(policies) = serde_json::from_str::(&policies) else { + eprintln!("'{}' has an invalid format, exiting...", path.display()); + exit(-1); + }; + Ok(policies) +} + +async fn execute_mpc( + state: MpcState, + code: String, + policy: &Policy, +) -> Result, Error> { + let Policy { + leader, + participants, + party, + input, + .. + } = policy; + let prg = compile(&code).map_err(|e| anyhow!(e.prettify(&code)))?; + let input = prg.parse_arg(*party, input)?.as_bits(); + let fpre = Preprocessor::TrustedDealer(participants.len() - 1); + let p_out: Vec<_> = vec![*leader]; + let channel = { + let mut state = state.lock().await; + let senders = state.borrow_mut(); + if !senders.is_empty() { + panic!("Cannot start a new MPC execution while there are still active senders!"); + } + let mut receivers = vec![]; + for _ in 0..policy.participants.len() { + let (s, r) = channel(1); + senders.push(s); + receivers.push(r); + } + + HttpChannel { + urls: participants.clone(), + party: *party, + recv: receivers, + } + }; + let output = mpc(channel, &prg.circuit, &input, fpre, 0, *party, &p_out).await?; + state.lock().await.clear(); + if output.is_empty() { + Ok(None) + } else { + Ok(Some(prg.parse_output(&output)?)) + } +} + +async fn run( + State((policies, state)): State<(Policies, MpcState)>, + Json(body): Json, +) { + for policy in policies.accepted { + if policy.participants == body.participants && policy.leader == body.leader { + let Ok(code) = fs::read_to_string(&policy.program).await else { + eprintln!("Could not load program {:?}", &policy.program); + return; + }; + let expected = blake3::hash(code.as_bytes()).to_string(); + if expected != body.program_hash { + eprintln!("Aborting due to different hashes for program in policy {policy:?}"); + return; + } + println!( + "Accepted policy for {}, starting execution", + policy.program.display() + ); + tokio::spawn(async move { + if let Err(e) = execute_mpc(state, code, &policy).await { + eprintln!("{e}"); + } + }); + return; + } + } + eprintln!("Policy not accepted: {body:?}"); +} + +async fn msg(State((_, state)): State<(Policies, MpcState)>, Path(from): Path, body: Bytes) { + let senders = state.lock().await; + if senders.len() > from as usize { + senders[from as usize].send(body.to_vec()).await.unwrap(); + } else { + eprintln!("No sender for party {from}"); + } +} + +struct HttpChannel { + urls: Vec, + party: usize, + recv: Vec>>, +} + +impl Channel for HttpChannel { + type SendError = anyhow::Error; + type RecvError = anyhow::Error; + + async fn send_bytes_to(&mut self, p: usize, msg: Vec) -> Result<(), Self::SendError> { + let client = reqwest::Client::new(); + let url = format!("{}msg/{}", self.urls[p], self.party); + loop { + match client.post(&url).body(msg.clone()).send().await?.status() { + StatusCode::OK => return Ok(()), + StatusCode::NOT_FOUND => { + println!("Could not reach party {p} at {url}..."); + sleep(Duration::from_millis(1000)).await; + } + status => anyhow::bail!("Unexpected status code: {status}"), + } + } + } + + async fn recv_bytes_from(&mut self, p: usize) -> Result, Self::RecvError> { + Ok(timeout(Duration::from_secs(10), self.recv[p].recv()) + .await + .context(format!("recv_bytes_from(p = {p})"))? + .unwrap_or_default()) + } +} diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs new file mode 100644 index 0000000..8ca6861 --- /dev/null +++ b/examples/postgres-integration/tests/cli.rs @@ -0,0 +1,92 @@ +use std::{ + io::{BufRead, BufReader}, + process::{Command, Stdio}, + thread::{self, sleep}, + time::Duration, +}; + +const URLS: &str = + "http://127.0.0.1:8000;http://127.0.0.1:8001;http://127.0.0.1:8002;http://127.0.0.1:8003"; + +#[test] +fn simulate() { + for millis in [50, 500, 2_000, 10_000] { + println!("\n\n--- {millis}ms ---\n\n"); + let mut children = vec![]; + let mut cmd = Command::new("cargo") + .args(["run", "--", URLS]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + sleep(Duration::from_millis(millis)); + let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines(); + let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); + thread::spawn(move || { + while let Some(Ok(line)) = stdout.next() { + println!("pre> {line}"); + } + }); + thread::spawn(move || { + while let Some(Ok(line)) = stderr.next() { + eprintln!("pre> {line}"); + } + }); + children.push(cmd); + for p in [1, 2] { + let party_arg = format!("--party={p}"); + let args = vec![ + "run", + "--", + "party", + URLS, + "--program=.add.garble.rs", + "--input=2u32", + &party_arg, + ]; + let mut cmd = Command::new("cargo") + .args(args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines().skip(4); + let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); + thread::spawn(move || { + while let Some(Ok(line)) = stdout.next() { + println!("party{p}> {line}"); + } + }); + thread::spawn(move || { + while let Some(Ok(line)) = stderr.next() { + eprintln!("party{p}> {line}"); + } + }); + children.push(cmd); + } + let args = vec![ + "run", + "--", + "party", + URLS, + "--program=.add.garble.rs", + "--input=2u32", + "--party=0", + ]; + let out = Command::new("cargo").args(args).output().unwrap(); + eprintln!("{}", String::from_utf8(out.stderr).unwrap()); + let out = String::from_utf8(out.stdout).unwrap(); + let out = out.lines().last().unwrap_or_default(); + for mut child in children { + child.kill().unwrap(); + } + if out == "The result is 6u32" { + return; + } else { + eprintln!("Last line is '{out}'"); + sleep(Duration::from_millis(millis)); + println!("Could not complete test successfully, trying again with more time..."); + } + } + panic!("Test failed repeatedly!"); +} diff --git a/src/fpre.rs b/src/fpre.rs index d1df8a5..7d60858 100644 --- a/src/fpre.rs +++ b/src/fpre.rs @@ -40,7 +40,10 @@ impl From for Error { } /// Runs FPre as a trusted dealer, communicating with all other parties. -pub async fn fpre(channel: impl Channel, parties: usize) -> Result<(), Error> { +pub async fn fpre(channel: C, parties: usize) -> Result +where + C: Channel, +{ let mut channel = MsgChannel(channel); for p in 0..parties { channel.recv_from(p, "delta (fpre)").await?; @@ -183,7 +186,7 @@ pub async fn fpre(channel: impl Channel, parties: usize) -> Result<(), Error> { for (p, and_shares) in and_shares.into_iter().enumerate() { channel.send_to(p, "AND shares (fpre)", &and_shares).await?; } - Ok(()) + Ok(channel.0) } /// The global key known only to a single party that is used to authenticate bits. diff --git a/src/protocol.rs b/src/protocol.rs index 4137902..88a253a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -126,10 +126,16 @@ impl From for Error { /// A custom error type for all MPC operations. #[derive(Debug)] pub enum MpcError { - /// No secret share was sent for the specified wire. - MissingShareForWire(usize), + /// No secret share was sent during preprocessing for the specified wire. + MissingPreprocessingShareForWire(usize), + /// No secret share was sent in the garbled table for the specified wire. + MissingTableShareForWire(usize), + /// No secret share was sent for the specified output wire. + MissingOutputShareForWire(usize), /// No AND share was sent for the specified wire. MissingAndShareForWire(usize), + /// No share was sent for the input wire, possibly because there are fewer parties than inputs. + MissingSharesForInput(usize), /// The input on the specified wire did not match the message authenatication code. InvalidInputMacOnWire(usize), /// Two different parties tried to provide an input mask for the wire. @@ -292,13 +298,12 @@ pub async fn mpc( let random_shares: Vec = channel.recv_from(p_fpre, "random shares").await?; let mut random_shares = random_shares.into_iter(); - //let mut wire_shares_and_labels = vec![(Share(false, Auth(vec![])), Label(0)); num_gates]; let mut shares = vec![Share(false, Auth(vec![])); num_gates]; let mut labels = vec![Label(0); num_gates]; for (w, gate) in circuit.wires().iter().enumerate() { if let Wire::Input(_) | Wire::And(_, _) = gate { let Some(share) = random_shares.next() else { - return Err(MpcError::MissingShareForWire(w).into()); + return Err(MpcError::MissingPreprocessingShareForWire(w).into()); }; shares[w] = share; if is_contrib { @@ -415,8 +420,11 @@ pub async fn mpc( for (w, gate) in circuit.wires().iter().enumerate() { if let Wire::Input(i) = gate { let Share(bit, Auth(macs_and_keys)) = shares[w].clone(); - if let Some((mac, _)) = macs_and_keys[*i] { - wire_shares_for_others[*i][w] = Some((bit, mac)); + let Some(mac_and_key) = macs_and_keys.get(*i) else { + return Err(MpcError::MissingSharesForInput(*i).into()); + }; + if let Some((mac, _)) = mac_and_key { + wire_shares_for_others[*i][w] = Some((bit, *mac)); } } } @@ -540,7 +548,7 @@ pub async fn mpc( let label_y = &labels_eval[*y]; let i = 2 * (input_x as usize) + (input_y as usize); let Some(table_shares) = &table_shares[w] else { - return Err(MpcError::MissingShareForWire(w).into()); + return Err(MpcError::MissingTableShareForWire(w).into()); }; let mut label = vec![Label(0); p_max]; @@ -607,30 +615,36 @@ pub async fn mpc( .await?; } } - let mut wires_and_labels: Vec> = vec![None; num_gates]; + let mut input_wires: Vec> = vec![None; num_gates]; if !is_contrib { for p_out in p_out.iter().copied().filter(|p| *p != p_own) { + let mut wires_and_labels: Vec> = vec![None; num_gates]; for w in circuit.output_gates.iter().copied() { wires_and_labels[w] = Some((values[w], labels_eval[w][p_out])); } channel.send_to(p_out, "lambda", &wires_and_labels).await?; } + for w in circuit.output_gates.iter().copied() { + input_wires[w] = Some(values[w]); + } } else if p_out.contains(&p_own) { - wires_and_labels = channel.recv_vec_from(p_eval, "lambda", num_gates).await?; + let wires_and_labels: Vec> = + channel.recv_vec_from(p_eval, "lambda", num_gates).await?; for w in circuit.output_gates.iter().copied() { if !(wires_and_labels[w] == Some((true, labels[w] ^ delta)) || wires_and_labels[w] == Some((false, labels[w]))) { return Err(MpcError::InvalidOutputWireLabel(w).into()); } + input_wires[w] = wires_and_labels[w].map(|(bit, _)| bit); } } let mut outputs: Vec = vec![]; if p_out.contains(&p_own) { let mut output_wires: Vec> = vec![None; num_gates]; for w in circuit.output_gates.iter().copied() { - let Some((input, _)) = wires_and_labels.get(w).copied().flatten() else { - return Err(MpcError::MissingShareForWire(w).into()); + let Some(input) = input_wires.get(w).copied().flatten() else { + return Err(MpcError::MissingOutputShareForWire(w).into()); }; let Share(bit, _) = &shares[w]; output_wires[w] = Some(input ^ bit); From c1db33defb4f82da17970f9db00aa0f41f6286de Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 10 Apr 2024 18:37:21 +0100 Subject: [PATCH 02/25] Fix other examples --- examples/http-multi-server-channels/src/main.rs | 3 ++- examples/http-single-server-channels/src/main.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/http-multi-server-channels/src/main.rs b/examples/http-multi-server-channels/src/main.rs index 1667526..9a572e5 100644 --- a/examples/http-multi-server-channels/src/main.rs +++ b/examples/http-multi-server-channels/src/main.rs @@ -66,7 +66,8 @@ async fn main() -> Result<(), Error> { Commands::Pre { urls } => { let parties = urls.len() - 1; let channel = HttpChannel::new(urls, parties).await?; - fpre(channel, parties).await.context("FPre") + fpre(channel, parties).await.context("FPre").unwrap(); + Ok(()) } Commands::Party { urls, diff --git a/examples/http-single-server-channels/src/main.rs b/examples/http-single-server-channels/src/main.rs index 027392b..679a67d 100644 --- a/examples/http-single-server-channels/src/main.rs +++ b/examples/http-single-server-channels/src/main.rs @@ -82,7 +82,7 @@ async fn main() { sleep(Duration::from_secs(1)).await; } } - fpre(channel, parties).await.unwrap() + fpre(channel, parties).await.unwrap(); } Commands::Party { url, From b7998b7c1aafb5d8c1efafce82c227db57c41953 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 15 Apr 2024 15:21:44 +0100 Subject: [PATCH 03/25] Add test for integration example --- examples/postgres-integration/src/main.rs | 1 + examples/postgres-integration/tests/cli.rs | 126 ++++++++++----------- 2 files changed, 62 insertions(+), 65 deletions(-) diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 47509a1..1d8e9e6 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -94,6 +94,7 @@ async fn main() -> Result<(), Error> { tracing::info!("listening on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await?; tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + println!("Found {} active policies", policies.accepted.len()); loop { for policy in &policies.accepted { if policy.leader == policy.party { diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index 8ca6861..625a320 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -1,92 +1,88 @@ use std::{ io::{BufRead, BufReader}, - process::{Command, Stdio}, + process::{exit, Command, Stdio}, thread::{self, sleep}, time::Duration, }; -const URLS: &str = - "http://127.0.0.1:8000;http://127.0.0.1:8001;http://127.0.0.1:8002;http://127.0.0.1:8003"; - #[test] fn simulate() { - for millis in [50, 500, 2_000, 10_000] { - println!("\n\n--- {millis}ms ---\n\n"); - let mut children = vec![]; + let millis = 1000; + println!("\n\n--- {millis}ms ---\n\n"); + let mut children = vec![]; + let mut cmd = Command::new("cargo") + .args(["run", "--", "--port=8002", "--config=preprocessor.json"]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + + sleep(Duration::from_millis(millis)); + let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines(); + let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); + thread::spawn(move || { + while let Some(Ok(line)) = stdout.next() { + println!("pre> {line}"); + } + }); + thread::spawn(move || { + while let Some(Ok(line)) = stderr.next() { + eprintln!("pre> {line}"); + } + }); + children.push(cmd); + for p in [1] { + let port = format!("--port=800{p}"); + let config = format!("--config=policies{p}.json"); + let args = vec!["run", "--", &port, &config]; let mut cmd = Command::new("cargo") - .args(["run", "--", URLS]) + .args(args) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn() .unwrap(); - sleep(Duration::from_millis(millis)); - let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines(); + let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines().skip(4); let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); thread::spawn(move || { while let Some(Ok(line)) = stdout.next() { - println!("pre> {line}"); + println!("party{p}> {line}"); } }); thread::spawn(move || { while let Some(Ok(line)) = stderr.next() { - eprintln!("pre> {line}"); + eprintln!("party{p}> {line}"); } }); children.push(cmd); - for p in [1, 2] { - let party_arg = format!("--party={p}"); - let args = vec![ - "run", - "--", - "party", - URLS, - "--program=.add.garble.rs", - "--input=2u32", - &party_arg, - ]; - let mut cmd = Command::new("cargo") - .args(args) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .unwrap(); - let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines().skip(4); - let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); - thread::spawn(move || { - while let Some(Ok(line)) = stdout.next() { - println!("party{p}> {line}"); - } - }); - thread::spawn(move || { - while let Some(Ok(line)) = stderr.next() { - eprintln!("party{p}> {line}"); - } - }); - children.push(cmd); - } - let args = vec![ - "run", - "--", - "party", - URLS, - "--program=.add.garble.rs", - "--input=2u32", - "--party=0", - ]; - let out = Command::new("cargo").args(args).output().unwrap(); - eprintln!("{}", String::from_utf8(out.stderr).unwrap()); - let out = String::from_utf8(out.stdout).unwrap(); - let out = out.lines().last().unwrap_or_default(); - for mut child in children { - child.kill().unwrap(); + } + sleep(Duration::from_millis(500)); + let args = vec!["run", "--", "--port=8000", "--config=policies0.json"]; + let mut cmd = Command::new("cargo") + .args(args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines(); + let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); + children.push(cmd); + thread::spawn(move || { + while let Some(Ok(line)) = stdout.next() { + println!("party0> {line}"); + if line == "Output is 5u32" { + exit(0); + } } - if out == "The result is 6u32" { - return; - } else { - eprintln!("Last line is '{out}'"); - sleep(Duration::from_millis(millis)); - println!("Could not complete test successfully, trying again with more time..."); + }); + thread::spawn(move || { + while let Some(Ok(line)) = stderr.next() { + eprintln!("party0> {line}"); } + }); + sleep(Duration::from_secs(10)); + eprintln!("Test did not complete!"); + for mut child in children { + child.kill().unwrap(); } - panic!("Test failed repeatedly!"); + exit(-1); } From 8d0b22da7a253c6c992009d3790055e8ca2c37a1 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 16 Apr 2024 17:31:54 +0100 Subject: [PATCH 04/25] Fix shutdown of child processes in tests/cli.rs --- examples/postgres-integration/.add.garble.rs | 4 ++-- examples/postgres-integration/policies0.json | 5 +++-- examples/postgres-integration/policies1.json | 5 +++-- examples/postgres-integration/policies2.json | 16 ++++++++++++++ .../postgres-integration/preprocessor.json | 3 ++- examples/postgres-integration/tests/cli.rs | 21 ++++++++++++------- 6 files changed, 40 insertions(+), 14 deletions(-) create mode 100644 examples/postgres-integration/policies2.json diff --git a/examples/postgres-integration/.add.garble.rs b/examples/postgres-integration/.add.garble.rs index ca6e9c6..2d202fe 100644 --- a/examples/postgres-integration/.add.garble.rs +++ b/examples/postgres-integration/.add.garble.rs @@ -1,3 +1,3 @@ -pub fn main(x: u32, y: u32) -> u32 { - x + y +pub fn main(x: u32, y: u32, z: u32) -> u32 { + x + y + z } diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 396cff6..3d65240 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -4,12 +4,13 @@ "participants": [ "http://localhost:8000", "http://localhost:8001", - "http://localhost:8002" + "http://localhost:8002", + "http://localhost:8003" ], "program": ".add.garble.rs", "leader": 0, "party": 0, - "input": "2u32" + "input": "0u32" } ] } diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index f9dfe7c..ec66ae9 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -4,12 +4,13 @@ "participants": [ "http://localhost:8000", "http://localhost:8001", - "http://localhost:8002" + "http://localhost:8002", + "http://localhost:8003" ], "program": ".add.garble.rs", "leader": 0, "party": 1, - "input": "3u32" + "input": "1u32" } ] } diff --git a/examples/postgres-integration/policies2.json b/examples/postgres-integration/policies2.json new file mode 100644 index 0000000..5039866 --- /dev/null +++ b/examples/postgres-integration/policies2.json @@ -0,0 +1,16 @@ +{ + "accepted": [ + { + "participants": [ + "http://localhost:8000", + "http://localhost:8001", + "http://localhost:8002", + "http://localhost:8003" + ], + "program": ".add.garble.rs", + "leader": 0, + "party": 2, + "input": "2u32" + } + ] +} diff --git a/examples/postgres-integration/preprocessor.json b/examples/postgres-integration/preprocessor.json index b1e8251..d29c11e 100644 --- a/examples/postgres-integration/preprocessor.json +++ b/examples/postgres-integration/preprocessor.json @@ -4,7 +4,8 @@ "participants": [ "http://localhost:8000", "http://localhost:8001", - "http://localhost:8002" + "http://localhost:8002", + "http://localhost:8003" ], "program": "", "leader": 0, diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index 625a320..49800fa 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -1,6 +1,7 @@ use std::{ io::{BufRead, BufReader}, process::{exit, Command, Stdio}, + sync::mpsc::channel, thread::{self, sleep}, time::Duration, }; @@ -11,7 +12,7 @@ fn simulate() { println!("\n\n--- {millis}ms ---\n\n"); let mut children = vec![]; let mut cmd = Command::new("cargo") - .args(["run", "--", "--port=8002", "--config=preprocessor.json"]) + .args(["run", "--", "--port=8003", "--config=preprocessor.json"]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn() @@ -31,7 +32,7 @@ fn simulate() { } }); children.push(cmd); - for p in [1] { + for p in [1, 2] { let port = format!("--port=800{p}"); let config = format!("--config=policies{p}.json"); let args = vec!["run", "--", &port, &config]; @@ -66,11 +67,12 @@ fn simulate() { let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines(); let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); children.push(cmd); + let (s, r) = channel::<()>(); thread::spawn(move || { while let Some(Ok(line)) = stdout.next() { println!("party0> {line}"); - if line == "Output is 5u32" { - exit(0); + if line == "Output is 3u32" { + return s.send(()).unwrap(); } } }); @@ -79,10 +81,15 @@ fn simulate() { eprintln!("party0> {line}"); } }); - sleep(Duration::from_secs(10)); - eprintln!("Test did not complete!"); + let result = r.recv_timeout(Duration::from_secs(10)); for mut child in children { child.kill().unwrap(); } - exit(-1); + match result { + Ok(_) => exit(0), + Err(_) => { + eprintln!("Test did not complete!"); + exit(-1); + } + } } From c783d701b67a4db21eba115f473e612aa0b70369 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 17 Apr 2024 17:34:33 +0100 Subject: [PATCH 05/25] Load MPC inputs from Postgres DBs --- examples/postgres-integration/.add.garble.rs | 3 - .../postgres-integration/.example.garble.rs | 30 +++++ examples/postgres-integration/Cargo.toml | 1 + .../postgres-integration/docker-compose.yml | 29 +++++ examples/postgres-integration/policies0.json | 6 +- examples/postgres-integration/policies1.json | 6 +- examples/postgres-integration/policies2.json | 2 +- .../postgres-integration/seeds/db0/init.sql | 11 ++ .../seeds/db0/residents.csv | 3 + .../postgres-integration/seeds/db1/init.sql | 11 ++ .../seeds/db1/insurance.csv | 3 + examples/postgres-integration/src/main.rs | 121 +++++++++++++++--- examples/postgres-integration/tests/cli.rs | 2 +- 13 files changed, 200 insertions(+), 28 deletions(-) delete mode 100644 examples/postgres-integration/.add.garble.rs create mode 100644 examples/postgres-integration/.example.garble.rs create mode 100644 examples/postgres-integration/docker-compose.yml create mode 100644 examples/postgres-integration/seeds/db0/init.sql create mode 100644 examples/postgres-integration/seeds/db0/residents.csv create mode 100644 examples/postgres-integration/seeds/db1/init.sql create mode 100644 examples/postgres-integration/seeds/db1/insurance.csv diff --git a/examples/postgres-integration/.add.garble.rs b/examples/postgres-integration/.add.garble.rs deleted file mode 100644 index 2d202fe..0000000 --- a/examples/postgres-integration/.add.garble.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub fn main(x: u32, y: u32, z: u32) -> u32 { - x + y + z -} diff --git a/examples/postgres-integration/.example.garble.rs b/examples/postgres-integration/.example.garble.rs new file mode 100644 index 0000000..a6f246b --- /dev/null +++ b/examples/postgres-integration/.example.garble.rs @@ -0,0 +1,30 @@ +enum Row0 { + None, + Some(i32, i32), +} + +enum Row1 { + None, + Some(i32, bool), +} + +pub fn main(residents: [Row0; 10], insurance: [Row1; 10], z: u32) -> u32 { + let mut result = 0u32; + for row in residents { + match row { + Row0::Some(id, age) => result = result + (age as u32), + Row0::None => {} + } + } + for row in insurance { + match row { + Row1::Some(id, disabled) => { + if disabled == true { + result = result + 1u32; + } + } + Row1::None => {} + } + } + result + z +} diff --git a/examples/postgres-integration/Cargo.toml b/examples/postgres-integration/Cargo.toml index e1e4724..7d2beaa 100644 --- a/examples/postgres-integration/Cargo.toml +++ b/examples/postgres-integration/Cargo.toml @@ -12,6 +12,7 @@ parlay = { path = "../../", version = "0.1.0" } reqwest = { version = "0.11.23", features = ["json"] } serde = "1.0.197" serde_json = "1.0.115" +sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres"] } tokio = { version = "1.35.1", features = ["macros", "rt", "rt-multi-thread"] } tower-http = { version = "0.5.1", features = ["fs", "trace"] } tracing = "0.1.40" diff --git a/examples/postgres-integration/docker-compose.yml b/examples/postgres-integration/docker-compose.yml new file mode 100644 index 0000000..aa76b24 --- /dev/null +++ b/examples/postgres-integration/docker-compose.yml @@ -0,0 +1,29 @@ +version: "3.9" +services: + db0: + image: postgres:16.2 + restart: always + shm_size: 128mb + volumes: + - ./seeds/db0:/docker-entrypoint-initdb.d + ports: + - 5550:5432 + environment: + POSTGRES_PASSWORD: test + + db1: + image: postgres:16.2 + restart: always + shm_size: 128mb + volumes: + - ./seeds/db1:/docker-entrypoint-initdb.d + ports: + - 5551:5432 + environment: + POSTGRES_PASSWORD: test + + adminer: + image: adminer + restart: always + ports: + - 8080:8080 diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 3d65240..54c8498 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -7,10 +7,12 @@ "http://localhost:8002", "http://localhost:8003" ], - "program": ".add.garble.rs", + "program": ".example.garble.rs", "leader": 0, "party": 0, - "input": "0u32" + "input": "SELECT id, age FROM residents", + "db": "postgres://postgres:test@localhost:5550/postgres", + "max_rows": 10 } ] } diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index ec66ae9..360ef7d 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -7,10 +7,12 @@ "http://localhost:8002", "http://localhost:8003" ], - "program": ".add.garble.rs", + "program": ".example.garble.rs", "leader": 0, "party": 1, - "input": "1u32" + "input": "SELECT id, disabled FROM insurance", + "db": "postgres://postgres:test@localhost:5551/postgres", + "max_rows": 10 } ] } diff --git a/examples/postgres-integration/policies2.json b/examples/postgres-integration/policies2.json index 5039866..46210b4 100644 --- a/examples/postgres-integration/policies2.json +++ b/examples/postgres-integration/policies2.json @@ -7,7 +7,7 @@ "http://localhost:8002", "http://localhost:8003" ], - "program": ".add.garble.rs", + "program": ".example.garble.rs", "leader": 0, "party": 2, "input": "2u32" diff --git a/examples/postgres-integration/seeds/db0/init.sql b/examples/postgres-integration/seeds/db0/init.sql new file mode 100644 index 0000000..30c94be --- /dev/null +++ b/examples/postgres-integration/seeds/db0/init.sql @@ -0,0 +1,11 @@ +CREATE TABLE residents ( + id SERIAL PRIMARY KEY, + name TEXT, + age INT, + address TEXT +); + +COPY residents +FROM '/docker-entrypoint-initdb.d/residents.csv' +DELIMITER ',' +CSV HEADER; \ No newline at end of file diff --git a/examples/postgres-integration/seeds/db0/residents.csv b/examples/postgres-integration/seeds/db0/residents.csv new file mode 100644 index 0000000..356d040 --- /dev/null +++ b/examples/postgres-integration/seeds/db0/residents.csv @@ -0,0 +1,3 @@ +Id,Name,Age,Address +0,"Max Mustermann",22,"Parkstraße 1, 12345 Berlin" +1,"Erika Musterfrau",33,"Schlossallee 2, 12345 Berlin" diff --git a/examples/postgres-integration/seeds/db1/init.sql b/examples/postgres-integration/seeds/db1/init.sql new file mode 100644 index 0000000..906d9e0 --- /dev/null +++ b/examples/postgres-integration/seeds/db1/init.sql @@ -0,0 +1,11 @@ +CREATE TABLE insurance ( + id SERIAL PRIMARY KEY, + name TEXT, + disabled BOOL, + address TEXT +); + +COPY insurance +FROM '/docker-entrypoint-initdb.d/insurance.csv' +DELIMITER ',' +CSV HEADER; \ No newline at end of file diff --git a/examples/postgres-integration/seeds/db1/insurance.csv b/examples/postgres-integration/seeds/db1/insurance.csv new file mode 100644 index 0000000..21f0512 --- /dev/null +++ b/examples/postgres-integration/seeds/db1/insurance.csv @@ -0,0 +1,3 @@ +Id,Name,Disabled,Address +0,"Max Mustermann",true,"Parkstraße 1, 12345 Berlin" +1,"Erika Musterfrau",false,"Schlossallee 2, 12345 Berlin" diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 1d8e9e6..aafe4af 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -1,36 +1,42 @@ -use anyhow::anyhow; -use anyhow::{Context, Error}; -use axum::Json; +use anyhow::{anyhow, bail, Context, Error}; use axum::{ body::Bytes, extract::{Path, State}, routing::post, - Router, + Json, Router, }; use clap::Parser; -use parlay::garble_lang::literal::Literal; use parlay::{ channel::Channel, fpre::fpre, - garble_lang::compile, + garble_lang::{ + compile, + literal::{Literal, VariantLiteral}, + token::SignedNumType, + }, protocol::{mpc, Preprocessor}, }; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; -use std::borrow::BorrowMut; -use std::process::exit; -use std::sync::Arc; -use std::{net::SocketAddr, path::PathBuf, result::Result, time::Duration}; -use tokio::sync::Mutex; -use tokio::time::timeout; +use sqlx::{postgres::PgPoolOptions, Row}; +use std::{ + borrow::BorrowMut, net::SocketAddr, path::PathBuf, process::exit, result::Result, sync::Arc, + time::Duration, +}; use tokio::{ fs, - sync::mpsc::{channel, Receiver, Sender}, - time::sleep, + sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, + }, + time::{sleep, timeout}, }; use tower_http::trace::TraceLayer; use url::Url; +const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(5); +const DEFAULT_MAX_ROWS: usize = 10; + /// A CLI for Multi-Party Computation using the Parlay engine. #[derive(Debug, Parser)] #[command(name = "parlay")] @@ -54,8 +60,9 @@ struct Policy { program: PathBuf, leader: usize, party: usize, - // TODO: replace with db connection info input: String, + db: Option, + max_rows: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -65,12 +72,40 @@ struct PolicyRequest { leader: usize, } -const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(5); - type MpcState = Arc>>>>; #[tokio::main] async fn main() -> Result<(), Error> { + /*let pool = PgPoolOptions::new() + .max_connections(5) + .connect("postgres://postgres:test@localhost:5555/postgres") + .await?; + match sqlx::query("SELECT * FROM residents") + .fetch_all(&pool) + .await + { + Ok(rows) => { + for row in rows { + for c in 0..row.len() { + if let Ok(s) = row.try_get::(c) { + print!("\"{s}\",") + } else if let Ok(n) = row.try_get::(c) { + print!("{n},") + } else if let Ok(n) = row.try_get::(c) { + print!("...{n},") + } else { + eprintln!("Could not decode column {c}"); + } + } + println!(""); + } + } + Err(e) => { + eprintln!("{e}") + } + } + + exit(0);*/ let Cli { port, config } = Cli::parse(); let policies = load_policies(config).await?; if policies.accepted.len() == 1 { @@ -201,14 +236,62 @@ async fn execute_mpc( policy: &Policy, ) -> Result, Error> { let Policy { + program: _program, leader, participants, party, input, - .. + db, + max_rows, } = policy; let prg = compile(&code).map_err(|e| anyhow!(e.prettify(&code)))?; - let input = prg.parse_arg(*party, input)?.as_bits(); + let input = if let Some(db) = db { + println!("Connecting to {db}..."); + let pool = PgPoolOptions::new().max_connections(5).connect(db).await?; + let rows = sqlx::query(input).fetch_all(&pool).await?; + println!("'{input}' returned {} rows in {db}", rows.len()); + let mut rows_as_literals = vec![]; + let max_rows = max_rows.unwrap_or(DEFAULT_MAX_ROWS); + for _ in 0..max_rows { + rows_as_literals.push(Literal::Enum( + format!("Row{party}"), + "None".to_string(), + VariantLiteral::Unit, + )); + } + for (r, row) in rows.iter().enumerate() { + let mut row_as_literal = vec![]; + for c in 0..row.len() { + let field = if let Ok(_) = row.try_get::(c) { + todo!("String literals are not yet implemented"); + } else if let Ok(b) = row.try_get::(c) { + Literal::from(b) + } else if let Ok(n) = row.try_get::(c) { + Literal::NumSigned(n as i64, SignedNumType::I32) + } else if let Ok(n) = row.try_get::(c) { + Literal::NumSigned(n as i64, SignedNumType::I64) + } else { + bail!("Could not decode column {c}"); + }; + row_as_literal.push(field); + } + if r >= max_rows { + eprintln!("Dropping record {r}"); + } else { + let literal = Literal::Enum( + format!("Row{party}"), + "Some".to_string(), + VariantLiteral::Tuple(row_as_literal), + ); + println!("rows[{r}] = {literal}"); + rows_as_literals[r] = literal; + } + } + let literal = Literal::Array(rows_as_literals); + prg.literal_arg(*party, literal)?.as_bits() + } else { + prg.parse_arg(*party, input)?.as_bits() + }; let fpre = Preprocessor::TrustedDealer(participants.len() - 1); let p_out: Vec<_> = vec![*leader]; let channel = { diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index 49800fa..40cc1b0 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -71,7 +71,7 @@ fn simulate() { thread::spawn(move || { while let Some(Ok(line)) = stdout.next() { println!("party0> {line}"); - if line == "Output is 3u32" { + if line == "Output is 58u32" { return s.send(()).unwrap(); } } From 9cafc34c162839dcbd09fcb0891dbabef598065e Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 17 Apr 2024 17:38:55 +0100 Subject: [PATCH 06/25] Fix clippy warnings --- examples/postgres-integration/src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index aafe4af..22a57aa 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -262,14 +262,14 @@ async fn execute_mpc( for (r, row) in rows.iter().enumerate() { let mut row_as_literal = vec![]; for c in 0..row.len() { - let field = if let Ok(_) = row.try_get::(c) { - todo!("String literals are not yet implemented"); + let field = if let Ok(s) = row.try_get::(c) { + todo!("String literals are not yet implemented: '{s}'"); } else if let Ok(b) = row.try_get::(c) { Literal::from(b) } else if let Ok(n) = row.try_get::(c) { Literal::NumSigned(n as i64, SignedNumType::I32) } else if let Ok(n) = row.try_get::(c) { - Literal::NumSigned(n as i64, SignedNumType::I64) + Literal::NumSigned(n, SignedNumType::I64) } else { bail!("Could not decode column {c}"); }; From 44e6cf0c1e57cb8b93a2834cebb7e9a9f0a6f004 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 17 Apr 2024 17:39:11 +0100 Subject: [PATCH 07/25] Run postgres example using CI --- .github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fffe078..4d3f38a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,3 +20,5 @@ jobs: working-directory: ./examples/http-single-server-channels - run: cargo test -- --nocapture working-directory: ./examples/iroh-p2p-channels + - run: docker compose -f docker-compose.yml up && cargo test -- --nocapture + working-directory: ./examples/postgres-integration From a1af61f8430d0f0d088c9eb1692d6b279f3a598c Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 17 Apr 2024 17:49:49 +0100 Subject: [PATCH 08/25] Run docker compose in background (in postgres example) --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4d3f38a..6c0183e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,5 +20,5 @@ jobs: working-directory: ./examples/http-single-server-channels - run: cargo test -- --nocapture working-directory: ./examples/iroh-p2p-channels - - run: docker compose -f docker-compose.yml up && cargo test -- --nocapture + - run: docker compose -f docker-compose.yml up -d && cargo test -- --nocapture working-directory: ./examples/postgres-integration From 74b71b7050bb990443a190f3eb8573d7e5651323 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 22 Apr 2024 18:04:10 +0100 Subject: [PATCH 09/25] Use string DB fields in Postgres example --- .../postgres-integration/.example.garble.rs | 32 +++-- examples/postgres-integration/policies0.json | 4 +- examples/postgres-integration/policies1.json | 4 +- examples/postgres-integration/policies2.json | 2 +- .../seeds/db0/residents.csv | 6 +- examples/postgres-integration/src/main.rs | 117 +++++++++++------- examples/postgres-integration/tests/cli.rs | 4 +- 7 files changed, 100 insertions(+), 69 deletions(-) diff --git a/examples/postgres-integration/.example.garble.rs b/examples/postgres-integration/.example.garble.rs index a6f246b..adc25bd 100644 --- a/examples/postgres-integration/.example.garble.rs +++ b/examples/postgres-integration/.example.garble.rs @@ -1,30 +1,28 @@ enum Row0 { None, - Some(i32, i32), + Some(i32, [u8; 16], i32), } enum Row1 { None, - Some(i32, bool), + Some(i32, [u8; 16], bool), } -pub fn main(residents: [Row0; 10], insurance: [Row1; 10], z: u32) -> u32 { - let mut result = 0u32; - for row in residents { - match row { - Row0::Some(id, age) => result = result + (age as u32), - Row0::None => {} - } - } - for row in insurance { - match row { - Row1::Some(id, disabled) => { - if disabled == true { - result = result + 1u32; +pub fn main(residents: [Row0; 4], insurance: [Row1; 4], age_threshold: u32) -> [[u8; 16]; 4] { + let mut result = [[0u8; 16]; 4]; + let mut i = 0usize; + for resident in residents { + for insurance in insurance { + match (resident, insurance) { + (Row0::Some(_, name0, age), Row1::Some(_, name1, health_problems)) => { + if name0 == name1 && health_problems == true && age > (age_threshold as i32) { + result[i] = name0; + i = i + 1usize; + } } + _ => {} } - Row1::None => {} } } - result + z + result } diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 54c8498..eefd69e 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -10,9 +10,9 @@ "program": ".example.garble.rs", "leader": 0, "party": 0, - "input": "SELECT id, age FROM residents", + "input": "SELECT id, name, age FROM residents", "db": "postgres://postgres:test@localhost:5550/postgres", - "max_rows": 10 + "max_rows": 4 } ] } diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index 360ef7d..8610bb7 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -10,9 +10,9 @@ "program": ".example.garble.rs", "leader": 0, "party": 1, - "input": "SELECT id, disabled FROM insurance", + "input": "SELECT id, name, disabled FROM insurance", "db": "postgres://postgres:test@localhost:5551/postgres", - "max_rows": 10 + "max_rows": 4 } ] } diff --git a/examples/postgres-integration/policies2.json b/examples/postgres-integration/policies2.json index 46210b4..6dd26eb 100644 --- a/examples/postgres-integration/policies2.json +++ b/examples/postgres-integration/policies2.json @@ -10,7 +10,7 @@ "program": ".example.garble.rs", "leader": 0, "party": 2, - "input": "2u32" + "input": "65u32" } ] } diff --git a/examples/postgres-integration/seeds/db0/residents.csv b/examples/postgres-integration/seeds/db0/residents.csv index 356d040..0d6996a 100644 --- a/examples/postgres-integration/seeds/db0/residents.csv +++ b/examples/postgres-integration/seeds/db0/residents.csv @@ -1,3 +1,5 @@ Id,Name,Age,Address -0,"Max Mustermann",22,"Parkstraße 1, 12345 Berlin" -1,"Erika Musterfrau",33,"Schlossallee 2, 12345 Berlin" +0,"John Doe",22,"Parkstraße 1, 12345 Berlin" +1,"Jane Doe",33,"Schlossallee 2, 12345 Berlin" +2,"Max Mustermann",66,"Parkstraße 1, 12345 Berlin" +3,"Erika Musterfrau",77,"Schlossallee 2, 12345 Berlin" diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 22a57aa..a6b4806 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, bail, Context, Error}; use axum::{ body::Bytes, - extract::{Path, State}, + extract::{DefaultBodyLimit, Path, State}, routing::post, Json, Router, }; @@ -12,7 +12,7 @@ use parlay::{ garble_lang::{ compile, literal::{Literal, VariantLiteral}, - token::SignedNumType, + token::{SignedNumType, UnsignedNumType}, }, protocol::{mpc, Preprocessor}, }; @@ -34,8 +34,9 @@ use tokio::{ use tower_http::trace::TraceLayer; use url::Url; -const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(5); +const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(30); const DEFAULT_MAX_ROWS: usize = 10; +const STR_LEN_BYTES: usize = 16; /// A CLI for Multi-Party Computation using the Parlay engine. #[derive(Debug, Parser)] @@ -76,36 +77,6 @@ type MpcState = Arc>>>>; #[tokio::main] async fn main() -> Result<(), Error> { - /*let pool = PgPoolOptions::new() - .max_connections(5) - .connect("postgres://postgres:test@localhost:5555/postgres") - .await?; - match sqlx::query("SELECT * FROM residents") - .fetch_all(&pool) - .await - { - Ok(rows) => { - for row in rows { - for c in 0..row.len() { - if let Ok(s) = row.try_get::(c) { - print!("\"{s}\",") - } else if let Ok(n) = row.try_get::(c) { - print!("{n},") - } else if let Ok(n) = row.try_get::(c) { - print!("...{n},") - } else { - eprintln!("Could not decode column {c}"); - } - } - println!(""); - } - } - Err(e) => { - eprintln!("{e}") - } - } - - exit(0);*/ let Cli { port, config } = Cli::parse(); let policies = load_policies(config).await?; if policies.accepted.len() == 1 { @@ -123,6 +94,7 @@ async fn main() -> Result<(), Error> { .route("/run", post(run)) .route("/msg/:from", post(msg)) .with_state((policies.clone(), Arc::clone(&state))) + .layer(DefaultBodyLimit::disable()) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], port)); @@ -168,10 +140,52 @@ async fn main() -> Result<(), Error> { } } println!("All participants have accepted the session, starting calculation now..."); - match execute_mpc(Arc::clone(&state), code, policy).await { - Ok(Some(output)) => { - println!("Output is {output}") + fn decode_literal(l: Literal) -> Result>, String> { + let Literal::Array(rows) = l else { + return Err(format!("Expected an array of rows, but found {l}")); + }; + let mut records = vec![]; + for row in rows { + let row = match row { + Literal::Tuple(row) => row, + record => vec![record], + }; + let mut record = vec![]; + fn stringify(elements: &[Literal]) -> Option { + let mut bytes = vec![]; + for e in elements { + if let Literal::NumUnsigned(n, UnsignedNumType::U8) = e { + if *n != 0 { + bytes.push(*n as u8); + } + } else { + return None; + } + } + String::from_utf8(bytes).ok() + } + for col in row { + record.push(match col { + Literal::True => "true".to_string(), + Literal::False => "false".to_string(), + Literal::NumUnsigned(n, _) => n.to_string(), + Literal::NumSigned(n, _) => n.to_string(), + Literal::Array(elements) => match stringify(&elements) { + Some(s) => format!("'{s}'"), + None => format!("'{}'", Literal::Array(elements)), + }, + l => format!("'{l}'"), + }); + } + records.push(record); } + Ok(records) + } + match execute_mpc(Arc::clone(&state), code, policy).await { + Ok(Some(output)) => match decode_literal(output) { + Ok(rows) => println!("MPC Output: {rows:?}"), + Err(e) => eprintln!("MPC Error: {e}"), + }, Ok(None) => {} Err(e) => { eprintln!("Error while executing MPC: {e}") @@ -200,6 +214,7 @@ async fn run_fpre(port: u16, urls: Vec) -> Result<(), Error> { let app = Router::new() .route("/msg/:from", post(msg)) .with_state((policies, Arc::new(Mutex::new(senders)))) + .layer(DefaultBodyLimit::disable()) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], port)); @@ -245,25 +260,39 @@ async fn execute_mpc( max_rows, } = policy; let prg = compile(&code).map_err(|e| anyhow!(e.prettify(&code)))?; + println!( + "Trying to execute circuit with {}K gates ({}K AND gates)", + prg.circuit.gates.len() / 1000, + prg.circuit.and_gates() / 1000 + ); let input = if let Some(db) = db { println!("Connecting to {db}..."); let pool = PgPoolOptions::new().max_connections(5).connect(db).await?; let rows = sqlx::query(input).fetch_all(&pool).await?; println!("'{input}' returned {} rows in {db}", rows.len()); - let mut rows_as_literals = vec![]; let max_rows = max_rows.unwrap_or(DEFAULT_MAX_ROWS); - for _ in 0..max_rows { - rows_as_literals.push(Literal::Enum( + let mut rows_as_literals = vec![ + Literal::Enum( format!("Row{party}"), "None".to_string(), VariantLiteral::Unit, - )); - } + ); + max_rows + ]; for (r, row) in rows.iter().enumerate() { let mut row_as_literal = vec![]; for c in 0..row.len() { let field = if let Ok(s) = row.try_get::(c) { - todo!("String literals are not yet implemented: '{s}'"); + let mut fixed_str = + vec![Literal::NumUnsigned(0, UnsignedNumType::U8); STR_LEN_BYTES]; + for (i, b) in s.as_bytes().into_iter().enumerate() { + if i < STR_LEN_BYTES { + fixed_str[i] = Literal::NumUnsigned(*b as u64, UnsignedNumType::U8); + } else { + bail!("String is longer than {STR_LEN_BYTES} bytes: '{s}'"); + } + } + Literal::Array(fixed_str) } else if let Ok(b) = row.try_get::(c) { Literal::from(b) } else if let Ok(n) = row.try_get::(c) { @@ -374,6 +403,8 @@ impl Channel for HttpChannel { async fn send_bytes_to(&mut self, p: usize, msg: Vec) -> Result<(), Self::SendError> { let client = reqwest::Client::new(); let url = format!("{}msg/{}", self.urls[p], self.party); + let mb = msg.len() as f64 / 1024.0 / 1024.0; + println!("{} -> {} ({:.2}MB)", self.party, p, mb,); loop { match client.post(&url).body(msg.clone()).send().await?.status() { StatusCode::OK => return Ok(()), @@ -387,7 +418,7 @@ impl Channel for HttpChannel { } async fn recv_bytes_from(&mut self, p: usize) -> Result, Self::RecvError> { - Ok(timeout(Duration::from_secs(10), self.recv[p].recv()) + Ok(timeout(Duration::from_secs(60), self.recv[p].recv()) .await .context(format!("recv_bytes_from(p = {p})"))? .unwrap_or_default()) diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index 40cc1b0..245f3a7 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -71,7 +71,7 @@ fn simulate() { thread::spawn(move || { while let Some(Ok(line)) = stdout.next() { println!("party0> {line}"); - if line == "Output is 58u32" { + if line.starts_with("MPC Output:") { return s.send(()).unwrap(); } } @@ -81,7 +81,7 @@ fn simulate() { eprintln!("party0> {line}"); } }); - let result = r.recv_timeout(Duration::from_secs(10)); + let result = r.recv_timeout(Duration::from_secs(300)); for mut child in children { child.kill().unwrap(); } From 9a6e38e3dcc23aa47b701cf8cccca95239a55abd Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 23 Apr 2024 12:23:34 +0100 Subject: [PATCH 10/25] Write output of Postgres integration example into DB --- .../postgres-integration/docker-compose.yml | 11 +++++++ examples/postgres-integration/policies0.json | 6 ++-- examples/postgres-integration/policies1.json | 2 +- .../seeds/db_out/init.sql | 4 +++ examples/postgres-integration/src/main.rs | 33 +++++++++++++++++-- 5 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 examples/postgres-integration/seeds/db_out/init.sql diff --git a/examples/postgres-integration/docker-compose.yml b/examples/postgres-integration/docker-compose.yml index aa76b24..2ce664b 100644 --- a/examples/postgres-integration/docker-compose.yml +++ b/examples/postgres-integration/docker-compose.yml @@ -22,6 +22,17 @@ services: environment: POSTGRES_PASSWORD: test + db_out: + image: postgres:16.2 + restart: always + shm_size: 128mb + volumes: + - ./seeds/db_out:/docker-entrypoint-initdb.d + ports: + - 5555:5432 + environment: + POSTGRES_PASSWORD: test + adminer: image: adminer restart: always diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index eefd69e..318b8c9 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -11,8 +11,10 @@ "leader": 0, "party": 0, "input": "SELECT id, name, age FROM residents", - "db": "postgres://postgres:test@localhost:5550/postgres", - "max_rows": 4 + "input_db": "postgres://postgres:test@localhost:5550/postgres", + "max_rows": 4, + "output": "INSERT INTO results (name) VALUES ($1)", + "output_db": "postgres://postgres:test@localhost:5555/postgres" } ] } diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index 8610bb7..b1d99d9 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -11,7 +11,7 @@ "leader": 0, "party": 1, "input": "SELECT id, name, disabled FROM insurance", - "db": "postgres://postgres:test@localhost:5551/postgres", + "input_db": "postgres://postgres:test@localhost:5551/postgres", "max_rows": 4 } ] diff --git a/examples/postgres-integration/seeds/db_out/init.sql b/examples/postgres-integration/seeds/db_out/init.sql new file mode 100644 index 0000000..b8d7ee2 --- /dev/null +++ b/examples/postgres-integration/seeds/db_out/init.sql @@ -0,0 +1,4 @@ +CREATE TABLE results ( + id SERIAL PRIMARY KEY, + name TEXT +); diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index a6b4806..930499c 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -62,8 +62,10 @@ struct Policy { leader: usize, party: usize, input: String, - db: Option, + input_db: Option, max_rows: Option, + output: Option, + output_db: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -183,7 +185,30 @@ async fn main() -> Result<(), Error> { } match execute_mpc(Arc::clone(&state), code, policy).await { Ok(Some(output)) => match decode_literal(output) { - Ok(rows) => println!("MPC Output: {rows:?}"), + Ok(rows) => { + let n_rows = rows.len(); + let Policy { + output, output_db, .. + } = policy; + if let (Some(output_db), Some(output)) = (output_db, output) { + println!("Connecting to {output_db}..."); + let pool = PgPoolOptions::new() + .max_connections(5) + .connect(output_db) + .await?; + for row in rows { + let mut query = sqlx::query(output); + for field in row { + query = query.bind(field); + } + let rows = query.execute(&pool).await?.rows_affected(); + println!("Inserted {rows} row(s)"); + } + } else { + println!("No 'output' and/or 'output_db' specified in the policy, dropping {n_rows} rows"); + } + println!("MPC Output: {n_rows} rows") + } Err(e) => eprintln!("MPC Error: {e}"), }, Ok(None) => {} @@ -256,8 +281,10 @@ async fn execute_mpc( participants, party, input, - db, + input_db: db, max_rows, + output: _output, + output_db: _output_db, } = policy; let prg = compile(&code).map_err(|e| anyhow!(e.prettify(&code)))?; println!( From 6c94bcc1e1c97109d5fd759af1f953a57678996b Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 23 Apr 2024 17:40:25 +0100 Subject: [PATCH 11/25] Clean results in Postgres example DB before new execution --- examples/postgres-integration/policies0.json | 1 + examples/postgres-integration/src/main.rs | 79 +++++++++++--------- examples/postgres-integration/tests/cli.rs | 6 +- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 318b8c9..6c4e9af 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -13,6 +13,7 @@ "input": "SELECT id, name, age FROM residents", "input_db": "postgres://postgres:test@localhost:5550/postgres", "max_rows": 4, + "setup": "TRUNCATE results", "output": "INSERT INTO results (name) VALUES ($1)", "output_db": "postgres://postgres:test@localhost:5555/postgres" } diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 930499c..1cf2721 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -32,6 +32,7 @@ use tokio::{ time::{sleep, timeout}, }; use tower_http::trace::TraceLayer; +use tracing::{debug, error, info, trace, warn}; use url::Url; const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(30); @@ -64,6 +65,7 @@ struct Policy { input: String, input_db: Option, max_rows: Option, + setup: Option, output: Option, output_db: Option, } @@ -79,16 +81,16 @@ type MpcState = Arc>>>>; #[tokio::main] async fn main() -> Result<(), Error> { + tracing_subscriber::fmt::init(); let Cli { port, config } = Cli::parse(); let policies = load_policies(config).await?; if policies.accepted.len() == 1 { let policy = &policies.accepted[0]; if policy.input.is_empty() { - println!("Running as preprocessor..."); + info!("Running as preprocessor..."); return run_fpre(port, policy.participants.clone()).await; } } - tracing_subscriber::fmt::init(); let state = Arc::new(Mutex::new(vec![])); @@ -100,20 +102,20 @@ async fn main() -> Result<(), Error> { .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], port)); - tracing::info!("listening on {}", addr); + info!("listening on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await?; tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); - println!("Found {} active policies", policies.accepted.len()); + info!("Found {} active policies", policies.accepted.len()); loop { for policy in &policies.accepted { if policy.leader == policy.party { - println!( + info!( "Acting as leader (party {}) for program {}", policy.leader, policy.program.display() ); let Ok(code) = fs::read_to_string(&policy.program).await else { - eprintln!("Could not load program {:?}", &policy.program); + error!("Could not load program {:?}", &policy.program); continue; }; let hash = blake3::hash(code.as_bytes()).to_string(); @@ -125,7 +127,7 @@ async fn main() -> Result<(), Error> { }; for party in policy.participants.iter().rev().skip(1).rev() { if party != &policy.participants[policy.party] { - println!("Waiting for confirmation from party {party}"); + info!("Waiting for confirmation from party {party}"); let url = format!("{party}run"); match client .post(&url) @@ -136,12 +138,12 @@ async fn main() -> Result<(), Error> { { StatusCode::OK => {} code => { - eprintln!("Unexpected response while trying to trigger execution for {url}: {code}"); + error!("Unexpected response while trying to trigger execution for {url}: {code}"); } } } } - println!("All participants have accepted the session, starting calculation now..."); + info!("All participants have accepted the session, starting calculation now..."); fn decode_literal(l: Literal) -> Result>, String> { let Literal::Array(rows) = l else { return Err(format!("Expected an array of rows, but found {l}")); @@ -188,32 +190,40 @@ async fn main() -> Result<(), Error> { Ok(rows) => { let n_rows = rows.len(); let Policy { - output, output_db, .. + setup, + output, + output_db, + .. } = policy; if let (Some(output_db), Some(output)) = (output_db, output) { - println!("Connecting to {output_db}..."); + info!("Connecting to output db at {output_db}..."); let pool = PgPoolOptions::new() .max_connections(5) .connect(output_db) .await?; + if let Some(setup) = setup { + let rows_affected = + sqlx::query(setup).execute(&pool).await?.rows_affected(); + debug!("{rows_affected} rows affected by '{setup}'"); + } for row in rows { let mut query = sqlx::query(output); for field in row { query = query.bind(field); } let rows = query.execute(&pool).await?.rows_affected(); - println!("Inserted {rows} row(s)"); + debug!("Inserted {rows} row(s)"); } } else { - println!("No 'output' and/or 'output_db' specified in the policy, dropping {n_rows} rows"); + warn!("No 'output' and/or 'output_db' specified in the policy, dropping {n_rows} rows"); } - println!("MPC Output: {n_rows} rows") + info!("MPC Output: {n_rows} rows") } - Err(e) => eprintln!("MPC Error: {e}"), + Err(e) => error!("MPC Error: {e}"), }, Ok(None) => {} Err(e) => { - eprintln!("Error while executing MPC: {e}") + error!("Error while executing MPC: {e}") } } } @@ -225,8 +235,6 @@ async fn main() -> Result<(), Error> { async fn run_fpre(port: u16, urls: Vec) -> Result<(), Error> { let parties = urls.len() - 1; - tracing_subscriber::fmt::init(); - let mut senders = vec![]; let mut receivers = vec![]; for _ in 0..parties { @@ -243,7 +251,7 @@ async fn run_fpre(port: u16, urls: Vec) -> Result<(), Error> { .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], port)); - tracing::info!("listening on {}", addr); + info!("listening on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await?; tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); @@ -253,18 +261,18 @@ async fn run_fpre(port: u16, urls: Vec) -> Result<(), Error> { recv: receivers, }; loop { - println!("Running FPre..."); + info!("Running FPre..."); channel = fpre(channel, parties).await.context("FPre")?; } } async fn load_policies(path: PathBuf) -> Result { let Ok(policies) = fs::read_to_string(&path).await else { - eprintln!("Could not find '{}', exiting...", path.display()); + error!("Could not find '{}', exiting...", path.display()); exit(-1); }; let Ok(policies) = serde_json::from_str::(&policies) else { - eprintln!("'{}' has an invalid format, exiting...", path.display()); + error!("'{}' has an invalid format, exiting...", path.display()); exit(-1); }; Ok(policies) @@ -283,20 +291,21 @@ async fn execute_mpc( input, input_db: db, max_rows, + setup: _setup, output: _output, output_db: _output_db, } = policy; let prg = compile(&code).map_err(|e| anyhow!(e.prettify(&code)))?; - println!( + info!( "Trying to execute circuit with {}K gates ({}K AND gates)", prg.circuit.gates.len() / 1000, prg.circuit.and_gates() / 1000 ); let input = if let Some(db) = db { - println!("Connecting to {db}..."); + info!("Connecting to input db at {db}..."); let pool = PgPoolOptions::new().max_connections(5).connect(db).await?; let rows = sqlx::query(input).fetch_all(&pool).await?; - println!("'{input}' returned {} rows in {db}", rows.len()); + info!("'{input}' returned {} rows from {db}", rows.len()); let max_rows = max_rows.unwrap_or(DEFAULT_MAX_ROWS); let mut rows_as_literals = vec![ Literal::Enum( @@ -332,14 +341,14 @@ async fn execute_mpc( row_as_literal.push(field); } if r >= max_rows { - eprintln!("Dropping record {r}"); + warn!("Dropping record {r}"); } else { let literal = Literal::Enum( format!("Row{party}"), "Some".to_string(), VariantLiteral::Tuple(row_as_literal), ); - println!("rows[{r}] = {literal}"); + debug!("rows[{r}] = {literal}"); rows_as_literals[r] = literal; } } @@ -385,27 +394,27 @@ async fn run( for policy in policies.accepted { if policy.participants == body.participants && policy.leader == body.leader { let Ok(code) = fs::read_to_string(&policy.program).await else { - eprintln!("Could not load program {:?}", &policy.program); + error!("Could not load program {:?}", &policy.program); return; }; let expected = blake3::hash(code.as_bytes()).to_string(); if expected != body.program_hash { - eprintln!("Aborting due to different hashes for program in policy {policy:?}"); + error!("Aborting due to different hashes for program in policy {policy:?}"); return; } - println!( + info!( "Accepted policy for {}, starting execution", policy.program.display() ); tokio::spawn(async move { if let Err(e) = execute_mpc(state, code, &policy).await { - eprintln!("{e}"); + error!("{e}"); } }); return; } } - eprintln!("Policy not accepted: {body:?}"); + error!("Policy not accepted: {body:?}"); } async fn msg(State((_, state)): State<(Policies, MpcState)>, Path(from): Path, body: Bytes) { @@ -413,7 +422,7 @@ async fn msg(State((_, state)): State<(Policies, MpcState)>, Path(from): Path from as usize { senders[from as usize].send(body.to_vec()).await.unwrap(); } else { - eprintln!("No sender for party {from}"); + error!("No sender for party {from}"); } } @@ -431,12 +440,12 @@ impl Channel for HttpChannel { let client = reqwest::Client::new(); let url = format!("{}msg/{}", self.urls[p], self.party); let mb = msg.len() as f64 / 1024.0 / 1024.0; - println!("{} -> {} ({:.2}MB)", self.party, p, mb,); + trace!("sending msg from {} to {} ({:.2}MB)", self.party, p, mb,); loop { match client.post(&url).body(msg.clone()).send().await?.status() { StatusCode::OK => return Ok(()), StatusCode::NOT_FOUND => { - println!("Could not reach party {p} at {url}..."); + error!("Could not reach party {p} at {url}..."); sleep(Duration::from_millis(1000)).await; } status => anyhow::bail!("Unexpected status code: {status}"), diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index 245f3a7..a4aff2a 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -23,12 +23,12 @@ fn simulate() { let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); thread::spawn(move || { while let Some(Ok(line)) = stdout.next() { - println!("pre> {line}"); + println!(" pre> {line}"); } }); thread::spawn(move || { while let Some(Ok(line)) = stderr.next() { - eprintln!("pre> {line}"); + eprintln!(" pre> {line}"); } }); children.push(cmd); @@ -71,7 +71,7 @@ fn simulate() { thread::spawn(move || { while let Some(Ok(line)) = stdout.next() { println!("party0> {line}"); - if line.starts_with("MPC Output:") { + if line.contains("MPC Output: 4 rows") { return s.send(()).unwrap(); } } From 8f8a9099df6ea9985045dc7cf82babb9688b63b4 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 24 Apr 2024 17:52:58 +0100 Subject: [PATCH 12/25] Add policy viewer routes in Postgres DB example --- examples/postgres-integration/Cargo.toml | 1 + examples/postgres-integration/policies0.json | 16 +++ examples/postgres-integration/src/main.rs | 97 +++++++++++++++---- .../templates/policies.html | 28 ++++++ .../templates/policy.html | 32 ++++++ 5 files changed, 157 insertions(+), 17 deletions(-) create mode 100644 examples/postgres-integration/templates/policies.html create mode 100644 examples/postgres-integration/templates/policy.html diff --git a/examples/postgres-integration/Cargo.toml b/examples/postgres-integration/Cargo.toml index 7d2beaa..69db645 100644 --- a/examples/postgres-integration/Cargo.toml +++ b/examples/postgres-integration/Cargo.toml @@ -8,6 +8,7 @@ anyhow = "1.0.79" axum = "0.7.4" blake3 = "1.5.1" clap = { version = "4.4.18", features = ["derive"] } +handlebars = "5.1.2" parlay = { path = "../../", version = "0.1.0" } reqwest = { version = "0.11.23", features = ["json"] } serde = "1.0.197" diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 6c4e9af..0eaad77 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -16,6 +16,22 @@ "setup": "TRUNCATE results", "output": "INSERT INTO results (name) VALUES ($1)", "output_db": "postgres://postgres:test@localhost:5555/postgres" + }, + { + "participants": [ + "http://localhost:8000", + "http://localhost:8001", + "http://localhost:8002" + ], + "program": "does_not_exist.garble.rs", + "leader": 1, + "party": 0, + "input": "SELECT id, name, age FROM residents", + "input_db": "postgres://postgres:test@localhost:5550/postgres", + "max_rows": 4, + "setup": "TRUNCATE results", + "output": "INSERT INTO results (name) VALUES ($1)", + "output_db": "postgres://postgres:test@localhost:5555/postgres" } ] } diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 1cf2721..ad75859 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -2,10 +2,12 @@ use anyhow::{anyhow, bail, Context, Error}; use axum::{ body::Bytes, extract::{DefaultBodyLimit, Path, State}, - routing::post, + response::Html, + routing::{get, post}, Json, Router, }; use clap::Parser; +use handlebars::Handlebars; use parlay::{ channel::Channel, fpre::fpre, @@ -18,6 +20,7 @@ use parlay::{ }; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; +use serde_json::json; use sqlx::{postgres::PgPoolOptions, Row}; use std::{ borrow::BorrowMut, net::SocketAddr, path::PathBuf, process::exit, result::Result, sync::Arc, @@ -51,12 +54,12 @@ struct Cli { config: PathBuf, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct Policies { accepted: Vec, } -#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] struct Policy { participants: Vec, program: PathBuf, @@ -83,9 +86,9 @@ type MpcState = Arc>>>>; async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); let Cli { port, config } = Cli::parse(); - let policies = load_policies(config).await?; - if policies.accepted.len() == 1 { - let policy = &policies.accepted[0]; + let local_policies = load_policies(config).await?; + if local_policies.accepted.len() == 1 { + let policy = &local_policies.accepted[0]; if policy.input.is_empty() { info!("Running as preprocessor..."); return run_fpre(port, policy.participants.clone()).await; @@ -97,7 +100,9 @@ async fn main() -> Result<(), Error> { let app = Router::new() .route("/run", post(run)) .route("/msg/:from", post(msg)) - .with_state((policies.clone(), Arc::clone(&state))) + .route("/policies", get(policies)) + .route("/policies/:id", get(policy)) + .with_state((local_policies.clone(), Arc::clone(&state))) .layer(DefaultBodyLimit::disable()) .layer(TraceLayer::new_for_http()); @@ -105,9 +110,9 @@ async fn main() -> Result<(), Error> { info!("listening on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await?; tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); - info!("Found {} active policies", policies.accepted.len()); + info!("Found {} active policies", local_policies.accepted.len()); loop { - for policy in &policies.accepted { + for policy in &local_policies.accepted { if policy.leader == policy.party { info!( "Acting as leader (party {}) for program {}", @@ -125,24 +130,32 @@ async fn main() -> Result<(), Error> { leader: policy.leader, program_hash: hash, }; + let mut participant_missing = false; for party in policy.participants.iter().rev().skip(1).rev() { if party != &policy.participants[policy.party] { info!("Waiting for confirmation from party {party}"); let url = format!("{party}run"); - match client - .post(&url) - .json(&policy_request) - .send() - .await? - .status() - { + let Ok(res) = client.post(&url).json(&policy_request).send().await else { + error!("Could not reach {url}"); + participant_missing = true; + continue; + }; + match res.status() { StatusCode::OK => {} code => { error!("Unexpected response while trying to trigger execution for {url}: {code}"); + participant_missing = true; } } } } + if participant_missing { + error!( + "Some participants of program {} are missing, skipping execution...", + policy.program.display() + ); + continue; + } info!("All participants have accepted the session, starting calculation now..."); fn decode_literal(l: Literal) -> Result>, String> { let Literal::Array(rows) = l else { @@ -321,7 +334,7 @@ async fn execute_mpc( let field = if let Ok(s) = row.try_get::(c) { let mut fixed_str = vec![Literal::NumUnsigned(0, UnsignedNumType::U8); STR_LEN_BYTES]; - for (i, b) in s.as_bytes().into_iter().enumerate() { + for (i, b) in s.as_bytes().iter().enumerate() { if i < STR_LEN_BYTES { fixed_str[i] = Literal::NumUnsigned(*b as u64, UnsignedNumType::U8); } else { @@ -426,6 +439,56 @@ async fn msg(State((_, state)): State<(Policies, MpcState)>, Path(from): Path, +) -> Result, axum::http::StatusCode> { + let mut accepted = vec![]; + for (i, p) in policies.accepted.into_iter().enumerate() { + accepted.push(json!({ + "id": i, + "num_participants": p.participants.len(), + "program": p.program.to_str().unwrap_or(""), + "leader": p.leader, + "party": p.party, + })); + } + let params = json!({ + "accepted": accepted + }); + render_template(include_str!("../templates/policies.html"), ¶ms) +} + +async fn policy( + State((policies, _)): State<(Policies, MpcState)>, + Path(id): Path, +) -> Result, axum::http::StatusCode> { + let Some(p) = policies.accepted.get(id) else { + return Err(axum::http::StatusCode::NOT_FOUND); + }; + let params = json!({ + "policy": { + "num_participants": p.participants.len(), + "participants": p.participants, + "program": p.program.to_str().unwrap_or(""), + "code": fs::read_to_string(&p.program).await.unwrap_or("Program not found".to_string()), + "leader": p.leader, + "party": p.party, + } + }); + render_template(include_str!("../templates/policy.html"), ¶ms) +} + +fn render_template( + template: &str, + params: &serde_json::Value, +) -> Result, axum::http::StatusCode> { + let h = Handlebars::new(); + let Ok(html) = h.render_template(template, ¶ms) else { + return Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR); + }; + Ok(Html(html)) +} + struct HttpChannel { urls: Vec, party: usize, diff --git a/examples/postgres-integration/templates/policies.html b/examples/postgres-integration/templates/policies.html new file mode 100644 index 0000000..4030824 --- /dev/null +++ b/examples/postgres-integration/templates/policies.html @@ -0,0 +1,28 @@ + + + + + + Policies + + + +
+

Policies

+ {{#each accepted}} +
+ +
{{num_participants}} participants
+
Own Role: Party {{party}}
+
Leader: Party {{leader}}
+
+ {{/each}} +
+ + diff --git a/examples/postgres-integration/templates/policy.html b/examples/postgres-integration/templates/policy.html new file mode 100644 index 0000000..30a3ef4 --- /dev/null +++ b/examples/postgres-integration/templates/policy.html @@ -0,0 +1,32 @@ + + + {{#with policy}} + + + + Policy {{program}} + + + +
+

+ Policies + > {{program}} +

+
+
{{num_participants}} participants
+
    + {{#each participants}} +
  1. {{this}}
  2. + {{/each}} +
+
Own Role: Party {{party}}
+
Leader: Party {{leader}}
+
{{code}}
+
+
+ + {{/with}} + From a40bd8b1da53e73dc2029e12aa38e2d944814529 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 24 Apr 2024 18:14:32 +0100 Subject: [PATCH 13/25] Use shorter strings for Postgres db example --- .github/workflows/test.yml | 2 +- examples/postgres-integration/.example.garble.rs | 8 ++++---- examples/postgres-integration/policies0.json | 1 + examples/postgres-integration/policies1.json | 3 ++- .../postgres-integration/seeds/db0/residents.csv | 4 ++-- .../postgres-integration/seeds/db1/insurance.csv | 4 ++-- examples/postgres-integration/src/main.rs | 15 +++++++++++---- examples/postgres-integration/tests/cli.rs | 2 +- 8 files changed, 24 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6c0183e..7d55414 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,5 +20,5 @@ jobs: working-directory: ./examples/http-single-server-channels - run: cargo test -- --nocapture working-directory: ./examples/iroh-p2p-channels - - run: docker compose -f docker-compose.yml up -d && cargo test -- --nocapture + - run: docker compose -f docker-compose.yml up -d && cargo test --release -- --nocapture working-directory: ./examples/postgres-integration diff --git a/examples/postgres-integration/.example.garble.rs b/examples/postgres-integration/.example.garble.rs index adc25bd..3bfc81d 100644 --- a/examples/postgres-integration/.example.garble.rs +++ b/examples/postgres-integration/.example.garble.rs @@ -1,15 +1,15 @@ enum Row0 { None, - Some(i32, [u8; 16], i32), + Some(i32, [u8; 6], i32), } enum Row1 { None, - Some(i32, [u8; 16], bool), + Some(i32, [u8; 6], bool), } -pub fn main(residents: [Row0; 4], insurance: [Row1; 4], age_threshold: u32) -> [[u8; 16]; 4] { - let mut result = [[0u8; 16]; 4]; +pub fn main(residents: [Row0; 4], insurance: [Row1; 4], age_threshold: u32) -> [[u8; 6]; 4] { + let mut result = [[0u8; 6]; 4]; let mut i = 0usize; for resident in residents { for insurance in insurance { diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 0eaad77..72a9ec4 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -13,6 +13,7 @@ "input": "SELECT id, name, age FROM residents", "input_db": "postgres://postgres:test@localhost:5550/postgres", "max_rows": 4, + "max_str_bytes": 6, "setup": "TRUNCATE results", "output": "INSERT INTO results (name) VALUES ($1)", "output_db": "postgres://postgres:test@localhost:5555/postgres" diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index b1d99d9..5ec8f4a 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -12,7 +12,8 @@ "party": 1, "input": "SELECT id, name, disabled FROM insurance", "input_db": "postgres://postgres:test@localhost:5551/postgres", - "max_rows": 4 + "max_rows": 4, + "max_str_bytes": 6 } ] } diff --git a/examples/postgres-integration/seeds/db0/residents.csv b/examples/postgres-integration/seeds/db0/residents.csv index 0d6996a..eabf47c 100644 --- a/examples/postgres-integration/seeds/db0/residents.csv +++ b/examples/postgres-integration/seeds/db0/residents.csv @@ -1,5 +1,5 @@ Id,Name,Age,Address 0,"John Doe",22,"Parkstraße 1, 12345 Berlin" 1,"Jane Doe",33,"Schlossallee 2, 12345 Berlin" -2,"Max Mustermann",66,"Parkstraße 1, 12345 Berlin" -3,"Erika Musterfrau",77,"Schlossallee 2, 12345 Berlin" +2,"Max Doe",66,"Parkstraße 1, 12345 Berlin" +3,"Bob Doe",77,"Schlossallee 2, 12345 Berlin" diff --git a/examples/postgres-integration/seeds/db1/insurance.csv b/examples/postgres-integration/seeds/db1/insurance.csv index 21f0512..7a98cc8 100644 --- a/examples/postgres-integration/seeds/db1/insurance.csv +++ b/examples/postgres-integration/seeds/db1/insurance.csv @@ -1,3 +1,3 @@ Id,Name,Disabled,Address -0,"Max Mustermann",true,"Parkstraße 1, 12345 Berlin" -1,"Erika Musterfrau",false,"Schlossallee 2, 12345 Berlin" +0,"Max Doe",true,"Parkstraße 1, 12345 Berlin" +1,"Bob Doe",false,"Schlossallee 2, 12345 Berlin" diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index ad75859..799ec4c 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -40,7 +40,7 @@ use url::Url; const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(30); const DEFAULT_MAX_ROWS: usize = 10; -const STR_LEN_BYTES: usize = 16; +const DEFAULT_MAX_STR_BYTES: usize = 8; /// A CLI for Multi-Party Computation using the Parlay engine. #[derive(Debug, Parser)] @@ -68,6 +68,7 @@ struct Policy { input: String, input_db: Option, max_rows: Option, + max_str_bytes: Option, setup: Option, output: Option, output_db: Option, @@ -304,6 +305,7 @@ async fn execute_mpc( input, input_db: db, max_rows, + max_str_bytes, setup: _setup, output: _output, output_db: _output_db, @@ -320,6 +322,7 @@ async fn execute_mpc( let rows = sqlx::query(input).fetch_all(&pool).await?; info!("'{input}' returned {} rows from {db}", rows.len()); let max_rows = max_rows.unwrap_or(DEFAULT_MAX_ROWS); + let max_str_bytes = max_str_bytes.unwrap_or(DEFAULT_MAX_STR_BYTES); let mut rows_as_literals = vec![ Literal::Enum( format!("Row{party}"), @@ -333,12 +336,16 @@ async fn execute_mpc( for c in 0..row.len() { let field = if let Ok(s) = row.try_get::(c) { let mut fixed_str = - vec![Literal::NumUnsigned(0, UnsignedNumType::U8); STR_LEN_BYTES]; + vec![Literal::NumUnsigned(0, UnsignedNumType::U8); max_str_bytes]; for (i, b) in s.as_bytes().iter().enumerate() { - if i < STR_LEN_BYTES { + if i < max_str_bytes { fixed_str[i] = Literal::NumUnsigned(*b as u64, UnsignedNumType::U8); } else { - bail!("String is longer than {STR_LEN_BYTES} bytes: '{s}'"); + warn!( + "String is longer than {max_str_bytes} bytes: '{s}', dropping '{}'", + &s[i..] + ); + break; } } Literal::Array(fixed_str) diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index a4aff2a..fd24032 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -8,7 +8,7 @@ use std::{ #[test] fn simulate() { - let millis = 1000; + let millis = 2000; println!("\n\n--- {millis}ms ---\n\n"); let mut children = vec![]; let mut cmd = Command::new("cargo") From 155f1f9ed45e75581cf573d89414d0b9a1655757 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 29 Apr 2024 14:22:54 +0100 Subject: [PATCH 14/25] Make waiting time between executions in Postgres db example configurable --- examples/postgres-integration/src/main.rs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 799ec4c..19f665b 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -32,13 +32,12 @@ use tokio::{ mpsc::{channel, Receiver, Sender}, Mutex, }, - time::{sleep, timeout}, + time::{self, sleep, timeout}, }; use tower_http::trace::TraceLayer; use tracing::{debug, error, info, trace, warn}; use url::Url; -const TIME_BETWEEN_EXECUTIONS: Duration = Duration::from_secs(30); const DEFAULT_MAX_ROWS: usize = 10; const DEFAULT_MAX_STR_BYTES: usize = 8; @@ -52,6 +51,9 @@ struct Cli { /// The location of the file with the policy configuration. #[arg(long, short)] config: PathBuf, + /// The time in minutes to wait before executing policies again. + #[arg(long, short, default_value = "30")] + sleep: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -86,7 +88,11 @@ type MpcState = Arc>>>>; #[tokio::main] async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); - let Cli { port, config } = Cli::parse(); + let Cli { + port, + config, + sleep, + } = Cli::parse(); let local_policies = load_policies(config).await?; if local_policies.accepted.len() == 1 { let policy = &local_policies.accepted[0]; @@ -242,7 +248,8 @@ async fn main() -> Result<(), Error> { } } } - sleep(TIME_BETWEEN_EXECUTIONS).await; + info!("Waiting {sleep} minutes before checking MPC policies again..."); + time::sleep(Duration::from_secs(60 * sleep)).await; } } From 89711163091d34dd5c99b7e4a1a494195f75a162 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 29 Apr 2024 14:34:06 +0100 Subject: [PATCH 15/25] Run `cargo build` before `cargo test` in postgres example CI --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d55414..3774588 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,5 +20,5 @@ jobs: working-directory: ./examples/http-single-server-channels - run: cargo test -- --nocapture working-directory: ./examples/iroh-p2p-channels - - run: docker compose -f docker-compose.yml up -d && cargo test --release -- --nocapture + - run: docker compose -f docker-compose.yml up -d && cargo build --release && cargo test --release -- --nocapture working-directory: ./examples/postgres-integration From 8aa932c51c1581bb90e76c9a74a2b2383c783bc5 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 3 Jun 2024 17:07:10 +0100 Subject: [PATCH 16/25] Use dynamic sizes for rows/strings in Postgres example --- Cargo.toml | 2 +- .../postgres-integration/.example.garble.rs | 33 +- examples/postgres-integration/policies0.json | 20 +- examples/postgres-integration/policies1.json | 15 +- .../postgres-integration/preprocessor.json | 6 +- examples/postgres-integration/src/main.rs | 314 +++++++++++++----- examples/postgres-integration/tests/cli.rs | 54 +-- 7 files changed, 305 insertions(+), 139 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a68405a..e38dbff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ edition = "2021" bincode = "1.3.3" blake3 = "1.5.0" chacha20poly1305 = "0.10.1" -garble_lang = "0.2.0" +garble_lang = { path = "../garble", features = ["serde"] } rand = "0.8.5" serde = { version = "1.0.195", features = ["derive"] } tokio = { version = "1.35.1", features = ["full"] } diff --git a/examples/postgres-integration/.example.garble.rs b/examples/postgres-integration/.example.garble.rs index 3bfc81d..4064087 100644 --- a/examples/postgres-integration/.example.garble.rs +++ b/examples/postgres-integration/.example.garble.rs @@ -1,26 +1,21 @@ -enum Row0 { - None, - Some(i32, [u8; 6], i32), -} - -enum Row1 { - None, - Some(i32, [u8; 6], bool), -} +const ROWS_0: usize = PARTY_0::ROWS; +const ROWS_1: usize = PARTY_1::ROWS; +const ROWS_RESULT: usize = max(PARTY_0::ROWS, PARTY_1::ROWS); +const STR_LEN: usize = max(PARTY_0::STR_LEN, PARTY_1::STR_LEN); -pub fn main(residents: [Row0; 4], insurance: [Row1; 4], age_threshold: u32) -> [[u8; 6]; 4] { - let mut result = [[0u8; 6]; 4]; +pub fn main( + residents: [(i32, [u8; STR_LEN], i32); ROWS_0], + insurance: [(i32, [u8; STR_LEN], bool); ROWS_1], +) -> [[u8; STR_LEN]; ROWS_RESULT] { + let mut result = [[0u8; STR_LEN]; ROWS_RESULT]; let mut i = 0usize; for resident in residents { for insurance in insurance { - match (resident, insurance) { - (Row0::Some(_, name0, age), Row1::Some(_, name1, health_problems)) => { - if name0 == name1 && health_problems == true && age > (age_threshold as i32) { - result[i] = name0; - i = i + 1usize; - } - } - _ => {} + let (_, name0, age) = resident; + let (_, name1, health_problems) = insurance; + if name0 == name1 && health_problems == true && age > 65i32 { + result[i] = name0; + i = i + 1usize; } } } diff --git a/examples/postgres-integration/policies0.json b/examples/postgres-integration/policies0.json index 72a9ec4..ff2ee6c 100644 --- a/examples/postgres-integration/policies0.json +++ b/examples/postgres-integration/policies0.json @@ -4,19 +4,26 @@ "participants": [ "http://localhost:8000", "http://localhost:8001", - "http://localhost:8002", - "http://localhost:8003" + "http://localhost:8002" ], "program": ".example.garble.rs", "leader": 0, "party": 0, "input": "SELECT id, name, age FROM residents", "input_db": "postgres://postgres:test@localhost:5550/postgres", - "max_rows": 4, - "max_str_bytes": 6, "setup": "TRUNCATE results", "output": "INSERT INTO results (name) VALUES ($1)", - "output_db": "postgres://postgres:test@localhost:5555/postgres" + "output_db": "postgres://postgres:test@localhost:5555/postgres", + "constants": { + "ROWS": { + "query": "SELECT COUNT(id) FROM residents", + "ty": "usize" + }, + "STR_LEN": { + "query": "SELECT MAX(LENGTH(name)) FROM residents", + "ty": "usize" + } + } }, { "participants": [ @@ -32,7 +39,8 @@ "max_rows": 4, "setup": "TRUNCATE results", "output": "INSERT INTO results (name) VALUES ($1)", - "output_db": "postgres://postgres:test@localhost:5555/postgres" + "output_db": "postgres://postgres:test@localhost:5555/postgres", + "constants": {} } ] } diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index 5ec8f4a..959bab4 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -4,16 +4,23 @@ "participants": [ "http://localhost:8000", "http://localhost:8001", - "http://localhost:8002", - "http://localhost:8003" + "http://localhost:8002" ], "program": ".example.garble.rs", "leader": 0, "party": 1, "input": "SELECT id, name, disabled FROM insurance", "input_db": "postgres://postgres:test@localhost:5551/postgres", - "max_rows": 4, - "max_str_bytes": 6 + "constants": { + "ROWS": { + "query": "SELECT COUNT(id) FROM insurance", + "ty": "usize" + }, + "STR_LEN": { + "query": "SELECT MAX(LENGTH(name)) FROM insurance", + "ty": "usize" + } + } } ] } diff --git a/examples/postgres-integration/preprocessor.json b/examples/postgres-integration/preprocessor.json index d29c11e..e86069a 100644 --- a/examples/postgres-integration/preprocessor.json +++ b/examples/postgres-integration/preprocessor.json @@ -4,13 +4,13 @@ "participants": [ "http://localhost:8000", "http://localhost:8001", - "http://localhost:8002", - "http://localhost:8003" + "http://localhost:8002" ], "program": "", "leader": 0, "party": 0, - "input": "" + "input": "", + "constants": {} } ] } diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 19f665b..2efd046 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -12,8 +12,9 @@ use parlay::{ channel::Channel, fpre::fpre, garble_lang::{ - compile, - literal::{Literal, VariantLiteral}, + ast::Type, + compile_with_constants, + literal::Literal, token::{SignedNumType, UnsignedNumType}, }, protocol::{mpc, Preprocessor}, @@ -23,8 +24,8 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use sqlx::{postgres::PgPoolOptions, Row}; use std::{ - borrow::BorrowMut, net::SocketAddr, path::PathBuf, process::exit, result::Result, sync::Arc, - time::Duration, + borrow::BorrowMut, collections::HashMap, net::SocketAddr, path::PathBuf, process::exit, + result::Result, sync::Arc, time::Duration, }; use tokio::{ fs, @@ -38,9 +39,6 @@ use tower_http::trace::TraceLayer; use tracing::{debug, error, info, trace, warn}; use url::Url; -const DEFAULT_MAX_ROWS: usize = 10; -const DEFAULT_MAX_STR_BYTES: usize = 8; - /// A CLI for Multi-Party Computation using the Parlay engine. #[derive(Debug, Parser)] #[command(name = "parlay")] @@ -61,7 +59,7 @@ struct Policies { accepted: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] struct Policy { participants: Vec, program: PathBuf, @@ -69,11 +67,16 @@ struct Policy { party: usize, input: String, input_db: Option, - max_rows: Option, - max_str_bytes: Option, setup: Option, output: Option, output_db: Option, + constants: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +struct Constant { + query: String, + ty: String, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -83,7 +86,17 @@ struct PolicyRequest { leader: usize, } -type MpcState = Arc>>>>; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +struct ConstsRequest { + consts: HashMap, +} + +struct MpcComms { + consts: HashMap>, + senders: Vec>>, +} + +type MpcState = Arc>; #[tokio::main] async fn main() -> Result<(), Error> { @@ -102,10 +115,14 @@ async fn main() -> Result<(), Error> { } } - let state = Arc::new(Mutex::new(vec![])); + let state = Arc::new(Mutex::new(MpcComms { + consts: HashMap::new(), + senders: vec![], + })); let app = Router::new() .route("/run", post(run)) + .route("/consts/:from", post(consts)) .route("/msg/:from", post(msg)) .route("/policies", get(policies)) .route("/policies/:id", get(policy)) @@ -267,7 +284,13 @@ async fn run_fpre(port: u16, urls: Vec) -> Result<(), Error> { let policies = Policies { accepted: vec![] }; let app = Router::new() .route("/msg/:from", post(msg)) - .with_state((policies, Arc::new(Mutex::new(senders)))) + .with_state(( + policies, + Arc::new(Mutex::new(MpcComms { + consts: HashMap::new(), + senders, + })), + )) .layer(DefaultBodyLimit::disable()) .layer(TraceLayer::new_for_http()); @@ -292,11 +315,13 @@ async fn load_policies(path: PathBuf) -> Result { error!("Could not find '{}', exiting...", path.display()); exit(-1); }; - let Ok(policies) = serde_json::from_str::(&policies) else { - error!("'{}' has an invalid format, exiting...", path.display()); - exit(-1); - }; - Ok(policies) + match serde_json::from_str::(&policies) { + Ok(policies) => Ok(policies), + Err(e) => { + error!("'{}' has an invalid format: {e}", path.display()); + exit(-1); + } + } } async fn execute_mpc( @@ -311,91 +336,210 @@ async fn execute_mpc( party, input, input_db: db, - max_rows, - max_str_bytes, setup: _setup, output: _output, output_db: _output_db, + constants, } = policy; - let prg = compile(&code).map_err(|e| anyhow!(e.prettify(&code)))?; - info!( - "Trying to execute circuit with {}K gates ({}K AND gates)", - prg.circuit.gates.len() / 1000, - prg.circuit.and_gates() / 1000 - ); - let input = if let Some(db) = db { + let (prg, input) = if let Some(db) = db { info!("Connecting to input db at {db}..."); let pool = PgPoolOptions::new().max_connections(5).connect(db).await?; let rows = sqlx::query(input).fetch_all(&pool).await?; info!("'{input}' returned {} rows from {db}", rows.len()); - let max_rows = max_rows.unwrap_or(DEFAULT_MAX_ROWS); - let max_str_bytes = max_str_bytes.unwrap_or(DEFAULT_MAX_STR_BYTES); - let mut rows_as_literals = vec![ - Literal::Enum( - format!("Row{party}"), - "None".to_string(), - VariantLiteral::Unit, + + let mut my_consts = HashMap::new(); + for (k, c) in constants { + let row = sqlx::query(&c.query).fetch_one(&pool).await?; + if row.len() != 1 { + bail!( + "Expected a single scalar value, but got {} from query '{}'", + row.len(), + c.query + ); + } else { + if let Ok(n) = row.try_get::(0) { + if n >= 0 && c.ty == "usize" { + my_consts.insert( + k.clone(), + Literal::NumUnsigned(n as u64, UnsignedNumType::Usize), + ); + continue; + } + } else if let Ok(n) = row.try_get::(0) { + if n >= 0 && c.ty == "usize" { + my_consts.insert( + k.clone(), + Literal::NumUnsigned(n as u64, UnsignedNumType::Usize), + ); + continue; + } + } + bail!("Could not decode scalar value as {} of '{}'", c.ty, c.query); + } + } + { + let mut locked = state.lock().await; + locked + .consts + .insert(format!("PARTY_{party}"), my_consts.clone()); + } + let client = reqwest::Client::new(); + for p in participants.iter().rev().skip(1).rev() { + if p != &participants[*party] { + info!("Sending constants to party {p}"); + let url = format!("{p}consts/{party}"); + let const_request = ConstsRequest { + consts: my_consts.clone(), + }; + let Ok(res) = client.post(&url).json(&const_request).send().await else { + bail!("Could not reach {url}"); + }; + match res.status() { + StatusCode::OK => {} + code => { + bail!("Unexpected response while trying to send consts to {url}: {code}"); + } + } + } + } + loop { + sleep(Duration::from_millis(500)).await; + let locked = state.lock().await; + if locked.consts.len() >= participants.len() - 1 { + break; + } else { + let missing = participants.len() - 1 - locked.consts.len(); + info!( + "Constants missing from {} parties, received constants from {:?}", + missing, + locked.consts.keys() + ); + } + } + + let prg = { + let locked = state.lock().await; + compile_with_constants(&code, locked.consts.clone()) + .map_err(|e| anyhow!(e.prettify(&code)))? + }; + let input_ty = &prg.main.params[*party].ty; + let Type::ArrayConst(row_type, _) = input_ty else { + bail!("Expected an array input type (with const size) for party {party}, but found {input_ty}"); + }; + let Type::Tuple(field_types) = row_type.as_ref() else { + bail!( + "Expected an array of tuples as input type for party {party}, but found {input_ty}" ); - max_rows - ]; + }; + info!( + "Trying to execute circuit with {}K gates ({}K AND gates)", + prg.circuit.gates.len() / 1000, + prg.circuit.and_gates() / 1000 + ); + let mut rows_as_literals = vec![]; for (r, row) in rows.iter().enumerate() { let mut row_as_literal = vec![]; + if field_types.len() != row.len() { + bail!( + "The program expects a tuple with {} fields, but the query returned a row with {} fields", + field_types.len(), + row.len() + ); + } for c in 0..row.len() { - let field = if let Ok(s) = row.try_get::(c) { - let mut fixed_str = - vec![Literal::NumUnsigned(0, UnsignedNumType::U8); max_str_bytes]; - for (i, b) in s.as_bytes().iter().enumerate() { - if i < max_str_bytes { - fixed_str[i] = Literal::NumUnsigned(*b as u64, UnsignedNumType::U8); - } else { - warn!( - "String is longer than {max_str_bytes} bytes: '{s}', dropping '{}'", - &s[i..] - ); - break; + let mut literal = None; + match &field_types[c] { + Type::Bool => { + if let Ok(b) = row.try_get::(c) { + literal = Some(Literal::from(b)) + } + } + Type::Signed(SignedNumType::I32) => { + if let Ok(n) = row.try_get::(c) { + literal = Some(Literal::NumSigned(n as i64, SignedNumType::I32)) + } + } + Type::Signed(SignedNumType::I64) => { + if let Ok(n) = row.try_get::(c) { + literal = Some(Literal::NumSigned(n, SignedNumType::I64)) + } + } + Type::Array(ty, size) + if ty.as_ref() == &Type::Unsigned(UnsignedNumType::U8) => + { + if let Ok(s) = row.try_get::(c) { + let mut fixed_str = + vec![Literal::NumUnsigned(0, UnsignedNumType::U8); *size]; + for (i, b) in s.as_bytes().iter().enumerate() { + if i < *size { + fixed_str[i] = + Literal::NumUnsigned(*b as u64, UnsignedNumType::U8); + } else { + warn!( + "String is longer than {size} bytes: '{s}', dropping '{}'", + &s[i..] + ); + break; + } + } + literal = Some(Literal::Array(fixed_str)) } } - Literal::Array(fixed_str) - } else if let Ok(b) = row.try_get::(c) { - Literal::from(b) - } else if let Ok(n) = row.try_get::(c) { - Literal::NumSigned(n as i64, SignedNumType::I32) - } else if let Ok(n) = row.try_get::(c) { - Literal::NumSigned(n, SignedNumType::I64) + Type::ArrayConst(ty, size) + if ty.as_ref() == &Type::Unsigned(UnsignedNumType::U8) => + { + let size = *prg.const_sizes.get(size).unwrap(); + if let Ok(s) = row.try_get::(c) { + let mut fixed_str = + vec![Literal::NumUnsigned(0, UnsignedNumType::U8); size]; + for (i, b) in s.as_bytes().iter().enumerate() { + if i < size { + fixed_str[i] = + Literal::NumUnsigned(*b as u64, UnsignedNumType::U8); + } else { + warn!( + "String is longer than {size} bytes: '{s}', dropping '{}'", + &s[i..] + ); + break; + } + } + literal = Some(Literal::Array(fixed_str)) + } + } + _ => {} + } + if let Some(literal) = literal { + row_as_literal.push(literal); } else { - bail!("Could not decode column {c}"); - }; - row_as_literal.push(field); - } - if r >= max_rows { - warn!("Dropping record {r}"); - } else { - let literal = Literal::Enum( - format!("Row{party}"), - "Some".to_string(), - VariantLiteral::Tuple(row_as_literal), - ); - debug!("rows[{r}] = {literal}"); - rows_as_literals[r] = literal; + bail!("Could not decode column {c} with type {}", field_types[c]); + } } + let literal = Literal::Tuple(row_as_literal); + debug!("rows[{r}] = {literal}"); + rows_as_literals.push(literal); } let literal = Literal::Array(rows_as_literals); - prg.literal_arg(*party, literal)?.as_bits() + let input = prg.literal_arg(*party, literal)?.as_bits(); + (prg, input) } else { - prg.parse_arg(*party, input)?.as_bits() + let consts = HashMap::new(); + let prg = compile_with_constants(&code, consts).map_err(|e| anyhow!(e.prettify(&code)))?; + let input = prg.parse_arg(*party, input)?.as_bits(); + (prg, input) }; let fpre = Preprocessor::TrustedDealer(participants.len() - 1); let p_out: Vec<_> = vec![*leader]; let channel = { - let mut state = state.lock().await; - let senders = state.borrow_mut(); - if !senders.is_empty() { + let mut locked = state.lock().await; + let state = locked.borrow_mut(); + if !state.senders.is_empty() { panic!("Cannot start a new MPC execution while there are still active senders!"); } let mut receivers = vec![]; for _ in 0..policy.participants.len() { let (s, r) = channel(1); - senders.push(s); + state.senders.push(s); receivers.push(r); } @@ -406,7 +550,7 @@ async fn execute_mpc( } }; let output = mpc(channel, &prg.circuit, &input, fpre, 0, *party, &p_out).await?; - state.lock().await.clear(); + state.lock().await.senders.clear(); if output.is_empty() { Ok(None) } else { @@ -444,10 +588,22 @@ async fn run( error!("Policy not accepted: {body:?}"); } +async fn consts( + State((_, state)): State<(Policies, MpcState)>, + Path(from): Path, + Json(body): Json, +) { + let mut state = state.lock().await; + state.consts.insert(format!("PARTY_{from}"), body.consts); +} + async fn msg(State((_, state)): State<(Policies, MpcState)>, Path(from): Path, body: Bytes) { - let senders = state.lock().await; - if senders.len() > from as usize { - senders[from as usize].send(body.to_vec()).await.unwrap(); + let state = state.lock().await; + if state.senders.len() > from as usize { + state.senders[from as usize] + .send(body.to_vec()) + .await + .unwrap(); } else { error!("No sender for party {from}"); } diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index fd24032..218069b 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -12,7 +12,7 @@ fn simulate() { println!("\n\n--- {millis}ms ---\n\n"); let mut children = vec![]; let mut cmd = Command::new("cargo") - .args(["run", "--", "--port=8003", "--config=preprocessor.json"]) + .args(["run", "--", "--port=8002", "--config=preprocessor.json"]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn() @@ -32,31 +32,31 @@ fn simulate() { } }); children.push(cmd); - for p in [1, 2] { - let port = format!("--port=800{p}"); - let config = format!("--config=policies{p}.json"); - let args = vec!["run", "--", &port, &config]; - let mut cmd = Command::new("cargo") - .args(args) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .unwrap(); - let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines().skip(4); - let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); - thread::spawn(move || { - while let Some(Ok(line)) = stdout.next() { - println!("party{p}> {line}"); - } - }); - thread::spawn(move || { - while let Some(Ok(line)) = stderr.next() { - eprintln!("party{p}> {line}"); - } - }); - children.push(cmd); - } - sleep(Duration::from_millis(500)); + + let port = format!("--port=8001"); + let config = format!("--config=policies1.json"); + let args = vec!["run", "--", &port, &config]; + let mut cmd = Command::new("cargo") + .args(args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + let mut stdout = BufReader::new(cmd.stdout.take().unwrap()).lines().skip(4); + let mut stderr = BufReader::new(cmd.stderr.take().unwrap()).lines(); + thread::spawn(move || { + while let Some(Ok(line)) = stdout.next() { + println!("party1> {line}"); + } + }); + thread::spawn(move || { + while let Some(Ok(line)) = stderr.next() { + eprintln!("party1> {line}"); + } + }); + children.push(cmd); + + sleep(Duration::from_millis(millis)); let args = vec!["run", "--", "--port=8000", "--config=policies0.json"]; let mut cmd = Command::new("cargo") .args(args) @@ -81,7 +81,7 @@ fn simulate() { eprintln!("party0> {line}"); } }); - let result = r.recv_timeout(Duration::from_secs(300)); + let result = r.recv_timeout(Duration::from_secs(60)); for mut child in children { child.kill().unwrap(); } From 4387cb407806547ca17d9a27708939d43c084adc Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 5 Jun 2024 11:49:44 +0100 Subject: [PATCH 17/25] Use Garble v0.3.0 for const support --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e38dbff..efeb156 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ edition = "2021" bincode = "1.3.3" blake3 = "1.5.0" chacha20poly1305 = "0.10.1" -garble_lang = { path = "../garble", features = ["serde"] } +garble_lang = { version = "0.3.0", features = ["serde"] } rand = "0.8.5" serde = { version = "1.0.195", features = ["derive"] } tokio = { version = "1.35.1", features = ["full"] } From 446267e7967096c6f74635b18044cf6aa3a87153 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 5 Jun 2024 11:52:16 +0100 Subject: [PATCH 18/25] Fix clippy warnings --- examples/postgres-integration/src/main.rs | 4 ++-- examples/postgres-integration/tests/cli.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 2efd046..83cf0d1 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -446,9 +446,9 @@ async fn execute_mpc( row.len() ); } - for c in 0..row.len() { + for (c, field_type) in field_types.iter().enumerate() { let mut literal = None; - match &field_types[c] { + match field_type { Type::Bool => { if let Ok(b) = row.try_get::(c) { literal = Some(Literal::from(b)) diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index 218069b..f62ae67 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -33,8 +33,8 @@ fn simulate() { }); children.push(cmd); - let port = format!("--port=8001"); - let config = format!("--config=policies1.json"); + let port = "--port=8001".to_string(); + let config = "--config=policies1.json".to_string(); let args = vec!["run", "--", &port, &config]; let mut cmd = Command::new("cargo") .args(args) From 11420c112f746a57c9d9029bb046350a87dade07 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 5 Jun 2024 15:38:17 +0100 Subject: [PATCH 19/25] Increase timeout in Postgres test to 5 mins --- examples/postgres-integration/tests/cli.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/postgres-integration/tests/cli.rs b/examples/postgres-integration/tests/cli.rs index f62ae67..41cbb3c 100644 --- a/examples/postgres-integration/tests/cli.rs +++ b/examples/postgres-integration/tests/cli.rs @@ -81,7 +81,7 @@ fn simulate() { eprintln!("party0> {line}"); } }); - let result = r.recv_timeout(Duration::from_secs(60)); + let result = r.recv_timeout(Duration::from_secs(5 * 60)); for mut child in children { child.kill().unwrap(); } From ce0cefad7b301f45d975f52487b89a761a3e8a40 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 11 Jun 2024 16:39:23 +0100 Subject: [PATCH 20/25] Implement SQL text -> Garble enum conversion in Postgres example --- .../postgres-integration/.example.garble.rs | 25 +++++++++++++++---- examples/postgres-integration/policies1.json | 2 +- .../postgres-integration/seeds/db1/init.sql | 2 +- .../seeds/db1/insurance.csv | 7 +++--- examples/postgres-integration/src/main.rs | 23 +++++++++++++++-- 5 files changed, 47 insertions(+), 12 deletions(-) diff --git a/examples/postgres-integration/.example.garble.rs b/examples/postgres-integration/.example.garble.rs index 4064087..4cc263e 100644 --- a/examples/postgres-integration/.example.garble.rs +++ b/examples/postgres-integration/.example.garble.rs @@ -3,19 +3,34 @@ const ROWS_1: usize = PARTY_1::ROWS; const ROWS_RESULT: usize = max(PARTY_0::ROWS, PARTY_1::ROWS); const STR_LEN: usize = max(PARTY_0::STR_LEN, PARTY_1::STR_LEN); +enum InsuranceStatus { + Foo, + Bar, + FooBar, +} + pub fn main( residents: [(i32, [u8; STR_LEN], i32); ROWS_0], - insurance: [(i32, [u8; STR_LEN], bool); ROWS_1], + insurance: [(i32, [u8; STR_LEN], InsuranceStatus); ROWS_1], ) -> [[u8; STR_LEN]; ROWS_RESULT] { let mut result = [[0u8; STR_LEN]; ROWS_RESULT]; let mut i = 0usize; for resident in residents { for insurance in insurance { let (_, name0, age) = resident; - let (_, name1, health_problems) = insurance; - if name0 == name1 && health_problems == true && age > 65i32 { - result[i] = name0; - i = i + 1usize; + let (_, name1, insurance_status) = insurance; + if name0 == name1 && age > 65i32 { + match insurance_status { + InsuranceStatus::Foo => { + result[i] = name0; + i = i + 1usize; + } + InsuranceStatus::Bar => { + result[i] = name0; + i = i + 1usize; + } + _ => {} + } } } } diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index 959bab4..e953382 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -9,7 +9,7 @@ "program": ".example.garble.rs", "leader": 0, "party": 1, - "input": "SELECT id, name, disabled FROM insurance", + "input": "SELECT id, name, status FROM insurance", "input_db": "postgres://postgres:test@localhost:5551/postgres", "constants": { "ROWS": { diff --git a/examples/postgres-integration/seeds/db1/init.sql b/examples/postgres-integration/seeds/db1/init.sql index 906d9e0..540e94e 100644 --- a/examples/postgres-integration/seeds/db1/init.sql +++ b/examples/postgres-integration/seeds/db1/init.sql @@ -1,7 +1,7 @@ CREATE TABLE insurance ( id SERIAL PRIMARY KEY, name TEXT, - disabled BOOL, + status TEXT, address TEXT ); diff --git a/examples/postgres-integration/seeds/db1/insurance.csv b/examples/postgres-integration/seeds/db1/insurance.csv index 7a98cc8..5d0ad18 100644 --- a/examples/postgres-integration/seeds/db1/insurance.csv +++ b/examples/postgres-integration/seeds/db1/insurance.csv @@ -1,3 +1,4 @@ -Id,Name,Disabled,Address -0,"Max Doe",true,"Parkstraße 1, 12345 Berlin" -1,"Bob Doe",false,"Schlossallee 2, 12345 Berlin" +Id,Name,Status,Address +0,"Max Doe","foo","Parkstraße 1, 12345 Berlin" +1,"Bob Doe","bar","Schlossallee 2, 12345 Berlin" +2,"Jane Doe","foobar","Schlossallee 2, 12345 Berlin" diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 83cf0d1..5199629 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -12,9 +12,9 @@ use parlay::{ channel::Channel, fpre::fpre, garble_lang::{ - ast::Type, + ast::{Type, Variant}, compile_with_constants, - literal::Literal, + literal::{Literal, VariantLiteral}, token::{SignedNumType, UnsignedNumType}, }, protocol::{mpc, Preprocessor}, @@ -507,6 +507,25 @@ async fn execute_mpc( literal = Some(Literal::Array(fixed_str)) } } + Type::Enum(name) => { + let Some(enum_def) = prg.program.enum_defs.get(name) else { + bail!("Could not find definition for enum {name} in program"); + }; + if let Ok(variant) = row.try_get::(c) { + for v in &enum_def.variants { + if let Variant::Unit(variant_name) = v { + if variant_name.to_lowercase() == variant.to_lowercase() { + literal = Some(Literal::Enum( + name.clone(), + variant_name.clone(), + VariantLiteral::Unit, + )); + break; + } + } + } + } + } _ => {} } if let Some(literal) = literal { From 1101aa8d42ca63f6f5097656139d2c6572671899 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 17 Jun 2024 17:13:10 +0200 Subject: [PATCH 21/25] Use sqlx::any instead of postgres --- examples/postgres-integration/Cargo.toml | 8 +++++++- examples/postgres-integration/src/main.rs | 25 +++++++++++++---------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/postgres-integration/Cargo.toml b/examples/postgres-integration/Cargo.toml index 69db645..99df69d 100644 --- a/examples/postgres-integration/Cargo.toml +++ b/examples/postgres-integration/Cargo.toml @@ -13,7 +13,13 @@ parlay = { path = "../../", version = "0.1.0" } reqwest = { version = "0.11.23", features = ["json"] } serde = "1.0.197" serde_json = "1.0.115" -sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres"] } +sqlx = { version = "0.7.4", features = [ + "runtime-tokio", + "any", + "postgres", + "mysql", + "sqlite", +] } tokio = { version = "1.35.1", features = ["macros", "rt", "rt-multi-thread"] } tower-http = { version = "0.5.1", features = ["fs", "trace"] } tracing = "0.1.40" diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 5199629..91089d8 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -22,7 +22,10 @@ use parlay::{ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::json; -use sqlx::{postgres::PgPoolOptions, Row}; +use sqlx::{ + any::{install_default_drivers, AnyQueryResult, AnyRow}, + AnyPool, Pool, Row, +}; use std::{ borrow::BorrowMut, collections::HashMap, net::SocketAddr, path::PathBuf, process::exit, result::Result, sync::Arc, time::Duration, @@ -101,6 +104,7 @@ type MpcState = Arc>; #[tokio::main] async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); + install_default_drivers(); let Cli { port, config, @@ -234,13 +238,11 @@ async fn main() -> Result<(), Error> { } = policy; if let (Some(output_db), Some(output)) = (output_db, output) { info!("Connecting to output db at {output_db}..."); - let pool = PgPoolOptions::new() - .max_connections(5) - .connect(output_db) - .await?; + let pool: AnyPool = Pool::connect(output_db).await?; if let Some(setup) = setup { - let rows_affected = - sqlx::query(setup).execute(&pool).await?.rows_affected(); + let result: AnyQueryResult = + sqlx::query(setup).execute(&pool).await?; + let rows_affected = result.rows_affected(); debug!("{rows_affected} rows affected by '{setup}'"); } for row in rows { @@ -248,7 +250,8 @@ async fn main() -> Result<(), Error> { for field in row { query = query.bind(field); } - let rows = query.execute(&pool).await?.rows_affected(); + let result: AnyQueryResult = query.execute(&pool).await?; + let rows = result.rows_affected(); debug!("Inserted {rows} row(s)"); } } else { @@ -343,13 +346,13 @@ async fn execute_mpc( } = policy; let (prg, input) = if let Some(db) = db { info!("Connecting to input db at {db}..."); - let pool = PgPoolOptions::new().max_connections(5).connect(db).await?; - let rows = sqlx::query(input).fetch_all(&pool).await?; + let pool: AnyPool = Pool::connect(db).await?; + let rows: Vec = sqlx::query(input).fetch_all(&pool).await?; info!("'{input}' returned {} rows from {db}", rows.len()); let mut my_consts = HashMap::new(); for (k, c) in constants { - let row = sqlx::query(&c.query).fetch_one(&pool).await?; + let row: AnyRow = sqlx::query(&c.query).fetch_one(&pool).await?; if row.len() != 1 { bail!( "Expected a single scalar value, but got {} from query '{}'", From 4c9df103d6289bd15a559d9b25a91c0536bb3147 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 18 Jun 2024 16:56:02 +0200 Subject: [PATCH 22/25] Use MySQL as second DB in example --- examples/postgres-integration/docker-compose.yml | 7 ++++--- examples/postgres-integration/policies1.json | 2 +- examples/postgres-integration/seeds/db1/init.sql | 13 +++++++++---- .../postgres-integration/seeds/db1/insurance.csv | 6 +++--- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/postgres-integration/docker-compose.yml b/examples/postgres-integration/docker-compose.yml index 2ce664b..a87bf1e 100644 --- a/examples/postgres-integration/docker-compose.yml +++ b/examples/postgres-integration/docker-compose.yml @@ -12,15 +12,16 @@ services: POSTGRES_PASSWORD: test db1: - image: postgres:16.2 + image: mysql:8.4 restart: always shm_size: 128mb volumes: - ./seeds/db1:/docker-entrypoint-initdb.d + - ./seeds/db1/insurance.csv:/var/lib/mysql-files/insurance.csv ports: - - 5551:5432 + - 5551:3306 environment: - POSTGRES_PASSWORD: test + MYSQL_ROOT_PASSWORD: test db_out: image: postgres:16.2 diff --git a/examples/postgres-integration/policies1.json b/examples/postgres-integration/policies1.json index e953382..8d72c0c 100644 --- a/examples/postgres-integration/policies1.json +++ b/examples/postgres-integration/policies1.json @@ -10,7 +10,7 @@ "leader": 0, "party": 1, "input": "SELECT id, name, status FROM insurance", - "input_db": "postgres://postgres:test@localhost:5551/postgres", + "input_db": "mysql://root:test@localhost:5551/test", "constants": { "ROWS": { "query": "SELECT COUNT(id) FROM insurance", diff --git a/examples/postgres-integration/seeds/db1/init.sql b/examples/postgres-integration/seeds/db1/init.sql index 540e94e..aab35aa 100644 --- a/examples/postgres-integration/seeds/db1/init.sql +++ b/examples/postgres-integration/seeds/db1/init.sql @@ -1,3 +1,6 @@ +CREATE DATABASE test; +USE test; + CREATE TABLE insurance ( id SERIAL PRIMARY KEY, name TEXT, @@ -5,7 +8,9 @@ CREATE TABLE insurance ( address TEXT ); -COPY insurance -FROM '/docker-entrypoint-initdb.d/insurance.csv' -DELIMITER ',' -CSV HEADER; \ No newline at end of file +LOAD DATA INFILE '/var/lib/mysql-files/insurance.csv' +INTO TABLE insurance +FIELDS TERMINATED BY ',' +OPTIONALLY ENCLOSED BY '"' +LINES TERMINATED BY '\n' +IGNORE 1 ROWS diff --git a/examples/postgres-integration/seeds/db1/insurance.csv b/examples/postgres-integration/seeds/db1/insurance.csv index 5d0ad18..f4ffc7c 100644 --- a/examples/postgres-integration/seeds/db1/insurance.csv +++ b/examples/postgres-integration/seeds/db1/insurance.csv @@ -1,4 +1,4 @@ Id,Name,Status,Address -0,"Max Doe","foo","Parkstraße 1, 12345 Berlin" -1,"Bob Doe","bar","Schlossallee 2, 12345 Berlin" -2,"Jane Doe","foobar","Schlossallee 2, 12345 Berlin" +1,"Max Doe","foo","Parkstraße 1, 12345 Berlin" +2,"Bob Doe","bar","Schlossallee 2, 12345 Berlin" +3,"Jane Doe","foobar","Schlossallee 2, 12345 Berlin" From dca47bc7a679eeabf207e3cde576caed8dd722e7 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 19 Jun 2024 16:36:49 +0200 Subject: [PATCH 23/25] Fix decoding of MySQL blob columns --- examples/postgres-integration/src/main.rs | 26 +++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/examples/postgres-integration/src/main.rs b/examples/postgres-integration/src/main.rs index 91089d8..26da636 100644 --- a/examples/postgres-integration/src/main.rs +++ b/examples/postgres-integration/src/main.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use sqlx::{ any::{install_default_drivers, AnyQueryResult, AnyRow}, - AnyPool, Pool, Row, + AnyPool, Pool, Row, ValueRef, }; use std::{ borrow::BorrowMut, collections::HashMap, net::SocketAddr, path::PathBuf, process::exit, @@ -492,7 +492,15 @@ async fn execute_mpc( if ty.as_ref() == &Type::Unsigned(UnsignedNumType::U8) => { let size = *prg.const_sizes.get(size).unwrap(); + let mut str = None; if let Ok(s) = row.try_get::(c) { + str = Some(s) + } else if let Ok(s) = row.try_get::, _>(c) { + if let Ok(s) = String::from_utf8(s) { + str = Some(s) + } + } + if let Some(s) = str { let mut fixed_str = vec![Literal::NumUnsigned(0, UnsignedNumType::U8); size]; for (i, b) in s.as_bytes().iter().enumerate() { @@ -514,7 +522,15 @@ async fn execute_mpc( let Some(enum_def) = prg.program.enum_defs.get(name) else { bail!("Could not find definition for enum {name} in program"); }; - if let Ok(variant) = row.try_get::(c) { + let mut str = None; + if let Ok(s) = row.try_get::(c) { + str = Some(s) + } else if let Ok(s) = row.try_get::, _>(c) { + if let Ok(s) = String::from_utf8(s) { + str = Some(s) + } + } + if let Some(variant) = str { for v in &enum_def.variants { if let Variant::Unit(variant_name) = v { if variant_name.to_lowercase() == variant.to_lowercase() { @@ -533,6 +549,12 @@ async fn execute_mpc( } if let Some(literal) = literal { row_as_literal.push(literal); + } else if let Ok(raw) = row.try_get_raw(c) { + bail!( + "Could not decode column {c} with type {}, column has type {}", + field_types[c], + raw.type_info() + ); } else { bail!("Could not decode column {c} with type {}", field_types[c]); } From ccbf3678ca175f38627671c7424777e876446b36 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 19 Jun 2024 16:57:03 +0200 Subject: [PATCH 24/25] Rename postgres-integration -> sql-integration --- Cargo.toml | 2 +- .../.example.garble.rs | 0 examples/{postgres-integration => sql-integration}/Cargo.toml | 2 +- .../docker-compose.yml | 0 .../{postgres-integration => sql-integration}/policies0.json | 0 .../{postgres-integration => sql-integration}/policies1.json | 0 .../{postgres-integration => sql-integration}/policies2.json | 0 .../{postgres-integration => sql-integration}/preprocessor.json | 0 .../seeds/db0/init.sql | 0 .../seeds/db0/residents.csv | 0 .../seeds/db1/init.sql | 0 .../seeds/db1/insurance.csv | 0 .../seeds/db_out/init.sql | 0 examples/{postgres-integration => sql-integration}/src/main.rs | 0 .../templates/policies.html | 0 .../templates/policy.html | 0 examples/{postgres-integration => sql-integration}/tests/cli.rs | 0 17 files changed, 2 insertions(+), 2 deletions(-) rename examples/{postgres-integration => sql-integration}/.example.garble.rs (100%) rename examples/{postgres-integration => sql-integration}/Cargo.toml (94%) rename examples/{postgres-integration => sql-integration}/docker-compose.yml (100%) rename examples/{postgres-integration => sql-integration}/policies0.json (100%) rename examples/{postgres-integration => sql-integration}/policies1.json (100%) rename examples/{postgres-integration => sql-integration}/policies2.json (100%) rename examples/{postgres-integration => sql-integration}/preprocessor.json (100%) rename examples/{postgres-integration => sql-integration}/seeds/db0/init.sql (100%) rename examples/{postgres-integration => sql-integration}/seeds/db0/residents.csv (100%) rename examples/{postgres-integration => sql-integration}/seeds/db1/init.sql (100%) rename examples/{postgres-integration => sql-integration}/seeds/db1/insurance.csv (100%) rename examples/{postgres-integration => sql-integration}/seeds/db_out/init.sql (100%) rename examples/{postgres-integration => sql-integration}/src/main.rs (100%) rename examples/{postgres-integration => sql-integration}/templates/policies.html (100%) rename examples/{postgres-integration => sql-integration}/templates/policy.html (100%) rename examples/{postgres-integration => sql-integration}/tests/cli.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index efeb156..d26173b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ workspace = { members = [ "examples/http-multi-server-channels", "examples/http-single-server-channels", "examples/iroh-p2p-channels", - "examples/postgres-integration", + "examples/sql-integration", ] } [package] name = "parlay" diff --git a/examples/postgres-integration/.example.garble.rs b/examples/sql-integration/.example.garble.rs similarity index 100% rename from examples/postgres-integration/.example.garble.rs rename to examples/sql-integration/.example.garble.rs diff --git a/examples/postgres-integration/Cargo.toml b/examples/sql-integration/Cargo.toml similarity index 94% rename from examples/postgres-integration/Cargo.toml rename to examples/sql-integration/Cargo.toml index 99df69d..6de82ff 100644 --- a/examples/postgres-integration/Cargo.toml +++ b/examples/sql-integration/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "parlay-postgres-integration" +name = "parlay-sql-integration" version = "0.1.0" edition = "2021" diff --git a/examples/postgres-integration/docker-compose.yml b/examples/sql-integration/docker-compose.yml similarity index 100% rename from examples/postgres-integration/docker-compose.yml rename to examples/sql-integration/docker-compose.yml diff --git a/examples/postgres-integration/policies0.json b/examples/sql-integration/policies0.json similarity index 100% rename from examples/postgres-integration/policies0.json rename to examples/sql-integration/policies0.json diff --git a/examples/postgres-integration/policies1.json b/examples/sql-integration/policies1.json similarity index 100% rename from examples/postgres-integration/policies1.json rename to examples/sql-integration/policies1.json diff --git a/examples/postgres-integration/policies2.json b/examples/sql-integration/policies2.json similarity index 100% rename from examples/postgres-integration/policies2.json rename to examples/sql-integration/policies2.json diff --git a/examples/postgres-integration/preprocessor.json b/examples/sql-integration/preprocessor.json similarity index 100% rename from examples/postgres-integration/preprocessor.json rename to examples/sql-integration/preprocessor.json diff --git a/examples/postgres-integration/seeds/db0/init.sql b/examples/sql-integration/seeds/db0/init.sql similarity index 100% rename from examples/postgres-integration/seeds/db0/init.sql rename to examples/sql-integration/seeds/db0/init.sql diff --git a/examples/postgres-integration/seeds/db0/residents.csv b/examples/sql-integration/seeds/db0/residents.csv similarity index 100% rename from examples/postgres-integration/seeds/db0/residents.csv rename to examples/sql-integration/seeds/db0/residents.csv diff --git a/examples/postgres-integration/seeds/db1/init.sql b/examples/sql-integration/seeds/db1/init.sql similarity index 100% rename from examples/postgres-integration/seeds/db1/init.sql rename to examples/sql-integration/seeds/db1/init.sql diff --git a/examples/postgres-integration/seeds/db1/insurance.csv b/examples/sql-integration/seeds/db1/insurance.csv similarity index 100% rename from examples/postgres-integration/seeds/db1/insurance.csv rename to examples/sql-integration/seeds/db1/insurance.csv diff --git a/examples/postgres-integration/seeds/db_out/init.sql b/examples/sql-integration/seeds/db_out/init.sql similarity index 100% rename from examples/postgres-integration/seeds/db_out/init.sql rename to examples/sql-integration/seeds/db_out/init.sql diff --git a/examples/postgres-integration/src/main.rs b/examples/sql-integration/src/main.rs similarity index 100% rename from examples/postgres-integration/src/main.rs rename to examples/sql-integration/src/main.rs diff --git a/examples/postgres-integration/templates/policies.html b/examples/sql-integration/templates/policies.html similarity index 100% rename from examples/postgres-integration/templates/policies.html rename to examples/sql-integration/templates/policies.html diff --git a/examples/postgres-integration/templates/policy.html b/examples/sql-integration/templates/policy.html similarity index 100% rename from examples/postgres-integration/templates/policy.html rename to examples/sql-integration/templates/policy.html diff --git a/examples/postgres-integration/tests/cli.rs b/examples/sql-integration/tests/cli.rs similarity index 100% rename from examples/postgres-integration/tests/cli.rs rename to examples/sql-integration/tests/cli.rs From a2dd5e1d8e9412c8bc94952735a7637005e46c29 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 19 Jun 2024 17:19:48 +0200 Subject: [PATCH 25/25] Fix GH workflow after renaming sql-integration example --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3774588..c5a8bdb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,4 +21,4 @@ jobs: - run: cargo test -- --nocapture working-directory: ./examples/iroh-p2p-channels - run: docker compose -f docker-compose.yml up -d && cargo build --release && cargo test --release -- --nocapture - working-directory: ./examples/postgres-integration + working-directory: ./examples/sql-integration