From 4facf7157ecd5a42e0dfa4ddc936ac8673321804 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 6 Jan 2025 19:13:16 -0800 Subject: [PATCH] Update KeyRefresh --- synedrion/src/cggmp21.rs | 3 + synedrion/src/cggmp21/aux_gen.rs | 13 +- synedrion/src/cggmp21/entities.rs | 10 - synedrion/src/cggmp21/key_refresh.rs | 1198 ++++++++++------- .../src/cggmp21/key_refresh_malicious.rs | 650 +++++++++ synedrion/src/paillier/encryption.rs | 13 +- synedrion/src/paillier/keys.rs | 11 + synedrion/src/paillier/ring_pedersen.rs | 24 + synedrion/src/paillier/rsa.rs | 28 + 9 files changed, 1469 insertions(+), 481 deletions(-) create mode 100644 synedrion/src/cggmp21/key_refresh_malicious.rs diff --git a/synedrion/src/cggmp21.rs b/synedrion/src/cggmp21.rs index e5b56db..18d4cf0 100644 --- a/synedrion/src/cggmp21.rs +++ b/synedrion/src/cggmp21.rs @@ -22,6 +22,9 @@ mod signing_malicious; #[cfg(test)] mod key_init_malicious; +#[cfg(test)] +mod key_refresh_malicious; + pub use aux_gen::{AuxGen, AuxGenProtocol}; pub use entities::{AuxInfo, KeyShare, KeyShareChange}; pub use interactive_signing::{InteractiveSigning, InteractiveSigningProtocol, PrehashedMessage}; diff --git a/synedrion/src/cggmp21/aux_gen.rs b/synedrion/src/cggmp21/aux_gen.rs index 65d3b08..dbed916 100644 --- a/synedrion/src/cggmp21/aux_gen.rs +++ b/synedrion/src/cggmp21/aux_gen.rs @@ -608,7 +608,6 @@ impl Round for Round3 { ( id, PublicAuxInfo { - el_gamal_pk: data.data.cap_y, paillier_pk: data.paillier_pk.into_wire(), rp_params: data.rp_params.to_wire(), }, @@ -618,7 +617,6 @@ impl Round for Round3 { let secret_aux = SecretAuxInfo { paillier_sk: self.context.paillier_sk.into_wire(), - el_gamal_sk: self.context.y, }; let aux_info = AuxInfo { @@ -660,18 +658,9 @@ mod tests { }) .collect::>(); - let aux_infos = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) + let _aux_infos = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) .unwrap() .results() .unwrap(); - - for (id, aux_info) in aux_infos.iter() { - for other_aux_info in aux_infos.values() { - assert_eq!( - aux_info.secret_aux.el_gamal_sk.mul_by_generator(), - other_aux_info.public_aux[id].el_gamal_pk - ); - } - } } } diff --git a/synedrion/src/cggmp21/entities.rs b/synedrion/src/cggmp21/entities.rs index 6119b02..7760a16 100644 --- a/synedrion/src/cggmp21/entities.rs +++ b/synedrion/src/cggmp21/entities.rs @@ -45,14 +45,12 @@ pub struct AuxInfo { #[serde(bound(deserialize = "SecretKeyPaillierWire: for <'x> Deserialize<'x>"))] pub(crate) struct SecretAuxInfo { pub(crate) paillier_sk: SecretKeyPaillierWire, - pub(crate) el_gamal_sk: Secret, // `y_i` } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound(serialize = "PublicKeyPaillierWire: Serialize"))] #[serde(bound(deserialize = "PublicKeyPaillierWire: for <'x> Deserialize<'x>"))] pub(crate) struct PublicAuxInfo { - pub(crate) el_gamal_pk: Point, // `Y_i` /// The Paillier public key. pub(crate) paillier_pk: PublicKeyPaillierWire, /// The ring-Pedersen parameters. @@ -68,14 +66,10 @@ pub(crate) struct AuxInfoPrecomputed { #[derive(Debug, Clone)] pub(crate) struct SecretAuxInfoPrecomputed { pub(crate) paillier_sk: SecretKeyPaillier, - #[allow(dead_code)] // TODO (#36): this will be needed for the 6-round presigning protocol. - pub(crate) el_gamal_sk: Secret, // `y_i` } #[derive(Debug, Clone)] pub(crate) struct PublicAuxInfoPrecomputed { - #[allow(dead_code)] // TODO (#36): this will be needed for the 6-round presigning protocol. - pub(crate) el_gamal_pk: Point, pub(crate) paillier_pk: PublicKeyPaillier, pub(crate) rp_params: RPParams, } @@ -259,7 +253,6 @@ impl AuxInfo { let secret_aux = (0..ids.len()) .map(|_| SecretAuxInfo { paillier_sk: SecretKeyPaillierWire::::random(rng), - el_gamal_sk: Secret::init_with(|| Scalar::random(rng)), }) .collect::>(); @@ -271,7 +264,6 @@ impl AuxInfo { id.clone(), PublicAuxInfo { paillier_pk: secret.paillier_sk.public_key(), - el_gamal_pk: secret.el_gamal_sk.mul_by_generator(), rp_params: RPParams::random(rng).to_wire(), }, ) @@ -297,7 +289,6 @@ impl AuxInfo { AuxInfoPrecomputed { secret_aux: SecretAuxInfoPrecomputed { paillier_sk: self.secret_aux.paillier_sk.clone().into_precomputed(), - el_gamal_sk: self.secret_aux.el_gamal_sk.clone(), }, public_aux: self .public_aux @@ -307,7 +298,6 @@ impl AuxInfo { ( id.clone(), PublicAuxInfoPrecomputed { - el_gamal_pk: public_aux.el_gamal_pk, paillier_pk: paillier_pk.clone(), rp_params: public_aux.rp_params.to_precomputed(), }, diff --git a/synedrion/src/cggmp21/key_refresh.rs b/synedrion/src/cggmp21/key_refresh.rs index 0596c52..eaf1c4f 100644 --- a/synedrion/src/cggmp21/key_refresh.rs +++ b/synedrion/src/cggmp21/key_refresh.rs @@ -1,26 +1,24 @@ -//! KeyRefresh protocol, in the paper Auxiliary Info. & Key Refresh in Three Rounds (Fig. 6). +//! KeyRefresh protocol, in the paper Auxiliary Info. & Key Refresh in Three Rounds (Fig. 7). //! This protocol generates an update to the secret key shares and new auxiliary parameters //! for ZK proofs (e.g. Paillier keys). -use alloc::{ - collections::{BTreeMap, BTreeSet}, - format, - string::String, - vec::Vec, +use alloc::collections::{BTreeMap, BTreeSet}; +use core::{ + fmt::{self, Debug, Display}, + marker::PhantomData, }; -use core::{fmt::Debug, marker::PhantomData}; use crypto_bigint::BitOps; use manul::protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage, - ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, Serializer, + ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId, + Serializer, }; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use super::{ - conversion::{secret_scalar_from_signed, secret_signed_from_scalar}, entities::{AuxInfo, KeyShareChange, PublicAuxInfo, SecretAuxInfo}, params::SchemeParams, sigma::{FacProof, ModProof, PrmProof, SchCommitment, SchProof, SchSecret}, @@ -28,13 +26,13 @@ use super::{ use crate::{ curve::{secret_split, Point, Scalar}, paillier::{ - Ciphertext, CiphertextWire, PaillierParams, PublicKeyPaillier, PublicKeyPaillierWire, RPParams, RPParamsWire, - RPSecret, SecretKeyPaillier, SecretKeyPaillierWire, + PaillierParams, PublicKeyPaillier, PublicKeyPaillierWire, RPParams, RPParamsWire, RPSecret, SecretKeyPaillier, + SecretKeyPaillierWire, }, tools::{ bitvec::BitVec, - hashing::{Chain, FofHasher, HashOutput}, - DowncastMap, Secret, Without, + hashing::{Chain, FofHasher, HashOutput, XofHasher}, + verify_that, DeserializeAll, DowncastMap, GetRound, SafeGet, Secret, Without, }, }; @@ -45,78 +43,373 @@ pub struct KeyRefreshProtocol(PhantomData<(P, I)>); impl Protocol for KeyRefreshProtocol { type Result = (KeyShareChange, AuxInfo); - type ProtocolError = KeyRefreshError

