From 314c2eb376562284a7ee018327c4ed6ee478a199 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Thu, 12 Dec 2024 16:14:47 -0800 Subject: [PATCH] PR review. --- Cargo.lock | 1 + aggregator/Cargo.toml | 1 + .../src/aggregator/aggregation_job_driver.rs | 40 ++++++++++++++----- aggregator_core/src/datastore.rs | 5 +++ aggregator_core/src/datastore/models.rs | 27 +++++++++++++ 5 files changed, 63 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5480ade20..40c9ff153 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2756,6 +2756,7 @@ dependencies = [ "rayon", "regex", "reqwest", + "retry-after", "rstest", "rustc_version", "rustls", diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index 6c7d8796c..975b4995a 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -78,6 +78,7 @@ regex = { workspace = true } reqwest = { workspace = true, features = ["json"] } rustls = { workspace = true } rustls-pemfile = { workspace = true } +retry-after.workspace = true sec1.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 937bef26e..18eaeb8fe 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -53,7 +53,13 @@ use prio::{ }; use rayon::iter::{IndexedParallelIterator as _, IntoParallelIterator as _, ParallelIterator as _}; use reqwest::Method; -use std::{collections::HashSet, panic, str::FromStr, sync::Arc, time::Duration}; +use retry_after::RetryAfter; +use std::{ + collections::HashSet, + panic, + sync::Arc, + time::{Duration, UNIX_EPOCH}, +}; use tokio::{join, sync::mpsc, try_join}; use tracing::{debug, error, info, info_span, trace_span, warn, Span}; @@ -885,7 +891,7 @@ where aggregation_job: AggregationJob, stepped_aggregations: Vec>, report_aggregations_to_write: Vec>, - retry_after: Option<&Duration>, + retry_after: Option<&RetryAfter>, helper_resp: AggregationJobResp, ) -> Result<(), Error> where @@ -899,7 +905,6 @@ where A::PublicShare: Send + Sync, { match helper_resp { - // TODO(#3436): implement asynchronous aggregation AggregationJobResp::Processing => { self.process_response_from_helper_pending( datastore, @@ -944,7 +949,7 @@ where aggregation_job: AggregationJob, stepped_aggregations: Vec>, mut report_aggregations_to_write: Vec>, - retry_after: Option<&Duration>, + retry_after: Option<&RetryAfter>, ) -> Result<(), Error> where A::AggregationParam: Send + Sync + Eq + PartialEq, @@ -993,7 +998,8 @@ where let aggregation_job_writer = Arc::new(aggregation_job_writer); let retry_after = retry_after - .copied() + .map(|ra| retry_after_to_duration(datastore.clock(), ra)) + .transpose()? .or_else(|| Some(Duration::from_secs(60))); let counters = datastore .run_tx("process_response_from_helper_pending", |tx| { @@ -1514,10 +1520,22 @@ struct SteppedAggregation, } -fn parse_retry_after(header_value: &HeaderValue) -> Result { - let val = header_value - .to_str() - .map_err(|err| Error::BadRequest(err.to_string()))?; - let val = u64::from_str(val).map_err(|err| Error::BadRequest(err.to_string()))?; - Ok(Duration::from_secs(val)) +fn parse_retry_after(header_value: &HeaderValue) -> Result { + RetryAfter::try_from(header_value) + .map_err(|err| Error::BadRequest(format!("couldn't parse retry-after header: {err}"))) +} + +fn retry_after_to_duration(clock: &C, retry_after: &RetryAfter) -> Result { + match retry_after { + RetryAfter::Delay(duration) => Ok(*duration), + RetryAfter::DateTime(next_retry_time) => { + let now = UNIX_EPOCH + Duration::from_secs(clock.now().as_seconds_since_epoch()); + if &now > next_retry_time { + return Ok(Duration::ZERO); + } + next_retry_time + .duration_since(now) + .map_err(|err| Error::Internal(format!("computing retry-after duration: {err}"))) + } + } } diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 267dbfffb..ce22317d7 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -381,6 +381,11 @@ impl Datastore { (rslt, retry.load(Ordering::Relaxed)) } + /// Returns the clock in use by this datastore. + pub fn clock(&self) -> &C { + &self.clock + } + /// See [`Datastore::run_tx`]. This method provides a placeholder transaction name. It is useful /// for tests where the transaction name is not important. #[cfg(feature = "test-util")] diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index 8175d1273..cc5364841 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -886,6 +886,7 @@ impl> PartialEq where A::InputShare: PartialEq, A::PrepareShare: PartialEq, + A::PrepareState: PartialEq, A::PublicShare: PartialEq, A::OutputShare: PartialEq, { @@ -909,6 +910,7 @@ impl> Eq where A::InputShare: Eq, A::PrepareShare: Eq, + A::PrepareState: Eq, A::PublicShare: Eq, A::OutputShare: Eq, { @@ -1112,6 +1114,7 @@ impl> PartialEq where A::InputShare: PartialEq, A::PrepareShare: PartialEq, + A::PrepareState: PartialEq, A::PublicShare: PartialEq, A::OutputShare: PartialEq, { @@ -1139,6 +1142,7 @@ where && lhs_leader_input_share == rhs_leader_input_share && lhs_helper_encrypted_input_share == rhs_helper_encrypted_input_share } + ( Self::LeaderContinue { transition: lhs_transition, @@ -1147,6 +1151,26 @@ where transition: rhs_transition, }, ) => lhs_transition == rhs_transition, + + ( + Self::LeaderPoll { + leader_state: lhs_leader_state, + }, + Self::LeaderPoll { + leader_state: rhs_leader_state, + }, + ) => match (lhs_leader_state, rhs_leader_state) { + ( + PingPongState::Continued(lhs_prepare_state), + PingPongState::Continued(rhs_prepare_state), + ) => lhs_prepare_state == rhs_prepare_state, + ( + PingPongState::Finished(lhs_output_share), + PingPongState::Finished(rhs_output_share), + ) => lhs_output_share == rhs_output_share, + _ => false, + }, + ( Self::HelperContinue { prepare_state: lhs_state, @@ -1155,6 +1179,7 @@ where prepare_state: rhs_state, }, ) => lhs_state == rhs_state, + ( Self::Failed { report_error: lhs_report_error, @@ -1163,6 +1188,7 @@ where report_error: rhs_report_error, }, ) => lhs_report_error == rhs_report_error, + _ => core::mem::discriminant(self) == core::mem::discriminant(other), } } @@ -1177,6 +1203,7 @@ impl> Eq where A::InputShare: Eq, A::PrepareShare: Eq, + A::PrepareState: Eq, A::PublicShare: Eq, A::OutputShare: Eq, {