diff --git a/synedrion/src/cggmp21.rs b/synedrion/src/cggmp21.rs index 06571347..e5b56dbb 100644 --- a/synedrion/src/cggmp21.rs +++ b/synedrion/src/cggmp21.rs @@ -19,6 +19,9 @@ mod sigma; #[cfg(test)] mod signing_malicious; +#[cfg(test)] +mod key_init_malicious; + pub use aux_gen::{AuxGen, AuxGenProtocol}; pub use entities::{AuxInfo, KeyShare, KeyShareChange}; pub use interactive_signing::{InteractiveSigning, InteractiveSigningProtocol, PrehashedMessage}; diff --git a/synedrion/src/cggmp21/entities.rs b/synedrion/src/cggmp21/entities.rs index b2aa0a54..6119b026 100644 --- a/synedrion/src/cggmp21/entities.rs +++ b/synedrion/src/cggmp21/entities.rs @@ -24,12 +24,12 @@ use crate::{ /// The result of the KeyInit protocol. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeyShare { - pub(crate) owner: I, + owner: I, /// Secret key share of this node. - pub(crate) secret_share: Secret, // `x_i` - pub(crate) public_shares: BTreeMap, // `X_j` + secret_share: Secret, // `x_i` + public_shares: BTreeMap, // `X_j` // TODO (#27): this won't be needed when Scalar/Point are a part of `P` - pub(crate) phantom: PhantomData

, + phantom: PhantomData

, } /// The result of the AuxGen protocol. @@ -125,7 +125,23 @@ pub(crate) struct PresigningValues { pub(crate) hat_cap_f: Ciphertext, } -impl KeyShare { +impl KeyShare { + pub(crate) fn new( + owner: I, + secret_share: Secret, + public_shares: BTreeMap, + ) -> Result { + if public_shares.values().sum::() == Point::IDENTITY { + return Err(LocalError::new("Key shares add up to zero")); + } + Ok(KeyShare { + owner, + secret_share, + public_shares, + phantom: PhantomData, + }) + } + /// Updates a key share with a change obtained from KeyRefresh protocol. pub fn update(self, change: KeyShareChange) -> Result { if self.owner != change.owner { @@ -217,6 +233,14 @@ impl KeyShare { &self.owner } + pub(crate) fn secret_share(&self) -> &Secret { + &self.secret_share + } + + pub(crate) fn public_shares(&self) -> &BTreeMap { + &self.public_shares + } + /// Returns the set of parties holding other shares from the set. pub fn all_parties(&self) -> BTreeSet { self.public_shares.keys().cloned().collect() diff --git a/synedrion/src/cggmp21/interactive_signing.rs b/synedrion/src/cggmp21/interactive_signing.rs index fe6d6958..5dfbb8ed 100644 --- a/synedrion/src/cggmp21/interactive_signing.rs +++ b/synedrion/src/cggmp21/interactive_signing.rs @@ -152,7 +152,7 @@ impl EntryPoint for InteractiveSigning { } let other_ids = key_share - .public_shares + .public_shares() .keys() .cloned() .collect::>() @@ -164,7 +164,7 @@ impl EntryPoint for InteractiveSigning { let ssid_hash = FofHasher::new_with_dst(b"ShareSetID") .chain_type::

