diff --git a/src/iter/arrays.rs b/src/iter/arrays.rs new file mode 100644 index 000000000..a60f2ba74 --- /dev/null +++ b/src/iter/arrays.rs @@ -0,0 +1,205 @@ +use super::plumbing::*; +use super::*; + +/// `Arrays` is an iterator that groups elements of an underlying iterator. +/// +/// This struct is created by the [`arrays()`] method on [`IndexedParallelIterator`] +/// +/// [`arrays()`]: trait.IndexedParallelIterator.html#method.arrays +/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html +#[must_use = "iterator adaptors are lazy and do nothing unless consumed"] +#[derive(Debug, Clone)] +pub struct Arrays +where + I: IndexedParallelIterator, +{ + iter: I, +} + +impl Arrays +where + I: IndexedParallelIterator, +{ + /// Creates a new `Arrays` iterator + pub(super) fn new(iter: I) -> Self { + Arrays { iter } + } +} + +impl ParallelIterator for Arrays +where + I: IndexedParallelIterator, +{ + type Item = [I::Item; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl IndexedParallelIterator for Arrays +where + I: IndexedParallelIterator, +{ + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.iter.len() / N + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + let len = self.iter.len(); + return self.iter.with_producer(Callback { len, callback }); + + struct Callback { + len: usize, + callback: CB, + } + + impl ProducerCallback for Callback + where + CB: ProducerCallback<[T; N]>, + { + type Output = CB::Output; + + fn callback

