Skip to content

Commit

Permalink
update beam stop logic
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jul 5, 2024
1 parent feae5c0 commit 5756bfa
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/text_utils/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,21 @@ def get_outputs() -> list[list[Beam]]:
key=lambda b: score_fn(b),
reverse=True
)[:n]

if len(beam_queue) < n:
active_beams = sorted(
active_beams,
key=lambda b: score_fn(b),
reverse=True
)
beam_queue.extend(active_beams[:n - len(beam_queue)])

pfx = 0 if return_full else initial_lengths[idx]
out_beams.append([
beam.truncate_prefix(pfx)
for beam in beam_queue
])

return out_beams

while len(indices_to_decode) > 0:
Expand Down Expand Up @@ -373,8 +376,11 @@ def get_outputs() -> list[list[Beam]]:
if len(new_beams) >= n:
break

current_beams[idx] = new_beams
update_info[idx] = (i, len(new_beams))
if len(new_beams) == 0:
stop_mask[idx] = True
else:
current_beams[idx] = new_beams
update_info[idx] = (i, len(new_beams))

indices_to_decode = get_indices_to_decode()

Expand Down

0 comments on commit 5756bfa

Please sign in to comment.