diff --git a/src/vdaf.rs b/src/vdaf.rs index 8ea17b71..b53a6f58 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -303,7 +303,16 @@ pub trait Aggregator: Vda &self, agg_param: &Self::AggregationParam, output_shares: M, - ) -> Result; + ) -> Result { + let mut share = self.aggregate_init(agg_param); + for output_share in output_shares { + share.accumulate(&output_share)?; + } + Ok(share) + } + + /// Create an empty aggregate share. + fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare; /// Validates an aggregation parameter with respect to all previous aggregaiton parameters used /// for the same input share. `prev` MUST be sorted from least to most recently used. diff --git a/src/vdaf/dummy.rs b/src/vdaf/dummy.rs index 1a78e3ee..9fec3a68 100644 --- a/src/vdaf/dummy.rs +++ b/src/vdaf/dummy.rs @@ -158,16 +158,8 @@ impl vdaf::Aggregator<0, 16> for Vdaf { (self.prep_step_fn)(&state) } - fn aggregate>( - &self, - _aggregation_param: &Self::AggregationParam, - output_shares: M, - ) -> Result { - let mut aggregate_share = AggregateShare(0); - for output_share in output_shares { - aggregate_share.accumulate(&output_share)?; - } - Ok(aggregate_share) + fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { + AggregateShare(0) } fn is_agg_param_valid(_cur: &Self::AggregationParam, _prev: &[Self::AggregationParam]) -> bool { diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index d16fbb09..15924165 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -725,23 +725,15 @@ where Ok(PrepareTransition::Finish(output_shares)) } - fn aggregate>>( - &self, - agg_param: &MasticAggregationParam, - output_shares: M, - ) -> Result, VdafError> { - let mut agg_share = MasticAggregateShare::::from(vec![ + fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare { + MasticAggregateShare::::from(vec![ T::Field::zero(); self.vidpf.weight_parameter * agg_param .level_and_prefixes .prefixes() .len() - ]); - for output_share in output_shares.into_iter() { - agg_share.accumulate(&output_share)?; - } - Ok(agg_share) + ]) } } diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 5a54eabd..49d36654 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -1251,15 +1251,10 @@ impl, const SEED_SIZE: usize> Aggregator } } - fn aggregate>( - &self, - agg_param: &Poplar1AggregationParam, - output_shares: M, - ) -> Result { - aggregate( + fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare { + Poplar1FieldVec::zero( usize::from(agg_param.level) == self.bits - 1, agg_param.prefixes.len(), - output_shares, ) } diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 96a8f5a3..dd35e1e3 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -317,17 +317,8 @@ impl Aggregator<32, 16> for Prio2 { Ok(PrepareTransition::Finish(OutputShare::from(data))) } - fn aggregate>>( - &self, - _agg_param: &Self::AggregationParam, - out_shares: M, - ) -> Result, VdafError> { - let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); - for out_share in out_shares.into_iter() { - agg_share.accumulate(&out_share)?; - } - - Ok(agg_share) + fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { + AggregateShare(vec![FieldPrio2::zero(); self.input_len]) } /// Returns `true` iff `prev.is_empty()` diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 22250203..7ac8233b 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1573,18 +1573,8 @@ where Ok(PrepareTransition::Finish(output_share)) } - /// Aggregates a sequence of output shares into an aggregate share. - fn aggregate>>( - &self, - _agg_param: &(), - output_shares: It, - ) -> Result, VdafError> { - let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); - for output_share in output_shares.into_iter() { - agg_share.accumulate(&output_share)?; - } - - Ok(agg_share) + fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { + AggregateShare(vec![T::Field::zero(); self.typ.output_len()]) } /// Returns `true` iff `prev.is_empty()`