; + type ProtocolError = KeyRefreshError; fn verify_direct_message_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &DirectMessage, + deserializer: &Deserializer, + round_id: &RoundId, + message: &DirectMessage, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &1 => message.verify_is_some(), + r if r == &2 => message.verify_is_some(), + r if r == &3 => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } fn verify_echo_broadcast_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &EchoBroadcast, + deserializer: &Deserializer, + round_id: &RoundId, + message: &EchoBroadcast, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &1 => message.verify_is_not::(deserializer), + r if r == &2 => message.verify_is_some(), + r if r == &3 => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } fn verify_normal_broadcast_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &NormalBroadcast, + deserializer: &Deserializer, + round_id: &RoundId, + message: &NormalBroadcast, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &1 => message.verify_is_some(), + r if r == &2 => message.verify_is_not::>(deserializer), + r if r == &3 => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } } /// Provable KeyRefresh faults. -#[derive(displaydoc::Display, Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound(serialize = " - KeyRefreshErrorEnum

: Serialize, + KeyRefreshErrorEnum: Serialize, "))] #[serde(bound(deserialize = " - KeyRefreshErrorEnum

: for<'x> Deserialize<'x>, + KeyRefreshErrorEnum: for<'x> Deserialize<'x>, "))] -pub struct KeyRefreshError(KeyRefreshErrorEnum

); +pub struct KeyRefreshError { + error: KeyRefreshErrorEnum, + phantom: PhantomData

, +} -#[derive(Debug, Clone, Serialize, Deserialize)] -enum KeyRefreshErrorEnum { - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round2(String), - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round3(String), - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round3MismatchedSecret { - cap_c: CiphertextWire, - x: Scalar, - mu: ::Uint, +impl KeyRefreshError { + fn new(error: KeyRefreshErrorEnum) -> Self { + Self { + error, + phantom: PhantomData, + } + } +} + +impl Display for KeyRefreshError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "{:?}", self.error) + } +} + +/// KeyRefresh error +#[derive(displaydoc::Display, Debug, Clone, Serialize, Deserialize)] +enum KeyRefreshErrorEnum { + /// Round2: public data hash mismatch + R2HashMismatch, + /// Round2: wrong IDs in public shares map + R2WrongIdsX, + /// Round2: wrong IDs in Elgamal keys map + R2WrongIdsY, + /// Round2: wrong IDs in Schnorr commitments map + R2WrongIdsA, + /// Round2: Paillier modulus is too small + R2PaillierModulusTooSmall, + /// Round2: ring-Pedersent modulus is too small + R2RPModulusTooSmall, + /// Round2: sum of share changes is not zero + R2NonZeroSumOfChanges, + /// Round2: P_prm verification failed + R2PrmFailed, + /// Round3: secret share change does not match the public commitment + R3ShareChangeMismatch { + /// The index $i$ of the node that produced the evidence. + reported_by: I, + /// $y_{i,j}$, where where $j$ is the index of the guilty party. + y: Scalar, + }, + /// Round3: P_mod verification failed + R3ModFailed, + /// Round3: P_fac verification failed + R3FacFailed { + /// The index $i$ of the node that produced the evidence. + reported_by: I, + }, + /// Round3: Wrong IDs in Schnorr proofs map + R3WrongIdsHatPsi, + /// Round3: P_sch verification failed + R3SchFailed { + /// The index $k$ for which the verification of $П^{sch}_{j,k}$ failed + /// (where $j$ is the index of the guilty party). + failed_for: I, }, } -impl ProtocolError for KeyRefreshError

