Skip to content

Commit

Permalink
add is_invalid, update beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jun 24, 2024
1 parent d9b712c commit 765236d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 60 deletions.
19 changes: 14 additions & 5 deletions python/text_utils/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ class Constraint:
Base class for constraints.
"""

def get(self) -> tuple[np.ndarray, bool]:
def get(self) -> np.ndarray:
"""
Returns the current constraint indices and whether we
are in a state that matches the constraint.
Returns the current constraint indices.
"""
raise NotImplementedError

Expand All @@ -39,6 +38,16 @@ def is_match(self) -> bool:
"""
raise NotImplementedError

def is_invalid(self) -> bool:
"""
Returns whether the current state is invalid.
This must be true iff get() returns an empty list of indices in
a non-match state.
We have a separate function for that because depending on the constraint
this can be implemented more efficiently.
"""
return not self.is_match() and len(self.get()) == 0

def clone(self) -> 'Constraint':
"""
Returns a copy of the constraint.
Expand All @@ -61,8 +70,8 @@ def __init__(
self.indices, self.value = cont_index.get(self.prefix)
self.cont_index = cont_index

def get(self) -> tuple[np.ndarray, bool]:
return self.indices, self.is_match()
def get(self) -> np.ndarray:
return self.indices

def reset(self, input: bytes | None = None) -> None:
self.prefix = input or bytes()
Expand Down
49 changes: 33 additions & 16 deletions python/text_utils/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
StopFn,
Beam,
BeamSampleFn,
BeamWidthFn,
BeamCandidateFn,
BeamStopFn,
beam_greedy,
Expand Down Expand Up @@ -199,7 +200,7 @@ def beam_search(
device: torch.device,
normalize_by_length: bool,
alpha: float,
beam_width: int,
beam_width: int | BeamWidthFn,
sample_fn: BeamSampleFn = beam_greedy(),
candidate_fn: BeamCandidateFn = default_beam_candidate_fn(),
logit_fns: list[LogitFn] | None = None,
Expand All @@ -215,16 +216,24 @@ def beam_search(

beam_queues: list[list[Beam]] = [[] for _ in range(batch_size)]

beam_widths: list[int] = []
current_beams: list[list[Beam]] = []
initial_lenghts = []
initial_lengths = []
for init in initial:
initial_lenghts.append(len(init))
if isinstance(init, Beam):
current_beams.append([init])
continue
assert initial_lenghts[-1] <= max_length, \
"initial token ids or beams cannot be longer than max length"
beam = Beam(init, [0.0] * len(init))
beam = init
else:
beam = Beam(init, [0.0] * len(init))

initial_lengths.append(len(beam))
assert initial_lengths[-1] <= max_length, \
"initial beam cannot be longer than max length"

if isinstance(beam_width, int):
beam_widths.append(beam_width)
else:
beam_widths.append(beam_width(beam))

current_beams.append([beam])

stop_mask = [False for _ in range(batch_size)]
Expand All @@ -244,20 +253,24 @@ def get_indices_to_decode() -> list[int]:

def get_outputs() -> list[list[Beam]]:
out_beams = []
for idx, (beam_queue, active_beams) in enumerate(zip(beam_queues, current_beams)):
for idx, (beam_queue, active_beams, n) in enumerate(zip(
beam_queues,
current_beams,
beam_widths
)):
beam_queue = sorted(
beam_queue,
key=lambda b: score_fn(b),
reverse=True
)[:beam_width]
if len(beam_queue) < beam_width:
)[:n]
if len(beam_queue) < n:
active_beams = sorted(
active_beams,
key=lambda b: score_fn(b),
reverse=True
)
beam_queue.extend(active_beams[:beam_width - len(beam_queue)])
pfx = 0 if return_full else initial_lenghts[idx]
beam_queue.extend(active_beams[:n - len(beam_queue)])
pfx = 0 if return_full else initial_lengths[idx]
out_beams.append([
beam.truncate_prefix(pfx)
for beam in beam_queue
Expand Down Expand Up @@ -326,9 +339,10 @@ def get_outputs() -> list[list[Beam]]:
indices_to_decode,
torch.split(log_probs, num_beams)
)):
n = beam_widths[idx]
candidates: list[tuple[Beam, int, float]] = []
for beam_idx, beam in enumerate(current_beams[idx]):
for token_id in sample_fn(log_probs[beam_idx], beam_width).tolist():
for token_id in sample_fn(log_probs[beam_idx], n).tolist():
candidates.append((
beam,
token_id,
Expand All @@ -343,17 +357,20 @@ def get_outputs() -> list[list[Beam]]:
)):
# update candidates
candidate = candidate_fn(beam, token_id, log_prob)
if candidate is None:
# skip invalid candidates
continue
# only consider eos beams if they are in top beam_width beams
stop = stop_fn(candidate)
if num < beam_width and stop:
if num < n and stop:
# we record all stop beams, but only stop when the top beam should stop
# (because then we are sure there is no better candidate left to decode)
beam_queues[idx].append(candidate)
stop_mask[idx] |= num == 0
elif not stop:
new_beams.append(candidate)

if len(new_beams) >= beam_width:
if len(new_beams) >= n:
break

current_beams[idx] = new_beams
Expand Down
33 changes: 20 additions & 13 deletions python/text_utils/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,39 +136,44 @@ def __repr__(self) -> str:
int,
float
],
Beam
Beam | None
]


BeamWidthFn = Callable[
[
# beam to calculate beam width from
Beam
],
# beam width
int
]


def constraint_logit_fn(
retrieve_constraint_fn: Callable[[int | Beam], Constraint | None],
eos_token_id: int
) -> LogitFn:
import time

times = []

def _constrain_logits(
logits: torch.Tensor,
beams_or_indices: list[int] | list[Beam]
) -> torch.Tensor:
nonlocal times
zeros = torch.full_like(logits, float("-inf"))

for i, beam_or_idx in enumerate(beams_or_indices):
constraint = retrieve_constraint_fn(beam_or_idx)

if constraint is None:
if constraint is None or constraint.is_invalid():
zeros[i] = logits[i]
continue

constrain_to, is_match = constraint.get()
constrain_to = constraint.get()
if constraint.is_match():
zeros[i, eos_token_id] = logits[i, eos_token_id]

constrain_to = torch.from_numpy(constrain_to).to(torch.long)
zeros[i, constrain_to] = logits[i, constrain_to]

if len(constrain_to) == 0 or is_match:
zeros[i, eos_token_id] = logits[i, eos_token_id]

return zeros

return _constrain_logits
Expand All @@ -189,8 +194,10 @@ def _constrain_sample(
continue

constraint = retrieve_constraint_fn(idx)
if constraint is not None:
constraint.next(token_id)
if constraint is None or constraint.is_invalid():
continue

constraint.next(token_id)

return token_ids

Expand Down
55 changes: 29 additions & 26 deletions src/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ impl RegexConstraint {
.lock()
.map(|mut inner| {
inner.state = state;
let indices = self.constraint.get_valid_continuations(&inner.state);
inner.indices = indices.into();
inner.indices = self.constraint.get_valid_continuations(&inner.state).into();
inner.is_match = self.constraint.is_match_state(&inner.state);
})
.map_err(|_| anyhow!("error locking inner state"))
Expand All @@ -96,15 +95,17 @@ impl RegexConstraint {
.map_err(|_| anyhow!("error locking inner state"))
}

fn get(&self, py: Python<'_>) -> anyhow::Result<(PyObject, bool)> {
fn get(&self, py: Python<'_>) -> anyhow::Result<PyObject> {
self.inner
.lock()
.map(|inner| {
(
inner.indices.clone().into_pyarray_bound(py).into_py(py),
inner.is_match,
)
})
.map(|inner| inner.indices.clone().into_pyarray_bound(py).into_py(py))
.map_err(|_| anyhow!("error locking inner state"))
}

fn is_invalid(&self) -> anyhow::Result<bool> {
self.inner
.lock()
.map(|inner| inner.indices.is_empty() && !inner.is_match)
.map_err(|_| anyhow!("error locking inner state"))
}

Expand All @@ -126,8 +127,7 @@ impl RegexConstraint {
.get_next_state(&inner.state, index)
.expect("invalid continuation");
inner.state = next_state;
let indices = constraint.get_valid_continuations(&inner.state);
inner.indices = indices.into();
inner.indices = constraint.get_valid_continuations(&inner.state).into();
inner.is_match = constraint.is_match_state(&inner.state);
});
// wait until spawned thread signals that is has locked
Expand Down Expand Up @@ -270,8 +270,7 @@ impl LR1Constraint {
.lock()
.map(|mut inner| {
inner.state = state;
let indices = self.constraint.get_valid_continuations(&inner.state);
inner.indices = indices;
inner.indices = self.constraint.get_valid_continuations(&inner.state);
inner.is_match = self.constraint.is_match_state(&inner.state);
})
.map_err(|_| anyhow!("error locking inner state"))
Expand All @@ -287,24 +286,29 @@ impl LR1Constraint {
.map_err(|_| anyhow!("error locking inner state"))
}

fn get(&self, py: Python<'_>) -> anyhow::Result<(PyObject, bool)> {
fn get(&self, py: Python<'_>) -> anyhow::Result<PyObject> {
self.inner
.lock()
.map(|inner| {
let indices =
if inner.is_match && self.constraint.only_skippable_matching(&inner.state) {
// should stop, return empty indices
vec![].into()
} else {
inner.indices.clone()
}
.into_pyarray_bound(py)
.into_py(py);
(indices, inner.is_match)
if inner.is_match && self.constraint.only_skippable_matching(&inner.state) {
// should stop, return empty indices
vec![].into()
} else {
inner.indices.clone()
}
.into_pyarray_bound(py)
.into_py(py)
})
.map_err(|_| anyhow!("error locking inner state"))
}

fn is_invalid(&self) -> anyhow::Result<bool> {
self.inner
.lock()
.map(|inner| inner.indices.is_empty() && !inner.is_match)
.map_err(|_| anyhow!("error locking inner state"))
}

fn is_match(&self) -> anyhow::Result<bool> {
self.inner
.lock()
Expand All @@ -322,8 +326,7 @@ impl LR1Constraint {
inner.state = constraint
.get_next_state(&inner.state, index)
.expect("invalid continuation");
let indices = constraint.get_valid_continuations(&inner.state);
inner.indices = indices;
inner.indices = constraint.get_valid_continuations(&inner.state);
inner.is_match = constraint.is_match_state(&inner.state);
});
// wait until spawned thread signals that is has locked
Expand Down

0 comments on commit 765236d

Please sign in to comment.