Skip to content

Commit

Permalink
Merge pull request #685 from Altair-Bueno/main
Browse files Browse the repository at this point in the history
feat(combinator): try_fold and verify_fold
  • Loading branch information
epage authored Jan 10, 2025
2 parents 4cbcb9c + e7350dc commit 83e2b58
Showing 1 changed file with 284 additions and 0 deletions.
284 changes: 284 additions & 0 deletions src/combinator/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::combinator::trace;
use crate::error::ErrMode;
use crate::error::ErrorKind;
use crate::error::FromExternalError;
use crate::error::ParserError;
use crate::stream::Accumulate;
use crate::stream::Range;
Expand Down Expand Up @@ -289,6 +290,157 @@ where
}
})
}

/// Like [`Repeat::fold`], except that the fold function may reject an element.
///
/// Repetition ends before the upper bound when the parser returns [`ErrMode::Backtrack`]. To
/// propagate that error instead, see [`cut_err`][crate::combinator::cut_err]. In addition, when
/// the fold function returns `None`, parsing stops and an error is returned.
///
/// # Arguments
/// * `init` A function returning the initial value.
/// * `op` The function that combines an output of the repeated parser with
///   the current accumulator, or returns `None` to reject that output.
///
/// <div class="warning">
///
/// **Warning:** If the parser passed to `repeat` accepts empty inputs
/// (like `alpha0` or `digit0`), `verify_fold` will return an error,
/// to prevent going into an infinite loop.
///
/// </div>
///
/// # Example
///
/// Guaranteeing that the input had unique elements:
/// ```rust
/// # use winnow::error::IResult;
/// # use winnow::{error::ErrMode, error::{InputError, ErrorKind}, error::Needed};
/// # use winnow::prelude::*;
/// use winnow::combinator::repeat;
/// use std::collections::HashSet;
///
/// fn parser(s: &str) -> IResult<&str, HashSet<&str>> {
///     repeat(
///         0..,
///         "abc"
///     ).verify_fold(
///         HashSet::new,
///         |mut acc: HashSet<_>, item| {
///             if acc.insert(item) {
///                 Some(acc)
///             } else {
///                 None
///             }
///         }
///     ).parse_peek(s)
/// }
///
/// assert_eq!(parser("abc"), Ok(("", HashSet::from(["abc"]))));
/// assert_eq!(parser("abcabc"), Err(ErrMode::Backtrack(InputError::new("abc", ErrorKind::Verify))));
/// assert_eq!(parser("abc123"), Ok(("123", HashSet::from(["abc"]))));
/// assert_eq!(parser("123123"), Ok(("123123", HashSet::from([]))));
/// assert_eq!(parser(""), Ok(("", HashSet::from([]))));
/// ```
#[inline(always)]
pub fn verify_fold<Init, Op, Result>(
    mut self,
    mut init: Init,
    mut op: Op,
) -> impl Parser<Input, Result, Error>
where
    Init: FnMut() -> Result,
    Op: FnMut(Result, Output) -> Option<Result>,
{
    let Range {
        start_inclusive: min,
        end_inclusive,
    } = self.occurrences;
    // A missing upper bound is treated as `usize::MAX` repetitions.
    let max = end_inclusive.unwrap_or(usize::MAX);

    trace("repeat_verify_fold", move |input: &mut Input| {
        verify_fold_m_n(min, max, &mut self.parser, &mut init, &mut op, input)
    })
}

/// Like [`Repeat::fold`], except that accumulating an element may fail with an error.
///
/// Repetition ends before the upper bound when the parser returns [`ErrMode::Backtrack`]. To
/// propagate that error instead, see [`cut_err`][crate::combinator::cut_err]. In addition, when
/// the fold function returns an error, parsing stops and that error is returned after being
/// converted through `FromExternalError`.
///
/// # Arguments
/// * `init` A function returning the initial value.
/// * `op` The function that combines an output of the repeated parser with
///   the current accumulator, or returns an error to abort.
///
/// <div class="warning">
///
/// **Warning:** If the parser passed to `repeat` accepts empty inputs
/// (like `alpha0` or `digit0`), `try_fold` will return an error,
/// to prevent going into an infinite loop.
///
/// </div>
///
/// # Example
///
/// Writing the output to a vector of bytes:
/// ```rust
/// # use winnow::error::IResult;
/// # use winnow::{error::ErrMode, error::{InputError, ErrorKind}, error::Needed};
/// # use winnow::prelude::*;
/// use winnow::combinator::repeat;
/// use std::io::Write;
/// use std::io::Error;
///
/// fn parser(s: &str) -> IResult<&str, Vec<u8>> {
///     repeat(
///         0..,
///         "abc"
///     ).try_fold(
///         Vec::new,
///         |mut acc, item: &str| -> Result<_, Error> {
///             acc.write(item.as_bytes())?;
///             Ok(acc)
///         }
///     ).parse_peek(s)
/// }
///
/// assert_eq!(parser("abc"), Ok(("", b"abc".to_vec())));
/// assert_eq!(parser("abc123"), Ok(("123", b"abc".to_vec())));
/// assert_eq!(parser("123123"), Ok(("123123", vec![])));
/// assert_eq!(parser(""), Ok(("", vec![])));
/// ```
#[inline(always)]
pub fn try_fold<Init, Op, OpError, Result>(
    mut self,
    mut init: Init,
    mut op: Op,
) -> impl Parser<Input, Result, Error>
where
    Init: FnMut() -> Result,
    Op: FnMut(Result, Output) -> core::result::Result<Result, OpError>,
    Error: FromExternalError<Input, OpError>,
{
    let Range {
        start_inclusive: min,
        end_inclusive,
    } = self.occurrences;
    // A missing upper bound is treated as `usize::MAX` repetitions.
    let max = end_inclusive.unwrap_or(usize::MAX);

    trace("repeat_try_fold", move |input: &mut Input| {
        try_fold_m_n(min, max, &mut self.parser, &mut init, &mut op, input)
    })
}
}