{ - type AssociatedData = (); +/// Reconstruct `rid` from echoed messages +fn reconstruct_rid( + deserializer: &Deserializer, + previous_messages: &BTreeMap, + combined_echos: &BTreeMap>, +) -> Result { + let r2_messages = combined_echos + .get_round(2)? + .deserialize_all::>(deserializer)?; + let r2_echo = previous_messages + .get_round(2)? + .echo_broadcast + .deserialize::>(deserializer)?; + let mut rid = r2_echo.rid_part; + for message in r2_messages.values() { + rid ^= &message.rid_part; + } + Ok(rid) +} + +impl ProtocolError for KeyRefreshError { + type AssociatedData = BTreeSet; fn required_messages(&self) -> RequiredMessages { - unimplemented!() + match self.error { + KeyRefreshErrorEnum::R2HashMismatch => RequiredMessages::new( + RequiredMessageParts::normal_broadcast(), + Some([(1.into(), RequiredMessageParts::echo_broadcast())].into()), + None, + ), + KeyRefreshErrorEnum::R2WrongIdsX => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast(), None, None) + } + KeyRefreshErrorEnum::R2WrongIdsY => { + RequiredMessages::new(RequiredMessageParts::echo_broadcast(), None, None) + } + KeyRefreshErrorEnum::R2WrongIdsA => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast(), None, None) + } + KeyRefreshErrorEnum::R2PaillierModulusTooSmall => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast(), None, None) + } + KeyRefreshErrorEnum::R2RPModulusTooSmall => { + RequiredMessages::new(RequiredMessageParts::echo_broadcast(), None, None) + } + KeyRefreshErrorEnum::R2NonZeroSumOfChanges => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast(), None, None) + } + KeyRefreshErrorEnum::R2PrmFailed => RequiredMessages::new( + RequiredMessageParts::echo_broadcast().and_normal_broadcast(), + None, + None, + ), + KeyRefreshErrorEnum::R3ShareChangeMismatch { .. } => RequiredMessages::new( + RequiredMessageParts::direct_message(), + Some([(2.into(), RequiredMessageParts::echo_broadcast().and_normal_broadcast())].into()), + Some([2.into()].into()), + ), + KeyRefreshErrorEnum::R3ModFailed => RequiredMessages::new( + RequiredMessageParts::normal_broadcast(), + Some([(2.into(), RequiredMessageParts::echo_broadcast().and_normal_broadcast())].into()), + Some([2.into()].into()), + ), + KeyRefreshErrorEnum::R3FacFailed { .. } => RequiredMessages::new( + RequiredMessageParts::direct_message(), + Some([(2.into(), RequiredMessageParts::echo_broadcast().and_normal_broadcast())].into()), + Some([2.into()].into()), + ), + KeyRefreshErrorEnum::R3WrongIdsHatPsi => { + RequiredMessages::new(RequiredMessageParts::echo_broadcast(), None, None) + } + KeyRefreshErrorEnum::R3SchFailed { .. } => RequiredMessages::new( + RequiredMessageParts::echo_broadcast(), + Some([(2.into(), RequiredMessageParts::echo_broadcast().and_normal_broadcast())].into()), + Some([2.into()].into()), + ), + } } fn verify_messages_constitute_error( &self, - _deserializer: &Deserializer, - _guilty_party: &I, - _shared_randomness: &[u8], - _associated_data: &Self::AssociatedData, - _message: ProtocolMessage, - _previous_messages: BTreeMap, - _combined_echos: BTreeMap>, + deserializer: &Deserializer, + guilty_party: &I, + shared_randomness: &[u8], + associated_data: &Self::AssociatedData, + message: ProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, ) -> Result<(), ProtocolValidationError> { - unimplemented!() + let sid_hash = FofHasher::new_with_dst(b"SID") + .chain_type::

() + .chain(&shared_randomness) + .finalize(); + + match &self.error { + KeyRefreshErrorEnum::R2HashMismatch => { + let r1_message = previous_messages + .get_round(1)? + .echo_broadcast + .deserialize::(deserializer)?; + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that(r2_message.hash(&sid_hash, guilty_party) != r1_message.cap_v) + } + KeyRefreshErrorEnum::R2WrongIdsX => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that(&r2_message.cap_xs.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R2WrongIdsY => { + let r2_message = message + .echo_broadcast + .deserialize::>(deserializer)?; + verify_that(&r2_message.cap_ys.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R2WrongIdsA => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that(&r2_message.cap_as.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R2PaillierModulusTooSmall => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that( + r2_message.paillier_pk.modulus().bits_vartime() < ::MODULUS_BITS - 2, + ) + } + KeyRefreshErrorEnum::R2RPModulusTooSmall => { + let r2_message = message + .echo_broadcast + .deserialize::>(deserializer)?; + verify_that( + r2_message.rp_params.modulus().bits_vartime() < ::MODULUS_BITS - 2, + ) + } + KeyRefreshErrorEnum::R2NonZeroSumOfChanges => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that(r2_message.cap_xs.values().sum::() != Point::IDENTITY) + } + KeyRefreshErrorEnum::R2PrmFailed => { + let r2_eb = message + .echo_broadcast + .deserialize::>(deserializer)?; + let r2_bc = message + .normal_broadcast + .deserialize::>(deserializer)?; + let aux = (&sid_hash, guilty_party); + let rp_params = r2_eb.rp_params.to_precomputed(); + verify_that(!r2_bc.psi.verify(&rp_params, &aux)) + } + KeyRefreshErrorEnum::R3ShareChangeMismatch { reported_by, y } => { + // Check that `y` attached to the evidence is correct + // (that is, can be verified against something signed by `guilty_party`). + // It is `y_{i,j}` where `i == reported_by` and `j == guilty_party` + let r2_message_i = combined_echos + .get_round(2)? + .try_get("combined echos for Round 2", reported_by)? + .deserialize::>(deserializer)?; + let cap_y_ij = r2_message_i.cap_ys.try_get("public Elgamal values", guilty_party)?; + if &y.mul_by_generator() != cap_y_ij { + return Err(ProtocolValidationError::InvalidEvidence( + "The provided `y` is invalid".into(), + )); + } + + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + + let r2_echo = previous_messages + .get_round(2)? + .echo_broadcast + .deserialize::>(deserializer)?; + let cap_y_ji = r2_echo.cap_ys.try_get("public Elgamal values", reported_by)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&sid_hash) + .chain(&rid) + .chain(guilty_party) + .chain(&(cap_y_ji * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_message = message + .direct_message + .deserialize::>(deserializer)?; + + let x = r3_message.cap_c - rho; + let cap_x_ji = r2_bc.cap_xs.try_get("public key share changes", reported_by)?; + verify_that(&x.mul_by_generator() != cap_x_ji) + } + KeyRefreshErrorEnum::R3ModFailed => { + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + let aux = (&sid_hash, guilty_party, &rid); + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_bc = message + .normal_broadcast + .deserialize::>(deserializer)?; + let paillier_pk = r2_bc.paillier_pk.into_precomputed(); + verify_that(!r3_bc.psi_prime.verify(&paillier_pk, &aux)) + } + KeyRefreshErrorEnum::R3FacFailed { reported_by } => { + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + let aux = (&sid_hash, guilty_party, &rid); + + let r2_eb = combined_echos + .get_round(2)? + .try_get("combined echos for Round 2", reported_by)? + .deserialize::>(deserializer)?; + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_dm = message + .direct_message + .deserialize::>(deserializer)?; + let paillier_pk = r2_bc.paillier_pk.into_precomputed(); + let rp_params = r2_eb.rp_params.to_precomputed(); + verify_that(!r3_dm.psi.verify(&paillier_pk, &rp_params, &aux)) + } + KeyRefreshErrorEnum::R3WrongIdsHatPsi => { + let r3_eb = message + .echo_broadcast + .deserialize::>(deserializer)?; + verify_that(&r3_eb.hat_psis.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R3SchFailed { failed_for } => { + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + let aux = (&sid_hash, guilty_party, &rid); + + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_eb = message + .echo_broadcast + .deserialize::>(deserializer)?; + + let cap_a = r2_bc.cap_as.try_get("Schnorr commitments", failed_for)?; + let cap_x = r2_bc.cap_xs.try_get("public share changes", failed_for)?; + let hat_psi = r3_eb.hat_psis.try_get("Schnorr proofs", failed_for)?; + verify_that(!hat_psi.verify(cap_a, cap_x, &aux)) + } + } } } @@ -157,161 +450,117 @@ impl EntryPoint for KeyRefresh { let other_ids = self.all_ids.clone().without(id); - let ids_ordering = self - .all_ids - .iter() - .cloned() - .enumerate() - .map(|(idx, id)| (id, idx)) - .collect(); - let sid_hash = FofHasher::new_with_dst(b"SID") .chain_type::

() .chain(&shared_randomness) - .chain(&self.all_ids) .finalize(); - // $p_i$, $q_i$ + // Paillier secret key $p_i$, $q_i$ let paillier_sk = SecretKeyPaillierWire::::random(rng); - // $N_i$ + // Paillier public key $N_i$ let paillier_pk = paillier_sk.public_key(); - // El-Gamal key - let y = Secret::init_with(|| Scalar::random(rng)); - let cap_y = y.mul_by_generator(); + // Ring-Pedersen secret $\lambda$. + let rp_secret = RPSecret::random(rng); + // Ring-Pedersen parameters ($N$, $s$, $t$) bundled in a single object. + let rp_params = RPParams::random_with_secret(rng, &rp_secret); - // The secret and the commitment for the Schnorr PoK of the El-Gamal key - let tau_y = SchSecret::random(rng); // $\tau$ - let cap_b = SchCommitment::new(&tau_y); + let aux = (&sid_hash, id); + let psi = PrmProof::

::new(rng, &rp_secret, &rp_params, &aux); - // Secret share updates for each node ($x_i^j$ where $i$ is this party's index). - let x_to_send = self + // Ephemeral DH keys $y_{i,j}$ where $i$ is this party's index. + let ys = self .all_ids .iter() .cloned() - .zip(secret_split( - rng, - Secret::init_with(|| Scalar::ZERO), - self.all_ids.len(), - )) + .map(|id| (id, Secret::init_with(|| Scalar::random(rng)))) .collect::>(); + // Corresponding public keys $Y_{i,j}$. + let cap_ys = ys.iter().map(|(id, y)| (id.clone(), y.mul_by_generator())).collect(); - // Public counterparts of secret share updates ($X_i^j$ where $i$ is this party's index). - let cap_x_to_send = x_to_send.values().map(|x| x.mul_by_generator()).collect(); - - let rp_secret = RPSecret::random(rng); - // Ring-Pedersen parameters ($s$, $t$) bundled in a single object. - let rp_params = RPParams::random_with_secret(rng, &rp_secret); + // Secret share updates for each node ($x_{i,j}$ where $i$ is this party's index). + let split_zero = secret_split(rng, Secret::init_with(|| Scalar::ZERO), self.all_ids.len()); + let xs = self.all_ids.iter().cloned().zip(split_zero).collect::>(); - let aux = (&sid_hash, id); - let hat_psi = PrmProof::

::new(rng, &rp_secret, &rp_params, &aux); + // Public counterparts of secret share updates ($X_i^j$ where $i$ is this party's index). + let cap_xs = xs.iter().map(|(id, x)| (id.clone(), x.mul_by_generator())).collect(); - // The secrets share changes ($\tau_j$, not to be confused with $\tau$) - let tau_x = self + // Schnorr proof secrets $\tau_j$ + let taus = self .all_ids .iter() .map(|id| (id.clone(), SchSecret::random(rng))) .collect::>(); - // The commitments for share changes ($A_i^j$ where $i$ is this party's index) - let cap_a_to_send = tau_x.values().map(SchCommitment::new).collect(); + // Schnorr commitments for share changes ($A_{i,j}$ where $i$ is this party's index) + let cap_as = taus + .iter() + .map(|(id, tau)| (id.clone(), SchCommitment::new(tau))) + .collect(); - let rho = BitVec::random(rng, P::SECURITY_PARAMETER); + let rid_part = BitVec::random(rng, P::SECURITY_PARAMETER); let u = BitVec::random(rng, P::SECURITY_PARAMETER); - let data = PublicData1 { - cap_x_to_send, - cap_a_to_send, - cap_y, - cap_b, + // Note: typo in the paper, $V$ hashes in $B_i$ which is not present in the '24 version of the paper. + let r2_normal_broadcast = Round2Broadcast { + cap_xs, + cap_as, paillier_pk: paillier_pk.clone(), - rp_params: rp_params.to_wire(), - hat_psi, - rho, + psi, u, }; - let data_precomp = PublicData1Precomp { - data, - paillier_pk: paillier_pk.into_precomputed(), - rp_params, + let r2_echo_broadcast = Round2EchoBroadcast { + rp_params: rp_params.to_wire(), + cap_ys, + rid_part, }; let context = Context { paillier_sk: paillier_sk.into_precomputed(), - y, - x_to_send, - tau_x, - tau_y, - data_precomp, + rp_params, + xs, + ys, + taus, my_id: id.clone(), other_ids, + all_ids: self.all_ids, sid_hash, - ids_ordering, }; - Ok(BoxedRound::new_dynamic(Round1 { context })) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound(serialize = " - PrmProof

: Serialize, - "))] -#[serde(bound(deserialize = " - PrmProof

: for<'x> Deserialize<'x>, - "))] -struct PublicData1 { - cap_x_to_send: Vec, // $X_i^j$ where $i$ is this party's index - cap_a_to_send: Vec, // $A_i^j$ where $i$ is this party's index - cap_y: Point, - cap_b: SchCommitment, - paillier_pk: PublicKeyPaillierWire, // $N_i$ - rp_params: RPParamsWire, // $s_i$ and $t_i$ - hat_psi: PrmProof

, - rho: BitVec, - u: BitVec, -} + let round = Round1 { + context, + r2_normal_broadcast, + r2_echo_broadcast, + }; -#[derive(Debug, Clone)] -struct PublicData1Precomp { - data: PublicData1

, - paillier_pk: PublicKeyPaillier, - rp_params: RPParams, + Ok(BoxedRound::new_dynamic(round)) + } } #[derive(Debug)] -struct Context { +pub(super) struct Context { paillier_sk: SecretKeyPaillier, - y: Secret, - x_to_send: BTreeMap>, // $x_i^j$ where $i$ is this party's index - tau_y: SchSecret, - tau_x: BTreeMap, - data_precomp: PublicData1Precomp

, - my_id: I, + rp_params: RPParams, + xs: BTreeMap>, // $x_{i,j}$ where $i$ is this party's index + ys: BTreeMap>, // $y_{i,j}$ where $i$ is this party's index + taus: BTreeMap, + pub(super) my_id: I, other_ids: BTreeSet, - sid_hash: HashOutput, - ids_ordering: BTreeMap, -} - -impl PublicData1

{ - fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { - FofHasher::new_with_dst(b"Auxiliary") - .chain(sid_hash) - .chain(id) - .chain(self) - .finalize() - } + all_ids: BTreeSet, + pub(super) sid_hash: HashOutput, } #[derive(Debug)] -struct Round1 { - context: Context, +pub(super) struct Round1 { + pub(super) context: Context, + pub(super) r2_normal_broadcast: Round2Broadcast, + pub(super) r2_echo_broadcast: Round2EchoBroadcast, } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Round1Message { - cap_v: HashOutput, +pub(super) struct Round1EchoBroadcast { + pub(super) cap_v: HashOutput, } struct Round1Payload { @@ -342,16 +591,12 @@ impl Round for Round1 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - EchoBroadcast::new( - serializer, - Round1Message { - cap_v: self - .context - .data_precomp - .data - .hash(&self.context.sid_hash, &self.context.my_id), - }, - ) + let message = Round1EchoBroadcast { + cap_v: self + .r2_normal_broadcast + .hash(&self.context.sid_hash, &self.context.my_id), + }; + EchoBroadcast::new(serializer, message) } fn receive_message( @@ -362,10 +607,13 @@ impl Round for Round1 { ) -> Result> { message.normal_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let echo_broadcast = message.echo_broadcast.deserialize::(deserializer)?; - Ok(Payload::new(Round1Payload { + let echo_broadcast = message + .echo_broadcast + .deserialize::(deserializer)?; + let payload = Round1Payload { cap_v: echo_broadcast.cap_v, - })) + }; + Ok(Payload::new(payload)) } fn finalize( @@ -376,28 +624,70 @@ impl Round for Round1 { ) -> Result, LocalError> { let payloads = payloads.downcast_all::()?; let others_cap_v = payloads.into_iter().map(|(id, payload)| (id, payload.cap_v)).collect(); - Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round2 { + let next_round = Round2 { context: self.context, + r2_echo_broadcast: self.r2_echo_broadcast, + r2_normal_broadcast: self.r2_normal_broadcast, others_cap_v, - }))) + }; + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(next_round))) } } #[derive(Debug)] -struct Round2 { +struct Round2 { context: Context, + r2_normal_broadcast: Round2Broadcast, + r2_echo_broadcast: Round2EchoBroadcast, others_cap_v: BTreeMap, } -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = "PublicData1

: Serialize"))] -#[serde(bound(deserialize = "PublicData1

: for<'x> Deserialize<'x>"))] -struct Round2Message { - data: PublicData1

, +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + PrmProof

: Serialize, +"))] +#[serde(bound(deserialize = " + PrmProof

: for<'x> Deserialize<'x>, +"))] +pub(super) struct Round2Broadcast { + pub(super) cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index + pub(super) cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index + pub(super) paillier_pk: PublicKeyPaillierWire, // $N_i$ + pub(super) psi: PrmProof

, + u: BitVec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + I: Serialize, +"))] +#[serde(bound(deserialize = " + I: for<'x> Deserialize<'x>, +"))] +pub(super) struct Round2EchoBroadcast { + pub(super) rp_params: RPParamsWire, // $\hat{N}_i$, $s_i$, and $t_i$ + pub(super) cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + rid_part: BitVec, +} + +impl Round2Broadcast { + pub(super) fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { + FofHasher::new_with_dst(b"Auxiliary") + .chain(sid_hash) + .chain(id) + .chain(self) + .finalize() + } } -struct Round2Payload { - data: PublicData1Precomp

, +#[derive(Debug)] +struct Round2Payload { + cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index + cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index + cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + paillier_pk: PublicKeyPaillier, // $N_i$ + rp_params: RPParams, // $\hat{N}_i$, $s_i$, and $t_i$ + rid_part: BitVec, } impl Round for Round2 { @@ -424,12 +714,15 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - NormalBroadcast::new( - serializer, - Round2Message { - data: self.context.data_precomp.data.clone(), - }, - ) + NormalBroadcast::new(serializer, self.r2_normal_broadcast.clone()) + } + + fn make_echo_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + EchoBroadcast::new(serializer, self.r2_echo_broadcast.clone()) } fn receive_message( @@ -438,50 +731,78 @@ impl Round for Round2 { from: &I, message: ProtocolMessage, ) -> Result> { - message.echo_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let normal_broadcast = message.normal_broadcast.deserialize::>(deserializer)?; - let cap_v = self - .others_cap_v - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing `V` for {from:?}")))?; - - if &normal_broadcast.data.hash(&self.context.sid_hash, from) != cap_v { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Hash mismatch".into(), - )))); + let echo_broadcast = message + .echo_broadcast + .deserialize::>(deserializer)?; + let normal_broadcast = message + .normal_broadcast + .deserialize::>(deserializer)?; + + let cap_v = self.others_cap_v.safe_get("other nodes' `V`", from)?; + + if &normal_broadcast.hash(&self.context.sid_hash, from) != cap_v { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2HashMismatch, + ))); } - let paillier_pk = normal_broadcast.data.paillier_pk.clone().into_precomputed(); + if normal_broadcast.cap_xs.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2WrongIdsX, + ))); + } - if (paillier_pk.modulus().bits_vartime() as usize) < 8 * P::SECURITY_PARAMETER { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Paillier modulus is too small".into(), - )))); + if echo_broadcast.cap_ys.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2WrongIdsY, + ))); } - if normal_broadcast.data.cap_x_to_send.iter().sum::() != Point::IDENTITY { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Sum of X points is not identity".into(), - )))); + if normal_broadcast.cap_as.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2WrongIdsA, + ))); } - let aux = (&self.context.sid_hash, &from); + let paillier_pk = normal_broadcast.paillier_pk.clone().into_precomputed(); + let rp_params = echo_broadcast.rp_params.to_precomputed(); - let rp_params = normal_broadcast.data.rp_params.to_precomputed(); - if !normal_broadcast.data.hat_psi.verify(&rp_params, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "PRM verification failed".into(), - )))); + if paillier_pk.modulus().bits_vartime() < ::MODULUS_BITS - 2 { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2PaillierModulusTooSmall, + ))); + } + + if rp_params.modulus().bits_vartime() < ::MODULUS_BITS - 2 { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2RPModulusTooSmall, + ))); + } + + if normal_broadcast.cap_xs.values().sum::() != Point::IDENTITY { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2NonZeroSumOfChanges, + ))); + } + + let aux = (&self.context.sid_hash, &from); + if !normal_broadcast.psi.verify(&rp_params, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2PrmFailed, + ))); } - Ok(Payload::new(Round2Payload { - data: PublicData1Precomp { - data: normal_broadcast.data, - paillier_pk, - rp_params, - }, - })) + let payload = Round2Payload:: { + cap_xs: normal_broadcast.cap_xs, + cap_as: normal_broadcast.cap_as, + cap_ys: echo_broadcast.cap_ys, + paillier_pk: normal_broadcast.paillier_pk.into_precomputed(), + rp_params: echo_broadcast.rp_params.to_precomputed(), + rid_part: echo_broadcast.rid_part, + }; + + Ok(Payload::new(payload)) } fn finalize( @@ -490,86 +811,107 @@ impl Round for Round2 { payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, LocalError> { - let payloads = payloads.downcast_all::>()?; - let others_data = payloads - .into_iter() - .map(|(id, payload)| (id, payload.data)) - .collect::>(); - let mut rho = self.context.data_precomp.data.rho.clone(); - for data in others_data.values() { - rho ^= &data.data.rho; + let mut payloads = payloads.downcast_all::>()?; + + let mut rid = self.r2_echo_broadcast.rid_part.clone(); + for payload in payloads.values() { + rid ^= &payload.rid_part; } + // Add in the payload with this node's info, for the sake of uniformity + let my_payload = Round2Payload:: { + cap_xs: self.r2_normal_broadcast.cap_xs, + cap_as: self.r2_normal_broadcast.cap_as, + cap_ys: self.r2_echo_broadcast.cap_ys, + paillier_pk: self.r2_normal_broadcast.paillier_pk.into_precomputed(), + rp_params: self.r2_echo_broadcast.rp_params.to_precomputed(), + rid_part: self.r2_echo_broadcast.rid_part, + }; + payloads.insert(self.context.my_id.clone(), my_payload); + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round3::new( rng, self.context, - others_data, - rho, - )))) + payloads, + rid, + )?))) } } #[derive(Debug)] struct Round3 { context: Context, - rho: BitVec, - others_data: BTreeMap>, - psi_mod: ModProof

, - pi: SchProof, -} - -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = " - ModProof

: Serialize, - FacProof

: Serialize, - CiphertextWire: Serialize, -"))] -#[serde(bound(deserialize = " - ModProof

: for<'x> Deserialize<'x>, - FacProof

: for<'x> Deserialize<'x>, - CiphertextWire: for<'x> Deserialize<'x>, -"))] -struct PublicData2 { - psi_mod: ModProof

, // $\psi_i$, a P^{mod} for the Paillier modulus - phi: FacProof

, - pi: SchProof, - paillier_enc_x: CiphertextWire, // `C_j,i` - psi_sch: SchProof, // $psi_i^j$, a P^{sch} for the secret share change + rid: BitVec, + r2_payloads: BTreeMap>, + psi_prime: ModProof

, + hat_psis: BTreeMap, } impl Round3 { fn new( rng: &mut impl CryptoRngCore, context: Context, - others_data: BTreeMap>, - rho: BitVec, - ) -> Self { - let aux = (&context.sid_hash, &context.my_id, &rho); - let psi_mod = ModProof::new(rng, &context.paillier_sk, &aux); - - let pi = SchProof::new( - &context.tau_y, - &context.y, - &context.data_precomp.data.cap_b, - &context.data_precomp.data.cap_y, - &aux, - ); + r2_payloads: BTreeMap>, + rid: BitVec, + ) -> Result { + let my_id = &context.my_id; + let aux = (&context.sid_hash, my_id, &rid); + let psi_prime = ModProof::new(rng, &context.paillier_sk, &aux); + + let my_r2_payload = r2_payloads.safe_get("Round 2 payloads", my_id)?; + + let mut hat_psis = BTreeMap::new(); + for id in context.all_ids.iter() { + let x = context.xs.safe_get("secret share changes", id)?; + let tau = context.taus.safe_get("Schnorr secrets", id)?; + let cap_a = my_r2_payload.cap_as.safe_get("Schnorr commitments", id)?; + let cap_x = my_r2_payload.cap_xs.safe_get("public share changes", id)?; + let hat_psi = SchProof::new(tau, x, cap_a, cap_x, &aux); + hat_psis.insert(id.clone(), hat_psi); + } - Self { + Ok(Self { context, - others_data, - rho, - psi_mod, - pi, - } + r2_payloads, + rid, + psi_prime, + hat_psis, + }) } } #[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = "PublicData2

