Skip to content

Commit

Permalink
Reduce clones / to_vec allocations in faand.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Sep 3, 2024
1 parent a09ed50 commit 701f964
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 96 deletions.
4 changes: 2 additions & 2 deletions benches/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use garble_lang::{
use parlay::protocol::simulate_mpc_async;

fn join_benchmark(c: &mut Criterion) {
let n_records = 100;
let n_records = 200;
let code = include_str!(".join.garble.rs");

let bench_id = format!("join {n_records} records");
Expand Down Expand Up @@ -73,7 +73,7 @@ fn join_benchmark(c: &mut Criterion) {

let inputs = vec![input0.as_slice(), input1.as_slice()];

simulate_mpc_async(&prg.circuit, &inputs, &[0], true)
simulate_mpc_async(&prg.circuit, &inputs, &[0], false)
.await
.unwrap();
let elapsed = now.elapsed();
Expand Down
148 changes: 54 additions & 94 deletions src/faand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ pub(crate) async fn fabit(
channel: &mut MsgChannel<impl Channel>,
p_to: usize,
delta: Delta,
c: Vec<bool>, //xbits
c: &[bool], //xbits
role: bool,
(b, r0, _r1, rb): (Vec<bool>, Vec<u128>, Vec<u128>, Vec<u128>),
(b, r0, _r1, rb): (&[bool], &[u128], &[u128], &[u128]),
) -> Result<Vec<u128>, Error> {
match role {
true => {
Expand All @@ -136,11 +136,9 @@ pub(crate) async fn fabit(
false => {
let x0: Vec<u128> = (0..r0.len()).map(|_| random::<u128>()).collect();
let x1: Vec<u128> = x0.iter().map(|&x| x ^ delta.0).collect();
let y0: Vec<u128> = x0.clone();
let y1: Vec<u128> = x1.clone();

let _kbits: Vec<bool> = channel.recv_from(p_to, "bits").await?;
channel.send_to(p_to, "y0y1", &(y0, y1)).await?;
channel.send_to(p_to, "y0y1", &(&x0, &x1)).await?;

Ok(x0)
}
Expand Down Expand Up @@ -171,59 +169,45 @@ pub(crate) async fn fabitn(
let mut xkeys: Vec<Vec<u128>> = vec![vec![]; p_max];
let mut xmacs: Vec<Vec<u128>> = vec![vec![]; p_max];
for p in (0..p_max).filter(|p| *p != p_own) {
let delta_added: Vec<u128> = sender_ot[p]
.clone()
.into_iter()
.map(|x| x ^ delta.0)
.collect();
let delta_added: Vec<u128> = sender_ot[p].iter().map(|x| x ^ delta.0).collect();
let macvec: Vec<u128>;
let keyvec: Vec<u128>;
if p_own < p {
macvec = fabit(
channel,
p,
delta,
x.to_vec(),
&x,
true,
(
receiver_ot[p].0.clone(),
vec![],
vec![],
receiver_ot[p].1.clone(),
),
(&receiver_ot[p].0, &[], &[], &receiver_ot[p].1),
)
.await?;
keyvec = fabit(
channel,
p,
delta,
x.to_vec(),
&x,
false,
(vec![], sender_ot[p].clone(), delta_added, vec![]),
(&[], &sender_ot[p], &delta_added, &[]),
)
.await?;
} else {
keyvec = fabit(
channel,
p,
delta,
x.to_vec(),
&x,
false,
(vec![], sender_ot[p].clone(), delta_added, vec![]),
(&[], &sender_ot[p], &delta_added, &[]),
)
.await?;
macvec = fabit(
channel,
p,
delta,
x.to_vec(),
&x,
true,
(
receiver_ot[p].0.clone(),
vec![],
vec![],
receiver_ot[p].1.clone(),
),
(&receiver_ot[p].0, &[], &[], &receiver_ot[p].1),
)
.await?;
}
Expand Down Expand Up @@ -260,9 +244,7 @@ pub(crate) async fn fabitn(
}

for p in (0..p_max).filter(|p| *p != p_own) {
channel
.send_to(p, "mac", &(macint[p].clone(), randbits.clone()))
.await?;
channel.send_to(p, "mac", &(&macint[p], &randbits)).await?;
}

let mut macp: Vec<Vec<u128>> = vec![vec![0; two_rho]; p_max];
Expand Down Expand Up @@ -452,7 +434,7 @@ pub(crate) async fn fhaand(
p_max: usize,
delta: Delta,
length: usize,
x: Vec<Share>,
x: &[Share],
y: Vec<bool>,
) -> Result<Vec<bool>, Error> {
// Step 1 obtain <x> (input).
Expand Down Expand Up @@ -522,14 +504,14 @@ pub(crate) fn hash128(input: u128) -> u128 {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn flaand(
channel: &mut MsgChannel<impl Channel>,
xbits: Vec<Share>,
ybits: Vec<Share>,
rbits: Vec<Share>,
xbits: &[Share],
ybits: &[Share],
rbits: &[Share],
p_own: usize,
p_max: usize,
delta: Delta,
length: usize,
) -> Result<(Vec<Share>, Vec<Share>, Vec<Share>), Error> {
) -> Result<Vec<Share>, Error> {
// Triple computation.
// Step 1 triple computation.
let mut yvec: Vec<bool> = vec![false; length];
Expand All @@ -538,7 +520,7 @@ pub(crate) async fn flaand(
}

// Step 2 run Pi_HaAND for each pair of parties.
let v = fhaand(channel, p_own, p_max, delta, length, xbits.clone(), yvec).await?;
let v = fhaand(channel, p_own, p_max, delta, length, xbits, yvec).await?;

// Step 3 compute z and e and broadcast e.
let mut z: Vec<bool> = vec![false; length];
Expand Down Expand Up @@ -657,7 +639,7 @@ pub(crate) async fn flaand(
}
}

Ok((xbits, ybits, zbits))
Ok(zbits)
}

/// Calculates the bucket size according to WRK17a, Table 4 for statistical security ρ = 40 (rho).
Expand All @@ -670,15 +652,12 @@ pub(crate) fn bucket_size(circuit_size: usize) -> usize {
}

/// Transforms all triples into a single vector of triples.
fn transform(
alltriples: (Vec<Share>, Vec<Share>, Vec<Share>),
length: usize,
) -> Vec<(Share, Share, Share)> {
fn transform(x: &[Share], y: &[Share], z: &[Share], length: usize) -> Vec<(Share, Share, Share)> {
let mut triples: Vec<(Share, Share, Share)> = vec![];
for l in 0..length {
let s1 = alltriples.0[l].clone();
let s2 = alltriples.1[l].clone();
let s3 = alltriples.2[l].clone();
let s1 = x[l].clone();
let s2 = y[l].clone();
let s3 = z[l].clone();
triples.push((s1, s2, s3));
}
triples
Expand All @@ -705,18 +684,9 @@ pub(crate) async fn faand_precomp(
let (ybits, rbits) = rest.split_at(lprime);

// Step 1 generate all leaky and triples.
let alltriples: (Vec<Share>, Vec<Share>, Vec<Share>) = flaand(
channel,
xbits.to_vec(),
ybits.to_vec(),
rbits.to_vec(),
p_own,
p_max,
delta,
lprime,
)
.await?;
let triples = transform(alltriples, lprime);
let zbits: Vec<Share> =
flaand(channel, xbits, ybits, rbits, p_own, p_max, delta, lprime).await?;
let triples = transform(xbits, ybits, &zbits, lprime);

// Step 2 assign objects to buckets.
let mut buckets: Vec<Vec<(Share, Share, Share)>> = vec![vec![]; length];
Expand All @@ -734,20 +704,15 @@ pub(crate) async fn faand_precomp(
}

// Step 3 check d-values.
let y: Vec<Vec<Share>> = buckets
.clone()
.into_iter()
.map(|bucket| bucket.into_iter().map(|(_, share2, _)| share2).collect())
.collect();
let dvalues = check_dvalue(channel, p_own, p_max, delta, y).await?;
let dvalues = check_dvalue(channel, p_own, p_max, &buckets).await?;

let mut bucketcombined: Vec<(Share, Share, Share)> = vec![];
let mut combined: Vec<(Share, Share, Share)> = Vec::with_capacity(buckets.len());

for (buc, dval) in buckets.into_iter().zip(dvalues.into_iter()) {
bucketcombined.push(combine_bucket(p_own, p_max, buc, dval)?);
for (bucket, dval) in buckets.into_iter().zip(dvalues.into_iter()) {
combined.push(combine_bucket(p_own, p_max, bucket, dval)?);
}

Ok(bucketcombined)
Ok(combined)
}

/// Protocol Pi_aAND that performs F_aAND.
Expand Down Expand Up @@ -801,13 +766,7 @@ pub(crate) async fn faand(
emacs.push(Some(emac));
fmacs.push(Some(fmac));
}
channel
.send_to(
p,
"ef",
&(e.clone(), f.clone(), emacs.clone(), fmacs.clone()),
)
.await?;
channel.send_to(p, "ef", &(&e, &f, &emacs, &fmacs)).await?;
}
let mut ep: Vec<Vec<bool>> = vec![vec![]; p_max];
let mut fp: Vec<Vec<bool>> = vec![vec![]; p_max];
Expand Down Expand Up @@ -838,7 +797,7 @@ pub(crate) async fn faand(
Ok(result)
}

/// Combine the whole bucket by combining elements two by two.
/// Combine the whole bucket by combining elements one by one.
pub(crate) fn combine_bucket(
p_own: usize,
p_max: usize,
Expand All @@ -849,12 +808,13 @@ pub(crate) fn combine_bucket(
return Err(Error::EmptyBucketError);
}

let mut result = bucket[0].clone();
let mut bucket = bucket.into_iter();
let mut result = bucket.next().unwrap();

// Combine elements two by two, starting from the second element.
for (i, triple) in bucket.iter().enumerate().skip(1) {
let d = d_values[i - 1];
result = combine_two_leaky_ands(p_own, p_max, &result, triple, d)?;
// Combine elements one by one, starting from the second element.
for (i, triple) in bucket.enumerate() {
let d = d_values[i];
result = combine_two_leaky_ands(p_own, p_max, result, triple, d)?;
}
Ok(result)
}
Expand All @@ -864,20 +824,20 @@ pub(crate) async fn check_dvalue(
channel: &mut MsgChannel<impl Channel>,
p_own: usize,
p_max: usize,
_delta: Delta,
y: Vec<Vec<Share>>,
buckets: &[Vec<(Share, Share, Share)>],
) -> Result<Vec<Vec<bool>>, Error> {
// Step (a) compute and check macs of d-values.
let mut d_values: Vec<Vec<bool>> = vec![vec![]; y.len()];
let mut d_macs: Vec<Vec<Vec<Option<Mac>>>> = vec![vec![vec![]; y.len()]; p_max];
let mut d_values: Vec<Vec<bool>> = vec![vec![]; buckets.len()];
let mut d_macs: Vec<Vec<Vec<Option<Mac>>>> = vec![vec![vec![]; buckets.len()]; p_max];

for j in 0..y.len() {
let first = y[j][0].0;
for y_value in y[j].iter().skip(1) {
for j in 0..buckets.len() {
let (_, y, _) = &buckets[j][0];
let first = y.0;
for (_, y_value, _) in buckets[j].iter().skip(1) {
let current_d = first ^ y_value.0;
d_values[j].push(current_d);
for p in (0..p_max).filter(|p| *p != p_own) {
let Some((y0mac, _)) = y[j][0].1 .0[p] else {
let Some((y0mac, _)) = y.1 .0[p] else {
return Err(Error::MissingMacKey);
};
let Some((ymac, _)) = y_value.1 .0[p] else {
Expand All @@ -890,7 +850,7 @@ pub(crate) async fn check_dvalue(

for p in (0..p_max).filter(|p| *p != p_own) {
channel
.send_to(p, "dvalue", &(d_values.clone(), d_macs[p].clone()))
.send_to(p, "dvalue", &(&d_values, &d_macs[p]))
.await?;
}
let mut all_d_values: Vec<Vec<Vec<bool>>> = vec![vec![]; p_max];
Expand All @@ -900,7 +860,7 @@ pub(crate) async fn check_dvalue(
}

for p in (0..p_max).filter(|p| *p != p_own) {
for (j, dval) in d_values.iter_mut().enumerate().take(y.len()) {
for (j, dval) in d_values.iter_mut().enumerate().take(buckets.len()) {
for (i, d) in dval.iter_mut().enumerate().take(all_d_macs[p][j].len()) {
*d ^= all_d_values[p][j][i];
}
Expand All @@ -914,8 +874,8 @@ pub(crate) async fn check_dvalue(
pub(crate) fn combine_two_leaky_ands(
p_own: usize,
p_max: usize,
(x1, y1, z1): &(Share, Share, Share),
(x2, _, z2): &(Share, Share, Share),
(x1, y1, z1): (Share, Share, Share),
(x2, _, z2): (Share, Share, Share),
d: bool,
) -> Result<(Share, Share, Share), Error> {
//Step (b) compute x, y, z.
Expand Down Expand Up @@ -951,5 +911,5 @@ pub(crate) fn combine_two_leaky_ands(
}
let zres: Share = Share(zbit, zauth);

Ok((xres, y1.clone(), zres))
}
Ok((xres, y1, zres))
}

0 comments on commit 701f964

Please sign in to comment.