From 765236d7d6755f64dd300d759bc95b0160d7704a Mon Sep 17 00:00:00 2001
From: Sebastian Walter
Date: Mon, 24 Jun 2024 12:39:24 +0200
Subject: [PATCH] add is_invalid, update beam search

---
 python/text_utils/constraints.py        | 19 ++++++---
 python/text_utils/inference/__init__.py | 49 +++++++++++++++-------
 python/text_utils/inference/utils.py    | 33 +++++++++------
 src/grammar.rs                          | 55 +++++++++++++------------
 4 files changed, 96 insertions(+), 60 deletions(-)

diff --git a/python/text_utils/constraints.py b/python/text_utils/constraints.py
index 3657d5d..9fbf8f1 100644
--- a/python/text_utils/constraints.py
+++ b/python/text_utils/constraints.py
@@ -14,10 +14,9 @@ class Constraint:
     Base class for constraints.
     """
 
-    def get(self) -> tuple[np.ndarray, bool]:
+    def get(self) -> np.ndarray:
         """
-        Returns the current constraint indices and whether we
-        are in a state that matches the constraint.
+        Returns the current constraint indices.
         """
         raise NotImplementedError
 
@@ -39,6 +38,16 @@ def is_match(self) -> bool:
         """
         raise NotImplementedError
 
+    def is_invalid(self) -> bool:
+        """
+        Returns whether the current state is invalid.
+        This must be true iff get() returns an empty list of indices in
+        a non-match state.
+        We have a separate function for that because depending on the constraint
+        this can be implemented more efficiently.
+        """
+        return not self.is_match() and len(self.get()) == 0
+
     def clone(self) -> 'Constraint':
         """
         Returns a copy of the constraint.
@@ -61,8 +70,8 @@ def __init__(
         self.indices, self.value = cont_index.get(self.prefix)
         self.cont_index = cont_index
 
-    def get(self) -> tuple[np.ndarray, bool]:
-        return self.indices, self.is_match()
+    def get(self) -> np.ndarray:
+        return self.indices
 
     def reset(self, input: bytes | None = None) -> None:
         self.prefix = input or bytes()
diff --git a/python/text_utils/inference/__init__.py b/python/text_utils/inference/__init__.py
index 3a5c3a1..6395ae7 100644
--- a/python/text_utils/inference/__init__.py
+++ b/python/text_utils/inference/__init__.py
@@ -13,6 +13,7 @@
     StopFn,
     Beam,
     BeamSampleFn,
+    BeamWidthFn,
     BeamCandidateFn,
     BeamStopFn,
     beam_greedy,
@@ -199,7 +200,7 @@ def beam_search(
     device: torch.device,
     normalize_by_length: bool,
     alpha: float,
-    beam_width: int,
+    beam_width: int | BeamWidthFn,
     sample_fn: BeamSampleFn = beam_greedy(),
     candidate_fn: BeamCandidateFn = default_beam_candidate_fn(),
     logit_fns: list[LogitFn] | None = None,
@@ -215,16 +216,24 @@ def beam_search(
 
     beam_queues: list[list[Beam]] = [[] for _ in range(batch_size)]
 
+    beam_widths: list[int] = []
     current_beams: list[list[Beam]] = []
-    initial_lenghts = []
+    initial_lengths = []
     for init in initial:
-        initial_lenghts.append(len(init))
         if isinstance(init, Beam):
-            current_beams.append([init])
-            continue
-        assert initial_lenghts[-1] <= max_length, \
-            "initial token ids or beams cannot be longer than max length"
-        beam = Beam(init, [0.0] * len(init))
+            beam = init
+        else:
+            beam = Beam(init, [0.0] * len(init))
+
+        initial_lengths.append(len(beam))
+        assert initial_lengths[-1] <= max_length, \
+            "initial beam cannot be longer than max length"
+
+        if isinstance(beam_width, int):
+            beam_widths.append(beam_width)
+        else:
+            beam_widths.append(beam_width(beam))
+
         current_beams.append([beam])
 
     stop_mask = [False for _ in range(batch_size)]
@@ -244,20 +253,24 @@ def get_indices_to_decode() -> list[int]:
 
     def get_outputs() -> list[list[Beam]]:
         out_beams = []
-        for idx, (beam_queue, active_beams) in enumerate(zip(beam_queues, current_beams)):
+        for idx, (beam_queue, active_beams, n) in enumerate(zip(
+            beam_queues,
+            current_beams,
+            beam_widths
+        )):
             beam_queue = sorted(
                 beam_queue,
                 key=lambda b: score_fn(b),
                 reverse=True
-            )[:beam_width]
-            if len(beam_queue) < beam_width:
+            )[:n]
+            if len(beam_queue) < n:
                 active_beams = sorted(
                     active_beams,
                     key=lambda b: score_fn(b),
                     reverse=True
                 )
-                beam_queue.extend(active_beams[:beam_width - len(beam_queue)])
-            pfx = 0 if return_full else initial_lenghts[idx]
+                beam_queue.extend(active_beams[:n - len(beam_queue)])
+            pfx = 0 if return_full else initial_lengths[idx]
             out_beams.append([
                 beam.truncate_prefix(pfx)
                 for beam in beam_queue
@@ -326,9 +339,10 @@ def get_outputs() -> list[list[Beam]]:
             indices_to_decode,
             torch.split(log_probs, num_beams)
         )):
+            n = beam_widths[idx]
            candidates: list[tuple[Beam, int, float]] = []
             for beam_idx, beam in enumerate(current_beams[idx]):
-                for token_id in sample_fn(log_probs[beam_idx], beam_width).tolist():
+                for token_id in sample_fn(log_probs[beam_idx], n).tolist():
                     candidates.append((
                         beam,
                         token_id,
@@ -343,9 +357,12 @@ def get_outputs() -> list[list[Beam]]:
             )):
                 # update candidates
                 candidate = candidate_fn(beam, token_id, log_prob)
+                if candidate is None:
+                    # skip invalid candidates
+                    continue
                 # only consider eos beams if they are in top beam_width beams
                 stop = stop_fn(candidate)
-                if num < beam_width and stop:
+                if num < n and stop:
                     # we record all stop beams, but only stop when the top beam should stop
                     # (because then we are sure there is no better candidate left to decode)
                     beam_queues[idx].append(candidate)
@@ -353,7 +370,7 @@ def get_outputs() -> list[list[Beam]]:
                 elif not stop:
                     new_beams.append(candidate)
 
-                if len(new_beams) >= beam_width:
+                if len(new_beams) >= n:
                     break
 
             current_beams[idx] = new_beams
diff --git a/python/text_utils/inference/utils.py b/python/text_utils/inference/utils.py
index 17f529f..5980af7 100644
--- a/python/text_utils/inference/utils.py
+++ b/python/text_utils/inference/utils.py
@@ -136,7 +136,17 @@ def __repr__(self) -> str:
         int,
         float
     ],
-    Beam
+    Beam | None
 ]
+
+
+BeamWidthFn = Callable[
+    [
+        # beam to calculate beam width from
+        Beam
+    ],
+    # beam width
+    int
+]
 
 
@@ -144,31 +154,26 @@ def constraint_logit_fn(
     retrieve_constraint_fn: Callable[[int | Beam], Constraint | None],
     eos_token_id: int
 ) -> LogitFn:
-    import time
-
-    times = []
-
     def _constrain_logits(
         logits: torch.Tensor,
         beams_or_indices: list[int] | list[Beam]
     ) -> torch.Tensor:
-        nonlocal times
 
         zeros = torch.full_like(logits, float("-inf"))
 
         for i, beam_or_idx in enumerate(beams_or_indices):
             constraint = retrieve_constraint_fn(beam_or_idx)
-            if constraint is None:
+            if constraint is None or constraint.is_invalid():
                 zeros[i] = logits[i]
                 continue
 
-            constrain_to, is_match = constraint.get()
+            constrain_to = constraint.get()
+            if constraint.is_match():
+                zeros[i, eos_token_id] = logits[i, eos_token_id]
+
             constrain_to = torch.from_numpy(constrain_to).to(torch.long)
             zeros[i, constrain_to] = logits[i, constrain_to]
 
-            if len(constrain_to) == 0 or is_match:
-                zeros[i, eos_token_id] = logits[i, eos_token_id]
-
         return zeros
 
     return _constrain_logits
@@ -189,8 +194,10 @@ def _constrain_sample(
                 continue
 
             constraint = retrieve_constraint_fn(idx)
-            if constraint is not None:
-                constraint.next(token_id)
+            if constraint is None or constraint.is_invalid():
+                continue
+
+            constraint.next(token_id)
 
         return token_ids
 
diff --git a/src/grammar.rs b/src/grammar.rs
index 781fd8b..ba2ccb3 100644
--- a/src/grammar.rs
+++ b/src/grammar.rs
@@ -79,8 +79,7 @@ impl RegexConstraint {
             .lock()
             .map(|mut inner| {
                 inner.state = state;
-                let indices = self.constraint.get_valid_continuations(&inner.state);
-                inner.indices = indices.into();
+                inner.indices = self.constraint.get_valid_continuations(&inner.state).into();
                 inner.is_match = self.constraint.is_match_state(&inner.state);
             })
             .map_err(|_| anyhow!("error locking inner state"))
@@ -96,15 +95,17 @@ impl RegexConstraint {
             .map_err(|_| anyhow!("error locking inner state"))
     }
 
-    fn get(&self, py: Python<'_>) -> anyhow::Result<(PyObject, bool)> {
+    fn get(&self, py: Python<'_>) -> anyhow::Result<PyObject> {
         self.inner
             .lock()
-            .map(|inner| {
-                (
-                    inner.indices.clone().into_pyarray_bound(py).into_py(py),
-                    inner.is_match,
-                )
-            })
+            .map(|inner| inner.indices.clone().into_pyarray_bound(py).into_py(py))
+            .map_err(|_| anyhow!("error locking inner state"))
+    }
+
+    fn is_invalid(&self) -> anyhow::Result<bool> {
+        self.inner
+            .lock()
+            .map(|inner| inner.indices.is_empty() && !inner.is_match)
             .map_err(|_| anyhow!("error locking inner state"))
     }
 
@@ -126,8 +127,7 @@ impl RegexConstraint {
                 .get_next_state(&inner.state, index)
                 .expect("invalid continuation");
             inner.state = next_state;
-            let indices = constraint.get_valid_continuations(&inner.state);
-            inner.indices = indices.into();
+            inner.indices = constraint.get_valid_continuations(&inner.state).into();
             inner.is_match = constraint.is_match_state(&inner.state);
         });
         // wait until spawned thread signals that is has locked
@@ -270,8 +270,7 @@ impl LR1Constraint {
             .lock()
             .map(|mut inner| {
                 inner.state = state;
-                let indices = self.constraint.get_valid_continuations(&inner.state);
-                inner.indices = indices;
+                inner.indices = self.constraint.get_valid_continuations(&inner.state);
                 inner.is_match = self.constraint.is_match_state(&inner.state);
             })
             .map_err(|_| anyhow!("error locking inner state"))
@@ -287,24 +286,29 @@ impl LR1Constraint {
             .map_err(|_| anyhow!("error locking inner state"))
     }
 
-    fn get(&self, py: Python<'_>) -> anyhow::Result<(PyObject, bool)> {
+    fn get(&self, py: Python<'_>) -> anyhow::Result<PyObject> {
         self.inner
             .lock()
             .map(|inner| {
-                let indices =
-                    if inner.is_match && self.constraint.only_skippable_matching(&inner.state) {
-                        // should stop, return empty indices
-                        vec![].into()
-                    } else {
-                        inner.indices.clone()
-                    }
-                    .into_pyarray_bound(py)
-                    .into_py(py);
-                (indices, inner.is_match)
+                if inner.is_match && self.constraint.only_skippable_matching(&inner.state) {
+                    // should stop, return empty indices
+                    vec![].into()
+                } else {
+                    inner.indices.clone()
+                }
+                .into_pyarray_bound(py)
+                .into_py(py)
             })
             .map_err(|_| anyhow!("error locking inner state"))
     }
 
+    fn is_invalid(&self) -> anyhow::Result<bool> {
+        self.inner
+            .lock()
+            .map(|inner| inner.indices.is_empty() && !inner.is_match)
+            .map_err(|_| anyhow!("error locking inner state"))
+    }
+
     fn is_match(&self) -> anyhow::Result<bool> {
         self.inner
             .lock()
@@ -322,8 +326,7 @@ impl LR1Constraint {
                 inner.state = constraint
                     .get_next_state(&inner.state, index)
                     .expect("invalid continuation");
-                let indices = constraint.get_valid_continuations(&inner.state);
-                inner.indices = indices;
+                inner.indices = constraint.get_valid_continuations(&inner.state);
                 inner.is_match = constraint.is_match_state(&inner.state);
             });
             // wait until spawned thread signals that is has locked
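
Below is a minimal sketch (not part of the patch) of a Constraint subclass written against the updated interface: get() now returns only the indices, the match flag comes from is_match(), and is_invalid() reports a dead-end state directly instead of forcing a call to get(). FixedSequenceConstraint is a hypothetical toy class for illustration only, not something from the repository.

import numpy as np

from text_utils.constraints import Constraint


class FixedSequenceConstraint(Constraint):
    """Toy constraint that forces decoding to follow a fixed token id sequence."""

    def __init__(self, sequence: list[int]):
        self.sequence = sequence
        self.position = 0
        self.invalid = False

    def get(self) -> np.ndarray:
        # empty once the sequence is finished or we left it
        if self.invalid or self.is_match():
            return np.empty(0, dtype=np.int64)
        return np.array([self.sequence[self.position]], dtype=np.int64)

    def next(self, index: int) -> None:
        # called with the sampled token id; anything off-sequence is a dead end
        if self.is_match() or self.invalid:
            return
        if index == self.sequence[self.position]:
            self.position += 1
        else:
            self.invalid = True

    def is_match(self) -> bool:
        return self.position == len(self.sequence)

    def is_invalid(self) -> bool:
        # cheaper than the default "not is_match() and len(get()) == 0" check
        return self.invalid

    def reset(self, input: bytes | None = None) -> None:
        self.position = 0
        self.invalid = False

    def clone(self) -> "FixedSequenceConstraint":
        copy = FixedSequenceConstraint(self.sequence)
        copy.position = self.position
        copy.invalid = self.invalid
        return copy

Under this contract, constraint_logit_fn passes the logits through unchanged for invalid constraints and only unmasks the eos token when the constraint is in a match state.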
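
A second sketch (also not part of the patch) for the widened beam_width parameter: beam_search now accepts either a plain int or a BeamWidthFn that maps the initial Beam of each batch element to its own width. The length-based heuristic and the cap of 8 below are purely illustrative, and Beam is assumed to be importable from text_utils.inference.utils as in this patch.

from text_utils.inference.utils import Beam


def width_by_prompt_length(beam: Beam) -> int:
    # wider beams for longer prompts, capped at 8 (illustrative heuristic only)
    return min(8, max(1, len(beam) // 16))


# can be passed wherever an int beam width was used before, e.g.
# beam_search(..., beam_width=width_by_prompt_length)

Note that BeamCandidateFn may now also return None; such candidates are simply skipped during beam expansion.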