: Serialize"))] -#[serde(bound(deserialize = "PublicData2

: for<'x> Deserialize<'x>"))] -struct Round3Message { - data2: PublicData2

, +#[serde(bound(serialize = " + SchProof: Serialize, +"))] +#[serde(bound(deserialize = " + SchProof: for<'x> Deserialize<'x>, +"))] +pub(super) struct Round3EchoBroadcast { + pub(super) hat_psis: BTreeMap, +} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + ModProof

: Serialize, +"))] +#[serde(bound(deserialize = " + ModProof

: for<'x> Deserialize<'x>, +"))] +pub(super) struct Round3Broadcast { + pub(super) psi_prime: ModProof

, +} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + FacProof

: Serialize, +"))] +#[serde(bound(deserialize = " + FacProof

: for<'x> Deserialize<'x>, +"))] +pub(super) struct Round3DirectMessage { + pub(super) psi: FacProof

, + pub(super) cap_c: Scalar, } struct Round3Payload { @@ -599,64 +941,55 @@ impl Round for Round3 { &self.context.other_ids } + fn make_echo_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let message = Round3EchoBroadcast { + hat_psis: self.hat_psis.clone(), + }; + EchoBroadcast::new(serializer, message) + } + + fn make_normal_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let message = Round3Broadcast { + psi_prime: self.psi_prime.clone(), + }; + NormalBroadcast::new(serializer, message) + } + fn make_direct_message( &self, rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &I, ) -> Result<(DirectMessage, Option), LocalError> { - let aux = (&self.context.sid_hash, &self.context.my_id, &self.rho); - - let data = self - .others_data - .get(destination) - .ok_or_else(|| LocalError::new(format!("Missing data for {destination:?}")))?; - - let phi = FacProof::new(rng, &self.context.paillier_sk, &data.rp_params, &aux); - - let destination_idx = *self - .context - .ids_ordering - .get(destination) - .ok_or_else(|| LocalError::new("destination={destination:?} is missing in ids_ordering"))?; - - let x_secret = self - .context - .x_to_send - .get(destination) - .ok_or_else(|| LocalError::new("destination={destination} is missing in x_to_send"))?; - let x_public = self - .context - .data_precomp - .data - .cap_x_to_send - .get(destination_idx) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in cap_x_to_send"))?; - let ciphertext = Ciphertext::new(rng, &data.paillier_pk, &secret_signed_from_scalar::

(x_secret)); - let proof_secret = self - .context - .tau_x - .get(destination) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in tau_x"))?; - let commitment = self - .context - .data_precomp - .data - .cap_a_to_send - .get(destination_idx) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in cap_a_to_send"))?; - - let psi_sch = SchProof::new(proof_secret, x_secret, commitment, x_public, &aux); - - let data2 = PublicData2 { - psi_mod: self.psi_mod.clone(), - phi, - pi: self.pi.clone(), - paillier_enc_x: ciphertext.to_wire(), - psi_sch, - }; - - let dm = DirectMessage::new(serializer, Round3Message { data2 })?; + let my_id = &self.context.my_id; + let aux = (&self.context.sid_hash, my_id, &self.rid); + + let r2_payload = self.r2_payloads.safe_get("Round 2 payloads", destination)?; + + let psi = FacProof::

