diff --git a/python/text_utils/inference/__init__.py b/python/text_utils/inference/__init__.py index c3ec3c3..ec43857 100644 --- a/python/text_utils/inference/__init__.py +++ b/python/text_utils/inference/__init__.py @@ -34,6 +34,7 @@ def beam_search( logit_fns: list[LogitFn] | None = None, kwargs_select_fn: MaskSelectFn | None = None, kwargs_update_fn: MaskUpdateFn | None = None, + max_outputs: int | list[int] | None = None, return_incomplete: bool = False, yield_intermediate: bool = False, **kwargs: Any @@ -44,6 +45,14 @@ def beam_search( update_info: list[int] = [] current_beams: list[list[Beam]] = [] beam_queues: list[list[Beam]] = [] + if max_outputs is None: + max_outputs = [beam_width] * batch_size + elif isinstance(max_outputs, int): + max_outputs = [max_outputs] * batch_size + else: + assert len(max_outputs) == batch_size, \ + "max_outputs must be None, int or list of length batch_size" + for init in initial: if isinstance(init, Beam): beams = [init] @@ -73,7 +82,7 @@ def filter_beams() -> bool: current_beams[idx] = new_beams finished = finished and ( len(current_beams[idx]) == 0 - or len(beam_queues[idx]) >= beam_width + or len(beam_queues[idx]) >= max_outputs[idx] ) return finished @@ -86,6 +95,9 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]: # for intermediate outputs we # return the active beams, so swap here beam_queue, current = current, beam_queue + n = beam_width + else: + n = max_outputs[idx] beam_queue = sorted( beam_queue, @@ -99,7 +111,7 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]: reverse=True ) - outputs.append(beam_queue[:beam_width]) + outputs.append(beam_queue[:n]) return outputs @@ -184,7 +196,8 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]: for beam_idx, beam in enumerate(current_beams[idx]): for token_id in sample_fn(log_probs[beam_idx], beam_width).tolist(): candidate = beam.clone() - candidate.add(token_id, org_log_probs[beam_idx, token_id].item()) + candidate.add( + token_id, org_log_probs[beam_idx, token_id].item()) candidates.append(candidate) # reset current beams and fill with best candidates