Skip to content

Commit

Permalink
various updates throughout the package
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Apr 11, 2024
1 parent dc68aa2 commit 97afe6f
Show file tree
Hide file tree
Showing 18 changed files with 1,301 additions and 327 deletions.
72 changes: 28 additions & 44 deletions python/text_utils/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import logging
import warnings
from typing import Iterator, Iterable, Union, Optional, Type
from typing import Iterator, Iterable, Union, Type
try:
import readline # noqa
except ImportError:
Expand Down Expand Up @@ -33,14 +33,15 @@ def parser(
) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(name, description)
model_group = parser.add_mutually_exclusive_group()
default_model = cls.text_processor_cls.default_model()
model_group.add_argument(
"-m",
"--model",
choices=[
model.name for model in
cls.text_processor_cls.available_models()
],
default=cls.text_processor_cls.default_model().name,
default=None if default_model is None else default_model.name,
help=f"Name of the model to use for {cls.text_processor_cls.task}"
)
model_group.add_argument(
Expand Down Expand Up @@ -185,26 +186,25 @@ def __init__(self, args: argparse.Namespace):
def version(self) -> str:
raise NotImplementedError

def format_output(self, item: data.InferenceData) -> Iterable[str]:
return [item.text]
def format_output(self, output: str) -> Iterable[str]:
return [output]

def _run_with_profiling(self, file: str) -> None:
import cProfile
cProfile.runctx("self.run()", globals(), locals(), file)

def process_iter(
self,
text_processor: TextProcessor,
iter: Iterator[data.InferenceData]
) -> Iterator[data.InferenceData]:
processor: TextProcessor,
iter: Iterator[str]
) -> Iterator[str]:
raise NotImplementedError

def process_file(
self,
text_processor: TextProcessor,
path: str,
lang: Optional[str],
out_file: Union[str, TextIOWrapper]
processor: TextProcessor,
input_file: str,
output_file: str | TextIOWrapper
) -> None:
raise NotImplementedError

Expand Down Expand Up @@ -277,8 +277,7 @@ def run(self) -> None:
start = time.perf_counter()
if self.args.process is not None:
self.args.progress = False
ipt = data.InferenceData(self.args.process)
opt = next(self.process_iter(self.cor, iter([ipt])))
opt = next(self.process_iter(self.cor, iter([self.args.process])))
for line in self.format_output(opt):
print(line)

Expand All @@ -290,7 +289,7 @@ def run(self) -> None:
assert isinstance(self.args.out_path, str)
out = self.args.out_path

self.process_file(self.cor, self.args.file, self.args.lang, out)
self.process_file(self.cor, self.args.file, out)

if self.args.report:
for d in self.cor.devices:
Expand All @@ -311,7 +310,7 @@ def run(self) -> None:
not self.args.unsorted,
self.cor.devices,
next(self.cor.model.parameters()).dtype,
batch_max_tokens=self.args.batch_max_tokens,
self.args.batch_max_tokens,
)
print(report)

Expand All @@ -328,34 +327,19 @@ def run(self) -> None:
return

try:
if self.args.unsorted:
# correct lines from stdin as they come
input_it = (
data.InferenceData(line.rstrip("\r\n"))
for line in sys.stdin
)
sized_it = ProgressIterator(
input_it,
self.inference_data_size
)
outputs = self.process_iter(self.cor, sized_it)
for opt in outputs:
for line in self.format_output(opt):
print(line)
else:
# read stdin completely, then potentially sort and correct
inputs = [
data.InferenceData(line.rstrip("\r\n"))
for line in sys.stdin
]
sized_it = ProgressIterator(
iter(inputs),
self.inference_data_size
)
outputs = self.process_iter(self.cor, sized_it)
for opt in outputs:
for line in self.format_output(opt):
print(line)
# correct lines from stdin as they come
input_it = (
line.rstrip("\r\n")
for line in sys.stdin
)
sized_it = ProgressIterator(
input_it,
self.inference_data_size
)
outputs = self.process_iter(self.cor, sized_it)
for opt in outputs:
for line in self.format_output(opt):
print(line)

if self.args.report:
for d in self.cor.devices:
Expand All @@ -373,7 +357,7 @@ def run(self) -> None:
not self.args.unsorted,
self.cor.devices,
next(self.cor.model.parameters()).dtype,
batch_max_tokens=self.args.batch_max_tokens,
self.args.batch_max_tokens,
)
print(report)

Expand Down
21 changes: 13 additions & 8 deletions python/text_utils/api/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from tqdm import tqdm
import torch
from torch import autocast, nn
from torch import nn
from torch.backends import cudnn, cuda