impl<P, I, O, C, E> Parser<I, C, E> for Repeat<P, I, O, C, E>
Expand Down Expand Up @@ -1353,3 +1505,135 @@ where

Ok(acc)
}

/// Runs `parse` up to `max` times, folding each output into an accumulator that may reject it.
///
/// The accumulator starts as `init()`. Each successful parse is handed to `fold`:
/// `Some(next)` becomes the new accumulator, while `None` rewinds `input` to just before the
/// rejected element and fails with an [`ErrorKind::Verify`] error.
///
/// When `parse` backtracks before `min` successes, the error is extended with
/// [`ErrorKind::Repeat`] and returned. After `min` successes, a backtrack rewinds `input` to
/// before the failed attempt and ends the repetition. Any other error is passed through as-is.
///
/// `min > max`, or a successful parse that consumed nothing, is reported through
/// `ErrMode::assert`.
#[inline(always)]
fn verify_fold_m_n<I, O, E, F, G, H, R>(
    min: usize,
    max: usize,
    parse: &mut F,
    init: &mut H,
    fold: &mut G,
    input: &mut I,
) -> PResult<R, E>
where
    I: Stream,
    F: Parser<I, O, E>,
    G: FnMut(R, O) -> Option<R>,
    H: FnMut() -> R,
    E: ParserError<I>,
{
    if min > max {
        return Err(ErrMode::assert(
            input,
            "range should be ascending, rather than descending",
        ));
    }

    let mut acc = init();
    for count in 0..max {
        let checkpoint = input.checkpoint();
        let remaining = input.eof_offset();

        let value = match parse.parse_next(input) {
            Ok(value) => value,
            // FIXME: handle failure properly
            Err(ErrMode::Backtrack(err)) => {
                if count >= min {
                    // Enough repetitions were seen: undo the failed attempt and stop.
                    input.reset(&checkpoint);
                    break;
                }
                return Err(ErrMode::Backtrack(err.append(
                    input,
                    &checkpoint,
                    ErrorKind::Repeat,
                )));
            }
            Err(other) => return Err(other),
        };

        // infinite loop check: the parser must always consume
        if input.eof_offset() == remaining {
            return Err(ErrMode::assert(
                input,
                "`repeat` parsers must always consume",
            ));
        }

        acc = match fold(acc, value) {
            Some(next) => next,
            None => {
                // Point the error at the start of the element that was rejected.
                input.reset(&checkpoint);
                let res = Err(ErrMode::from_error_kind(input, ErrorKind::Verify));
                super::debug::trace_result("verify_fold", &res);
                return res;
            }
        };
    }

    Ok(acc)
}

/// Runs `parse` up to `max` times, folding each output into an accumulator with a fallible `fold`.
///
/// The accumulator starts as `init()`. Each successful parse is handed to `fold`:
/// `Ok(next)` becomes the new accumulator, while `Err(e)` rewinds `input` to just before the
/// offending element and fails with `e` converted via `ErrMode::from_external_error` under
/// [`ErrorKind::Verify`].
///
/// When `parse` backtracks before `min` successes, the error is extended with
/// [`ErrorKind::Repeat`] and returned. After `min` successes, a backtrack rewinds `input` to
/// before the failed attempt and ends the repetition. Any other error is passed through as-is.
///
/// `min > max`, or a successful parse that consumed nothing, is reported through
/// `ErrMode::assert`.
#[inline(always)]
fn try_fold_m_n<I, O, E, F, G, H, R, GE>(
    min: usize,
    max: usize,
    parse: &mut F,
    init: &mut H,
    fold: &mut G,
    input: &mut I,
) -> PResult<R, E>
where
    I: Stream,
    F: Parser<I, O, E>,
    G: FnMut(R, O) -> Result<R, GE>,
    H: FnMut() -> R,
    E: ParserError<I> + FromExternalError<I, GE>,
{
    if min > max {
        return Err(ErrMode::assert(
            input,
            "range should be ascending, rather than descending",
        ));
    }

    let mut acc = init();
    for count in 0..max {
        let checkpoint = input.checkpoint();
        let remaining = input.eof_offset();

        let value = match parse.parse_next(input) {
            Ok(value) => value,
            // FIXME: handle failure properly
            Err(ErrMode::Backtrack(err)) => {
                if count >= min {
                    // Enough repetitions were seen: undo the failed attempt and stop.
                    input.reset(&checkpoint);
                    break;
                }
                return Err(ErrMode::Backtrack(err.append(
                    input,
                    &checkpoint,
                    ErrorKind::Repeat,
                )));
            }
            Err(other) => return Err(other),
        };

        // infinite loop check: the parser must always consume
        if input.eof_offset() == remaining {
            return Err(ErrMode::assert(
                input,
                "`repeat` parsers must always consume",
            ));
        }

        acc = match fold(acc, value) {
            Ok(next) => next,
            Err(e) => {
                // Point the error at the start of the element that failed to accumulate.
                input.reset(&checkpoint);
                let res = Err(ErrMode::from_external_error(input, ErrorKind::Verify, e));
                super::debug::trace_result("try_fold", &res);
                return res;
            }
        };
    }

    Ok(acc)
}

0 comments on commit 83e2b58

Please sign in to comment.