From 33afa19a5866062f9c130f51099aa2b3ac03e6e7 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 5 Sep 2024 19:56:21 -0400 Subject: [PATCH] feat: initial to_fsm logic for core regex elements --- src/interegular/fsm.rs | 72 ++++++++--- src/interegular/patterns.rs | 205 ++++++++++++++++++++++++++++--- src/interegular/simple_parser.rs | 6 +- src/lib.rs | 2 +- src/python_bindings/mod.rs | 48 ++++++++ 5 files changed, 296 insertions(+), 37 deletions(-) diff --git a/src/interegular/fsm.rs b/src/interegular/fsm.rs index a98277c9..f042ca8a 100644 --- a/src/interegular/fsm.rs +++ b/src/interegular/fsm.rs @@ -25,10 +25,19 @@ impl From for usize { } } +impl From for u32 { + fn from(c: TransitionKey) -> Self { + match c { + TransitionKey::Symbol(i) => i as u32, + _ => panic!("Cannot convert `anything else` to u32"), + } + } +} + pub trait SymbolTrait: Eq + Hash + Clone + Debug + From {} impl> SymbolTrait for T {} -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Alphabet { pub symbol_mapping: HashMap, pub by_transition: HashMap>, @@ -49,6 +58,14 @@ impl Alphabet { } } + #[must_use] + pub fn empty() -> Self { + Alphabet { + symbol_mapping: HashMap::new(), + by_transition: HashMap::new(), + } + } + pub fn get(&self, item: &T) -> TransitionKey { match self.symbol_mapping.get(item) { Some(x) => *x, @@ -60,7 +77,8 @@ impl Alphabet { self.symbol_mapping.contains_key(item) } - #[must_use] pub fn from_groups(groups: &[HashSet]) -> Self { + #[must_use] + pub fn from_groups(groups: &[HashSet]) -> Self { let mut symbol_mapping = HashMap::new(); for (i, group) in groups.iter().enumerate() { for symbol in group { @@ -118,16 +136,17 @@ impl Alphabet { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Fsm { - alphabet: Alphabet, + pub alphabet: Alphabet, pub states: HashSet, pub initial: TransitionKey, pub finals: HashSet, pub map: HashMap>, } impl Fsm { - #[must_use] pub fn new( + #[must_use] + pub fn new( alphabet: Alphabet, states: HashSet, initial: TransitionKey, @@ -166,7 +185,8 @@ impl Fsm { self.finals.contains(&state) } - #[must_use] pub fn reduce(&self) -> Self { + #[must_use] + pub fn reduce(&self) -> Self { self.reversed().reversed() } @@ -203,7 +223,8 @@ impl Fsm { crawl(&self.alphabet, initial, final_fn, follow) } - #[must_use] pub fn is_live(&self, state: TransitionKey) -> bool { + #[must_use] + pub fn is_live(&self, state: TransitionKey) -> bool { let mut seen = HashSet::new(); let mut reachable = vec![state]; let mut i = 0; @@ -226,7 +247,8 @@ impl Fsm { false } - #[must_use] pub fn is_empty(&self) -> bool { + #[must_use] + pub fn is_empty(&self) -> bool { !self.is_live(self.initial) } @@ -268,27 +290,32 @@ impl Fsm { }) } - #[must_use] pub fn union(fsms: &[Self]) -> Self { + #[must_use] + pub fn union(fsms: &[Self]) -> Self { Self::parallel(fsms, |accepts| accepts.iter().any(|&x| x)) } - #[must_use] pub fn intersection(fsms: &[Self]) -> Self { + #[must_use] + pub fn intersection(fsms: &[Self]) -> Self { Self::parallel(fsms, |accepts| accepts.iter().all(|&x| x)) } - #[must_use] pub fn symmetric_difference(fsms: &[Self]) -> Self { + #[must_use] + pub fn symmetric_difference(fsms: &[Self]) -> Self { Self::parallel(fsms, |accepts| { accepts.iter().filter(|&&x| x).count() % 2 == 1 }) } - #[must_use] pub fn difference(fsms: &[Self]) -> Self { + #[must_use] + pub fn difference(fsms: &[Self]) -> Self { Self::parallel(fsms, |accepts| { accepts[0] && !accepts[1..].iter().any(|&x| x) }) } - #[must_use] pub fn concatenate(fsms: &[Self]) -> Self { + #[must_use] + pub fn concatenate(fsms: &[Self]) -> Self { let alphabets_from_fsms: Vec> = fsms.iter().map(|f| f.alphabet.clone()).collect(); let alphabets = Alphabet::union(alphabets_from_fsms.as_slice()); @@ -362,7 +389,8 @@ impl Fsm { crawl(&alphabet, initial, final_fn, follow) } - #[must_use] pub fn star(&self) -> Self { + #[must_use] + pub fn star(&self) -> Self { let initial = HashSet::from([self.initial]); let follow = |state: &HashSet, @@ -398,7 +426,8 @@ impl Fsm { result } - #[must_use] pub fn times(&self, multiplier: usize) -> Self { + #[must_use] + pub fn times(&self, multiplier: usize) -> Self { // metastate is a set of iterations+states let initial = HashSet::from([(self.initial, 0)]); let final_fn = |state: &HashSet<(TransitionKey, usize)>| { @@ -433,7 +462,8 @@ impl Fsm { crawl(&self.alphabet, initial, final_fn, follow) } - #[must_use] pub fn everythingbut(&self) -> Self { + #[must_use] + pub fn everythingbut(&self) -> Self { let initial = HashSet::from([(self.initial, 0)]); let follow = |current: &HashSet<(TransitionKey, usize)>, @@ -522,7 +552,8 @@ impl Fsm { } } -#[must_use] pub fn null(alphabet: &Alphabet) -> Fsm { +#[must_use] +pub fn null(alphabet: &Alphabet) -> Fsm { Fsm::new( alphabet.clone(), HashSet::from([0.into()]), @@ -539,7 +570,8 @@ impl Fsm { ) } -#[must_use] pub fn epsilon(alphabet: &Alphabet) -> Fsm { +#[must_use] +pub fn epsilon(alphabet: &Alphabet) -> Fsm { Fsm::new( alphabet.clone(), HashSet::from([0.into()]), @@ -575,7 +607,9 @@ where for transition in alphabet.by_transition.keys() { match follow(&state, transition) { Some(next) => { - let j = if let Some(index) = states.iter().position(|s| s == &next) { index } else { + let j = if let Some(index) = states.iter().position(|s| s == &next) { + index + } else { states.push_back(next.clone()); states.len() - 1 }; diff --git a/src/interegular/patterns.rs b/src/interegular/patterns.rs index 7f3a3ccb..e8463761 100644 --- a/src/interegular/patterns.rs +++ b/src/interegular/patterns.rs @@ -6,6 +6,7 @@ use std::rc::Rc; use std::vec; use crate::interegular::fsm::SymbolTrait; +use crate::interegular::fsm::TransitionKey; use crate::interegular::fsm::{Alphabet, Fsm}; const SPECIAL_CHARS_INNER: [&str; 2] = ["\\", "]"]; @@ -98,7 +99,8 @@ fn _combine_char_groups(groups: &[RegexElement], negate: bool) -> RegexElement { } impl RegexElement { - #[must_use] pub fn repeat(self, min: usize, max: Option) -> Self { + #[must_use] + pub fn repeat(self, min: usize, max: Option) -> Self { RegexElement::Repeated { element: Box::new(self), min, @@ -106,15 +108,18 @@ impl RegexElement { } } - #[must_use] pub fn capture(self) -> Self { + #[must_use] + pub fn capture(self) -> Self { RegexElement::Capture(Box::new(self)) } - #[must_use] pub fn group(self) -> Self { + #[must_use] + pub fn group(self) -> Self { RegexElement::Group(Box::new(self)) } - #[must_use] pub fn with_flags(self, added: Vec, removed: Vec) -> Self { + #[must_use] + pub fn with_flags(self, added: Vec, removed: Vec) -> Self { RegexElement::Flag { element: Box::new(self), added, @@ -124,26 +129,91 @@ impl RegexElement { } impl RegexElement { - #[must_use] pub fn to_fsm( + pub fn to_fsm( &self, alphabet: Option>, prefix_postfix: Option<(usize, Option)>, flags: Option>, ) -> Fsm { match self { + RegexElement::Literal(c) => { + let alphabet = alphabet + .unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default())); + let prefix_postfix = prefix_postfix.unwrap_or_else(|| self.get_prefix_postfix()); + + let case_insensitive = flags + .clone() + .as_ref() + .map_or(false, |f| f.contains(&Flag::CaseInsensitive)); + + let mut mapping = HashMap::<_, HashMap<_, _>>::new(); + let symbol = alphabet.get(c); + + let mut m = std::collections::HashMap::new(); + m.insert(symbol, TransitionKey::Symbol(1_usize)); + mapping.insert(TransitionKey::Symbol(0_usize), m); + + let states = (0..=1).map(std::convert::Into::into).collect(); + let finals = (1..=1).map(std::convert::Into::into).collect(); + + Fsm::new( + alphabet, + states, // {0, 1} + 0.into(), + finals, // {1} + mapping, + ) + } RegexElement::CharGroup { chars, inverted } => { let alphabet = alphabet .unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default())); let prefix_postfix = prefix_postfix.unwrap_or_else(|| self.get_prefix_postfix()); - assert!(prefix_postfix == (0, Some(0)), "Cannot have prefix/postfix on CharGroup-level"); + assert!( + prefix_postfix == (0, Some(0)), + "Cannot have prefix/postfix on CharGroup-level" + ); let case_insensitive = flags .clone() .as_ref() .map_or(false, |f| f.contains(&Flag::CaseInsensitive)); - let mapping = HashMap::<_, HashMap<_, _>>::new(); + let mut mapping = HashMap::<_, HashMap<_, _>>::new(); + + if *inverted { + let chars = chars.clone(); + let alphabet = alphabet.clone(); + let alphabet_set = alphabet + .clone() + .by_transition + .keys() + .copied() + .collect::>(); + + let char_as_usize = chars + .iter() + .map(|c| TransitionKey::Symbol(*c as usize)) + .collect(); + let diff = alphabet_set + .difference(&char_as_usize) + .copied() + .collect::>(); + + let mut m = std::collections::HashMap::new(); + for symbol in diff { + m.insert(symbol, TransitionKey::Symbol(1_usize)); + } + mapping.insert(TransitionKey::Symbol(0_usize), m); + } else { + let chars = chars.clone(); + for symbol in chars { + let mut m = std::collections::HashMap::new(); + let symbol_value = alphabet.get(&symbol); + m.insert(symbol_value, TransitionKey::Symbol(1_usize)); + mapping.insert(TransitionKey::Symbol(0_usize), m); + } + } let states = (0..=1).map(std::convert::Into::into).collect(); let finals = (1..=1).map(std::convert::Into::into).collect(); @@ -156,12 +226,68 @@ impl RegexElement { mapping, ) } - // Implement other variants as needed + RegexElement::Repeated { element, min, max } => { + let unit = element.to_fsm(alphabet.clone(), None, flags.clone()); + let alphabet = alphabet + .unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default())); + let mandatory = std::iter::repeat(unit.clone()).take(*min).fold( + Fsm::new( + // TODO: fix if alphabet is None + alphabet.clone(), + HashSet::new(), + 0.into(), + HashSet::new(), + std::collections::HashMap::new(), + ), + |acc, f| Fsm::concatenate(&[acc, f]), + ); + + let optional = if max.is_none() { + unit.star() + } else { + let mut optional = unit.clone(); + optional.finals.insert(optional.initial); + optional = std::iter::repeat(optional.clone()) + .take(max.unwrap() - min) + .fold( + Fsm::new( + alphabet.clone(), + HashSet::new(), + 0.into(), + HashSet::new(), + std::collections::HashMap::new(), + ), + |acc, f| Fsm::concatenate(&[acc, f]), + ); + + optional + }; + + Fsm::concatenate(&[mandatory, optional]) + } + RegexElement::Concatenation(parts) => { + let mut current = vec![]; + for part in parts { + current.push(part.to_fsm(alphabet.clone(), None, flags.clone())); + } + + Fsm::concatenate(¤t) + } + RegexElement::Alternation(options) => { + let mut current = vec![]; + for option in options { + current.push(option.to_fsm(alphabet.clone(), None, flags.clone())); + } + + Fsm::union(¤t) + } + // throw on non implemented variants _ => unimplemented!("FSM conversion not implemented for this variant"), } } - #[must_use] pub fn get_alphabet(&self, flags: &HashSet) -> Alphabet { + #[must_use] + pub fn get_alphabet(&self, flags: &HashSet) -> Alphabet { match self { RegexElement::CharGroup { chars, .. } => { let relevant = if flags.contains(&Flag::CaseInsensitive) { @@ -177,11 +303,31 @@ impl RegexElement { Alphabet::from_groups(&[relevant, HashSet::from(['\0'.into()])]) } RegexElement::Literal(c) => Alphabet::from_groups(&[HashSet::from([(*c).into()])]), + RegexElement::Repeated { element, .. } => element.get_alphabet(flags), + RegexElement::Alternation(options) => { + let mut alphabet = Alphabet::empty(); + for option in options { + let alphabets = vec![alphabet, option.get_alphabet(flags)]; + let (res, new_to_old) = Alphabet::union(alphabets.as_slice()); + alphabet = res; + } + alphabet + } + RegexElement::Concatenation(parts) => { + let mut alphabet = Alphabet::empty(); + for part in parts { + let alphabets = vec![alphabet, part.get_alphabet(flags)]; + let (res, new_to_old) = Alphabet::union(alphabets.as_slice()); + alphabet = res; + } + alphabet + } _ => unimplemented!("Alphabet not implemented for this variant"), } } - #[must_use] pub fn get_prefix_postfix(&self) -> (usize, Option) { + #[must_use] + pub fn get_prefix_postfix(&self) -> (usize, Option) { match self { RegexElement::CharGroup { .. } => (0, Some(0)), RegexElement::Literal(_) => (1, Some(1)), @@ -228,7 +374,8 @@ impl RegexElement { } } - #[must_use] pub fn get_lengths(&self) -> (usize, Option) { + #[must_use] + pub fn get_lengths(&self) -> (usize, Option) { match self { RegexElement::CharGroup { .. } => (1, Some(1)), RegexElement::Literal(_) => (1, Some(1)), @@ -270,11 +417,13 @@ impl RegexElement { } } - #[must_use] pub fn simplify(&self) -> Rc { + #[must_use] + pub fn simplify(&self) -> Rc { Rc::new(self.clone()) } - #[must_use] pub fn to_concrete(&self) -> RegexElement { + #[must_use] + pub fn to_concrete(&self) -> RegexElement { self.clone() } } @@ -286,7 +435,8 @@ pub struct ParsePattern<'a> { } impl<'a> ParsePattern<'a> { - #[must_use] pub fn new(data: &'a str) -> Self { + #[must_use] + pub fn new(data: &'a str) -> Self { ParsePattern { parser: crate::interegular::simple_parser::SimpleParser::new(data), flags: None, @@ -671,7 +821,7 @@ mod tests { #[test] fn test_parse_pattern_simple() { - let pattern = "a"; + let pattern: &str = "a"; let result = parse_pattern(pattern); assert_eq!( result, @@ -946,4 +1096,29 @@ mod tests { let result = parse_pattern(pattern); assert!(result.is_err()); } + #[test] + fn test_parse_pattern_simple_to_fsm() { + let pattern: &str = "a"; + let result = parse_pattern(pattern).unwrap(); + let result = result.to_fsm(None, None, None); + + let expected = Fsm { + alphabet: Alphabet { + symbol_mapping: HashMap::from([('a', TransitionKey::Symbol(0))]), + by_transition: HashMap::from([(TransitionKey::Symbol(0), vec!['a'])]), + }, + states: HashSet::from([TransitionKey::Symbol(0), TransitionKey::Symbol(1)]), + initial: TransitionKey::Symbol(0), + finals: HashSet::from([TransitionKey::Symbol(1)]), + map: HashMap::from([ + ( + TransitionKey::Symbol(0), + HashMap::from([(TransitionKey::Symbol(0), TransitionKey::Symbol(1))]), + ), + (TransitionKey::Symbol(1), HashMap::new()), + ]), + }; + + assert_eq!(result, expected); + } } diff --git a/src/interegular/simple_parser.rs b/src/interegular/simple_parser.rs index a426ecc5..53cf9520 100644 --- a/src/interegular/simple_parser.rs +++ b/src/interegular/simple_parser.rs @@ -13,7 +13,8 @@ pub struct NoMatch { } impl NoMatch { - #[must_use] pub fn new(data: &str, index: usize, expected: Vec) -> Self { + #[must_use] + pub fn new(data: &str, index: usize, expected: Vec) -> Self { NoMatch { data: data.to_string(), index, @@ -54,7 +55,8 @@ pub struct SimpleParser { } impl SimpleParser { - #[must_use] pub fn new(data: &str) -> Self { + #[must_use] + pub fn new(data: &str) -> Self { SimpleParser { data: data.to_string(), index: 0, diff --git a/src/lib.rs b/src/lib.rs index d0082831..49b5d556 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ +pub mod interegular; pub mod json_schema; pub mod regex; -pub mod interegular; #[cfg(feature = "python-bindings")] mod python_bindings; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index fb8c7c1b..7d3dcad6 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,3 +1,4 @@ +use crate::interegular::fsm::Fsm; use crate::interegular::patterns::parse_pattern; use crate::interegular::patterns::RegexElement; use crate::json_schema; @@ -455,6 +456,50 @@ pub fn parse_pattern_internal(py: Python, pattern: &str) -> PyResult { } } +#[pyclass] +pub struct InteregularFSMInfo { + #[pyo3(get)] + initial: u32, + #[pyo3(get)] + finals: HashSet, + #[pyo3(get)] + states: HashSet, + #[pyo3(get)] + map: HashMap>, +} + +#[pyfunction(name = "parse_pattern_to_fsm")] +#[pyo3(text_signature = "(pattern: &str)")] +pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult { + let regex_element = + parse_pattern(pattern).map_err(|_| PyValueError::new_err("Invalid pattern"))?; + + let alphabet = None; + let prefix_postfix = None; + let flags = None; + + let fsm_info = regex_element.to_fsm(alphabet, prefix_postfix, flags); + let map: HashMap> = fsm_info + .map + .iter() + .map(|(key, map)| { + let u32_key = u32::from(*key); + let map_as_u32s = map + .iter() + .map(|(key, value)| (u32::from(*key), u32::from(*value))) + .collect(); + (u32_key, map_as_u32s) + }) + .collect(); + + Ok(InteregularFSMInfo { + initial: fsm_info.initial.into(), + finals: fsm_info.finals.iter().map(|f| (*f).into()).collect(), + states: fsm_info.states.iter().map(|s| (*s).into()).collect(), + map, + }) +} + #[pymodule] fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?; @@ -473,6 +518,9 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(parse_pattern_to_fsm_internal, m)?)?; + m.add_class::()?; + m.add_class::()?; m.add("BOOLEAN", json_schema::BOOLEAN)?;