Commit 509017e

update

bastiscode committed Jun 25, 2024
1 parent 765236d commit 509017e

Showing 6 changed files with 154 additions and 269 deletions.
3 changes: 2 additions & 1 deletion python/text_utils/api/__init__.py

@@ -5,7 +5,8 @@
 from text_utils.api.utils import (
     to,
     byte_progress_bar,
-    sequence_progress_bar,
+    item_progress_bar,
+    progress_bar,
     cpu_info,
     gpu_info,
     device_info,
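
With this re-export change, the renamed helpers become importable from text_utils.api directly. A quick check, assuming the package is installed:

    from text_utils.api import item_progress_bar, progress_bar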
165 changes: 67 additions & 98 deletions python/text_utils/api/processor.py

@@ -3,9 +3,8 @@
 import sys
 import os
 import pprint
-from typing import Iterator, Any
+from typing import Iterator, Any, Callable, TypeVar
 
-from tqdm import tqdm
 import torch
 from torch import nn
 from torch.backends import cudnn, cuda
@@ -26,6 +25,8 @@
     ["name", "description", "tags"]
 )
 
+T = TypeVar("T")
+
 
 class TextProcessor:
     task: str
@@ -180,26 +181,27 @@ def __init__(
         self.model = model
         self.to(device)
 
-    @torch.inference_mode()
-    def _inference(self, batch: data.InferenceBatch) -> Iterator[Any]:
-        raise NotImplementedError
-
-    def _process_results(
-        self,
-        items: list[data.InferenceItem],
-        outputs: list[Any]
-    ) -> data.InferenceData:
-        raise NotImplementedError
-
-    def _get_loader(
+    def _process(
         self,
         iter: Iterator[data.InferenceData],
+        inference_fn: Callable[
+            [data.InferenceBatch],
+            list[T]
+        ],
+        postprocessing_fn: Callable[
+            [list[data.InferenceItem], list[T]],
+            data.InferenceData
+        ],
+        progress_desc: str,
         batch_size: int = 16,
         batch_max_tokens: int | None = None,
         sort: bool = True,
         num_threads: int | None = None,
+        progress_total: int | None = None,
+        progress_unit: str = "it",
+        show_progress: bool = False,
         **kwargs: Any
-    ) -> data.InferenceLoader:
+    ) -> Iterator[data.InferenceData]:
         if num_threads is None:
             num_threads = min(len(os.sched_getaffinity(0)), 4)
 
@@ -229,101 +231,68 @@ def _get_loader(
             "sort": sort
         }
         inference_cfg.update(kwargs)
-        return data.InferenceLoader.from_iterator(
+        loader = data.InferenceLoader.from_iterator(
             iter,
             **inference_cfg
         )
 
-    def _pbar(
-        self,
-        progress_desc: str,
-        progress_total: int,
-        progress_unit: str = "seq",
-        show_progress: bool = False,
-    ) -> tqdm:
-        if progress_unit == "seq":
-            return api.sequence_progress_bar(progress_desc, progress_total, not show_progress)
-        elif progress_unit == "byte":
-            return api.byte_progress_bar(progress_desc, progress_total, not show_progress)
-        else:
-            raise ValueError(
-                f"unknown progress unit {progress_unit}, must be either 'seq' or 'byte'"
-            )
-
-    def _process_sorted(
-        self,
-        loader: data.InferenceLoader,
-        progress_desc: str,
-        progress_total: int,
-        progress_unit: str = "seq",
-        show_progress: bool = False,
-    ) -> list[data.InferenceData]:
-        results = {}
-        pbar = self._pbar(
+        pbar = api.progress_bar(
             progress_desc,
             progress_total,
             progress_unit,
             show_progress
         )
-        for batch in loader:
-            # use only last output of inference iterator
-            *_, outputs = self._inference(batch)
-            for item, output in zip(batch.items(), outputs):
-                if item.item_idx not in results:
-                    results[item.item_idx] = {}
-                if progress_unit == "seq":
-                    pbar.update(1)
-                if progress_unit == "byte":
-                    pbar.update(item.window_bytes())
-                results[item.item_idx][item.window_idx] = (item, output)
-        outputs = []
-        for item_idx in range(len(results)):
-            window_items = []
-            window_outputs = []
-            for window_idx in range(len(results[item_idx])):
-                item, output = results[item_idx][window_idx]
-                window_items.append(item)
-                window_outputs.append(output)
-            outputs.append(self._process_results(window_items, window_outputs))
-        return outputs
-
-    def _process_unsorted(
-        self,
-        loader: data.InferenceLoader,
-        progress_desc: str,
-        progress_total: int,
-        progress_unit: str = "seq",
-        show_progress: bool = False,
-    ) -> Iterator[data.InferenceData]:
-        prev_item_idx = 0
-        window_items = []
-        window_outputs = []
-        pbar = self._pbar(
-            progress_desc,
-            progress_total,
-            progress_unit,
-            show_progress
-        )
-        for batch in loader:
-            *_, outputs = self._inference(batch)
-            for item, output in zip(batch.items(), outputs):
-                if item.item_idx == prev_item_idx:
-                    window_items.append(item)
-                    window_outputs.append(output)
-                    continue
-                yield self._process_results(window_items, window_outputs)
-                if progress_unit == "seq":
-                    pbar.update(1)
-                else:
-                    pbar.update(sum(
-                        item.window_bytes()
-                        for item in window_items
-                    ))
-                prev_item_idx = item.item_idx
-                window_items = [item]
-                window_outputs = [output]
-        # dont forget to yield final item
-        yield self._process_results(window_items, window_outputs)
+        if sorted:
+            results = {}
+            for batch in loader:
+                with torch.inference_mode():
+                    outputs = inference_fn(batch)
+                for item, output in zip(batch.items(), outputs):
+                    if item.item_idx not in results:
+                        results[item.item_idx] = {}
+                    if progress_unit == "byte":
+                        pbar.update(item.window_bytes())
+                    results[item.item_idx][item.window_idx] = (item, output)
+                if progress_unit == "it":
+                    pbar.update(1)
+            outputs = []
+            for item_idx in range(len(results)):
+                window_items = []
+                window_outputs = []
+                for window_idx in range(len(results[item_idx])):
+                    item, output = results[item_idx][window_idx]
+                    window_items.append(item)
+                    window_outputs.append(output)
+                yield postprocessing_fn(window_items, window_outputs)
+
+        else:
+            # not sorted, we can yield as we go
+            prev_item_idx = 0
+            window_items = []
+            window_outputs = []
+            for batch in loader:
+                with torch.inference_mode():
+                    outputs = inference_fn(batch)
+                for item, output in zip(batch.items(), outputs):
+                    if item.item_idx == prev_item_idx:
+                        window_items.append(item)
+                        window_outputs.append(output)
+                        continue
+                    yield postprocessing_fn(window_items, window_outputs)
+                    if progress_unit == "byte":
+                        pbar.update(sum(
+                            item.window_bytes()
+                            for item in window_items
+                        ))
+                    prev_item_idx = item.item_idx
+                    window_items = [item]
+                    window_outputs = [output]
+                if progress_unit == "it":
+                    pbar.update(1)
+            # dont forget to yield final item
+            yield postprocessing_fn(window_items, window_outputs)
+
+        pbar.close()
 
     def to(self, device: Device) -> "TextProcessor":
         self.devices = get_devices(device)
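
The net effect in this file: the abstract _inference/_process_results overrides and the near-duplicate _process_sorted/_process_unsorted methods are folded into one generator, _process, parameterized by an inference_fn and a postprocessing_fn. A minimal self-contained sketch of that callable-driven pattern (illustrative names only, not part of text_utils):

    from typing import Callable, Iterator, TypeVar

    T = TypeVar("T")
    R = TypeVar("R")

    def process(
        batches: Iterator[list[str]],
        inference_fn: Callable[[list[str]], list[T]],
        postprocessing_fn: Callable[[list[str], list[T]], R],
    ) -> Iterator[R]:
        # one generic driver instead of two near-duplicate methods:
        # run inference per batch, then hand items and outputs to postprocessing
        for batch in batches:
            outputs = inference_fn(batch)
            yield postprocessing_fn(batch, outputs)

    # usage: an upper-casing "model" with a joining postprocessor
    results = list(process(
        iter([["a", "b"], ["c"]]),
        lambda batch: [s.upper() for s in batch],
        lambda items, outs: " ".join(outs),
    ))
    assert results == ["A B", "C"]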
24 changes: 18 additions & 6 deletions python/text_utils/api/utils.py

@@ -205,20 +205,17 @@ def num_parameters(module: nn.Module) -> Dict[str, int]:
     return {"trainable": trainable, "fixed": fixed, "total": trainable + fixed}
 
 
-def sequence_progress_bar(desc: str, total: int, disable: bool = False) -> tqdm:
+def item_progress_bar(desc: str, total: int | None = None, disable: bool = False) -> tqdm:
     return tqdm(
         desc=desc,
         total=total,
         disable=disable,
         ascii=True,
         leave=False,
-        unit="Seq",
-        unit_scale=True,
-        unit_divisor=1000
     )
 
 
-def byte_progress_bar(desc: str, total: Optional[int] = None, disable: bool = False) -> tqdm:
+def byte_progress_bar(desc: str, total: int | None = None, disable: bool = False) -> tqdm:
     return tqdm(
         desc=desc,
         total=total,
@@ -227,5 +224,20 @@ def byte_progress_bar(desc: str, total: Optional[int] = None, disable: bool = False) -> tqdm:
         leave=False,
         unit="B",
         unit_scale=True,
         unit_divisor=1000
     )
+
+
+def progress_bar(
+    desc: str,
+    total: int | None = None,
+    progress_unit: str = "it",
+    show_progress: bool = False,
+) -> tqdm:
+    if progress_unit == "it":
+        return item_progress_bar(desc, total, not show_progress)
+    elif progress_unit == "byte":
+        return byte_progress_bar(desc, total, not show_progress)
+    else:
+        raise ValueError(
+            f"unknown progress unit {progress_unit}, must be either 'it' or 'byte'"
+        )
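
The new progress_bar helper dispatches on progress_unit, replacing the processor-private _pbar and its "seq" unit. A usage sketch, assuming text_utils is installed (the call matches the signature added above):

    from text_utils.api.utils import progress_bar

    pbar = progress_bar("processing", total=100, progress_unit="it", show_progress=True)
    for _ in range(100):
        pbar.update(1)  # count one item per update
    pbar.close()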
2 changes: 1 addition & 1 deletion python/text_utils/constraints.py

@@ -63,7 +63,7 @@ class ContinuationConstraint(Constraint):
 
     def __init__(
         self,
-        cont_index: continuations.ContinuationIndex,
+        cont_index: continuations.MmapContinuationIndex,
         prefix: bytes | None = None
     ):
        self.prefix = prefix or bytes()