Skip to content

Commit

Permalink
save progress
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jun 3, 2024
1 parent 62871a1 commit d6bc237
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 24 deletions.
8 changes: 4 additions & 4 deletions python/text_utils/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def setup(self) -> TextProcessor:
return cor

@staticmethod
def inference_data_size(item: data.InferenceData) -> int:
return len(item.text.encode("utf8"))
def input_size(item: str) -> int:
return len(item.encode("utf8"))

def run(self) -> None:
# ignore warnings in CLI, so the user doesn't get confused
Expand Down Expand Up @@ -289,7 +289,7 @@ def run(self) -> None:
)
sized_it = ProgressIterator(
input_it,
self.inference_data_size
self.input_size
)
for output in self.process_iter(self.cor, sized_it):
for opt in self.format_output(output):
Expand Down Expand Up @@ -338,7 +338,7 @@ def run(self) -> None:
)
sized_it = ProgressIterator(
input_it,
self.inference_data_size
self.input_size
)
for output in self.process_iter(self.cor, sized_it):
for opt in self.format_output(output):
Expand Down
14 changes: 7 additions & 7 deletions python/text_utils/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def __next__(self):


def download_zip(
name: str,
url: str,
download_dir: str,
cache_dir: str,
sub_cache_dir: str,
force_download: bool,
logger: logging.Logger
name: str,
url: str,
download_dir: str,
cache_dir: str,
sub_cache_dir: str,
force_download: bool,
logger: logging.Logger
) -> str:
"""
Downloads and extracts a zip into cache dir and returns the path to the only subdirectory
Expand Down
5 changes: 2 additions & 3 deletions python/text_utils/constraints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import copy

from text_utils._internal import grammar
from text_utils._internal import continuations

Expand Down Expand Up @@ -46,6 +44,7 @@ def clone(self) -> 'Constraint':
raise NotImplementedError



class ContinuationConstraint(Constraint):
"""
Constraint for only allowing certain continuations for
Expand Down Expand Up @@ -81,5 +80,5 @@ def clone(self) -> 'ContinuationConstraint':
self.prefix
)

def get_value(self) -> str | None:
def get_value(self) -> str:
return self.value
10 changes: 5 additions & 5 deletions python/text_utils/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@


def eos_stop_fn(eos_token_id: int) -> StopFn:
def _stop(token_ids: torch.Tensor, _: list[int]) -> torch.Tensor:
return token_ids == eos_token_id
def _stop(token_ids: torch.Tensor, _: int) -> bool:
return token_ids[-1].item() == eos_token_id

return _stop

Expand Down Expand Up @@ -160,9 +160,9 @@ def get_outputs() -> list[list[int]]:
max_length_indices = torch.where(lengths >= max_length)[0]
smaller_max_length_mask[max_length_indices] = False

stop_mask = stop_fn(sel_ids.to("cpu"), batch_indices)
new_stop_indices = indices[mask][stop_mask]
non_stop_mask[new_stop_indices] = False
for idx in batch_indices:
if stop_fn(token_ids[idx, :lengths[idx]], idx):
non_stop_mask[idx] = False

mask = non_stop_mask & smaller_max_length_mask

Expand Down
10 changes: 5 additions & 5 deletions python/text_utils/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def __repr__(self) -> str:
# checks if decoding should be stopped
StopFn = Callable[
[
# selected token ids, shape [batch_size]
# token ids decoded so far
torch.Tensor,
# indices of input batch elements which are checked for stopping
list[int]
# index of input batch element checked for stopping
int
],
# mask indicating which elements should be stopped
torch.Tensor
# bool indicating if decoding should be stopped
bool
]

# takes in log probs and beam width and returns
Expand Down
4 changes: 4 additions & 0 deletions src/continuations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ impl ContinuationIndex {
.map(|c| c.as_slice())
}

/// Returns an owned copy of the continuations held by this index's trie.
///
/// The value is produced by cloning `self.cont_trie.continuations`, so every
/// call copies the whole collection and the caller may modify the result
/// without affecting the index.
// NOTE(review): presumably each inner `Vec<u8>` holds the raw bytes of one
// continuation — confirm against the `cont_trie` definition.
fn get_continuations(&self) -> Vec<Vec<u8>> {
self.cont_trie.continuations.clone()
}

fn get(&self, prefix: &[u8]) -> (Vec<usize>, Option<String>) {
(
self.cont_trie.contains_continuations(prefix),
Expand Down

0 comments on commit d6bc237

Please sign in to comment.