Skip to content

Commit

Permalink
Add tests for KeyInit errors
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Dec 31, 2024
1 parent e0565ce commit 9f0022f
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 37 deletions.
3 changes: 3 additions & 0 deletions synedrion/src/cggmp21.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ mod sigma;
#[cfg(test)]
mod signing_malicious;

#[cfg(test)]
mod key_init_malicious;

pub use aux_gen::{AuxGen, AuxGenProtocol};
pub use entities::{AuxInfo, KeyShare, KeyShareChange};
pub use interactive_signing::{InteractiveSigning, InteractiveSigningProtocol, PrehashedMessage};
Expand Down
63 changes: 27 additions & 36 deletions synedrion/src/cggmp21/key_init.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
//! KeyInit protocol, in the paper ECDSA Key-Generation (Fig. 5).
//! KeyInit protocol, in the paper ECDSA Key-Generation (Fig. 6).
//! Note that this protocol only generates the key itself which is not enough to perform signing;
//! auxiliary parameters need to be generated as well (during the KeyRefresh protocol).
use alloc::{
collections::{BTreeMap, BTreeSet},
format,
string::String,
vec::Vec,
};
Expand All @@ -28,7 +27,7 @@ use crate::{
tools::{
bitvec::BitVec,
hashing::{Chain, FofHasher, HashOutput},
DowncastMap, Secret, Without,
DowncastMap, SafeGet, Secret, Without,
},
};

