From d6bc2377b82b7e6c9858c68b5ef34705e2328410 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Mon, 3 Jun 2024 17:57:42 +0200 Subject: [PATCH] save progress --- python/text_utils/api/cli.py | 8 ++++---- python/text_utils/api/utils.py | 14 +++++++------- python/text_utils/constraints.py | 5 ++--- python/text_utils/inference/__init__.py | 10 +++++----- python/text_utils/inference/utils.py | 10 +++++----- src/continuations.rs | 4 ++++ 6 files changed, 27 insertions(+), 24 deletions(-) diff --git a/python/text_utils/api/cli.py b/python/text_utils/api/cli.py index a812e50..cfd7dd2 100644 --- a/python/text_utils/api/cli.py +++ b/python/text_utils/api/cli.py @@ -222,8 +222,8 @@ def setup(self) -> TextProcessor: return cor @staticmethod - def inference_data_size(item: data.InferenceData) -> int: - return len(item.text.encode("utf8")) + def input_size(item: str) -> int: + return len(item.encode("utf8")) def run(self) -> None: # ignore warnings in CLI, so the user doesn't get confused @@ -289,7 +289,7 @@ def run(self) -> None: ) sized_it = ProgressIterator( input_it, - self.inference_data_size + self.input_size ) for output in self.process_iter(self.cor, sized_it): for opt in self.format_output(output): @@ -338,7 +338,7 @@ def run(self) -> None: ) sized_it = ProgressIterator( input_it, - self.inference_data_size + self.input_size ) for output in self.process_iter(self.cor, sized_it): for opt in self.format_output(output): diff --git a/python/text_utils/api/utils.py b/python/text_utils/api/utils.py index 78db15a..460dd63 100644 --- a/python/text_utils/api/utils.py +++ b/python/text_utils/api/utils.py @@ -79,13 +79,13 @@ def __next__(self): def download_zip( - name: str, - url: str, - download_dir: str, - cache_dir: str, - sub_cache_dir: str, - force_download: bool, - logger: logging.Logger + name: str, + url: str, + download_dir: str, + cache_dir: str, + sub_cache_dir: str, + force_download: bool, + logger: logging.Logger ) -> str: """ Downloads and extracts a zip into cache dir and returns the path to the only subdirectory diff --git a/python/text_utils/constraints.py b/python/text_utils/constraints.py index 68d4870..b38445b 100644 --- a/python/text_utils/constraints.py +++ b/python/text_utils/constraints.py @@ -1,5 +1,3 @@ -import copy - from text_utils._internal import grammar from text_utils._internal import continuations @@ -46,6 +44,7 @@ def clone(self) -> 'Constraint': raise NotImplementedError + class ContinuationConstraint(Constraint): """ Constraint for only allowing certain continuations for @@ -81,5 +80,5 @@ def clone(self) -> 'ContinuationConstraint': self.prefix ) - def get_value(self) -> str | None: + def get_value(self) -> str: return self.value diff --git a/python/text_utils/inference/__init__.py b/python/text_utils/inference/__init__.py index 847e507..7db876f 100644 --- a/python/text_utils/inference/__init__.py +++ b/python/text_utils/inference/__init__.py @@ -22,8 +22,8 @@ def eos_stop_fn(eos_token_id: int) -> StopFn: - def _stop(token_ids: torch.Tensor, _: list[int]) -> torch.Tensor: - return token_ids == eos_token_id + def _stop(token_ids: torch.Tensor, _: int) -> bool: + return token_ids[-1].item() == eos_token_id return _stop @@ -160,9 +160,9 @@ def get_outputs() -> list[list[int]]: max_length_indices = torch.where(lengths >= max_length)[0] smaller_max_length_mask[max_length_indices] = False - stop_mask = stop_fn(sel_ids.to("cpu"), batch_indices) - new_stop_indices = indices[mask][stop_mask] - non_stop_mask[new_stop_indices] = False + for idx in batch_indices: + if stop_fn(token_ids[idx, :lengths[idx]], idx): + non_stop_mask[idx] = False mask = non_stop_mask & smaller_max_length_mask diff --git a/python/text_utils/inference/utils.py b/python/text_utils/inference/utils.py index cc8d535..144a562 100644 --- a/python/text_utils/inference/utils.py +++ b/python/text_utils/inference/utils.py @@ -91,13 +91,13 @@ def __repr__(self) -> str: # checks if decoding should be stopped StopFn = Callable[ [ - # selected token ids, shape [batch_size] + # token ids decoded so far torch.Tensor, - # indices of input batch elements which are checked for stopping - list[int] + # index of input batch element checked for stopping + int ], - # mask indicating which elements should be stopped - torch.Tensor + # bool indicating if decoding should be stopped + bool ] # takes in log probs and beam width and returns diff --git a/src/continuations.rs b/src/continuations.rs index 6f3a344..db735c2 100644 --- a/src/continuations.rs +++ b/src/continuations.rs @@ -60,6 +60,10 @@ impl ContinuationIndex { .map(|c| c.as_slice()) } + fn get_continuations(&self) -> Vec> { + self.cont_trie.continuations.clone() + } + fn get(&self, prefix: &[u8]) -> (Vec, Option) { ( self.cont_trie.contains_continuations(prefix),