Skip to content

Commit

Permalink
allow max outputs in beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Aug 6, 2024
1 parent 1ebf966 commit 721b63c
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions python/text_utils/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def beam_search(
logit_fns: list[LogitFn] | None = None,
kwargs_select_fn: MaskSelectFn | None = None,
kwargs_update_fn: MaskUpdateFn | None = None,
max_outputs: int | list[int] | None = None,
return_incomplete: bool = False,
yield_intermediate: bool = False,
**kwargs: Any
Expand All @@ -44,6 +45,14 @@ def beam_search(
update_info: list[int] = []
current_beams: list[list[Beam]] = []
beam_queues: list[list[Beam]] = []
if max_outputs is None:
max_outputs = [beam_width] * batch_size
elif isinstance(max_outputs, int):
max_outputs = [max_outputs] * batch_size
else:
assert len(max_outputs) == batch_size, \
"max_outputs must be None, int or list of length batch_size"

for init in initial:
if isinstance(init, Beam):
beams = [init]
Expand Down Expand Up @@ -73,7 +82,7 @@ def filter_beams() -> bool:
current_beams[idx] = new_beams
finished = finished and (
len(current_beams[idx]) == 0
or len(beam_queues[idx]) >= beam_width
or len(beam_queues[idx]) >= max_outputs[idx]
)
return finished

Expand All @@ -86,6 +95,9 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]:
# for intermediate outputs we
# return the active beams, so swap here
beam_queue, current = current, beam_queue
n = beam_width
else:
n = max_outputs[idx]

beam_queue = sorted(
beam_queue,
Expand All @@ -99,7 +111,7 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]:
reverse=True
)

outputs.append(beam_queue[:beam_width])
outputs.append(beam_queue[:n])

return outputs

Expand Down Expand Up @@ -184,7 +196,8 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]:
for beam_idx, beam in enumerate(current_beams[idx]):
for token_id in sample_fn(log_probs[beam_idx], beam_width).tolist():
candidate = beam.clone()
candidate.add(token_id, org_log_probs[beam_idx, token_id].item())
candidate.add(
token_id, org_log_probs[beam_idx, token_id].item())
candidates.append(candidate)

# reset current beams and fill with best candidates
Expand Down

0 comments on commit 721b63c

Please sign in to comment.