Skip to content

Commit

Permalink
refactor advance_by and advance_back_by. Add back cfg for with_critic…
Browse files Browse the repository at this point in the history
…al_section
  • Loading branch information
Owen-CH-Leung committed Jan 15, 2025
1 parent 00e4802 commit b7373aa
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions src/types/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,13 +640,14 @@ impl<'py> BoundListIterator<'py> {
list.get_item(target_index).expect("get-item failed")
}
};
*length = Length(target_index);
length.0 = target_index;
Some(item)
} else {
None
}
}

#[cfg(not(Py_LIMITED_API))]
fn with_critical_section<R>(
&mut self,
f: impl FnOnce(&mut Index, &mut Length, &Bound<'py, PyList>) -> R,
Expand Down Expand Up @@ -818,12 +819,25 @@ impl<'py> Iterator for BoundListIterator<'py> {
#[cfg(feature = "nightly")]
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
self.with_critical_section(|index, length, list| {
for i in 0..n {
if unsafe { Self::next_unchecked(index, length, list).is_none() } {
return Err(unsafe { NonZero::new_unchecked(n - i) });
let max_len = length.0.min(list.len());
let currently_at = index.0;
if currently_at >= max_len {
if n == 0 {
return Ok(());
} else {
return Err(unsafe { NonZero::new_unchecked(n) });
}
}
Ok(())

let items_left = max_len - currently_at;
if n <= items_left {
index.0 += n;
Ok(())
} else {
index.0 = max_len;
let remainder = n - items_left;
Err(unsafe { NonZero::new_unchecked(remainder) })
}
})
}
}
Expand Down Expand Up @@ -891,12 +905,25 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
#[cfg(feature = "nightly")]
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
self.with_critical_section(|index, length, list| {
for i in 0..n {
if unsafe { Self::next_back_unchecked(index, length, list).is_none() } {
return Err(unsafe { NonZero::new_unchecked(n - i) });
let max_len = length.0.min(list.len());
let currently_at = index.0;
if currently_at >= max_len {
if n == 0 {
return Ok(());
} else {
return Err(unsafe { NonZero::new_unchecked(n) });
}
}
Ok(())

let items_left = max_len - currently_at;
if n <= items_left {
length.0 = max_len - n;
Ok(())
} else {
length.0 = max_len;
let remainder = n - items_left;
Err(unsafe { NonZero::new_unchecked(remainder) })
}
})
}
}
Expand Down Expand Up @@ -1637,8 +1664,7 @@ mod tests {
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 10);

let mut iter = list.iter();
println!("iter.nth_back(1) = {:?}", iter.nth_back(1));
// assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 9);
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 9);
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 8);
assert!(iter.next().is_none());
});
Expand Down

0 comments on commit b7373aa

Please sign in to comment.