From 701f96436b7cef0d64c5c2a808cf3f2823d39bb2 Mon Sep 17 00:00:00 2001
From: Frederic Kettelhoit
Date: Tue, 3 Sep 2024 18:13:00 +0100
Subject: [PATCH] Reduce clones / to_vec allocations in faand.rs

---
 benches/join.rs |   4 +-
 src/faand.rs    | 148 ++++++++++++++++++------------------------------
 2 files changed, 56 insertions(+), 96 deletions(-)

diff --git a/benches/join.rs b/benches/join.rs
index 62b3748..125a79c 100644
--- a/benches/join.rs
+++ b/benches/join.rs
@@ -9,7 +9,7 @@ use garble_lang::{
 use parlay::protocol::simulate_mpc_async;
 
 fn join_benchmark(c: &mut Criterion) {
-    let n_records = 100;
+    let n_records = 200;
     let code = include_str!(".join.garble.rs");
     let bench_id = format!("join {n_records} records");
 
@@ -73,7 +73,7 @@ fn join_benchmark(c: &mut Criterion) {
 
             let inputs = vec![input0.as_slice(), input1.as_slice()];
 
-            simulate_mpc_async(&prg.circuit, &inputs, &[0], true)
+            simulate_mpc_async(&prg.circuit, &inputs, &[0], false)
                 .await
                 .unwrap();
             let elapsed = now.elapsed();
diff --git a/src/faand.rs b/src/faand.rs
index f1d7d7c..688d5d5 100644
--- a/src/faand.rs
+++ b/src/faand.rs
@@ -109,9 +109,9 @@ pub(crate) async fn fabit(
     channel: &mut MsgChannel,
     p_to: usize,
     delta: Delta,
-    c: Vec<bool>, //xbits
+    c: &[bool], //xbits
     role: bool,
-    (b, r0, _r1, rb): (Vec<bool>, Vec<u128>, Vec<u128>, Vec<u128>),
+    (b, r0, _r1, rb): (&[bool], &[u128], &[u128], &[u128]),
 ) -> Result<Vec<u128>, Error> {
     match role {
         true => {
@@ -136,11 +136,9 @@ pub(crate) async fn fabit(
         false => {
             let x0: Vec<u128> = (0..r0.len()).map(|_| random::<u128>()).collect();
             let x1: Vec<u128> = x0.iter().map(|&x| x ^ delta.0).collect();
-            let y0: Vec<u128> = x0.clone();
-            let y1: Vec<u128> = x1.clone();
 
             let _kbits: Vec<bool> = channel.recv_from(p_to, "bits").await?;
-            channel.send_to(p_to, "y0y1", &(y0, y1)).await?;
+            channel.send_to(p_to, "y0y1", &(&x0, &x1)).await?;
 
             Ok(x0)
         }
@@ -171,11 +169,7 @@ pub(crate) async fn fabitn(
     let mut xkeys: Vec<Vec<u128>> = vec![vec![]; p_max];
     let mut xmacs: Vec<Vec<u128>> = vec![vec![]; p_max];
     for p in (0..p_max).filter(|p| *p != p_own) {
-        let delta_added: Vec<u128> = sender_ot[p]
-            .clone()
-            .into_iter()
-            .map(|x| x ^ delta.0)
-            .collect();
+        let delta_added: Vec<u128> = sender_ot[p].iter().map(|x| x ^ delta.0).collect();
         let macvec: Vec<u128>;
         let keyvec: Vec<u128>;
         if p_own < p {
@@ -183,23 +177,18 @@ pub(crate) async fn fabitn(
                 channel,
                 p,
                 delta,
-                x.to_vec(),
+                &x,
                 true,
-                (
-                    receiver_ot[p].0.clone(),
-                    vec![],
-                    vec![],
-                    receiver_ot[p].1.clone(),
-                ),
+                (&receiver_ot[p].0, &[], &[], &receiver_ot[p].1),
             )
             .await?;
             keyvec = fabit(
                 channel,
                 p,
                 delta,
-                x.to_vec(),
+                &x,
                 false,
-                (vec![], sender_ot[p].clone(), delta_added, vec![]),
+                (&[], &sender_ot[p], &delta_added, &[]),
             )
             .await?;
         } else {
@@ -207,23 +196,18 @@ pub(crate) async fn fabitn(
                 channel,
                 p,
                 delta,
-                x.to_vec(),
+                &x,
                 false,
-                (vec![], sender_ot[p].clone(), delta_added, vec![]),
+                (&[], &sender_ot[p], &delta_added, &[]),
             )
             .await?;
             macvec = fabit(
                 channel,
                 p,
                 delta,
-                x.to_vec(),
+                &x,
                 true,
-                (
-                    receiver_ot[p].0.clone(),
-                    vec![],
-                    vec![],
-                    receiver_ot[p].1.clone(),
-                ),
+                (&receiver_ot[p].0, &[], &[], &receiver_ot[p].1),
             )
             .await?;
         }
@@ -260,9 +244,7 @@ pub(crate) async fn fabitn(
     }
 
     for p in (0..p_max).filter(|p| *p != p_own) {
-        channel
-            .send_to(p, "mac", &(macint[p].clone(), randbits.clone()))
-            .await?;
+        channel.send_to(p, "mac", &(&macint[p], &randbits)).await?;
     }
 
     let mut macp: Vec<Vec<u128>> = vec![vec![0; two_rho]; p_max];
@@ -452,7 +434,7 @@ pub(crate) async fn fhaand(
     p_max: usize,
     delta: Delta,
     length: usize,
-    x: Vec<Share>,
+    x: &[Share],
     y: Vec<bool>,
 ) -> Result<Vec<bool>, Error> {
     // Step 1 obtain (input).
@@ -522,14 +504,14 @@ pub(crate) fn hash128(input: u128) -> u128 {
 #[allow(clippy::too_many_arguments)]
 pub(crate) async fn flaand(
     channel: &mut MsgChannel,
-    xbits: Vec<Share>,
-    ybits: Vec<Share>,
-    rbits: Vec<Share>,
+    xbits: &[Share],
+    ybits: &[Share],
+    rbits: &[Share],
     p_own: usize,
     p_max: usize,
     delta: Delta,
     length: usize,
-) -> Result<(Vec<Share>, Vec<Share>, Vec<Share>), Error> {
+) -> Result<Vec<Share>, Error> {
     // Triple computation.
     // Step 1 triple computation.
     let mut yvec: Vec<bool> = vec![false; length];
@@ -538,7 +520,7 @@ pub(crate) async fn flaand(
     }
 
     // Step 2 run Pi_HaAND for each pair of parties.
-    let v = fhaand(channel, p_own, p_max, delta, length, xbits.clone(), yvec).await?;
+    let v = fhaand(channel, p_own, p_max, delta, length, xbits, yvec).await?;
 
     // Step 3 compute z and e and broadcast e.
     let mut z: Vec<bool> = vec![false; length];
@@ -657,7 +639,7 @@ pub(crate) async fn flaand(
         }
     }
 
-    Ok((xbits, ybits, zbits))
+    Ok(zbits)
 }
 
 /// Calculates the bucket size according to WRK17a, Table 4 for statistical security ρ = 40 (rho).
@@ -670,15 +652,12 @@ pub(crate) fn bucket_size(circuit_size: usize) -> usize {
 }
 
 /// Transforms all triples into a single vector of triples.
-fn transform(
-    alltriples: (Vec<Share>, Vec<Share>, Vec<Share>),
-    length: usize,
-) -> Vec<(Share, Share, Share)> {
+fn transform(x: &[Share], y: &[Share], z: &[Share], length: usize) -> Vec<(Share, Share, Share)> {
     let mut triples: Vec<(Share, Share, Share)> = vec![];
     for l in 0..length {
-        let s1 = alltriples.0[l].clone();
-        let s2 = alltriples.1[l].clone();
-        let s3 = alltriples.2[l].clone();
+        let s1 = x[l].clone();
+        let s2 = y[l].clone();
+        let s3 = z[l].clone();
         triples.push((s1, s2, s3));
     }
     triples
@@ -705,18 +684,9 @@ pub(crate) async fn faand_precomp(
     let (ybits, rbits) = rest.split_at(lprime);
 
     // Step 1 generate all leaky and triples.
-    let alltriples: (Vec<Share>, Vec<Share>, Vec<Share>) = flaand(
-        channel,
-        xbits.to_vec(),
-        ybits.to_vec(),
-        rbits.to_vec(),
-        p_own,
-        p_max,
-        delta,
-        lprime,
-    )
-    .await?;
-    let triples = transform(alltriples, lprime);
+    let zbits: Vec<Share> =
+        flaand(channel, xbits, ybits, rbits, p_own, p_max, delta, lprime).await?;
+    let triples = transform(xbits, ybits, &zbits, lprime);
 
     // Step 2 assign objects to buckets.
     let mut buckets: Vec<Vec<(Share, Share, Share)>> = vec![vec![]; length];
@@ -734,20 +704,15 @@ pub(crate) async fn faand_precomp(
     }
 
     // Step 3 check d-values.
-    let y: Vec<Vec<Share>> = buckets
-        .clone()
-        .into_iter()
-        .map(|bucket| bucket.into_iter().map(|(_, share2, _)| share2).collect())
-        .collect();
-    let dvalues = check_dvalue(channel, p_own, p_max, delta, y).await?;
+    let dvalues = check_dvalue(channel, p_own, p_max, &buckets).await?;
 
-    let mut bucketcombined: Vec<(Share, Share, Share)> = vec![];
+    let mut combined: Vec<(Share, Share, Share)> = Vec::with_capacity(buckets.len());
 
-    for (buc, dval) in buckets.into_iter().zip(dvalues.into_iter()) {
-        bucketcombined.push(combine_bucket(p_own, p_max, buc, dval)?);
+    for (bucket, dval) in buckets.into_iter().zip(dvalues.into_iter()) {
+        combined.push(combine_bucket(p_own, p_max, bucket, dval)?);
     }
 
-    Ok(bucketcombined)
+    Ok(combined)
 }
 
 /// Protocol Pi_aAND that performs F_aAND.
@@ -801,13 +766,7 @@ pub(crate) async fn faand(
             emacs.push(Some(emac));
             fmacs.push(Some(fmac));
         }
-        channel
-            .send_to(
-                p,
-                "ef",
-                &(e.clone(), f.clone(), emacs.clone(), fmacs.clone()),
-            )
-            .await?;
+        channel.send_to(p, "ef", &(&e, &f, &emacs, &fmacs)).await?;
     }
     let mut ep: Vec<Vec<bool>> = vec![vec![]; p_max];
     let mut fp: Vec<Vec<bool>> = vec![vec![]; p_max];
@@ -838,7 +797,7 @@ pub(crate) async fn faand(
     Ok(result)
 }
 
-/// Combine the whole bucket by combining elements two by two.
+/// Combine the whole bucket by combining elements one by one.
 pub(crate) fn combine_bucket(
     p_own: usize,
     p_max: usize,
@@ -849,12 +808,13 @@ pub(crate) fn combine_bucket(
         return Err(Error::EmptyBucketError);
     }
 
-    let mut result = bucket[0].clone();
+    let mut bucket = bucket.into_iter();
+    let mut result = bucket.next().unwrap();
 
-    // Combine elements two by two, starting from the second element.
-    for (i, triple) in bucket.iter().enumerate().skip(1) {
-        let d = d_values[i - 1];
-        result = combine_two_leaky_ands(p_own, p_max, &result, triple, d)?;
+    // Combine elements one by one, starting from the second element.
+    for (i, triple) in bucket.enumerate() {
+        let d = d_values[i];
+        result = combine_two_leaky_ands(p_own, p_max, result, triple, d)?;
     }
     Ok(result)
 }
@@ -864,20 +824,20 @@ pub(crate) async fn check_dvalue(
     channel: &mut MsgChannel,
     p_own: usize,
     p_max: usize,
-    _delta: Delta,
-    y: Vec<Vec<Share>>,
+    buckets: &[Vec<(Share, Share, Share)>],
 ) -> Result<Vec<Vec<bool>>, Error> {
     // Step (a) compute and check macs of d-values.
-    let mut d_values: Vec<Vec<bool>> = vec![vec![]; y.len()];
-    let mut d_macs: Vec>>> = vec![vec![vec![]; y.len()]; p_max];
+    let mut d_values: Vec<Vec<bool>> = vec![vec![]; buckets.len()];
+    let mut d_macs: Vec>>> = vec![vec![vec![]; buckets.len()]; p_max];
 
-    for j in 0..y.len() {
-        let first = y[j][0].0;
-        for y_value in y[j].iter().skip(1) {
+    for j in 0..buckets.len() {
+        let (_, y, _) = &buckets[j][0];
+        let first = y.0;
+        for (_, y_value, _) in buckets[j].iter().skip(1) {
             let current_d = first ^ y_value.0;
             d_values[j].push(current_d);
             for p in (0..p_max).filter(|p| *p != p_own) {
-                let Some((y0mac, _)) = y[j][0].1 .0[p] else {
+                let Some((y0mac, _)) = y.1 .0[p] else {
                     return Err(Error::MissingMacKey);
                 };
                 let Some((ymac, _)) = y_value.1 .0[p] else {
@@ -890,7 +850,7 @@ pub(crate) async fn check_dvalue(
 
     for p in (0..p_max).filter(|p| *p != p_own) {
         channel
-            .send_to(p, "dvalue", &(d_values.clone(), d_macs[p].clone()))
+            .send_to(p, "dvalue", &(&d_values, &d_macs[p]))
            .await?;
     }
     let mut all_d_values: Vec<Vec<Vec<bool>>> = vec![vec![]; p_max];
@@ -900,7 +860,7 @@ pub(crate) async fn check_dvalue(
     }
 
     for p in (0..p_max).filter(|p| *p != p_own) {
-        for (j, dval) in d_values.iter_mut().enumerate().take(y.len()) {
+        for (j, dval) in d_values.iter_mut().enumerate().take(buckets.len()) {
            for (i, d) in dval.iter_mut().enumerate().take(all_d_macs[p][j].len()) {
                *d ^= all_d_values[p][j][i];
            }
@@ -914,8 +874,8 @@ pub(crate) fn combine_two_leaky_ands(
 pub(crate) fn combine_two_leaky_ands(
     p_own: usize,
     p_max: usize,
-    (x1, y1, z1): &(Share, Share, Share),
-    (x2, _, z2): &(Share, Share, Share),
+    (x1, y1, z1): (Share, Share, Share),
+    (x2, _, z2): (Share, Share, Share),
     d: bool,
 ) -> Result<(Share, Share, Share), Error> {
     //Step (b) compute x, y, z.
@@ -951,5 +911,5 @@ pub(crate) fn combine_two_leaky_ands(
     }
 
     let zres: Share = Share(zbit, zauth);
-    Ok((xres, y1.clone(), zres))
-}
\ No newline at end of file
+    Ok((xres, y1, zres))
+}
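
Note (illustrative, not part of the patch): every hunk above applies the same pattern.
Functions that used to take owned Vecs, which forced callers to .clone() or .to_vec(),
now borrow slices, and values that were cloned only so they could be serialized are now
sent by reference. The sketch below shows the idea with a stand-in Share struct and a
simplified version of the patch's transform function (without its length parameter); it
is not the crate's real API. The reference-sending half of the pattern relies only on
the assumption that messages are serialized with serde, where &T and &[T] serialize
exactly like T and Vec<T>, so the receiving side needs no change.

// Sketch only: stand-in types, not the parlay/polytune API.
#[derive(Clone, Debug)]
struct Share(bool, Vec<u128>);

// Before: owned parameters force `xbits.to_vec()` / `.clone()` at every call site.
fn transform_owned(x: Vec<Share>, y: Vec<Share>, z: Vec<Share>) -> Vec<(Share, Share, Share)> {
    x.into_iter()
        .zip(y)
        .zip(z)
        .map(|((a, b), c)| (a, b, c))
        .collect()
}

// After: borrowed slices accept `&xbits` or a `split_at` slice without copying;
// only the final triples are cloned, once.
fn transform_borrowed(x: &[Share], y: &[Share], z: &[Share]) -> Vec<(Share, Share, Share)> {
    x.iter()
        .zip(y)
        .zip(z)
        .map(|((a, b), c)| (a.clone(), b.clone(), c.clone()))
        .collect()
}

fn main() {
    let xbits = vec![Share(true, vec![1]); 4];
    let ybits = vec![Share(false, vec![2]); 4];
    let zbits = vec![Share(true, vec![3]); 4];

    // Borrowed version: the caller keeps ownership and allocates nothing extra.
    let triples = transform_borrowed(&xbits, &ybits, &zbits);
    assert_eq!(triples.len(), 4);

    // Owned version: the caller has to clone all three vectors to keep using them.
    let _ = transform_owned(xbits.clone(), ybits.clone(), zbits.clone());
}

The same reasoning motivates sending &(&e, &f, &emacs, &fmacs) instead of cloning each
vector before send_to.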