diff --git a/python/text_utils/inference/utils.py b/python/text_utils/inference/utils.py index a157f89..08ac4e2 100644 --- a/python/text_utils/inference/utils.py +++ b/python/text_utils/inference/utils.py @@ -167,7 +167,7 @@ def _constrain_logits( zeros[i] = logits[i] continue - indices = torch.from_numpy(constraint.get()).to(torch.long) + indices = torch.from_numpy(constraint.get()) zeros[i, indices] = logits[i, indices] if constraint.is_match(): diff --git a/src/grammar.rs b/src/grammar.rs index 5daa39b..32cc089 100644 --- a/src/grammar.rs +++ b/src/grammar.rs @@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex}; use anyhow::anyhow; use lru::LruCache; use numpy::ndarray::Array1; -use numpy::IntoPyArray; +use numpy::{IntoPyArray, PyArray1}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use rayon::spawn_fifo; @@ -20,7 +20,7 @@ use text_utils_grammar::{ #[derive(Clone)] struct RegexInner { state: RegularExpressionState, - indices: Array1, + indices: Array1, is_match: bool, is_invalid: bool, } @@ -34,7 +34,11 @@ struct RegexConstraint { impl RegexConstraint { fn init(constraint: RegularExpressionConstraint) -> Self { let state = constraint.get_start_state(); - let indices = constraint.get_valid_continuations(&state).into(); + let indices = constraint + .get_valid_continuations(&state) + .into_iter() + .map(|v| v as i32) + .collect(); let is_match = constraint.is_match_state(&state); Self { constraint: Arc::new(constraint), @@ -84,7 +88,12 @@ impl RegexConstraint { .lock() .map(|mut inner| { inner.state = state; - inner.indices = self.constraint.get_valid_continuations(&inner.state).into(); + inner.indices = self + .constraint + .get_valid_continuations(&inner.state) + .into_iter() + .map(|v| v as i32) + .collect(); inner.is_match = self.constraint.is_match_state(&inner.state); inner.is_invalid = false; }) @@ -134,7 +143,11 @@ impl RegexConstraint { return; }; inner.state = next_state; - inner.indices = constraint.get_valid_continuations(&inner.state).into(); + inner.indices = constraint + .get_valid_continuations(&inner.state) + .into_iter() + .map(|v| v as i32) + .collect(); inner.is_match = constraint.is_match_state(&inner.state); }); // wait until spawned thread signals that is has locked @@ -152,12 +165,12 @@ enum LR1Type { #[derive(Clone)] struct LR1Inner { state: LR1State, - indices: Array1, + indices: Array1, is_match: bool, is_invalid: bool, } -type LR1ConstraintCache = LruCache, bool)>; +type LR1ConstraintCache = LruCache, bool)>; #[pyclass] struct LR1Constraint { @@ -181,12 +194,14 @@ impl LR1Type { } } - fn get_valid_continuations(&self, state: &LR1State) -> Array1 { + fn get_valid_continuations(&self, state: &LR1State) -> Array1 { match self { LR1Type::Exact(inner) => inner.get_valid_continuations(state), LR1Type::Regular(inner) => inner.get_valid_continuations(state), } - .into() + .into_iter() + .map(|v| v as i32) + .collect() } fn get_next_state(&self, state: &LR1State, continuation: usize) -> Option { @@ -320,7 +335,7 @@ impl LR1Constraint { .map_err(|_| anyhow!("error locking inner state")) } - fn get(&self, py: Python<'_>) -> anyhow::Result { + fn get(&self, py: Python<'_>) -> anyhow::Result>> { self.inner .lock() .map(|inner| { @@ -331,7 +346,7 @@ impl LR1Constraint { inner.indices.clone() } .into_pyarray_bound(py) - .into_py(py) + .unbind() }) .map_err(|_| anyhow!("error locking inner state")) }