from text_utils import (
Expand Down Expand Up @@ -40,8 +40,10 @@ def available_models(cls) -> List[ModelInfo]:
raise NotImplementedError

@classmethod
def default_model(cls) -> ModelInfo:
def default_model(cls) -> ModelInfo | None:
available_models = cls.available_models()
if len(available_models) == 0:
return None
for info in available_models:
if "default" in info.tags:
return info
Expand Down Expand Up @@ -85,7 +87,10 @@ def from_pretrained(
force_download: bool = False
):
if model is None:
model = cls.default_model().name
default = cls.default_model()
assert default is not None, "no default model available"
model = default.name

assert model is not None
assert any(model == m.name for m in cls.available_models()), \
f"model {model} does not match any of the available models:\n" \
Expand Down Expand Up @@ -195,7 +200,7 @@ def _process_results(

def _get_loader(
self,
inputs: Union[Tuple[List[str], Optional[List[str]]], Iterator[data.InferenceData]],
inputs: list[str] | Iterator[data.InferenceData],
batch_size: int = 16,
batch_max_tokens: Optional[int] = None,
sort: bool = True,
Expand Down Expand Up @@ -229,7 +234,7 @@ def _get_loader(
"sort": sort
})
self._inference_loader_cfg.update(kwargs)
if isinstance(inputs, tuple):
if isinstance(inputs, list):
files, languages = inputs
loader = data.InferenceLoader.from_files(
files=files,
Expand All @@ -245,8 +250,8 @@ def _get_loader(
)
else:
raise ValueError(
f"unknown input type {type(inputs)}, must either be a tuple of "
f"files and languages or an iterator over sequence language pairs"
f"unknown input type {type(inputs)}, must either be a list of "
f"files and an iterator over strings"
)

return loader
Expand Down Expand Up @@ -274,7 +279,7 @@ def _process_sorted(
progress_total: int,
progress_unit: str = "seq",
show_progress: bool = False,
) -> List[data.InferenceData]:
) -> list[data.InferenceData]:
results = {}
pbar = self._pbar(
progress_desc,
Expand Down
4 changes: 2 additions & 2 deletions python/text_utils/cli/create_continuation_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def create(args: argparse.Namespace):
os.makedirs(dir, exist_ok=True)

start = time.perf_counter()
continuations.Continuations.build_from_file(
continuations.ContinuationIndex.build_from_file(
args.input_file,
args.output_file
)
Expand All @@ -29,7 +29,7 @@ def create(args: argparse.Namespace):
start = time.perf_counter()
# empty continuations for testing
conts = []
continuations.Continuations.load_with_continuations(
continuations.ContinuationIndex.load_with_continuations(
args.output_file,
conts
)
Expand Down
85 changes: 85 additions & 0 deletions python/text_utils/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import copy

from text_utils._internal import grammar
from text_utils._internal import continuations


# re-export grammar constraints
RegexConstraint = grammar.RegexConstraint
LR1Constraint = grammar.LR1Constraint


class Constraint:
"""
Base class for constraints.
"""

def get(self) -> tuple[list[int], bool]:
"""
Returns the current constraint indices and whether we
are in a state that matches the constraint.
"""
raise NotImplementedError

def reset(self, input: bytes | None = None) -> None:
"""
Resets the constraint to the initial state.
"""
raise NotImplementedError

def next(self, index: int) -> None:
"""
Updates the constraint based on the chosen index / token id.
"""
raise NotImplementedError

def is_match(self) -> bool:
"""
Returns whether the current state matches the constraint.
"""
raise NotImplementedError

def clone(self) -> 'Constraint':
"""
Returns a copy of the constraint.
"""
raise NotImplementedError


class ContinuationConstraint(Constraint):
"""
Constraint for only allowing certain continuations for
a given prefix.
"""

def __init__(
self,
cont_index: continuations.ContinuationIndex,
prefix: bytes | None = None
):
self.prefix = prefix or bytes()
self.value = cont_index.get_value(self.prefix)
self.cont_index = cont_index

def get(self) -> tuple[list[int], bool]:
indices, value = self.cont_index.get(self.prefix)
self.value = value
return indices, self.is_match()

def reset(self, input: bytes | None = None) -> None:
self.prefix = input or bytes()

def next(self, index: int) -> None:
self.prefix += self.cont_index.get_continuation(index)

def is_match(self) -> bool:
return self.value is not None

def clone(self) -> 'ContinuationConstraint':
return ContinuationConstraint(
self.cont_index,
self.prefix
)

def get_value(self) -> str | None:
return self.value
Loading

0 comments on commit 97afe6f

Please sign in to comment.