Expand Down Expand Up @@ -109,7 +108,7 @@ impl<P: SchemeParams, I: PartyId> ProtocolError<I> for KeyInitError<P, I> {
fn required_echo_broadcasts(&self) -> BTreeSet<RoundId> {
match self.error {
KeyInitErrorEnum::R2HashMismatch => [RoundId::new(1)].into(),
KeyInitErrorEnum::R3InvalidSchProof => [].into(),
KeyInitErrorEnum::R3InvalidSchProof => [RoundId::new(2)].into(),
}
}

Expand Down Expand Up @@ -190,12 +189,12 @@ impl<P: SchemeParams, I: PartyId> ProtocolError<I> for KeyInitError<P, I> {
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct PublicData<P: SchemeParams> {
cap_x: Point,
cap_a: SchCommitment,
rho: BitVec,
u: BitVec,
phantom: PhantomData<P>,
pub(super) struct PublicData<P: SchemeParams> {
pub(super) cap_x: Point,
pub(super) cap_a: SchCommitment,
pub(super) rho: BitVec,
pub(super) u: BitVec,
pub(super) phantom: PhantomData<P>,
}

impl<P: SchemeParams> PublicData<P> {
Expand Down Expand Up @@ -244,7 +243,6 @@ impl<P: SchemeParams, I: PartyId> EntryPoint<I> for KeyInit<P, I> {
let sid_hash = FofHasher::new_with_dst(b"SID")
.chain_type::<P>()
.chain(&shared_randomness)
.chain(&self.all_ids)
.finalize();

// The secret share
Expand Down Expand Up @@ -279,13 +277,13 @@ impl<P: SchemeParams, I: PartyId> EntryPoint<I> for KeyInit<P, I> {
}

#[derive(Debug)]
struct Context<P: SchemeParams, I> {
other_ids: BTreeSet<I>,
my_id: I,
x: Secret<Scalar>,
tau: SchSecret,
public_data: PublicData<P>,
sid_hash: HashOutput,
pub(super) struct Context<P: SchemeParams, I> {
pub(super) other_ids: BTreeSet<I>,
pub(super) my_id: I,
pub(super) x: Secret<Scalar>,
pub(super) tau: SchSecret,
pub(super) public_data: PublicData<P>,
pub(super) sid_hash: HashOutput,
}

#[derive(Debug)]
Expand Down Expand Up @@ -374,8 +372,8 @@ struct Round2<P: SchemeParams, I> {
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "PublicData<P>: Serialize"))]
#[serde(bound(deserialize = "PublicData<P>: for<'x> Deserialize<'x>"))]
struct Round2EchoBroadcast<P: SchemeParams> {
data: PublicData<P>,
pub(super) struct Round2EchoBroadcast<P: SchemeParams> {
pub(super) data: PublicData<P>,
}

struct Round2Payload<P: SchemeParams> {
Expand Down Expand Up @@ -426,10 +424,7 @@ impl<P: SchemeParams, I: PartyId> Round<I> for Round2<P, I> {
normal_broadcast.assert_is_none()?;
direct_message.assert_is_none()?;
let echo = echo_broadcast.deserialize::<Round2EchoBroadcast<P>>(deserializer)?;
let cap_v = self
.others_cap_v
.get(from)
.ok_or_else(|| LocalError::new(format!("Missing `V` for {from:?}")))?;
let cap_v = self.others_cap_v.safe_get("vector `V`", from)?;

if &echo.data.hash(&self.context.sid_hash, from) != cap_v {
return Err(ReceiveError::protocol(KeyInitError::new(
Expand Down Expand Up @@ -466,16 +461,16 @@ impl<P: SchemeParams, I: PartyId> Round<I> for Round2<P, I> {
}

#[derive(Debug)]
struct Round3<P: SchemeParams, I> {
context: Context<P, I>,
others_data: BTreeMap<I, PublicData<P>>,
rho: BitVec,
phantom: PhantomData<P>,
pub(super) struct Round3<P: SchemeParams, I> {
pub(super) context: Context<P, I>,
pub(super) others_data: BTreeMap<I, PublicData<P>>,
pub(super) rho: BitVec,
pub(super) phantom: PhantomData<P>,
}

#[derive(Clone, Serialize, Deserialize)]
struct Round3Broadcast {
psi: SchProof,
pub(super) struct Round3Broadcast {
pub(super) psi: SchProof,
}

impl<P: SchemeParams, I: PartyId> Round<I> for Round3<P, I> {
Expand Down Expand Up @@ -528,13 +523,9 @@ impl<P: SchemeParams, I: PartyId> Round<I> for Round3<P, I> {
) -> Result<Payload, ReceiveError<I, Self::Protocol>> {
echo_broadcast.assert_is_none()?;
direct_message.assert_is_none()?;

let bc = normal_broadcast.deserialize::<Round3Broadcast>(deserializer)?;

let data = self
.others_data
.get(from)
.ok_or_else(|| LocalError::new(format!("Missing data for {from:?}")))?;
let data = self.others_data.safe_get("other nodes' public data", from)?;

let aux = (&self.context.sid_hash, from, &self.rho);
if !bc.psi.verify(&data.cap_a, &data.cap_x, &aux) {
Expand Down
143 changes: 143 additions & 0 deletions synedrion/src/cggmp21/key_init_malicious.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use alloc::collections::{BTreeMap, BTreeSet};
use core::marker::PhantomData;

use manul::{
combinators::misbehave::{Misbehaving, MisbehavingEntryPoint},
dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
protocol::{
BoxedRound, Deserializer, EchoBroadcast, EntryPoint, LocalError, NormalBroadcast, PartyId, ProtocolMessagePart,
RoundId, Serializer,
},
session::SessionReport,
signature::Keypair,
};
use rand_core::{CryptoRngCore, OsRng};

use super::{
key_init::{KeyInit, KeyInitProtocol, Round2EchoBroadcast, Round3, Round3Broadcast},
params::{SchemeParams, TestParams},
sigma::SchProof,
};
use crate::{
curve::Scalar,
tools::{bitvec::BitVec, Secret},
};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Behavior {
R2RandomU,
R3InvalidSchProof,
}

struct MaliciousKeyInitOverride<P>(PhantomData<P>);

impl<P: SchemeParams, Id: PartyId> Misbehaving<Id, Behavior> for MaliciousKeyInitOverride<P> {
type EntryPoint = KeyInit<P, Id>;

fn modify_echo_broadcast(
rng: &mut impl CryptoRngCore,
round: &BoxedRound<Id, <Self::EntryPoint as EntryPoint<Id>>::Protocol>,
behavior: &Behavior,
serializer: &Serializer,
deserializer: &Deserializer,
echo_broadcast: EchoBroadcast,
) -> Result<EchoBroadcast, LocalError> {
if round.id() == RoundId::new(2) && behavior == &Behavior::R2RandomU {
let orig_message = echo_broadcast
.deserialize::<Round2EchoBroadcast<P>>(deserializer)
.unwrap();
let mut data = orig_message.data;

// Replace `u` with something other than we committed to when hashing it in Round 1.
data.u = BitVec::random(rng, data.u.bits().len());

let message = Round2EchoBroadcast { data };
return EchoBroadcast::new(serializer, message);
}

Ok(echo_broadcast)
}

fn modify_normal_broadcast(
rng: &mut impl CryptoRngCore,
round: &BoxedRound<Id, <Self::EntryPoint as EntryPoint<Id>>::Protocol>,
behavior: &Behavior,
serializer: &Serializer,
_deserializer: &Deserializer,
normal_broadcast: NormalBroadcast,
) -> Result<NormalBroadcast, LocalError> {
if round.id() == RoundId::new(3) && behavior == &Behavior::R3InvalidSchProof {
let round3 = round.downcast_ref::<Round3<P, Id>>()?;
let context = &round3.context;
let aux = (&context.sid_hash, &context.my_id, &round3.rho);

// Make a proof for a random secret. This won't pass verification.
let x = Secret::init_with(|| Scalar::random(rng));
let psi = SchProof::new(
&context.tau,
&x,
&context.public_data.cap_a,
&x.mul_by_generator(),
&aux,
);

let message = Round3Broadcast { psi };
return NormalBroadcast::new(serializer, message);
}

Ok(normal_broadcast)
}
}

type MaliciousKeyEP<P, Id> = MisbehavingEntryPoint<Id, Behavior, MaliciousKeyInitOverride<P>>;

type Protocol = KeyInitProtocol<TestParams, TestVerifier>;
type SP = TestSessionParams<BinaryFormat>;

fn run_with_one_malicious_party(
behavior: Behavior,
) -> (Vec<TestVerifier>, BTreeMap<TestVerifier, SessionReport<Protocol, SP>>) {
let signers = (0..3).map(TestSigner::new).collect::<Vec<_>>();
let ids = signers.iter().map(|signer| signer.verifying_key()).collect::<Vec<_>>();
let ids_set = BTreeSet::from_iter(ids.clone());

let entry_points = signers
.into_iter()
.map(|signer| {
let id = signer.verifying_key();
let entry_point = KeyInit::<TestParams, TestVerifier>::new(ids_set.clone()).unwrap();
let behavior = if id == ids[0] { Some(behavior) } else { None };
let entry_point = MaliciousKeyEP::new(entry_point, behavior);
(signer, entry_point)
})
.collect();

let reports = run_sync::<_, SP>(&mut OsRng, entry_points).unwrap().reports;
(ids, reports)
}

#[test]
fn r2_hash_mismatch() {
let (ids, mut reports) = run_with_one_malicious_party(Behavior::R2RandomU);

let report0 = reports.remove(&ids[0]).unwrap();
let report1 = reports.remove(&ids[1]).unwrap();
let report2 = reports.remove(&ids[2]).unwrap();

assert!(report0.provable_errors.is_empty());
assert!(report1.provable_errors[&ids[0]].verify().is_ok());
assert!(report2.provable_errors[&ids[0]].verify().is_ok());
}

#[test]
fn r3_invalid_sch_proof() {
let (ids, mut reports) = run_with_one_malicious_party(Behavior::R3InvalidSchProof);

let report0 = reports.remove(&ids[0]).unwrap();
let report1 = reports.remove(&ids[1]).unwrap();
let report2 = reports.remove(&ids[2]).unwrap();

assert!(report0.provable_errors.is_empty());
assert!(report1.provable_errors[&ids[0]].verify().is_ok());
assert!(report2.provable_errors[&ids[0]].verify().is_ok());
}
17 changes: 16 additions & 1 deletion synedrion/src/tools.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use alloc::collections::{BTreeMap, BTreeSet};
use alloc::{
collections::{BTreeMap, BTreeSet},
format,
};
use core::fmt::Debug;

pub(crate) mod bitvec;
pub(crate) mod hashing;
Expand Down Expand Up @@ -45,3 +49,14 @@ impl<K: Ord> DowncastMap for BTreeMap<K, Artifact> {
.collect::<Result<_, _>>()
}
}

pub(crate) trait SafeGet<K, V> {
fn safe_get(&self, container: &str, key: &K) -> Result<&V, LocalError>;
}

impl<K: Ord + Debug, V> SafeGet<K, V> for BTreeMap<K, V> {
fn safe_get(&self, container: &str, key: &K) -> Result<&V, LocalError> {
self.get(key)
.ok_or_else(|| LocalError::new(format!("Key {key:?} not found in {container}")))
}
}

0 comments on commit 9f0022f

Please sign in to comment.