Skip to content

Commit

Permalink
improve constraint speed
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jul 15, 2024
1 parent e2236a3 commit c354e17
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/text_utils/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _constrain_logits(
zeros[i] = logits[i]
continue

indices = torch.from_numpy(constraint.get()).to(torch.long)
indices = torch.from_numpy(constraint.get())
zeros[i, indices] = logits[i, indices]

if constraint.is_match():
Expand Down
37 changes: 26 additions & 11 deletions src/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
use anyhow::anyhow;
use lru::LruCache;
use numpy::ndarray::Array1;
use numpy::IntoPyArray;
use numpy::{IntoPyArray, PyArray1};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use rayon::spawn_fifo;
Expand All @@ -20,7 +20,7 @@ use text_utils_grammar::{
#[derive(Clone)]
struct RegexInner {
state: RegularExpressionState,
indices: Array1<usize>,
indices: Array1<i32>,
is_match: bool,
is_invalid: bool,
}
Expand All @@ -34,7 +34,11 @@ struct RegexConstraint {
impl RegexConstraint {
fn init(constraint: RegularExpressionConstraint) -> Self {
let state = constraint.get_start_state();
let indices = constraint.get_valid_continuations(&state).into();
let indices = constraint
.get_valid_continuations(&state)
.into_iter()
.map(|v| v as i32)
.collect();
let is_match = constraint.is_match_state(&state);
Self {
constraint: Arc::new(constraint),
Expand Down Expand Up @@ -84,7 +88,12 @@ impl RegexConstraint {
.lock()
.map(|mut inner| {
inner.state = state;
inner.indices = self.constraint.get_valid_continuations(&inner.state).into();
inner.indices = self
.constraint
.get_valid_continuations(&inner.state)
.into_iter()
.map(|v| v as i32)
.collect();
inner.is_match = self.constraint.is_match_state(&inner.state);
inner.is_invalid = false;
})
Expand Down Expand Up @@ -134,7 +143,11 @@ impl RegexConstraint {
return;
};
inner.state = next_state;
inner.indices = constraint.get_valid_continuations(&inner.state).into();
inner.indices = constraint
.get_valid_continuations(&inner.state)
.into_iter()
.map(|v| v as i32)
.collect();
inner.is_match = constraint.is_match_state(&inner.state);
});
// wait until spawned thread signals that is has locked
Expand All @@ -152,12 +165,12 @@ enum LR1Type {
#[derive(Clone)]
struct LR1Inner {
state: LR1State,
indices: Array1<usize>,
indices: Array1<i32>,
is_match: bool,
is_invalid: bool,
}

type LR1ConstraintCache = LruCache<LR1State, (Array1<usize>, bool)>;
type LR1ConstraintCache = LruCache<LR1State, (Array1<i32>, bool)>;

#[pyclass]
struct LR1Constraint {
Expand All @@ -181,12 +194,14 @@ impl LR1Type {
}
}

fn get_valid_continuations(&self, state: &LR1State) -> Array1<usize> {
fn get_valid_continuations(&self, state: &LR1State) -> Array1<i32> {
match self {
LR1Type::Exact(inner) => inner.get_valid_continuations(state),
LR1Type::Regular(inner) => inner.get_valid_continuations(state),
}
.into()
.into_iter()
.map(|v| v as i32)
.collect()
}

fn get_next_state(&self, state: &LR1State, continuation: usize) -> Option<LR1State> {
Expand Down Expand Up @@ -320,7 +335,7 @@ impl LR1Constraint {
.map_err(|_| anyhow!("error locking inner state"))
}

fn get(&self, py: Python<'_>) -> anyhow::Result<PyObject> {
fn get(&self, py: Python<'_>) -> anyhow::Result<Py<PyArray1<i32>>> {
self.inner
.lock()
.map(|inner| {
Expand All @@ -331,7 +346,7 @@ impl LR1Constraint {
inner.indices.clone()
}
.into_pyarray_bound(py)
.into_py(py)
.unbind()
})
.map_err(|_| anyhow!("error locking inner state"))
}
Expand Down

0 comments on commit c354e17

Please sign in to comment.