::new(rng, &self.context.paillier_sk, &r2_payload.rp_params, &aux); + + let cap_y = r2_payload.cap_ys.safe_get("Elgamal public keys", my_id)?; + let y = self.context.ys.safe_get("Elgamal secrets", destination)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&self.context.sid_hash) + .chain(&self.rid) + .chain(my_id) + .chain(&(cap_y * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + let x = self.context.xs.safe_get("secret share changes", destination)?; + let cap_c = *(x + &rho).expose_secret(); + + let message = Round3DirectMessage { psi, cap_c }; + let dm = DirectMessage::new(serializer, message)?; Ok((dm, None)) } @@ -666,89 +999,72 @@ impl Round for Round3 { from: &I, message: ProtocolMessage, ) -> Result> { - message.echo_broadcast.assert_is_none()?; - message.normal_broadcast.assert_is_none()?; - let direct_message = message.direct_message.deserialize::>(deserializer)?; - - let sender_data = &self - .others_data - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing data for {from:?}")))?; - - let enc_x = direct_message - .data2 - .paillier_enc_x - .to_precomputed(&self.context.data_precomp.paillier_pk); - - let x = secret_scalar_from_signed::

(&enc_x.decrypt(&self.context.paillier_sk)); - - let my_idx = *self - .context - .ids_ordering - .get(&self.context.my_id) - .ok_or_else(|| LocalError::new(format!("my_id={:?} is missing in ids_ordering", self.context.my_id)))?; - - if x.mul_by_generator() - != *sender_data - .data - .cap_x_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_x_to_send"))? - { - let mu = enc_x.derive_randomizer(&self.context.paillier_sk); - return Err(ReceiveError::protocol(KeyRefreshError( - KeyRefreshErrorEnum::Round3MismatchedSecret { - cap_c: direct_message.data2.paillier_enc_x, - x: *x.expose_secret(), - mu: mu.expose(), + let echo_broadcast = message + .echo_broadcast + .deserialize::>(deserializer)?; + let normal_broadcast = message + .normal_broadcast + .deserialize::>(deserializer)?; + let direct_message = message + .direct_message + .deserialize::>(deserializer)?; + + let my_id = &self.context.my_id; + + let r2_payload = self.r2_payloads.safe_get("Round 2 payloads", from)?; + let cap_y = r2_payload.cap_ys.safe_get("Elgamal public keys", my_id)?; + let y = self.context.ys.safe_get("Elgamal secrets", from)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&self.context.sid_hash) + .chain(&self.rid) + .chain(from) + .chain(&(cap_y * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + + let x = Secret::init_with(|| direct_message.cap_c - rho); + let my_cap_x = r2_payload.cap_xs.safe_get("public share changes", my_id)?; + if &x.mul_by_generator() != my_cap_x { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3ShareChangeMismatch { + reported_by: my_id.clone(), + y: *y.expose_secret(), }, ))); } - let aux = (&self.context.sid_hash, &from, &self.rho); - - if !direct_message.data2.psi_mod.verify(&sender_data.paillier_pk, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Mod proof verification failed".into(), - )))); + let aux = (&self.context.sid_hash, from, &self.rid); + if !normal_broadcast.psi_prime.verify(&r2_payload.paillier_pk, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3ModFailed, + ))); } if !direct_message - .data2 - .phi - .verify(&sender_data.paillier_pk, &self.context.data_precomp.rp_params, &aux) + .psi + .verify(&r2_payload.paillier_pk, &self.context.rp_params, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Fac proof verification failed".into(), - )))); + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3FacFailed { + reported_by: my_id.clone(), + }, + ))); } - if !direct_message - .data2 - .pi - .verify(&sender_data.data.cap_b, &sender_data.data.cap_y, &aux) - { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Sch proof verification (Y) failed".into(), - )))); + if echo_broadcast.hat_psis.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3WrongIdsHatPsi, + ))); } - if !direct_message.data2.psi_sch.verify( - sender_data - .data - .cap_a_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_a_to_send"))?, - sender_data - .data - .cap_x_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_a_to_send"))?, - &aux, - ) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Sch proof verification (X) failed".into(), - )))); + for (id, hat_psi) in echo_broadcast.hat_psis.iter() { + let cap_a = r2_payload.cap_as.safe_get("Schnorr commitments", id)?; + let cap_x = r2_payload.cap_xs.safe_get("Public share changes", id)?; + if !hat_psi.verify(cap_a, cap_x, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3SchFailed { failed_for: id.clone() }, + ))); + } } Ok(Payload::new(Round3Payload { x })) @@ -761,58 +1077,47 @@ impl Round for Round3 { _artifacts: BTreeMap, ) -> Result, LocalError> { let payloads = payloads.downcast_all::()?; - let others_x = payloads + + let my_id = &self.context.my_id; + + // Share changes from other nodes + let xs = payloads .into_iter() .map(|(id, payload)| (id, payload.x)) .collect::>(); - // The combined secret share change - let x_star = - others_x.into_values().sum::>() - + self.context.x_to_send.get(&self.context.my_id).ok_or_else(|| { - LocalError::new(format!("my_id={:?} is missing in x_to_send", self.context.my_id)) - })?; - - let my_id = self.context.my_id.clone(); - let mut all_ids = self.context.other_ids; - all_ids.insert(self.context.my_id); + // Share change generated by this node + let my_x = self.context.xs.safe_get("secret share changes", my_id)?; - let mut all_data = self.others_data; - all_data.insert(my_id.clone(), self.context.data_precomp); + // The combined secret share change + let x_star = xs.into_values().sum::>() + my_x; // The combined public share changes for each node - let cap_x_star = all_ids - .iter() - .enumerate() - .map(|(idx, id)| { - Ok(( - id.clone(), - all_data - .values() - .map(|data| data.data.cap_x_to_send.get(idx)) - .sum::>() - .ok_or_else(|| LocalError::new("idx={idx} is missing in cap_x_to_send"))?, - )) - }) - .collect::>()?; + let mut cap_x_star = BTreeMap::new(); - let public_aux = all_data + for id_k in self.context.all_ids.iter() { + let mut result = Point::IDENTITY; + for payload in self.r2_payloads.values() { + let cap_x = payload.cap_xs.safe_get("public share changes", id_k)?; + result = result + *cap_x; + } + cap_x_star.insert(id_k.clone(), result); + } + + let public_aux = self + .r2_payloads .into_iter() - .map(|(id, data)| { - ( - id, - PublicAuxInfo { - el_gamal_pk: data.data.cap_y, - paillier_pk: data.paillier_pk.into_wire(), - rp_params: data.rp_params.to_wire(), - }, - ) + .map(|(id, payload)| { + let aux_info = PublicAuxInfo { + paillier_pk: payload.paillier_pk.into_wire(), + rp_params: payload.rp_params.to_wire(), + }; + (id, aux_info) }) - .collect(); + .collect::>(); let secret_aux = SecretAuxInfo { paillier_sk: self.context.paillier_sk.into_wire(), - el_gamal_sk: self.context.y, }; let key_share_change = KeyShareChange { @@ -867,7 +1172,7 @@ mod tests { .results() .unwrap(); - let (changes, aux_infos): (BTreeMap<_, _>, BTreeMap<_, _>) = results + let (changes, _aux_infos): (BTreeMap<_, _>, BTreeMap<_, _>) = results .into_iter() .map(|(id, (change, aux))| ((id, change), (id, aux))) .unzip(); @@ -882,15 +1187,6 @@ mod tests { } } - for (id, aux_info) in aux_infos.iter() { - for other_aux_info in aux_infos.values() { - assert_eq!( - aux_info.secret_aux.el_gamal_sk.mul_by_generator(), - other_aux_info.public_aux[id].el_gamal_pk - ); - } - } - // The resulting sum of masks should be zero, since the combined secret key // should not change after applying the masks at each node. let mask_sum: Scalar = changes diff --git a/synedrion/src/cggmp21/key_refresh_malicious.rs b/synedrion/src/cggmp21/key_refresh_malicious.rs new file mode 100644 index 0000000..ace96c2 --- /dev/null +++ b/synedrion/src/cggmp21/key_refresh_malicious.rs @@ -0,0 +1,650 @@ +use alloc::collections::{BTreeMap, BTreeSet}; + +use manul::{ + combinators::misbehave::{Misbehaving, MisbehavingEntryPoint}, + dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + protocol::{ + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, LocalError, NormalBroadcast, + ProtocolMessagePart, Serializer, + }, + session::SessionReport, + signature::Keypair, +}; +use rand_chacha::ChaCha8Rng; +use rand_core::{CryptoRngCore, OsRng, SeedableRng}; + +use super::{ + key_refresh::{ + KeyRefresh, KeyRefreshProtocol, Round1, Round1EchoBroadcast, Round2Broadcast, Round2EchoBroadcast, + Round3Broadcast, Round3DirectMessage, Round3EchoBroadcast, + }, + params::{SchemeParams, TestParams}, + sigma::{FacProof, ModProof, PrmProof, SchCommitment, SchProof, SchSecret}, +}; +use crate::{ + curve::Scalar, + paillier::{PaillierParams, PublicKeyPaillierWire, RPParams, RPParamsWire, RPSecret, SecretKeyPaillierWire}, + tools::{hashing::FofHasher, Secret}, +}; + +type Id = TestVerifier; +type P = TestParams; +type Protocol = KeyRefreshProtocol; +type SP = TestSessionParams; + +fn run_with_one_malicious_party() -> (Vec, BTreeMap>) +where + M: Misbehaving>, +{ + let signers = (0..3).map(TestSigner::new).collect::>(); + let ids = signers.iter().map(|signer| signer.verifying_key()).collect::>(); + let ids_set = BTreeSet::from_iter(ids.clone()); + + let entry_points = signers + .into_iter() + .map(|signer| { + let id = signer.verifying_key(); + let entry_point = KeyRefresh::::new(ids_set.clone()).unwrap(); + let behavior = if id == ids[0] { Some(()) } else { None }; + let entry_point = MisbehavingEntryPoint::::new(entry_point, behavior); + (signer, entry_point) + }) + .collect(); + + let reports = run_sync::<_, SP>(&mut OsRng, entry_points).unwrap().reports; + (ids, reports) +} + +fn check_evidence(expected_description: &str) +where + M: Misbehaving>, +{ + let (ids, mut reports) = run_with_one_malicious_party::(); + + let report0 = reports.remove(&ids[0]).unwrap(); + let report1 = reports.remove(&ids[1]).unwrap(); + let report2 = reports.remove(&ids[2]).unwrap(); + + let ids_set = BTreeSet::from_iter(ids.clone()); + + assert!(report0.provable_errors.is_empty()); + for report in [report1, report2] { + let description = report.provable_errors[&ids[0]].description(); + assert!( + description.starts_with(expected_description), + "Got {description}, expected {expected_description}" + ); + + let verification_result = report.provable_errors[&ids[0]].verify(&ids_set); + assert!(verification_result.is_ok(), "Failed to verify: {verification_result:?}"); + } +} + +#[test] +fn r2_hash_mismatch() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + _deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 1 { + // Send a wrong hash in the Round 1 message + let message = Round1EchoBroadcast { + cap_v: FofHasher::new_with_dst(b"bad hash").finalize(), + }; + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R2HashMismatch"); +} + +#[test] +fn r2_wrong_ids_x() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_normal_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = normal_broadcast + .deserialize::>(deserializer) + .unwrap(); + message.cap_xs.pop_first(); + let normal_broadcast = NormalBroadcast::new(serializer, message)?; + return Ok(normal_broadcast); + } + + Ok(normal_broadcast) + } + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + _deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 1 { + // Technically we only need to modify `X`, but we need to substitute the hash in Round 1 too, + // so that in Round 2 the hash check could pass and the execution reaches the IDs check. + let round1 = round.downcast_ref::>()?; + + let mut r2_normal_broadcast = round1.r2_normal_broadcast.clone(); + r2_normal_broadcast.cap_xs.pop_first(); + + let message = Round1EchoBroadcast { + cap_v: r2_normal_broadcast.hash(&round1.context.sid_hash, &round1.context.my_id), + }; + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R2WrongIdsX"); +} + +#[test] +fn r2_wrong_ids_y() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = echo_broadcast + .deserialize::>(deserializer) + .unwrap(); + message.cap_ys.pop_first(); + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R2WrongIdsY"); +} + +#[test] +fn r2_wrong_ids_a() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_normal_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = normal_broadcast + .deserialize::>(deserializer) + .unwrap(); + + message.cap_as.pop_first(); + let normal_broadcast = NormalBroadcast::new(serializer, message)?; + return Ok(normal_broadcast); + } + + Ok(normal_broadcast) + } + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + _deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 1 { + // Technically we only need to modify `A`, but we need to substitute the hash in Round 1 too, + // so that in Round 2 the hash check could pass and the execution reaches the IDs check. + let round1 = round.downcast_ref::>()?; + + let mut r2_normal_broadcast = round1.r2_normal_broadcast.clone(); + + r2_normal_broadcast.cap_as.pop_first(); + + let message = Round1EchoBroadcast { + cap_v: r2_normal_broadcast.hash(&round1.context.sid_hash, &round1.context.my_id), + }; + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R2WrongIdsA"); +} + +#[test] +fn r2_paillier_modulus_too_small() { + fn make_small_modulus_pk() -> PublicKeyPaillierWire

