From b7373aaa145c1ee68e01c39273f658a4df9ff80d Mon Sep 17 00:00:00 2001 From: Owen Leung Date: Thu, 16 Jan 2025 00:06:52 +0800 Subject: [PATCH] refactor advance_by and advance_back_by. Add back cfg for with_critical_section --- src/types/list.rs | 48 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/src/types/list.rs b/src/types/list.rs index a3071a22c17..2d65dd4280f 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -640,13 +640,14 @@ impl<'py> BoundListIterator<'py> { list.get_item(target_index).expect("get-item failed") } }; - *length = Length(target_index); + length.0 = target_index; Some(item) } else { None } } + #[cfg(not(Py_LIMITED_API))] fn with_critical_section( &mut self, f: impl FnOnce(&mut Index, &mut Length, &Bound<'py, PyList>) -> R, @@ -818,12 +819,25 @@ impl<'py> Iterator for BoundListIterator<'py> { #[cfg(feature = "nightly")] fn advance_by(&mut self, n: usize) -> Result<(), NonZero> { self.with_critical_section(|index, length, list| { - for i in 0..n { - if unsafe { Self::next_unchecked(index, length, list).is_none() } { - return Err(unsafe { NonZero::new_unchecked(n - i) }); + let max_len = length.0.min(list.len()); + let currently_at = index.0; + if currently_at >= max_len { + if n == 0 { + return Ok(()); + } else { + return Err(unsafe { NonZero::new_unchecked(n) }); } } - Ok(()) + + let items_left = max_len - currently_at; + if n <= items_left { + index.0 += n; + Ok(()) + } else { + index.0 = max_len; + let remainder = n - items_left; + Err(unsafe { NonZero::new_unchecked(remainder) }) + } }) } } @@ -891,12 +905,25 @@ impl DoubleEndedIterator for BoundListIterator<'_> { #[cfg(feature = "nightly")] fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero> { self.with_critical_section(|index, length, list| { - for i in 0..n { - if unsafe { Self::next_back_unchecked(index, length, list).is_none() } { - return Err(unsafe { NonZero::new_unchecked(n - i) }); + let max_len = length.0.min(list.len()); + let currently_at = index.0; + if currently_at >= max_len { + if n == 0 { + return Ok(()); + } else { + return Err(unsafe { NonZero::new_unchecked(n) }); } } - Ok(()) + + let items_left = max_len - currently_at; + if n <= items_left { + length.0 = max_len - n; + Ok(()) + } else { + length.0 = max_len; + let remainder = n - items_left; + Err(unsafe { NonZero::new_unchecked(remainder) }) + } }) } } @@ -1637,8 +1664,7 @@ mod tests { assert_eq!(iter.next().unwrap().extract::().unwrap(), 10); let mut iter = list.iter(); - println!("iter.nth_back(1) = {:?}", iter.nth_back(1)); - // assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 9); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 9); assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 8); assert!(iter.next().is_none()); });