Skip to content

Commit

Permalink
Refactor dvalue computation and resolve clippy warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
kisakishy committed Sep 2, 2024
1 parent 8a42410 commit 3a3fb1c
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 58 deletions.
15 changes: 12 additions & 3 deletions examples/http-single-server-channels/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,18 @@ async fn main() {
}
let fpre = Preprocessor::TrustedDealer(P_DEALER);
let p_out: Vec<_> = (0..parties).collect();
let output = mpc(channel, &prg.circuit, &input, fpre, p_eval, party, &p_out, true)
.await
.unwrap();
let output = mpc(
channel,
&prg.circuit,
&input,
fpre,
p_eval,
party,
&p_out,
true,
)
.await
.unwrap();
if !output.is_empty() {
let result = prg.parse_output(&output).unwrap();
println!("\nThe result is {result}");
Expand Down
15 changes: 12 additions & 3 deletions examples/iroh-p2p-channels/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,18 @@ async fn main() -> Result<()> {
let fpre = Preprocessor::TrustedDealer(conns.len() - 1);
let channel = IrohChannel::new(conns, MAX_MSG_BYTES);
let p_out: Vec<_> = (0..parties).collect();
let output = mpc(channel, &prg.circuit, &input, fpre, p_eval, party, &p_out, true)
.await
.unwrap();
let output = mpc(
channel,
&prg.circuit,
&input,
fpre,
p_eval,
party,
&p_out,
true,
)
.await
.unwrap();
if !output.is_empty() {
let result = prg.parse_output(&output).unwrap();
println!("\nThe result is {result}");
Expand Down
111 changes: 59 additions & 52 deletions src/faand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ pub(crate) async fn fabitn(
}

let mut macint: Vec<Vec<u128>> = vec![vec![0; 2 * RHO]; p_max];
for ind in 0..2 * RHO {
for (ind, randbit_vec) in randbits.iter().take(2 * RHO).enumerate() {
for p in (0..p_max).filter(|p| *p != p_own) {
for (i, rbit) in randbits[ind].iter().enumerate().take(len_abit) {
if *rbit {
for (i, &rbit) in randbit_vec.iter().take(len_abit).enumerate() {
if rbit {
macint[p][ind] ^= xmacs[p][i];
}
}
Expand Down Expand Up @@ -813,9 +813,18 @@ pub(crate) async fn faand_precomp(
}

// Step 3

let y: Vec<Vec<Share>> = buckets
.clone()
.into_iter()
.map(|bucket| bucket.into_iter().map(|(_, share2, _)| share2).collect())
.collect();
let dvalues = check_dvalue(channel, p_own, p_max, delta, y).await?;

let mut bucketcombined: Vec<(Share, Share, Share)> = vec![];
for buc in buckets {
bucketcombined.push(combine_bucket(channel, p_own, p_max, delta, buc).await?);

for (buc, dval) in buckets.into_iter().zip(dvalues.into_iter()) {
bucketcombined.push(combine_bucket(p_own, p_max, buc, dval)?);
}

Ok(bucketcombined)
Expand Down Expand Up @@ -936,28 +945,23 @@ pub(crate) async fn faand(
}

/// Combine the whole bucket by combining elements two by two.
pub(crate) async fn combine_bucket(
channel: &mut MsgChannel<impl Channel>,
pub(crate) fn combine_bucket(
p_own: usize,
p_max: usize,
delta: Delta,
bucket: Vec<(Share, Share, Share)>,
d_values: Vec<bool>,
) -> Result<(Share, Share, Share), Error> {
if bucket.is_empty() {
return Err(Error::EmptyBucketError);
}

let y: Vec<Share> = bucket.iter().map(|(_, share1, _)| share1.clone()).collect();
let d_values = check_dvalue(channel, p_own, p_max, delta, y).await?;

let mut result = bucket[0].clone();

// Combine elements two by two, starting from the second element
for (i, triple) in bucket.iter().enumerate().skip(1) {
let d = d_values[i - 1];
result = combine_two_leaky_ands(p_own, p_max, &result, triple, d)?;
}

Ok(result)
}

Expand All @@ -966,63 +970,66 @@ pub(crate) async fn check_dvalue(
channel: &mut MsgChannel<impl Channel>,
p_own: usize,
p_max: usize,
delta: Delta,
y: Vec<Share>,
) -> Result<Vec<bool>, Error> {
_delta: Delta,
y: Vec<Vec<Share>>,
) -> Result<Vec<Vec<bool>>, Error> {
// Step (a)
let mut d_values: Vec<bool> = vec![false; y.len() - 1];
let mut d_macs: Vec<Vec<Option<Mac>>> = vec![vec![None; y.len() - 1]; p_max];

let first = y[0].0;
let mut i = 0;
for y_value in y.iter().skip(1) {
let current_d = first ^ y_value.0;
d_values[i] = current_d;
for p in (0..p_max).filter(|p| *p != p_own) {
let Some((y0mac, _)) = y[0].1 .0[p] else {
return Err(Error::MissingMacKey);
};
let Some((ymac, _)) = y_value.1 .0[p] else {
return Err(Error::MissingMacKey);
};
d_macs[p][i] = Some(y0mac ^ ymac);
let mut d_values: Vec<Vec<bool>> = vec![vec![]; y.len()];
let mut d_macs: Vec<Vec<Vec<Option<Mac>>>> = vec![vec![vec![]; y.len()]; p_max];

for j in 0..y.len() {
let first = y[j][0].0;
for y_value in y[j].iter().skip(1) {
let current_d = first ^ y_value.0;
d_values[j].push(current_d);
for p in (0..p_max).filter(|p| *p != p_own) {
let Some((y0mac, _)) = y[j][0].1 .0[p] else {
return Err(Error::MissingMacKey);
};
let Some((ymac, _)) = y_value.1 .0[p] else {
return Err(Error::MissingMacKey);
};
d_macs[p][j].push(Some(y0mac ^ ymac));
}
}
i += 1;
}

for p in (0..p_max).filter(|p| *p != p_own) {
channel
.send_to(p, "dvalue", &(d_values.clone(), d_macs[p].clone()))
.await?;
}
let mut all_d_values: Vec<Vec<bool>> = vec![vec![false; y.len() - 1]; p_max];
let mut all_d_macs: Vec<Vec<Option<Mac>>> = vec![vec![None; y.len() - 1]; p_max];
let mut all_d_values: Vec<Vec<Vec<bool>>> = vec![vec![]; p_max];
let mut all_d_macs: Vec<Vec<Vec<Option<Mac>>>> = vec![vec![]; p_max];
for p in (0..p_max).filter(|p| *p != p_own) {
(all_d_values[p], all_d_macs[p]) = channel.recv_from(p, "dvalue").await?;
}

for p in (0..p_max).filter(|p| *p != p_own) {
for i in 0..y.len() - 1 {
let Some(dmac) = all_d_macs[p][i] else {
return Err(Error::MissingMacKey);
};
let Some((_, y0key)) = y[0].1 .0[p] else {
return Err(Error::MissingMacKey);
};
/*if i != 0{
let Some((_, ykey)) = y[i].1 .0[p] else {
for (j, dval) in d_values.iter_mut().enumerate().take(y.len()) {
//for i in 0..all_d_macs[p][j].len() {
for (i, d) in dval.iter_mut().enumerate().take(all_d_macs[p][j].len()) {
/*let Some(dmac) = all_d_macs[p][j][i] else {
return Err(Error::MissingMacKey);
};
if (all_d_values[p][i] && dmac.0 != y0key.0 ^ ykey.0 ^ delta.0)
|| (!all_d_values[p][i] && dmac.0 != y0key.0 ^ ykey.0)
{
println!("{:?} {:?} {:?}", dmac.0, y0key.0 ^ ykey.0 ^ delta.0, y0key.0 ^ ykey.0);
return Err(Error::AANDWrongDMAC);
}
}*/
//TODO TOFIX
let Some((_, y0key)) = y[j][0].1 .0[p] else {
return Err(Error::MissingMacKey);
};
if i != 0{
let Some((_, ykey)) = y[j][i].1 .0[p] else {
return Err(Error::MissingMacKey);
};
if (all_d_values[p][j][i] && dmac.0 != y0key.0 ^ ykey.0 ^ delta.0)
|| (!all_d_values[p][j][i] && dmac.0 != y0key.0 ^ ykey.0)
{
println!("{:?} {:?} {:?}", dmac.0, y0key.0 ^ ykey.0 ^ delta.0, y0key.0 ^ ykey.0);
return Err(Error::AANDWrongDMAC);
}
}*/
//TODO TOFIX

d_values[i] ^= all_d_values[p][i];
*d ^= all_d_values[p][j][i];
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ pub enum Preprocessor {
}

/// Executes the MPC protocol for one party and returns the outputs (empty for the contributor).
#[allow(clippy::too_many_arguments)]
pub async fn mpc(
channel: impl Channel,
circuit: &Circuit,
Expand Down

0 comments on commit 3a3fb1c

Please sign in to comment.