{ + let mut rng = ChaCha8Rng::seed_from_u64(123); + let paillier_sk = SecretKeyPaillierWire::

::random_small(&mut rng); + paillier_sk.public_key() + } + + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_normal_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = normal_broadcast + .deserialize::>(deserializer) + .unwrap(); + message.paillier_pk = make_small_modulus_pk::<

::Paillier>(); + let normal_broadcast = NormalBroadcast::new(serializer, message)?; + return Ok(normal_broadcast); + } + + Ok(normal_broadcast) + } + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + _deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 1 { + let round1 = round.downcast_ref::>()?; + let mut r2_normal_broadcast = round1.r2_normal_broadcast.clone(); + r2_normal_broadcast.paillier_pk = make_small_modulus_pk::<

::Paillier>(); + let message = Round1EchoBroadcast { + cap_v: r2_normal_broadcast.hash(&round1.context.sid_hash, &round1.context.my_id), + }; + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R2PaillierModulusTooSmall"); +} + +#[test] +fn r2_rp_modulus_too_small() { + fn make_small_modulus_rp_params() -> RPParamsWire

{ + let mut rng = ChaCha8Rng::seed_from_u64(123); + RPParams::random_small(&mut rng).to_wire() + } + + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = echo_broadcast + .deserialize::>(deserializer) + .unwrap(); + message.rp_params = make_small_modulus_rp_params::<

::Paillier>(); + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R2RPModulusTooSmall"); +} + +#[test] +fn r2_non_zero_sum_of_changes() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + _deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 1 { + let round1 = round.downcast_ref::>()?; + let mut r2_normal_broadcast = round1.r2_normal_broadcast.clone(); + + let (id, _point) = r2_normal_broadcast.cap_xs.pop_first().unwrap(); + let mut rng = ChaCha8Rng::seed_from_u64(123); + r2_normal_broadcast + .cap_xs + .insert(id, Scalar::random(&mut rng).mul_by_generator()); + + let message = Round1EchoBroadcast { + cap_v: r2_normal_broadcast.hash(&round1.context.sid_hash, &round1.context.my_id), + }; + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + + fn modify_normal_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = normal_broadcast + .deserialize::>(deserializer) + .unwrap(); + + let (id, _point) = message.cap_xs.pop_first().unwrap(); + let mut rng = ChaCha8Rng::seed_from_u64(123); + message.cap_xs.insert(id, Scalar::random(&mut rng).mul_by_generator()); + + let normal_broadcast = NormalBroadcast::new(serializer, message)?; + return Ok(normal_broadcast); + } + + Ok(normal_broadcast) + } + } + + check_evidence::("Protocol error: R2NonZeroSumOfChanges"); +} + +#[test] +fn r2_prm_failed() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + _deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 1 { + let round1 = round.downcast_ref::>()?; + let mut r2_normal_broadcast = round1.r2_normal_broadcast.clone(); + + let mut rng = ChaCha8Rng::seed_from_u64(123); + let secret = RPSecret::random(&mut rng); + let rp_params = RPParams::random_with_secret(&mut rng, &secret); + r2_normal_broadcast.psi = PrmProof::new(&mut rng, &secret, &rp_params, &1u8); + + let message = Round1EchoBroadcast { + cap_v: r2_normal_broadcast.hash(&round1.context.sid_hash, &round1.context.my_id), + }; + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + + fn modify_normal_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == 2 { + let mut message = normal_broadcast + .deserialize::>(deserializer) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(123); + let secret = RPSecret::random(&mut rng); + let rp_params = RPParams::random_with_secret(&mut rng, &secret); + message.psi = PrmProof::new(&mut rng, &secret, &rp_params, &1u8); + + let normal_broadcast = NormalBroadcast::new(serializer, message)?; + return Ok(normal_broadcast); + } + + Ok(normal_broadcast) + } + } + + check_evidence::("Protocol error: R2PrmFailed"); +} + +#[test] +fn r3_share_change_mismatch() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_direct_message( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + _destination: &Id, + direct_message: DirectMessage, + artifact: Option, + ) -> Result<(DirectMessage, Option), LocalError> { + if round.id() == 3 { + let mut message = direct_message + .deserialize::>(deserializer) + .unwrap(); + message.cap_c = Scalar::random(rng); + let direct_message = DirectMessage::new(serializer, message)?; + return Ok((direct_message, artifact)); + } + + Ok((direct_message, artifact)) + } + } + + check_evidence::("Protocol error: R3ShareChangeMismatch"); +} + +#[test] +fn r3_mod_failed() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_normal_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + if round.id() == 3 { + let mut message = normal_broadcast + .deserialize::>(deserializer) + .unwrap(); + + let sk = SecretKeyPaillierWire::random(rng).into_precomputed(); + message.psi_prime = ModProof::new(rng, &sk, &1u8); + + let normal_broadcast = NormalBroadcast::new(serializer, message)?; + return Ok(normal_broadcast); + } + + Ok(normal_broadcast) + } + } + + check_evidence::("Protocol error: R3ModFailed"); +} + +#[test] +fn r3_fac_failed() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_direct_message( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + _destination: &Id, + direct_message: DirectMessage, + artifact: Option, + ) -> Result<(DirectMessage, Option), LocalError> { + if round.id() == 3 { + let mut message = direct_message + .deserialize::>(deserializer) + .unwrap(); + let sk = SecretKeyPaillierWire::random(&mut OsRng).into_precomputed(); + let rp_params = RPParams::random(rng); + message.psi = FacProof::new(rng, &sk, &rp_params, &1u8); + let direct_message = DirectMessage::new(serializer, message)?; + return Ok((direct_message, artifact)); + } + + Ok((direct_message, artifact)) + } + } + + check_evidence::("Protocol error: R3FacFailed"); +} + +#[test] +fn r3_wrong_ids_hat_psi() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 3 { + let mut message = echo_broadcast + .deserialize::>(deserializer) + .unwrap(); + message.hat_psis.pop_first(); + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R3WrongIdsHatPsi"); +} + +#[test] +fn r3_sch_failed() { + struct Override; + + impl Misbehaving for Override { + type EntryPoint = KeyRefresh; + + fn modify_echo_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + _behavior: &(), + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + if round.id() == 3 { + let mut message = echo_broadcast + .deserialize::>(deserializer) + .unwrap(); + let (id, _hat_psi) = message.hat_psis.pop_last().unwrap(); + let x = Secret::init_with(|| Scalar::random(rng)); + let cap_x = x.mul_by_generator(); + let secret = SchSecret::random(rng); + let commitment = SchCommitment::new(&secret); + let hat_psi = SchProof::new(&secret, &x, &commitment, &cap_x, &1u8); + message.hat_psis.insert(id, hat_psi); + let echo_broadcast = EchoBroadcast::new(serializer, message)?; + return Ok(echo_broadcast); + } + + Ok(echo_broadcast) + } + } + + check_evidence::("Protocol error: R3SchFailed"); +} diff --git a/synedrion/src/paillier/encryption.rs b/synedrion/src/paillier/encryption.rs index 7d2b60a..e3f37e7 100644 --- a/synedrion/src/paillier/encryption.rs +++ b/synedrion/src/paillier/encryption.rs @@ -53,13 +53,6 @@ impl Randomizer