() .chain(&shared_randomness) - .chain(&key_share.public_shares) + .chain(&key_share.public_shares()) .chain(&aux_info.public_aux) .finalize(); @@ -221,11 +221,11 @@ struct Context { impl Context where P: SchemeParams, - I: Ord + Debug, + I: Clone + Ord + Debug, { pub fn public_share(&self, i: &I) -> Result<&Point, LocalError> { self.key_share - .public_shares + .public_shares() .get(i) .ok_or_else(|| LocalError::new("Missing public_share for party Id {i:?}")) } @@ -507,7 +507,7 @@ impl Round for Round2 { let hat_s = Randomizer::random(rng, target_pk); let gamma = secret_signed_from_scalar::

(&self.context.gamma); - let x = secret_signed_from_scalar::

(&self.context.key_share.secret_share); + let x = secret_signed_from_scalar::

(self.context.key_share.secret_share()); let others_cap_k = self .all_cap_k @@ -518,7 +518,7 @@ impl Round for Round2 { let cap_d = others_cap_k * &gamma + Ciphertext::new_with_randomizer(target_pk, &-&beta, &s); let hat_cap_f = Ciphertext::new_with_randomizer(pk, &hat_beta, &hat_r); - let hat_cap_d = others_cap_k * &secret_signed_from_scalar::

(&self.context.key_share.secret_share) + let hat_cap_d = others_cap_k * &secret_signed_from_scalar::

(self.context.key_share.secret_share()) + Ciphertext::new_with_randomizer(target_pk, &-&hat_beta, &hat_s); let cap_g = self.all_cap_g.get(&self.context.my_id).ok_or(LocalError::new(format!( @@ -737,7 +737,7 @@ impl Round for Round2 { let hat_alpha_sum: SecretSigned<_> = payloads.values().map(|payload| &payload.hat_alpha).sum(); let hat_beta_sum: SecretSigned<_> = artifacts.values().map(|artifact| &artifact.hat_beta).sum(); - let chi = secret_signed_from_scalar::

(&self.context.key_share.secret_share) + let chi = secret_signed_from_scalar::

(self.context.key_share.secret_share()) * secret_signed_from_scalar::

(&self.context.k) + &hat_alpha_sum + &hat_beta_sum; @@ -1257,7 +1257,7 @@ impl Round for Round4 { let p_aff_g = AffGProof::

::new( rng, AffGSecretInputs { - x: &secret_signed_from_scalar::

(&self.context.key_share.secret_share), + x: &secret_signed_from_scalar::

(self.context.key_share.secret_share()), y: &values.hat_beta, rho: &values.hat_s, rho_y: &values.hat_r, @@ -1293,7 +1293,7 @@ impl Round for Round4 { // mul* proofs - let x = &self.context.key_share.secret_share; + let x = &self.context.key_share.secret_share(); let cap_x = self.context.public_share(&my_id)?; let rho = Randomizer::random(rng, pk); diff --git a/synedrion/src/cggmp21/key_init.rs b/synedrion/src/cggmp21/key_init.rs index 892a134b..b767544a 100644 --- a/synedrion/src/cggmp21/key_init.rs +++ b/synedrion/src/cggmp21/key_init.rs @@ -1,17 +1,21 @@ -//! KeyInit protocol, in the paper ECDSA Key-Generation (Fig. 5). +//! KeyInit protocol, in the paper ECDSA Key-Generation (Fig. 6). //! Note that this protocol only generates the key itself which is not enough to perform signing; //! auxiliary parameters need to be generated as well (during the KeyRefresh protocol). use alloc::{ collections::{BTreeMap, BTreeSet}, - format, + vec::Vec, +}; +use core::{ + fmt::{self, Debug, Display}, + marker::PhantomData, }; -use core::{fmt::Debug, marker::PhantomData}; use manul::protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage, - ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, Serializer, + ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId, + Serializer, }; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -26,7 +30,7 @@ use crate::{ tools::{ bitvec::BitVec, hashing::{Chain, FofHasher, HashOutput}, - DowncastMap, Secret, Without, + DowncastMap, SafeGet, Secret, Without, }, }; @@ -36,70 +40,176 @@ pub struct KeyInitProtocol(PhantomData<(P, I)>); impl Protocol for KeyInitProtocol { type Result = KeyShare; - type ProtocolError = KeyInitError; + type ProtocolError = KeyInitError

; fn verify_direct_message_is_invalid( _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &DirectMessage, + round_id: &RoundId, + message: &DirectMessage, ) -> Result<(), MessageValidationError> { - unimplemented!() + if round_id == &RoundId::new(1) || round_id == &RoundId::new(2) || round_id == &RoundId::new(3) { + message.verify_is_some() + } else { + Err(MessageValidationError::InvalidEvidence("Invalid round number".into())) + } } fn verify_echo_broadcast_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &EchoBroadcast, + deserializer: &Deserializer, + round_id: &RoundId, + message: &EchoBroadcast, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &RoundId::new(1) => message.verify_is_not::(deserializer), + r if r == &RoundId::new(2) => message.verify_is_not::>(deserializer), + r if r == &RoundId::new(3) => message.verify_is_some(), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } fn verify_normal_broadcast_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &NormalBroadcast, + deserializer: &Deserializer, + round_id: &RoundId, + message: &NormalBroadcast, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &RoundId::new(1) => message.verify_is_some(), + r if r == &RoundId::new(2) => message.verify_is_some(), + r if r == &RoundId::new(3) => message.verify_is_not::(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } } /// Possible verifiable errors of the KeyGen protocol. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct KeyInitError

{ + error: KeyInitErrorEnum, + phantom: PhantomData

, +} + +impl

Display for KeyInitError

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "{}", self.error) + } +} + +impl

KeyInitError

{ + fn new(error: KeyInitErrorEnum) -> Self { + Self { + error, + phantom: PhantomData, + } + } +} + #[derive(displaydoc::Display, Debug, Clone, Copy, Serialize, Deserialize)] -pub enum KeyInitError { +enum KeyInitErrorEnum { /// A hash mismatch in Round 2. R2HashMismatch, /// Failed to verify `П^sch` in Round 3. R3InvalidSchProof, } -impl ProtocolError for KeyInitError { +impl ProtocolError for KeyInitError

{ type AssociatedData = (); fn required_messages(&self) -> RequiredMessages { - unimplemented!() + match self.error { + KeyInitErrorEnum::R2HashMismatch => RequiredMessages::new( + RequiredMessageParts::echo_broadcast_only(), + Some([(RoundId::new(1), RequiredMessageParts::echo_broadcast_only())].into()), + None, + ), + KeyInitErrorEnum::R3InvalidSchProof => RequiredMessages::new( + RequiredMessageParts::normal_broadcast_only(), + Some([(RoundId::new(2), RequiredMessageParts::echo_broadcast_only())].into()), + Some([RoundId::new(2)].into()), + ), + } } fn verify_messages_constitute_error( &self, - _deserializer: &Deserializer, - _guilty_party: &I, - _shared_randomness: &[u8], + deserializer: &Deserializer, + guilty_party: &I, + shared_randomness: &[u8], _associated_data: &Self::AssociatedData, - _message: ProtocolMessage, - _previous_messages: BTreeMap, - _combined_echos: BTreeMap>, + message: ProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, ) -> Result<(), ProtocolValidationError> { - unimplemented!() + let sid_hash = FofHasher::new_with_dst(b"SID") + .chain_type::

() + .chain(&shared_randomness) + .finalize(); + + match self.error { + KeyInitErrorEnum::R2HashMismatch => { + let r1_serialized = &previous_messages + .get(&RoundId::new(1)) + .ok_or_else(|| { + ProtocolValidationError::InvalidEvidence("Missing echo broadcast from Round 1".into()) + })? + .echo_broadcast; + let r1_message = r1_serialized.deserialize::(deserializer)?; + let r2_message = message + .echo_broadcast + .deserialize::>(deserializer)?; + if r2_message.data.hash(&sid_hash, guilty_party) != r1_message.cap_v { + Ok(()) + } else { + Err(ProtocolValidationError::InvalidEvidence( + "The received hash is valid".into(), + )) + } + } + KeyInitErrorEnum::R3InvalidSchProof => { + let r2_combined = combined_echos.get(&RoundId::new(2)).ok_or_else(|| { + ProtocolValidationError::InvalidEvidence("Missing combined echos from Round 2".into()) + })?; + let r2_messages = r2_combined + .values() + .map(|echo| echo.deserialize::>(deserializer)) + .collect::, _>>()?; + + let r2_serialized = &previous_messages + .get(&RoundId::new(2)) + .ok_or_else(|| { + ProtocolValidationError::InvalidEvidence("Missing echo broadcast from Round 2".into()) + })? + .echo_broadcast; + let r2_message = r2_serialized.deserialize::>(deserializer)?; + + let mut rho = r2_message.data.rho; + for message in r2_messages { + rho ^= &message.data.rho; + } + + let r3_message = message.normal_broadcast.deserialize::(deserializer)?; + let aux = (&sid_hash, guilty_party, &rho); + if !r3_message + .psi + .verify(&r2_message.data.cap_a, &r2_message.data.cap_x, &aux) + { + Ok(()) + } else { + Err(ProtocolValidationError::InvalidEvidence( + "The Schnorr proof is valid".into(), + )) + } + } + } } } #[derive(Debug, Clone, Serialize, Deserialize)] -struct PublicData { - cap_x: Point, - cap_a: SchCommitment, - rid: BitVec, - u: BitVec, - phantom: PhantomData

, +pub(super) struct PublicData { + pub(super) cap_x: Point, + pub(super) cap_a: SchCommitment, + pub(super) rho: BitVec, + pub(super) u: BitVec, + pub(super) phantom: PhantomData

, } impl PublicData

{ @@ -148,7 +258,6 @@ impl EntryPoint for KeyInit { let sid_hash = FofHasher::new_with_dst(b"SID") .chain_type::

() .chain(&shared_randomness) - .chain(&self.all_ids) .finalize(); // The secret share @@ -156,7 +265,7 @@ impl EntryPoint for KeyInit { // The public share let cap_x = x.mul_by_generator(); - let rid = BitVec::random(rng, P::SECURITY_PARAMETER); + let rho = BitVec::random(rng, P::SECURITY_PARAMETER); let tau = SchSecret::random(rng); let cap_a = SchCommitment::new(&tau); let u = BitVec::random(rng, P::SECURITY_PARAMETER); @@ -164,7 +273,7 @@ impl EntryPoint for KeyInit { let public_data = PublicData { cap_x, cap_a, - rid, + rho, u, phantom: PhantomData, }; @@ -183,13 +292,13 @@ impl EntryPoint for KeyInit { } #[derive(Debug)] -struct Context { - other_ids: BTreeSet, - my_id: I, - x: Secret, - tau: SchSecret, - public_data: PublicData

, - sid_hash: HashOutput, +pub(super) struct Context { + pub(super) other_ids: BTreeSet, + pub(super) my_id: I, + pub(super) x: Secret, + pub(super) tau: SchSecret, + pub(super) public_data: PublicData

, + pub(super) sid_hash: HashOutput, } #[derive(Debug)] @@ -198,7 +307,7 @@ struct Round1 { } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Round1Message { +struct Round1EchoBroadcast { cap_v: HashOutput, } @@ -234,7 +343,7 @@ impl Round for Round1 { .context .public_data .hash(&self.context.sid_hash, &self.context.my_id); - EchoBroadcast::new(serializer, Round1Message { cap_v }) + EchoBroadcast::new(serializer, Round1EchoBroadcast { cap_v }) } fn receive_message( @@ -246,7 +355,9 @@ impl Round for Round1 { ) -> Result> { message.normal_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let echo = message.echo_broadcast.deserialize::(deserializer)?; + let echo = message + .echo_broadcast + .deserialize::(deserializer)?; Ok(Payload::new(Round1Payload { cap_v: echo.cap_v })) } @@ -276,8 +387,8 @@ struct Round2 { #[derive(Clone, Serialize, Deserialize)] #[serde(bound(serialize = "PublicData

: Serialize"))] #[serde(bound(deserialize = "PublicData

: for<'x> Deserialize<'x>"))] -struct Round2Message { - data: PublicData

, +pub(super) struct Round2EchoBroadcast { + pub(super) data: PublicData

, } struct Round2Payload { @@ -310,7 +421,7 @@ impl Round for Round2 { ) -> Result { EchoBroadcast::new( serializer, - Round2Message { + Round2EchoBroadcast { data: self.context.public_data.clone(), }, ) @@ -325,14 +436,15 @@ impl Round for Round2 { ) -> Result> { message.normal_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let echo = message.echo_broadcast.deserialize::>(deserializer)?; - let cap_v = self - .others_cap_v - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing `V` for {from:?}")))?; + let echo = message + .echo_broadcast + .deserialize::>(deserializer)?; + let cap_v = self.others_cap_v.safe_get("vector `V`", from)?; if &echo.data.hash(&self.context.sid_hash, from) != cap_v { - return Err(ReceiveError::protocol(KeyInitError::R2HashMismatch)); + return Err(ReceiveError::protocol(KeyInitError::new( + KeyInitErrorEnum::R2HashMismatch, + ))); } Ok(Payload::new(Round2Payload { data: echo.data })) @@ -344,12 +456,12 @@ impl Round for Round2 { payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, LocalError> { - let mut rid = self.context.public_data.rid.clone(); + let mut rho = self.context.public_data.rho.clone(); let payloads = payloads.downcast_all::>()?; for payload in payloads.values() { - rid ^= &payload.data.rid; + rho ^= &payload.data.rho; } let others_data = payloads.into_iter().map(|(k, v)| (k, v.data)).collect(); @@ -357,23 +469,23 @@ impl Round for Round2 { Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round3 { context: self.context, others_data, - rid, + rho, phantom: PhantomData, }))) } } #[derive(Debug)] -struct Round3 { - context: Context, - others_data: BTreeMap>, - rid: BitVec, - phantom: PhantomData

, +pub(super) struct Round3 { + pub(super) context: Context, + pub(super) others_data: BTreeMap>, + pub(super) rho: BitVec, + pub(super) phantom: PhantomData

, } #[derive(Clone, Serialize, Deserialize)] -struct Round3Message { - psi: SchProof, +pub(super) struct Round3Broadcast { + pub(super) psi: SchProof, } impl Round for Round3 { @@ -404,7 +516,7 @@ impl Round for Round3 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - let aux = (&self.context.sid_hash, &self.context.my_id, &self.rid); + let aux = (&self.context.sid_hash, &self.context.my_id, &self.rho); let psi = SchProof::new( &self.context.tau, &self.context.x, @@ -412,7 +524,7 @@ impl Round for Round3 { &self.context.public_data.cap_x, &aux, ); - NormalBroadcast::new(serializer, Round3Message { psi }) + NormalBroadcast::new(serializer, Round3Broadcast { psi }) } fn receive_message( @@ -424,17 +536,15 @@ impl Round for Round3 { ) -> Result> { message.echo_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; + let bc = message.normal_broadcast.deserialize::(deserializer)?; - let bc = message.normal_broadcast.deserialize::(deserializer)?; + let data = self.others_data.safe_get("other nodes' public data", from)?; - let data = self - .others_data - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing data for {from:?}")))?; - - let aux = (&self.context.sid_hash, from, &self.rid); + let aux = (&self.context.sid_hash, from, &self.rho); if !bc.psi.verify(&data.cap_a, &data.cap_x, &aux) { - return Err(ReceiveError::protocol(KeyInitError::R3InvalidSchProof)); + return Err(ReceiveError::protocol(KeyInitError::new( + KeyInitErrorEnum::R3InvalidSchProof, + ))); } Ok(Payload::empty()) } @@ -452,12 +562,13 @@ impl Round for Round3 { .map(|(k, v)| (k, v.cap_x)) .collect::>(); public_shares.insert(my_id.clone(), self.context.public_data.cap_x); - Ok(FinalizeOutcome::Result(KeyShare { - owner: my_id, - secret_share: self.context.x, - public_shares, - phantom: PhantomData, - })) + + // This can fail if the shares add up to zero. + // Can't really protect from it, and it should be extremely rare. + // If that happens one can only restart the whole thing. + let key_share = KeyShare::::new(my_id, self.context.x, public_shares)?; + + Ok(FinalizeOutcome::Result(key_share)) } } @@ -500,7 +611,7 @@ mod tests { let public_sets = shares .iter() - .map(|(id, share)| (*id, share.public_shares.clone())) + .map(|(id, share)| (*id, share.public_shares().clone())) .collect::>(); assert!(public_sets.values().all(|pk| pk == &public_sets[&id0])); @@ -510,7 +621,7 @@ mod tests { let public_from_secret = shares .into_iter() - .map(|(id, share)| (id, share.secret_share.mul_by_generator())) + .map(|(id, share)| (id, share.secret_share().mul_by_generator())) .collect(); assert!(public_set == &public_from_secret); diff --git a/synedrion/src/cggmp21/key_init_malicious.rs b/synedrion/src/cggmp21/key_init_malicious.rs new file mode 100644 index 00000000..1bba10ac --- /dev/null +++ b/synedrion/src/cggmp21/key_init_malicious.rs @@ -0,0 +1,143 @@ +use alloc::collections::{BTreeMap, BTreeSet}; +use core::marker::PhantomData; + +use manul::{ + combinators::misbehave::{Misbehaving, MisbehavingEntryPoint}, + dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + protocol::{ + BoxedRound, Deserializer, EchoBroadcast, EntryPoint, LocalError, NormalBroadcast, PartyId, ProtocolMessagePart, + RoundId, Serializer, + }, + session::SessionReport, + signature::Keypair, +}; +use rand_core::{CryptoRngCore, OsRng}; + +use super::{ + key_init::{KeyInit, KeyInitProtocol, Round2EchoBroadcast, Round3, Round3Broadcast}, + params::{SchemeParams, TestParams}, + sigma::SchProof, +}; +use crate::{ + curve::Scalar, + tools::{bitvec::BitVec, Secret}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Behavior { + R2RandomU, + R3InvalidSchProof, +} + +struct MaliciousKeyInitOverride

(PhantomData

); + +impl Misbehaving for MaliciousKeyInitOverride

{ + type EntryPoint = KeyInit; + + fn modify_echo_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &Behavior, + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == RoundId::new(2) && behavior == &Behavior::R2RandomU { + let orig_message = echo_broadcast + .deserialize::>(deserializer) + .unwrap(); + let mut data = orig_message.data; + + // Replace `u` with something other than we committed to when hashing it in Round 1. + data.u = BitVec::random(rng, data.u.bits().len()); + + let message = Round2EchoBroadcast { data }; + return EchoBroadcast::new(serializer, message); + } + + Ok(echo_broadcast) + } + + fn modify_normal_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &Behavior, + serializer: &Serializer, + _deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == RoundId::new(3) && behavior == &Behavior::R3InvalidSchProof { + let round3 = round.downcast_ref::>()?; + let context = &round3.context; + let aux = (&context.sid_hash, &context.my_id, &round3.rho); + + // Make a proof for a random secret. This won't pass verification. + let x = Secret::init_with(|| Scalar::random(rng)); + let psi = SchProof::new( + &context.tau, + &x, + &context.public_data.cap_a, + &x.mul_by_generator(), + &aux, + ); + + let message = Round3Broadcast { psi }; + return NormalBroadcast::new(serializer, message); + } + + Ok(normal_broadcast) + } +} + +type MaliciousKeyEP = MisbehavingEntryPoint>; + +type Protocol = KeyInitProtocol; +type SP = TestSessionParams; + +fn run_with_one_malicious_party( + behavior: Behavior, +) -> (Vec, BTreeMap>) { + let signers = (0..3).map(TestSigner::new).collect::>(); + let ids = signers.iter().map(|signer| signer.verifying_key()).collect::>(); + let ids_set = BTreeSet::from_iter(ids.clone()); + + let entry_points = signers + .into_iter() + .map(|signer| { + let id = signer.verifying_key(); + let entry_point = KeyInit::::new(ids_set.clone()).unwrap(); + let behavior = if id == ids[0] { Some(behavior) } else { None }; + let entry_point = MaliciousKeyEP::new(entry_point, behavior); + (signer, entry_point) + }) + .collect(); + + let reports = run_sync::<_, SP>(&mut OsRng, entry_points).unwrap().reports; + (ids, reports) +} + +#[test] +fn r2_hash_mismatch() { + let (ids, mut reports) = run_with_one_malicious_party(Behavior::R2RandomU); + + let report0 = reports.remove(&ids[0]).unwrap(); + let report1 = reports.remove(&ids[1]).unwrap(); + let report2 = reports.remove(&ids[2]).unwrap(); + + assert!(report0.provable_errors.is_empty()); + assert!(report1.provable_errors[&ids[0]].verify(&()).is_ok()); + assert!(report2.provable_errors[&ids[0]].verify(&()).is_ok()); +} + +#[test] +fn r3_invalid_sch_proof() { + let (ids, mut reports) = run_with_one_malicious_party(Behavior::R3InvalidSchProof); + + let report0 = reports.remove(&ids[0]).unwrap(); + let report1 = reports.remove(&ids[1]).unwrap(); + let report2 = reports.remove(&ids[2]).unwrap(); + + assert!(report0.provable_errors.is_empty()); + assert!(report1.provable_errors[&ids[0]].verify(&()).is_ok()); + assert!(report2.provable_errors[&ids[0]].verify(&()).is_ok()); +} diff --git a/synedrion/src/tools.rs b/synedrion/src/tools.rs index aeada742..16597054 100644 --- a/synedrion/src/tools.rs +++ b/synedrion/src/tools.rs @@ -1,4 +1,8 @@ -use alloc::collections::{BTreeMap, BTreeSet}; +use alloc::{ + collections::{BTreeMap, BTreeSet}, + format, +}; +use core::fmt::Debug; pub(crate) mod bitvec; pub(crate) mod hashing; @@ -45,3 +49,14 @@ impl DowncastMap for BTreeMap { .collect::>() } } + +pub(crate) trait SafeGet { + fn safe_get(&self, container: &str, key: &K) -> Result<&V, LocalError>; +} + +impl SafeGet for BTreeMap { + fn safe_get(&self, container: &str, key: &K) -> Result<&V, LocalError> { + self.get(key) + .ok_or_else(|| LocalError::new(format!("Key {key:?} not found in {container}"))) + } +} diff --git a/synedrion/src/www02/entities.rs b/synedrion/src/www02/entities.rs index 868be3ed..16b8204d 100644 --- a/synedrion/src/www02/entities.rs +++ b/synedrion/src/www02/entities.rs @@ -173,12 +173,7 @@ impl ThresholdKeyShare>()?; - Ok(KeyShare { - owner: self.owner.clone(), - secret_share, - public_shares, - phantom: PhantomData, - }) + KeyShare::new(self.owner.clone(), secret_share, public_shares) } /// Creates a t-of-t threshold keyshare that can be used in KeyResharing protocol. @@ -196,7 +191,7 @@ impl ThresholdKeyShare ThresholdKeyShare ThresholdKeyShare