(self, base: P) -> CB::Output + where + P: Producer, + { + self.callback.callback(ArrayProducer { + len: self.len, + base, + }) + } + } + } +} + +struct ArrayProducer +where + P: Producer, +{ + len: usize, + base: P, +} + +impl Producer for ArrayProducer +where + P: Producer, +{ + type Item = [P::Item; N]; + type IntoIter = ArraySeq; + + fn into_iter(self) -> Self::IntoIter { + // TODO: we're ignoring any remainder -- should we no-op consume it? + let remainder = self.len % N; + let len = self.len - remainder; + let inner = (len > 0).then(|| self.base.split_at(len).0); + ArraySeq { len, inner } + } + + fn split_at(self, index: usize) -> (Self, Self) { + let elem_index = index * N; + let (left, right) = self.base.split_at(elem_index); + ( + ArrayProducer { + len: elem_index, + base: left, + }, + ArrayProducer { + len: self.len - elem_index, + base: right, + }, + ) + } + + fn min_len(&self) -> usize { + self.base.min_len() / N + } + + fn max_len(&self) -> usize { + self.base.max_len() / N + } +} + +struct ArraySeq { + len: usize, + inner: Option

, +} + +impl Iterator for ArraySeq +where + P: Producer, +{ + type Item = [P::Item; N]; + + fn next(&mut self) -> Option { + let mut producer = self.inner.take()?; + debug_assert!(self.len > 0 && self.len % N == 0); + if self.len > N { + let (left, right) = producer.split_at(N); + producer = left; + self.inner = Some(right); + self.len -= N; + } else { + self.len = 0; + } + Some(collect_array(producer.into_iter())) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl ExactSizeIterator for ArraySeq +where + P: Producer, +{ + #[inline] + fn len(&self) -> usize { + self.len / N + } +} + +impl DoubleEndedIterator for ArraySeq +where + P: Producer, +{ + fn next_back(&mut self) -> Option { + let mut producer = self.inner.take()?; + debug_assert!(self.len > 0 && self.len % N == 0); + if self.len > N { + let (left, right) = producer.split_at(self.len - N); + producer = right; + self.inner = Some(left); + self.len -= N; + } else { + self.len = 0; + } + Some(collect_array(producer.into_iter())) + } +} + +fn collect_array(mut iter: impl ExactSizeIterator) -> [T; N] { + debug_assert_eq!(iter.len(), N); + let array = std::array::from_fn(|_| iter.next().expect("should have N items")); + debug_assert!(iter.next().is_none()); + array +} diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 7b5a29aeb..75509bd63 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -102,6 +102,7 @@ mod test; // e.g. `find::find()`, are always used **prefixed**, so that they // can be readily distinguished. +mod arrays; mod chain; mod chunks; mod cloned; @@ -159,6 +160,7 @@ mod zip; mod zip_eq; pub use self::{ + arrays::Arrays, chain::Chain, chunks::Chunks, cloned::Cloned, @@ -2544,6 +2546,32 @@ pub trait IndexedParallelIterator: ParallelIterator { InterleaveShortest::new(self, other.into_par_iter()) } + /// Splits an iterator up into fixed-size arrays. + /// + /// Returns an iterator that returns arrays with the given number of elements. + /// If the number of elements in the iterator is not divisible by `N`, + /// the remaining items are ignored. + /// + /// See also [`par_array_chunks()`] and [`par_array_chunks_mut()`] for similar + /// behavior on slices, although they yield array references instead. + /// + /// [`par_array_chunks()`]: ../slice/trait.ParallelSlice.html#method.par_array_chunks + /// [`par_array_chunks_mut()`]: ../slice/trait.ParallelSliceMut.html#method.par_array_chunks_mut + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let a = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + /// let r: Vec<[i32; 3]> = a.into_par_iter().arrays().collect(); + /// assert_eq!(r, vec![[1, 2, 3], [4, 5, 6], [7, 8, 9]]); + /// ``` + #[track_caller] + fn arrays(self) -> Arrays { + assert!(N != 0, "array length must not be zero"); + Arrays::new(self) + } + /// Splits an iterator up into fixed-size chunks. /// /// Returns an iterator that returns `Vec`s of the given number of elements. diff --git a/tests/clones.rs b/tests/clones.rs index 0d6c86487..da0cd1476 100644 --- a/tests/clones.rs +++ b/tests/clones.rs @@ -151,6 +151,7 @@ fn clone_adaptors() { check(v.par_iter().interleave_shortest(&v)); check(v.par_iter().intersperse(&None)); check(v.par_iter().chunks(3)); + check(v.par_iter().arrays::<3>()); check(v.par_iter().map(|x| x)); check(v.par_iter().map_with(0, |_, x| x)); check(v.par_iter().map_init(|| 0, |_, x| x)); diff --git a/tests/debug.rs b/tests/debug.rs index 14f37917b..2705543eb 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -165,6 +165,7 @@ fn debug_adaptors() { check(v.par_iter().interleave_shortest(&v)); check(v.par_iter().intersperse(&-1)); check(v.par_iter().chunks(3)); + check(v.par_iter().arrays::<3>()); check(v.par_iter().map(|x| x)); check(v.par_iter().map_with(0, |_, x| x)); check(v.par_iter().map_init(|| 0, |_, x| x)); diff --git a/tests/producer_split_at.rs b/tests/producer_split_at.rs index d71050492..80ed07f82 100644 --- a/tests/producer_split_at.rs +++ b/tests/producer_split_at.rs @@ -343,6 +343,29 @@ fn chunks() { check(&v, || s.par_iter().cloned().chunks(2)); } +#[test] +fn arrays() { + use std::convert::TryInto; + fn check_len(s: &[i32]) { + let v: Vec<[_; N]> = s.chunks_exact(N).map(|c| c.try_into().unwrap()).collect(); + check(&v, || s.par_iter().copied().arrays::()); + } + + let s: Vec<_> = (0..10).collect(); + check_len::<1>(&s); + check_len::<2>(&s); + check_len::<3>(&s); + check_len::<4>(&s); + check_len::<5>(&s); + check_len::<6>(&s); + check_len::<7>(&s); + check_len::<8>(&s); + check_len::<9>(&s); + check_len::<10>(&s); + check_len::<11>(&s); + check_len::<12>(&s); +} + #[test] fn map() { let v: Vec<_> = (0..10).collect();