{ Self::new(pk, randomizer) } - /// Expose this secret randomizer. - /// - /// Supposed to be used in certain error branches where it is needed to generate a malicious behavior evidence. - pub fn expose(&self) -> P::Uint { - *self.randomizer.expose_secret() - } - /// Converts the randomizer to a publishable form by masking it with another randomizer and a public exponent. pub fn to_masked(&self, coeff: &Self, exponent: &PublicSigned) -> MaskedRandomizer

{ MaskedRandomizer( @@ -209,6 +202,7 @@ impl Ciphertext

{ } /// Encrypts the plaintext with a random randomizer. + #[cfg(test)] pub fn new(rng: &mut impl CryptoRngCore, pk: &PublicKeyPaillier

, plaintext: &SecretSigned) -> Self { Self::new_with_randomizer(pk, plaintext, &Randomizer::random(rng, pk)) } @@ -461,7 +455,10 @@ mod tests { let randomizer = Randomizer::random(&mut OsRng, pk); let ciphertext = Ciphertext::::new_with_randomizer(pk, &plaintext, &randomizer); let randomizer_back = ciphertext.derive_randomizer(&sk); - assert_eq!(randomizer.expose(), randomizer_back.expose()); + assert_eq!( + randomizer.randomizer.expose_secret(), + randomizer_back.randomizer.expose_secret() + ); } #[test] diff --git a/synedrion/src/paillier/keys.rs b/synedrion/src/paillier/keys.rs index 4520a93..c9db130 100644 --- a/synedrion/src/paillier/keys.rs +++ b/synedrion/src/paillier/keys.rs @@ -27,6 +27,13 @@ pub(crate) struct SecretKeyPaillierWire { } impl SecretKeyPaillierWire

{ + #[cfg(test)] + pub fn random_small(rng: &mut impl CryptoRngCore) -> Self { + Self { + primes: SecretPrimesWire::

::random_small_paillier_blum(rng), + } + } + pub fn random(rng: &mut impl CryptoRngCore) -> Self { Self { primes: SecretPrimesWire::

::random_paillier_blum(rng), @@ -293,6 +300,10 @@ impl PublicKeyPaillierWire

{ } } + pub fn modulus(&self) -> &P::Uint { + self.modulus.modulus() + } + pub fn into_precomputed(self) -> PublicKeyPaillier

{ PublicKeyPaillier::new(self.modulus.into_precomputed()) } diff --git a/synedrion/src/paillier/ring_pedersen.rs b/synedrion/src/paillier/ring_pedersen.rs index 5369157..c8409fc 100644 --- a/synedrion/src/paillier/ring_pedersen.rs +++ b/synedrion/src/paillier/ring_pedersen.rs @@ -22,6 +22,20 @@ pub(crate) struct RPSecret { } impl RPSecret

{ + #[cfg(test)] + pub fn random_small(rng: &mut impl CryptoRngCore) -> Self { + let primes = SecretPrimesWire::

::random_small_safe(rng).into_precomputed(); + let bound = NonZero::new(primes.totient().expose_secret().wrapping_shr_vartime(2)) + .expect("totient / 4 is still non-zero because p, q >= 5"); + let lambda = SecretUnsigned::new( + Secret::init_with(|| P::Uint::random_mod(rng, &bound)), + P::MODULUS_BITS - 2, + ) + .expect("totient < N < 2^MODULUS_BITS, so totient / 4 < 2^(MODULUS_BITS - 2)"); + + Self { primes, lambda } + } + pub fn random(rng: &mut impl CryptoRngCore) -> Self { let primes = SecretPrimesWire::

::random_safe(rng).into_precomputed(); @@ -70,6 +84,12 @@ pub(crate) struct RPParams { } impl RPParams

{ + #[cfg(test)] + pub fn random_small(rng: &mut impl CryptoRngCore) -> Self { + let secret = RPSecret::random_small(rng); + Self::random_with_secret(rng, &secret) + } + pub fn random(rng: &mut impl CryptoRngCore) -> Self { let secret = RPSecret::random(rng); Self::random_with_secret(rng, &secret) @@ -152,6 +172,10 @@ pub(crate) struct RPParamsWire { } impl RPParamsWire

{ + pub fn modulus(&self) -> &P::Uint { + self.modulus.modulus() + } + pub fn to_precomputed(&self) -> RPParams

{ let modulus = self.modulus.clone().into_precomputed(); let base_randomizer = self.base_randomizer.to_montgomery(modulus.monty_params_mod_n()); diff --git a/synedrion/src/paillier/rsa.rs b/synedrion/src/paillier/rsa.rs index ec53074..8f8c261 100644 --- a/synedrion/src/paillier/rsa.rs +++ b/synedrion/src/paillier/rsa.rs @@ -10,6 +10,16 @@ use crate::{ uint::{FromXofReader, HasWide, IsInvertible, PublicSigned, SecretSigned, SecretUnsigned, ToMontgomery}, }; +#[cfg(test)] +fn random_small_paillier_blum_prime(rng: &mut impl CryptoRngCore) -> P::HalfUint { + loop { + let prime = P::HalfUint::generate_prime_with_rng(rng, P::PRIME_BITS - 2); + if prime.as_ref().first().expect("First Limb exists").0 & 3 == 3 { + return prime; + } + } +} + fn random_paillier_blum_prime(rng: &mut impl CryptoRngCore) -> P::HalfUint { loop { let prime = P::HalfUint::generate_prime_with_rng(rng, P::PRIME_BITS); @@ -44,6 +54,15 @@ impl SecretPrimesWire

{ Self { p, q } } + /// Creates smaller than required primes to trigger an error during tests. + #[cfg(test)] + pub fn random_small_paillier_blum(rng: &mut impl CryptoRngCore) -> Self { + Self::new( + Secret::init_with(|| random_small_paillier_blum_prime::

(rng)), + Secret::init_with(|| random_small_paillier_blum_prime::

(rng)), + ) + } + /// Creates the primes for a Paillier-Blum modulus, /// that is `p` and `q` are regular primes with an additional condition `p, q mod 3 = 4`. pub fn random_paillier_blum(rng: &mut impl CryptoRngCore) -> Self { @@ -53,6 +72,15 @@ impl SecretPrimesWire

{ ) } + /// Creates smaller than required primes to trigger an error during tests. + #[cfg(test)] + pub fn random_small_safe(rng: &mut impl CryptoRngCore) -> Self { + Self::new( + Secret::init_with(|| P::HalfUint::generate_safe_prime_with_rng(rng, P::PRIME_BITS - 2)), + Secret::init_with(|| P::HalfUint::generate_safe_prime_with_rng(rng, P::PRIME_BITS - 2)), + ) + } + /// Creates a pair of safe primes. pub fn random_safe(rng: &mut impl CryptoRngCore) -> Self { Self::new(