Skip to content

Commit

Permalink
Add trusted dealer parameter and substitute unwraps
Browse files Browse the repository at this point in the history
  • Loading branch information
kisakishy committed Aug 29, 2024
1 parent b81a632 commit 6190746
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 100 deletions.
2 changes: 1 addition & 1 deletion examples/http-multi-server-channels/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async fn main() -> Result<(), Error> {
let fpre = Preprocessor::TrustedDealer(urls.len() - 1);
let p_out: Vec<_> = (0..(urls.len() - 1)).collect();
let channel = HttpChannel::new(urls, party).await?;
let output = mpc(channel, &prg.circuit, &input, fpre, 0, party, &p_out).await?;
let output = mpc(channel, &prg.circuit, &input, fpre, 0, party, &p_out, true).await?;
if !output.is_empty() {
println!("\nThe result is {}", prg.parse_output(&output)?);
}
Expand Down
2 changes: 1 addition & 1 deletion examples/http-single-server-channels/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async fn main() {
}
let fpre = Preprocessor::TrustedDealer(P_DEALER);
let p_out: Vec<_> = (0..parties).collect();
let output = mpc(channel, &prg.circuit, &input, fpre, p_eval, party, &p_out)
let output = mpc(channel, &prg.circuit, &input, fpre, p_eval, party, &p_out, true)
.await
.unwrap();
if !output.is_empty() {
Expand Down
2 changes: 1 addition & 1 deletion examples/iroh-p2p-channels/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ async fn main() -> Result<()> {
let fpre = Preprocessor::TrustedDealer(conns.len() - 1);
let channel = IrohChannel::new(conns, MAX_MSG_BYTES);
let p_out: Vec<_> = (0..parties).collect();
let output = mpc(channel, &prg.circuit, &input, fpre, p_eval, party, &p_out)
let output = mpc(channel, &prg.circuit, &input, fpre, p_eval, party, &p_out, true)
.await
.unwrap();
if !output.is_empty() {
Expand Down
2 changes: 1 addition & 1 deletion examples/sql-integration/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ async fn execute_mpc(
recv_count: 0,
}
};
let output = mpc(channel, &prg.circuit, &input, fpre, 0, *party, &p_out).await?;
let output = mpc(channel, &prg.circuit, &input, fpre, 0, *party, &p_out, true).await?;
state.lock().await.senders.clear();
let elapsed = now.elapsed();
info!(
Expand Down
30 changes: 23 additions & 7 deletions src/faand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ pub enum Error {
AANDWrongEFMAC,
/// No Mac or Key
MissingMacKey,
/// Conversion error
ConversionError,
/// Empty bucket
EmptyBucketError,
}

impl From<channel::Error> for Error {
Expand Down Expand Up @@ -383,8 +387,12 @@ pub(crate) async fn fashare(
dm_entry.push(shares[length + r].0 as u8);
for p in 0..p_max {
if p != p_own {
d0[r] ^= shares[length + r].1 .0[p].unwrap().1 .0;
dm_entry.extend(&shares[length + r].1 .0[p].unwrap().0 .0.to_be_bytes());
if let Some((mac, key)) = shares[length + r].1 .0[p] {
d0[r] ^= key.0;
dm_entry.extend(&mac.0.to_be_bytes());
} else {
return Err(Error::MissingMacKey);
}
} else {
dm_entry.extend(&[0; 16]);
}
Expand Down Expand Up @@ -443,10 +451,11 @@ pub(crate) async fn fashare(
for (p, pitem) in dmp.iter().enumerate().take(p_max) {
for pp in (0..p_max).filter(|pp| *pp != p) {
if !pitem[r].is_empty() {
let b: u128 = u128::from_be_bytes(
pitem[r][(1 + pp * 16)..(17 + pp * 16)].try_into().unwrap(),
);
xormacs[pp][r] ^= b;
if let Ok(b) = pitem[r][(1 + pp * 16)..(17 + pp * 16)].try_into().map(u128::from_be_bytes) {
xormacs[pp][r] ^= b;
} else {
return Err(Error::ConversionError);
}
}
}
}
Expand Down Expand Up @@ -967,10 +976,17 @@ pub(crate) async fn combine_bucket(
bucket: Vec<(Share, Share, Share)>,
) -> Result<(Share, Share, Share), Error> {
let mut bucketcopy = bucket.clone();
let mut result = bucketcopy.pop().unwrap();

let mut result = match bucketcopy.pop() {
Some(first_triple) => first_triple,
None => {
return Err(Error::EmptyBucketError);
}
};
while let Some(triple) = bucketcopy.pop() {
result = combine_two_leaky_ands(channel, p_own, p_max, delta, &result, &triple).await?;
}

Ok(result)
}

Expand Down
83 changes: 10 additions & 73 deletions src/ot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,27 @@ pub(crate) async fn mpz_ot_sender(
sender
.setup(ctx_sender)
.map_err(OTError::from)
.await
.unwrap();
.await?;

// extend once.
// Extend once.
let num = LPN_PARAMETERS_TEST.k;
sender
.extend(ctx_sender, num)
.map_err(OTError::from)
.await
.unwrap();
.await?;

// extend twice
// Extend twice.
sender
.extend(ctx_sender, count)
.map_err(OTError::from)
.await
.unwrap();
.await?;

let RCOTSenderOutput {
id: sender_id,
msgs: u,
} = sender
.send_random_correlated(ctx_sender, count)
.await
.unwrap();
.await?;

Ok((sender_id, u, block_to_u128(sender.delta())))
}
Expand All @@ -87,91 +83,32 @@ pub(crate) async fn mpz_ot_receiver(
receiver
.setup(ctx_receiver)
.map_err(OTError::from)
.await
.unwrap();
.await?;

// extend once.
let num = LPN_PARAMETERS_TEST.k;
receiver
.extend(ctx_receiver, num)
.map_err(OTError::from)
.await
.unwrap();
.await?;

// extend twice
receiver
.extend(ctx_receiver, count)
.map_err(OTError::from)
.await
.unwrap();
.await?;

let RCOTReceiverOutput {
id: receiver_id,
choices: b,
msgs: w,
} = receiver
.receive_random_correlated(ctx_receiver, count)
.await
.unwrap();
.await?;

Ok((receiver_id, b, w))
}

pub(crate) async fn _mpz_ot(count: usize) -> Result<bool, OTError> {
let lpn_type: LpnType = LpnType::Regular;
let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8);

let (rcot_sender, rcot_receiver) = ideal_rcot();

let config = FerretConfig::new(LPN_PARAMETERS_TEST, lpn_type);

let mut sender = Sender::new(config.clone(), rcot_sender);
let mut receiver = Receiver::new(config, rcot_receiver);

tokio::try_join!(
sender.setup(&mut ctx_sender).map_err(OTError::from),
receiver.setup(&mut ctx_receiver).map_err(OTError::from)
)
.unwrap();

// extend once.
let num = LPN_PARAMETERS_TEST.k;
tokio::try_join!(
sender.extend(&mut ctx_sender, num).map_err(OTError::from),
receiver
.extend(&mut ctx_receiver, num)
.map_err(OTError::from)
)
.unwrap();

// extend twice
tokio::try_join!(
sender.extend(&mut ctx_sender, count).map_err(OTError::from),
receiver
.extend(&mut ctx_receiver, count)
.map_err(OTError::from)
)
.unwrap();

let (
RCOTSenderOutput {
id: _sender_id,
msgs: _u,
},
RCOTReceiverOutput {
id: _receiver_id,
choices: _b,
msgs: _w,
},
) = tokio::try_join!(
sender.send_random_correlated(&mut ctx_sender, count),
receiver.receive_random_correlated(&mut ctx_receiver, count)
)
.unwrap();

Ok(true)
}

pub(crate) async fn generate_ots(
count: usize,
) -> Result<(u128, Vec<u128>, Vec<bool>, Vec<u128>), OTError> {
Expand Down
17 changes: 10 additions & 7 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ pub(crate) struct GarbledGate(pub(crate) [Vec<u8>; 4]);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct Label(pub(crate) u128);

const TRUSTEDDEALER: bool = false;

impl BitXor for Label {
type Output = Self;

Expand Down Expand Up @@ -178,22 +176,24 @@ pub fn simulate_mpc(
circuit: &Circuit,
inputs: &[&[bool]],
output_parties: &[usize],
trusted: bool,
) -> Result<Vec<bool>, Error> {
let tokio = Runtime::new().expect("Could not start tokio runtime");
tokio.block_on(simulate_mpc_async(circuit, inputs, output_parties))
tokio.block_on(simulate_mpc_async(circuit, inputs, output_parties, trusted))
}

/// Simulates the multi party computation with the given inputs and party 0 as the evaluator.
pub async fn simulate_mpc_async(
circuit: &Circuit,
inputs: &[&[bool]],
output_parties: &[usize],
trusted: bool,
) -> Result<Vec<bool>, Error> {
let p_eval = 0;
let p_pre = inputs.len();

let mut channels: Vec<SimpleChannel>;
if TRUSTEDDEALER {
if trusted {
channels = SimpleChannel::channels(inputs.len() + 1);
tokio::spawn(fpre(channels.pop().unwrap(), inputs.len()));
} else {
Expand All @@ -219,6 +219,7 @@ pub async fn simulate_mpc_async(
p_eval,
p_own,
&output_parties,
trusted,
)
.await;
let mut res_party: Vec<bool> = Vec::new();
Expand All @@ -242,6 +243,7 @@ pub async fn simulate_mpc_async(
p_eval,
p_eval,
output_parties,
trusted,
)
.await;
match eval_result {
Expand Down Expand Up @@ -282,6 +284,7 @@ pub async fn mpc(
p_eval: usize,
p_own: usize,
p_out: &[usize],
trusted: bool,
) -> Result<Vec<bool>, Error> {
let p_max = circuit.input_gates.len();
let is_contrib = p_own != p_eval;
Expand Down Expand Up @@ -334,7 +337,7 @@ pub async fn mpc(
];

let mut delta: Delta = Delta(0);
if TRUSTEDDEALER {
if trusted {
channel.send_to(p_fpre, "delta", &()).await?;
delta = channel.recv_from(p_fpre, "delta").await?;
} else {
Expand Down Expand Up @@ -368,7 +371,7 @@ pub async fn mpc(
let mut labels: Vec<Label> = vec![Label(0); num_gates];
let mut shared_rng = shared_rng(&mut channel, p_own, p_max).await?;

if TRUSTEDDEALER {
if trusted {
channel
.send_to(p_fpre, "random shares", &(secret_bits as u32))
.await?;
Expand Down Expand Up @@ -423,7 +426,7 @@ pub async fn mpc(
}
}

if TRUSTEDDEALER {
if trusted {
channel.send_to(p_fpre, "AND shares", &and_shares).await?;
auth_bits = channel.recv_from(p_fpre, "AND shares").await?;
} else {
Expand Down
18 changes: 9 additions & 9 deletions tests/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn eval_xor_circuits_2pc() -> Result<(), Error> {
output_gates: vec![4],
};

let output = simulate_mpc(&circuit, &[&[x, z], &[y]], &output_parties)?;
let output = simulate_mpc(&circuit, &[&[x, z], &[y]], &output_parties, false)?;
assert_eq!(output, vec![x ^ y ^ z]);
}
}
Expand All @@ -36,7 +36,7 @@ fn eval_xor_circuits_3pc() -> Result<(), Error> {
output_gates: vec![4],
};

let output = simulate_mpc(&circuit, &[&[x], &[y], &[z]], &output_parties)?;
let output = simulate_mpc(&circuit, &[&[x], &[y], &[z]], &output_parties, false)?;
assert_eq!(output, vec![x ^ y ^ z]);
}
}
Expand All @@ -55,7 +55,7 @@ fn eval_not_circuits_2pc() -> Result<(), Error> {
output_gates: vec![2, 3, 4, 5],
};

let output = simulate_mpc(&circuit, &[&[x], &[y]], &output_parties)?;
let output = simulate_mpc(&circuit, &[&[x], &[y]], &output_parties, false)?;
assert_eq!(output, vec![!x, !y, x, y]);
}
}
Expand All @@ -81,7 +81,7 @@ fn eval_not_circuits_3pc() -> Result<(), Error> {
output_gates: vec![3, 4, 5, 6, 7, 8],
};

let output = simulate_mpc(&circuit, &[&[x], &[y], &[z]], &output_parties)?;
let output = simulate_mpc(&circuit, &[&[x], &[y], &[z]], &output_parties, false)?;
assert_eq!(output, vec![!x, !y, !z, x, y, z]);
}
}
Expand All @@ -101,7 +101,7 @@ fn eval_and_circuits_2pc() -> Result<(), Error> {
output_gates: vec![4],
};

let output = simulate_mpc(&circuit, &[&[x, z], &[y]], &output_parties)?;
let output = simulate_mpc(&circuit, &[&[x, z], &[y]], &output_parties, false)?;
assert_eq!(output, vec![x & y & z]);
}
}
Expand All @@ -121,7 +121,7 @@ fn eval_and_circuits_3pc() -> Result<(), Error> {
output_gates: vec![4],
};

let output = simulate_mpc(&circuit, &[&[x], &[y], &[z]], &output_parties)?;
let output = simulate_mpc(&circuit, &[&[x], &[y], &[z]], &output_parties, false)?;
assert_eq!(output, vec![x & y & z]);
}
}
Expand All @@ -141,7 +141,7 @@ fn eval_garble_prg_3pc() -> Result<(), Error> {
let x = prg.parse_arg(0, &format!("{x}u8")).unwrap().as_bits();
let y = prg.parse_arg(1, &format!("{y}u8")).unwrap().as_bits();
let z = prg.parse_arg(2, &format!("{z}u8")).unwrap().as_bits();
let output = simulate_mpc(&prg.circuit, &[&x, &y, &z], &output_parties)?;
let output = simulate_mpc(&prg.circuit, &[&x, &y, &z], &output_parties, false)?;
let result = prg.parse_output(&output).unwrap();
println!("{calculation} = {result}");
assert_eq!(format!("{result}"), format!("{expected}u8"));
Expand Down Expand Up @@ -175,7 +175,7 @@ fn eval_large_and_circuit() -> Result<(), Error> {
output_gates,
};

let output_smpc = simulate_mpc(&circuit, &[&in_a, &in_b], &output_parties)?;
let output_smpc = simulate_mpc(&circuit, &[&in_a, &in_b], &output_parties, false)?;
let output_direct = eval_directly(&circuit, &[&in_a, &in_b]);
assert_eq!(output_smpc, vec![true]);
assert_eq!(output_smpc, output_direct);
Expand Down Expand Up @@ -229,7 +229,7 @@ fn eval_mixed_circuits() -> Result<(), Error> {
for (w, (circuit, in_a, in_b)) in circuits_with_inputs.into_iter().enumerate() {
if w % eval_only_every_n == 0 {
total_tests += 1;
let output_smpc = simulate_mpc(&circuit, &[&in_a, &in_b], &output_parties)?;
let output_smpc = simulate_mpc(&circuit, &[&in_a, &in_b], &output_parties, false)?;
let output_direct = eval_directly(&circuit, &[&in_a, &in_b]);
if output_smpc != output_direct {
println!("Circuit: {:?}", circuit);
Expand Down

0 comments on commit 6190746

Please sign in to comment.