From b41838f35b6ecda8cb9dbdd3c408e14bcb75b0ad Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 3 Oct 2024 22:57:36 -0700 Subject: [PATCH] Simplify tokenization pipeline, make it work with large numbers of shards again, (re)add configuration metadata to cache (#752) Co-authored-by: Ahmed Ahmed --- .dockerignore | 1 + config/data/dclm_gpt_neo.yaml | 78 + config/data/dolma_olmo_paloma.yaml | 44 +- config/llama_7b_with_dclm.yaml | 33 + pyproject.toml | 5 +- src/levanter/data/_preprocessor.py | 16 +- src/levanter/data/_queue.py | 248 ---- src/levanter/data/audio.py | 41 +- src/levanter/data/text.py | 58 +- src/levanter/main/train_asr.py | 2 +- src/levanter/store/_prefetch_actor.py | 156 ++ src/levanter/store/cache.py | 1944 ++++++++++++------------- src/levanter/store/jagged_array.py | 68 +- src/levanter/store/tree_store.py | 18 +- src/levanter/utils/py_utils.py | 35 + src/levanter/utils/ray_utils.py | 42 +- tests/test_audio.py | 10 + tests/test_jagged_array.py | 48 +- tests/test_new_cache.py | 619 ++------ tests/test_prefetch_actor.py | 137 ++ tests/test_tree_store.py | 15 +- tests/test_utils.py | 6 +- 22 files changed, 1762 insertions(+), 1862 deletions(-) create mode 100644 config/data/dclm_gpt_neo.yaml create mode 100644 config/llama_7b_with_dclm.yaml delete mode 100644 src/levanter/data/_queue.py create mode 100644 src/levanter/store/_prefetch_actor.py create mode 100644 tests/test_prefetch_actor.py diff --git a/.dockerignore b/.dockerignore index 45dfa95e6..9abaa045d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -117,3 +117,4 @@ dmypy.json # local execution commands local_*.sh +.aider* diff --git a/config/data/dclm_gpt_neo.yaml b/config/data/dclm_gpt_neo.yaml new file mode 100644 index 000000000..fd70a5d52 --- /dev/null +++ b/config/data/dclm_gpt_neo.yaml @@ -0,0 +1,78 @@ +cache_dir: "gs://marin-us-central2/tokenized/gpt_neox/" +tokenizer: "EleutherAI/gpt-neox-20b" +cache_options: + batch_size: 256 + num_shard_groups: 1024 +stop_strategy: restart +shuffle: 100000 +configs: + "dclm": + train_urls: + - gs://marin-us-central2/raw/dclm/v2024-07-09-baseline-dedup/**/*.zstd + # these are just for eval + "paloma/4chan": + validation_urls: + - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz + "paloma/c4_100_domains": + validation_urls: + - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz + "paloma/c4_en": + validation_urls: + - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz + "paloma/dolma-v1_5": + validation_urls: + - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz + "paloma/dolma_100_programing_languages": + validation_urls: + - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz + "paloma/dolma_100_subreddits": + validation_urls: + - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz + "paloma/falcon-refinedweb": + validation_urls: + - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz + "paloma/gab": + validation_urls: + - gs://levanter-data/paloma/gab/val/val*.jsonl.gz + "paloma/m2d2_s2orc_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz + "paloma/m2d2_wikipedia_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz + "paloma/manosphere_meta_sep": + validation_urls: + - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz + "paloma/mc4": + validation_urls: + - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz + "paloma/ptb": + validation_urls: + - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz + 
"paloma/redpajama": + validation_urls: + - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz + "paloma/twitterAAE_HELM_fixed": + validation_urls: + - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz + "paloma/wikitext_103": + validation_urls: + - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz +train_weights: + dclm: 1.0 + paloma/4chan: 0.0 + paloma/c4_100_domains: 0.0 + paloma/c4_en: 0.0 + paloma/dolma-v1_5: 0.0 + paloma/dolma_100_programing_languages: 0.0 + paloma/dolma_100_subreddits: 0.0 + paloma/falcon-refinedweb: 0.0 + paloma/gab: 0.0 + paloma/m2d2_s2orc_unsplit: 0.0 + paloma/m2d2_wikipedia_unsplit: 0.0 + paloma/manosphere_meta_sep: 0.0 + paloma/mc4: 0.0 + paloma/ptb: 0.0 + paloma/redpajama: 0.0 + paloma/twitterAAE_HELM_fixed: 0.0 + paloma/wikitext_103: 0.0 diff --git a/config/data/dolma_olmo_paloma.yaml b/config/data/dolma_olmo_paloma.yaml index 54cbcd05f..6aefbdd47 100644 --- a/config/data/dolma_olmo_paloma.yaml +++ b/config/data/dolma_olmo_paloma.yaml @@ -1,59 +1,59 @@ -cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7" +cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/dolma/v1.7" tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo` # tokenizer: "meta-llama/Llama-2-7b-hf" stop_strategy: restart configs: dolma-algebraic-stack: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/algebraic-stack-train-{0000..0015}.json.gz dolma-arxiv: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz dolma-gutenberg: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/books-{0000..0002}.json.gz dolma-c4: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/c4-{0000..0170}.json.gz dolma-cc: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz - - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing - - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz - - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing - - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_head-{0000..0274}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0240..0379}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0154..0444}.json.gz dolma-cc-news: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz - - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz - - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_head-{0000..0004}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_middle-{0000..0002}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_tail-0000.json.gz dolma-falcon: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/falcon-{0000..0499}.json.gz dolma-megawika: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz + - 
gs://marin-us-central2/raw/dolma/v1.7/megawika-{0000..0261}.json.gz dolma-owmath: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/open-web-math-train-{0000..0012}.json.gz dolma-pes2o: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/pes2o-{0000..0025}.json.gz dolma-reddit: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/reddit-{0000..0077}.json.gz dolma-stackexchange: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/stackexchange-{0000..0025}.json.gz dolma-starcoder: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/starcoder-{0000..0048}.json.gz dolma-flan: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/tulu_flan-{0000..0065}.json.gz dolma-wiki: train_urls: - - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/wiki-{0000..0001}.json.gz # these are just for eval "paloma/4chan": validation_urls: diff --git a/config/llama_7b_with_dclm.yaml b/config/llama_7b_with_dclm.yaml new file mode 100644 index 000000000..980e64e41 --- /dev/null +++ b/config/llama_7b_with_dclm.yaml @@ -0,0 +1,33 @@ +data: !include data/dclm_gpt_neo.yaml +model: # 7B class model + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True +trainer: + tracker: + type: wandb + entity: "stanford-mercury" + project: "marin" + tags: ["dclm", "7B", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 2048 + num_train_steps: 70000 # 280B / 4M + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 4e-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + beta1: 0.9 + beta2: 0.95 + warmup: 5000 + +z_loss_weight: 5e-6 diff --git a/pyproject.toml b/pyproject.toml index babf664e9..b0c3df90a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "braceexpand>=0.1.7", "jmp>=0.0.3", "fsspec[http]>=2024.2,<2024.10", - "tensorstore==0.1.63", + "tensorstore>=0.1.65", "pytimeparse>=1.1.8", "humanfriendly==10.0", "safetensors[numpy]~=0.4.2", @@ -50,7 +50,8 @@ dependencies = [ "filelock~=3.13", # "ai2-olmo", "async-lru~=2.0", - "tqdm-loggable>=0.2" + "tqdm-loggable>=0.2", + "deepdiff" ] [project.urls] diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 9ee1e2dc2..09efb364d 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -27,10 +27,8 @@ class BatchProcessor(Generic[T_contra, U], ABC): @abstractmethod def __call__(self, batch: Sequence[T_contra]) -> Sequence[U] | U: # U can be batched "structure of arrays" form """ - Process a batch of data. You should return either a RecordBatch, a sequence of dicts (one per output + Process a batch of data. You should return a sequence of dicts (one per output example), or a dict of sequences (one per output field). - - (We allow Mapping so that you can just return HF's BatchEncoding if you want.) 
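# Illustrative sketch (not part of the patch): a minimal BatchProcessor subclass that
# follows the __call__ contract described above and implements the new abstract
# `metadata` property added by this change. `UppercaseProcessor` and its field names
# are hypothetical; only the BatchProcessor interface comes from the source, and the
# exact set of other abstract members follows the existing base class.
from typing import Any, Dict, Sequence
from levanter.data._preprocessor import BatchProcessor

class UppercaseProcessor(BatchProcessor[str, dict]):
    def __call__(self, batch: Sequence[str]) -> dict:
        # "dict of sequences" form: one key per output field
        return {"text": [s.upper() for s in batch]}

    @property
    def output_exemplar(self) -> dict:
        return {"text": ""}

    @property
    def num_cpus(self) -> int:
        return 1

    @property
    def metadata(self) -> Dict[str, Any]:
        # anything that changes this processor's output belongs here; it gets folded
        # into the cache's metadata so stale caches can be detected on load
        return {"transform": "uppercase"}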
""" raise NotImplementedError @@ -58,8 +56,10 @@ def num_gpus(self) -> int: return 0 @property - def batch_size(self) -> int: - return 128 + @abstractmethod + def metadata(self) -> Dict[str, Any]: + """Any metadata that changes the behavior of this processor.""" + raise NotImplementedError class _DatasetTransform(ABC): @@ -150,7 +150,7 @@ def rec(dataset): class _CompositeBatchProcessor(BatchProcessor): - def __init__(self, transforms, batch_size, num_cpus, num_gpus, resources): + def __init__(self, transforms, num_cpus, num_gpus, resources): self.transforms = transforms self._num_cpus = num_cpus self._num_gpus = num_gpus @@ -207,6 +207,10 @@ def __call__(self, batch): return batch + @property + def metadata(self): + return {} + def dict_from_record_batch(b) -> dict: # we follow the convention from hf batchencoding where homogeneous-lengthed arrays are turned into nd arrays diff --git a/src/levanter/data/_queue.py b/src/levanter/data/_queue.py deleted file mode 100644 index fd8f84860..000000000 --- a/src/levanter/data/_queue.py +++ /dev/null @@ -1,248 +0,0 @@ -import asyncio -import dataclasses -import heapq -import logging as pylogging -import threading -import time -from dataclasses import dataclass -from queue import PriorityQueue -from typing import List, Optional, Protocol, Sequence, TypeVar - -import ray -from ray.actor import ActorHandle - -from levanter.utils.ray_utils import RefBox - -from ._preprocessor import BatchProcessor - - -logger = pylogging.getLogger(__name__) - -T = TypeVar("T") -U = TypeVar("U") -LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - - -class PriorityWorkTaskGroupSpec(Protocol): - name: str - - def build(self) -> "PriorityWorkTaskGroup": - raise NotImplementedError() - - -class PriorityWorkTaskGroup(Protocol): - name: str - spec: PriorityWorkTaskGroupSpec - - def items(self) -> Sequence["PriorityWorkItem"]: - raise NotImplementedError() - - -class PriorityWorkItem(Protocol): - name: str - priority: float - spec: PriorityWorkTaskGroupSpec - - def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: - """ - Returns true if the item is finished, false if it should be rescheduled. - The object ref is used (1) to block shutting down the actor too early - and (2) for backpressure. 
- """ - raise NotImplementedError() - - # needs to be sortable by priority - def __lt__(self, other: "PriorityWorkItem"): - if self.priority == other.priority: - return self.name < other.name - else: - return self.priority < other.priority - - def __le__(self, other: "PriorityWorkItem"): - if self.priority == other.priority: - return self.name <= other.name - else: - return self.priority <= other.priority - - -def _mk_queue_aware_process_task(processor: BatchProcessor[T, U], queue: ActorHandle): - @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(desc, batch: List[T]): - # pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) - logger.debug(f"Processing batch {desc}") - queue.task_running.remote() - try: - result = processor(batch) - logger.debug(f"Finished processing batch {desc}") - return result - except Exception as e: - logger.exception(f"Error while processing batch {desc}") - raise e - finally: - pass - - return process_task - - -@dataclass(order=True, frozen=True) -class _QueueItem: - priority: float - desc: str - batch: ray.ObjectRef = dataclasses.field(compare=False) - task_id: int - task_future: asyncio.Future = dataclasses.field(compare=False) - - -@ray.remote(num_cpus=0) -class _BatchProcessorQueue: # (Generic[T]): ray doesn't like generics - """ - A queue of tasks to be processed by a BatchProcessor. - - BatchProcessorQueue spins up tasks to process batches of data. - It spins up tasks until it reaches the maximum number of tasks that can be run in parallel. - It then waits for a task to finish before spinning up another one. - """ - - pqueue: PriorityQueue[_QueueItem] - processor: BatchProcessor - _next_task_id: int - ready: bool # whether or not we can spin up a new task - - @property - def batch_size(self): - return self.processor.batch_size - - def __init__(self, batch_processor: BatchProcessor[T, U]): - self.pqueue = PriorityQueue() - self.processor = batch_processor - self._next_task_id = 0 - self.ready = True # whether we're ready to ask ray to start a new task - self_ref = ray.runtime_context.get_runtime_context().current_actor - self._task_processor = _mk_queue_aware_process_task(batch_processor, self_ref) - - # we don't need/want to dereference the batch, so we wrap it in a RefBox - # one virtue of doing things this way is that we can let Ray try to schedule the compute near the data. - async def submit(self, priority: float, desc: str, batch: RefBox): - """Returns a future that is set to the *ObjectRef* of the processed batch. The future is "complete" when the task - starts, not when it finishes. 
You then call ray.get on the future's result to get the actual batch.""" - task_id = self._next_task_id - self._next_task_id += 1 - f: asyncio.Future = asyncio.Future() - self.pqueue.put(_QueueItem(priority, desc, batch.ref, task_id, f)) - self._maybe_start_task() - return await f - - def _maybe_start_task(self): - if self.ready and not self.pqueue.empty(): - self.ready = False - item = self.pqueue.get() - batch = item.batch - try: - item.task_future.set_result(self._task_processor.remote(item.desc, batch)) - except Exception as e: - item.task_future.set_exception(e) - - def task_running(self): - self.ready = True - self._maybe_start_task() - - -@ray.remote(num_cpus=0.5, scheduling_strategy="SPREAD") -class WorkQueueDispatcherActor: - def __init__(self, max_in_flight: Optional[int] = 200): - pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) - self._queue: list[PriorityWorkItem] = [] # heapq - self._queue_lock = threading.Lock() - self._shutdown_event = threading.Event() - self._current_item: Optional[PriorityWorkItem] = None - self._max_in_flight = max_in_flight - - self._max_priority: Optional[float] = None - self._processing_thread = threading.Thread(target=self._loop, daemon=True) - self._processing_thread.start() - - def set_max_dispatch_priority(self, max_priority: Optional[float]): - """ - When the sink is full, we will not dispatch items with a priority higher than this. - """ - with self._queue_lock: - self._max_priority = max_priority - - def assign_work(self, group: PriorityWorkTaskGroupSpec): - items = group.build().items() - with self._queue_lock: - for item in items: - heapq.heappush(self._queue, item) - - def is_group_finished(self, group: PriorityWorkTaskGroupSpec): - with self._queue_lock: - if any(item.spec == group for item in self._queue): - return False - - if self._current_item is not None and self._current_item.spec == group: - return False - - logger.debug(f"Group {group.name} is finished.") - - return True - - def cancel_work_group(self, group: PriorityWorkTaskGroupSpec): - # kill all the items in the group - with self._queue_lock: - self._queue = [item for item in self._queue if item.spec != group] - heapq.heapify(self._queue) - - def shutdown(self): - if not self._shutdown_event.is_set(): - self._shutdown_event.set() - - if self._processing_thread.is_alive(): - self._processing_thread.join() - - def _loop(self: "WorkQueueDispatcherActor"): - should_sleep = False - backpressure_queue: list[ray.ObjectRef] = [] - - def drain_backpressure_to(count): - nonlocal backpressure_queue - while len(backpressure_queue) > count: - finished, remaining = ray.wait(backpressure_queue, num_returns=1, fetch_local=False) - backpressure_queue = remaining - - while not self._shutdown_event.is_set(): - if should_sleep: - time.sleep(0.1) - - drain_backpressure_to(self._max_in_flight) - - with self._queue_lock: - if len(self._queue) == 0: - should_sleep = True - continue - else: - should_sleep = False - - item = heapq.heappop(self._queue) - if self._max_priority is not None and item.priority > self._max_priority: - logger.debug(f"Item {item.name} has priority {item.priority} which is too high. Rescheduling.") - heapq.heappush(self._queue, item) - continue - self._current_item = item - - try: - item_is_finished, ref = item.execute() - if ref is not None: - backpressure_queue.append(ref) - except Exception: - logger.exception(f"Error while processing {item.name}. 
Killing all associated work.") - self.cancel_work_group(item.spec) - continue - - with self._queue_lock: - self._current_item = None - if not item_is_finished: - heapq.heappush(self._queue, item) - - logger.debug("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") - drain_backpressure_to(0) - logger.debug("Backpressure drained. Shutting down PriorityProcessorActor.") diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index d04479a24..f8a193f04 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -4,7 +4,7 @@ import os from dataclasses import dataclass from functools import cached_property -from typing import Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union import braceexpand import datasets @@ -29,7 +29,7 @@ # intercept the logging nonsense here from levanter.logging import silence_transformer_nag from levanter.models.asr_model import AudioTextExample -from levanter.store.cache import TreeCache, build_or_load_cache +from levanter.store.cache import CacheOptions, TreeCache, build_or_load_cache from levanter.utils.jax_utils import local_cpu_mesh @@ -73,7 +73,6 @@ def __init__( enforce_bos=True, enforce_eos=True, *, - batch_size=128, override_resources=None, max_length=448, padding=True, @@ -83,7 +82,6 @@ def __init__( tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, - batch_size=batch_size, override_resources=override_resources, return_attention_mask=True, padding="max_length" if padding else False, @@ -91,7 +89,6 @@ def __init__( ) self.override_resources = override_resources - self._batch_size = batch_size def __call__(self, batch: Sequence[Tuple[np.ndarray, int, str]]) -> Sequence[AudioTextDict]: """ @@ -123,6 +120,13 @@ def __call__(self, batch: Sequence[Tuple[np.ndarray, int, str]]) -> Sequence[Aud return out # type: ignore + @property + def metadata(self) -> Dict[str, Any]: + return { + "tokenizer": self.bt.metadata, + "processor": self.feature_extractor.to_dict(), + } + @property def output_exemplar(self): return AudioTextDict_exemplar @@ -136,10 +140,6 @@ def num_cpus(self) -> int: def num_gpus(self) -> int: return self.bt.num_gpus - @property - def batch_size(self) -> int: - return self.bt._batch_size - @dataclass class AudioDatasetSourceConfig: @@ -247,8 +247,10 @@ def the_feature_extractor(self) -> SequenceFeatureExtractor: @abc.abstractmethod def train_set( - self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> AsyncDataset[np.ndarray]: + self, + monitors: Union[bool, List[MetricsMonitor]] = True, + options: CacheOptions = CacheOptions.default(), + ) -> AsyncDataset[AudioTextDict]: pass @abc.abstractmethod @@ -294,18 +296,17 @@ def build_or_load( tokenizer: PreTrainedTokenizerBase, enforce_bos=True, enforce_eos=True, - batch_size=128, monitors=None, await_finished=True, override_resources=None, max_length=448, + cache_options: CacheOptions = CacheOptions.default(), ) -> "ProcessedAudioCache": bp = BatchAudioProcessor( processor, tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, - batch_size=batch_size, override_resources=override_resources, max_length=max_length, ) @@ -316,6 +317,7 @@ def build_or_load( bp, await_finished=await_finished, monitors=monitors, + options=cache_options, ) if cache.is_finished: logger.info(f"Cache {cache_dir} is complete.") @@ -339,7 +341,8 @@ def load(cache_dir): """ try: - cache = TreeCache.load(cache_dir, AudioTextDict_exemplar) + # TODO: 
populate cache config + cache = TreeCache.load(cache_dir, AudioTextDict_exemplar, options=None) return ProcessedAudioCache(cache) except FileNotFoundError: raise FileNotFoundError(f"{cache_dir} is not a complete cache") @@ -352,8 +355,10 @@ def load(cache_dir): class AudioIODatasetConfig(AudioDatasetSourceConfig, AudioTaskConfig): """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" - def train_set(self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] = True) -> ProcessedAudioCache: - ds = self.build_or_load_cache(self.train_split, batch_size=batch_size, monitors=monitors) + def train_set( + self, monitors: Union[bool, List[MetricsMonitor]] = True, options: CacheOptions = CacheOptions.default() + ) -> ProcessedAudioCache: + ds = self.build_or_load_cache(self.train_split, monitors=monitors) if ds is None: raise ValueError("No training set!") return ds @@ -388,9 +393,9 @@ def _has_validation_set(self): def build_or_load_cache( self, split: str, - batch_size: int = 128, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None, + cache_options: CacheOptions = CacheOptions.default(), ) -> Optional[ProcessedAudioCache]: split_cache_dir = os.path.join(self.cache_dir, split) name = logger_name or os.path.basename(self.cache_dir) @@ -422,10 +427,10 @@ def build_or_load_cache( self.the_tokenizer, enforce_bos=self.enforce_bos, enforce_eos=self.enforce_eos, - batch_size=batch_size, monitors=monitors, await_finished=(split == "validation"), max_length=self.max_length, + cache_options=cache_options, ) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 20a11d090..62dfb62ba 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import chain -from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union import braceexpand import datasets @@ -34,7 +34,7 @@ from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample -from levanter.store.cache import TreeCache +from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -114,23 +114,6 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: out = await asyncio.gather(*out) return out - def get_batch_sync(self, indices: Sequence[int]) -> Sequence[T_co]: - token_arrays = self.doc_cache.store.tree["input_ids"] - # logger.info(f"Time to get token cache: {time.time() - time_in}") - # len = await self.wait_until_len_at_least(max(indices) + 1) - # if len is not None and len < max(indices) + 1: - # raise ValueError("Requested indices beyond the end of the dataset") - offsets = np.array(indices) * self.seq_len - with ts.Batch(): - out = [] - for offset in offsets: - out.append(token_arrays.data[offset : offset + self.seq_len].read()) - # logger.info(f"Time to read token cache: {time.time() - time_in}") - - out = [x.result() for x in out] - # logger.info(f"Time to wait for token cache: {time.time() - time_in}") - return out - async def wait_until_len_at_least(self, length: int) -> int: # length is brutally slow to compute, so we cache it if 
self._cached_len is not None and self._cached_len >= length: @@ -213,7 +196,6 @@ def __init__( enforce_bos=True, enforce_eos=True, *, - batch_size=128, override_resources=None, _workaround_len=LONG_STRING_WORKAROUND, return_attention_mask=False, @@ -247,8 +229,6 @@ def __init__( should_append_eos = False should_append_bos = False - self._batch_size = batch_size - self._need_to_add_eos = should_append_eos self._need_to_add_bos = should_append_bos self._workaround_len = _workaround_len @@ -306,6 +286,18 @@ def _break_for_long_sequences(self, batch): batch.append(d) return batch, needs_merge + @property + def metadata(self) -> Dict[str, Any]: + return { + "tokenizer": self.tokenizer.name_or_path, + "vocab_size": len(self.tokenizer), + "return_attention_mask": self.return_attention_mask, + "padding": self.padding, + "max_length": self.max_length, + "append_bos": self._need_to_add_bos, + "append_eos": self._need_to_add_eos, + } + @property def output_exemplar(self) -> dict: return dict(**self.tokenizer("hi there", return_attention_mask=self.return_attention_mask, verbose=False)) @@ -385,10 +377,6 @@ def num_gpus(self) -> int: return self.override_resources.get("num_gpus", 0) return 0 - @property - def batch_size(self) -> int: - return self._batch_size - def concatenate_and_group_texts( encoding: BatchEncoding, @@ -543,7 +531,7 @@ class LMTaskConfig(abc.ABC): # config related to caching cache_dir: str = "cache/" - tokenizer_batch_size: int = 32 + cache_options: CacheOptions = field(default_factory=CacheOptions) enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't ignore_token_id: Optional[int] = None @@ -650,6 +638,7 @@ def build_or_load_cache( name = logger_name or os.path.basename(self.cache_dir) try: + # TODO: pass in options return TreeCache.load(split_cache_dir, exemplar={"input_ids": np.zeros(0, dtype=np.int32)}) except FileNotFoundError: pass @@ -669,20 +658,15 @@ def build_or_load_cache( elif monitors is False: monitors = [] - bt = BatchTokenizer( - self.the_tokenizer, enforce_bos=True, enforce_eos=self.enforce_eos, batch_size=self.tokenizer_batch_size - ) + bt = BatchTokenizer(self.the_tokenizer, enforce_bos=True, enforce_eos=self.enforce_eos) return build_or_load_cache( split_cache_dir, source, bt, - await_finished=False, monitors=monitors, - cache_config={ - "tokenizer": self.the_tokenizer.name_or_path, - "vocab_size": self.the_tokenizer.vocab_size, - }, + await_finished=False, + options=self.cache_options, ) @@ -820,10 +804,12 @@ def build_caches( # in practice it works best if we block on validation caches if split == "validation": - logger.info("Waiting for validation caches to finish building...") for cache in caches.values(): cache.await_finished() + else: + logger.info(f"Not waiting for {split} caches to finish building") + return caches @property diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py index 72e6d5adb..681a806a6 100644 --- a/src/levanter/main/train_asr.py +++ b/src/levanter/main/train_asr.py @@ -115,7 +115,7 @@ def compute_loss( eval_datasets = config.data.validation_sets() train_dataset = AudioTextDataset( - config.data.train_set(config.batch_size), + config.data.train_set(), Pos, [config.model.Mels, config.model.MelPos], KeyPos, diff --git a/src/levanter/store/_prefetch_actor.py b/src/levanter/store/_prefetch_actor.py new file mode 100644 index 000000000..6b3c302c2 --- /dev/null +++ b/src/levanter/store/_prefetch_actor.py @@ -0,0 +1,156 @@ +import asyncio +import logging +from dataclasses import dataclass +from 
queue import Empty as QueueEmpty +from typing import Callable, Generic, Iterator, List, Optional, TypeVar + +import ray + +from levanter.utils.ray_utils import ExceptionInfo, ser_exc_info + + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +@dataclass +class _PrefetchException: + info: ExceptionInfo + + +class _Sentinel: + pass + + +_SENTINEL = _Sentinel() + + +class RayPrefetchQueue(Generic[T]): + def __init__( + self, producer: Callable[[], Iterator[T]], max_queue_size: int = 100, producer_options: dict | None = None + ): + self.max_queue_size = max_queue_size + if producer_options is None: + producer_options = {} + self.queue_actor = _QueueActor.remote(max_queue_size) # type: ignore + self.producer_task = _run_producer.options(**producer_options).remote(self.queue_actor, producer) + self._stopped = False + self._finished = False + + def queue_size(self): + return ray.get(self.queue_actor.qsize.remote()) + + def __next__(self): + return self.get_next() + + def __iter__(self): + return self + + def get_next(self, timeout: float | None = None) -> T: + """ + Get the next item from the producer. If the producer raises an exception, it will be reraised here. + + If the producer is done, this will raise StopIteration. + + Args: + timeout (float|None): Timeout in seconds for getting the next item. If None, will block indefinitely. + + Raises: + Empty: If the queue is empty and the timeout is reached. + """ + if self._finished: + raise StopIteration + # time_in = time.time() + item = ray.get(self.queue_actor.get_next.remote(timeout)) + # time_out = time.time() + # if time_out - time_in > 0.1: + # current_name = ray.get_runtime_context().get_actor_name() + # print(f"{current_name} :: Queue get took {time_out - time_in} seconds :: {self.queue_size()}") + # logger.info(f"{current_name} :: Queue get took {time_out - time_in} seconds :: {self.queue_size()}") + if isinstance(item, _PrefetchException): + item.info.reraise() + if isinstance(item, _Sentinel): + self._finished = True + raise StopIteration + return item + + def stop(self): + ray.cancel(self.producer_task) + ray.get(self.queue_actor.stop.remote()) + self._stopped = True + + def is_stopped(self): + return self._stopped + + def drain_available(self, max_size: int) -> List[T]: + return ray.get(self.queue_actor.drain_available.remote(max_size)) + + +@ray.remote +class _QueueActor: + def __init__(self, max_queue_size: int): + self.queue: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size) + self._stopped = False + self._finished = False + + async def put(self, item): + await self.queue.put(item) + + async def get_next(self, timeout: Optional[float] = None): + try: + if timeout is not None: + item = await asyncio.wait_for(self.queue.get(), timeout) + else: + item = await self.queue.get() + if isinstance(item, _Sentinel): + self._finished = True + return item + except asyncio.TimeoutError: + raise QueueEmpty + + async def drain_available(self, max_size: int) -> List[T]: + items: list[T] = [] + while len(items) < max_size: + try: + item = self.queue.get_nowait() + if isinstance(item, _Sentinel): + self._finished = True + break + if isinstance(item, _PrefetchException): + item.info.reraise() + items.append(item) + except asyncio.QueueEmpty: + break + return items + + async def qsize(self): + return self.queue.qsize() + + async def stop(self): + self._stopped = True + + +@ray.remote +def _run_producer(queue_actor, producer_fn: Callable[[], Iterator[T]]): + async def _run_producer(queue_actor, producer_fn): + previous_put = None + try: 
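# Usage sketch (illustrative, not part of this module): how the RayPrefetchQueue
# defined above is intended to be consumed. The producer is a zero-argument callable
# returning an iterator; a background Ray task fills the queue and consumers pull
# items with get_next() or plain iteration. `make_batches` is a hypothetical
# producer, and a running Ray runtime (ray.init()) is assumed.
def _example_prefetch_usage():
    def make_batches():
        for i in range(10):
            yield {"batch": i}

    queue = RayPrefetchQueue(make_batches, max_queue_size=4)
    first = queue.get_next(timeout=30)  # per the docstring, raises Empty on timeout
    rest = list(queue)                  # iteration ends when the producer is exhausted
    queue.stop()                        # cancels the producer task if still running
    return [first, *rest]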
+ producer = producer_fn() + del producer_fn + + while True: + next_item = next(producer) + if previous_put is not None: + await previous_put + previous_put = queue_actor.put.remote(next_item) + except StopIteration: + if previous_put is not None: + await previous_put + await queue_actor.put.remote(_SENTINEL) + except Exception as e: + if previous_put is not None: + await previous_put + await queue_actor.put.remote(_PrefetchException(ser_exc_info(e))) + + asyncio.run(_run_producer(queue_actor, producer_fn)) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 56aa54f99..eae9f8402 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -1,44 +1,36 @@ import asyncio import concurrent +import copy import dataclasses -import heapq import logging as pylogging import os +import pprint +import random import threading import time from asyncio import InvalidStateError from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TypeVar, Union +import deepdiff import fsspec.core import pyarrow as pa import ray from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from ray.actor import ActorHandle +from ray.remote_function import RemoteFunction from levanter.data.dataset import AsyncDataset +from levanter.store._prefetch_actor import QueueEmpty, RayPrefetchQueue +from levanter.utils.py_utils import Stopwatch from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch -from ..data._queue import ( - PriorityWorkItem, - PriorityWorkTaskGroup, - PriorityWorkTaskGroupSpec, - WorkQueueDispatcherActor, - _BatchProcessorQueue, -) from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource -from ..utils.ray_utils import ( - ExceptionInfo, - RefBox, - SnitchRecipient, - current_actor_handle, - log_failures_to, - ser_exc_info, -) +from ..utils.ray_utils import ExceptionInfo, SnitchRecipient, current_actor_handle, log_failures_to, ser_exc_info from .tree_store import TreeStore @@ -53,10 +45,46 @@ DEFAULT_LOG_LEVEL = pylogging.INFO LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -# TODO: should probably do this in terms of bytes -# this is kinda silly, but the bigger the better. -MIN_ITEMS_TO_WRITE = 32 * 1024 -MAX_TIME_BETWEEN_WRITES = 100.0 + +@dataclass_json +@dataclass(frozen=True) +class CacheOptions: + """ + Configuration for a cache. This is used to configure a few parts of the cache creation process and to + store metadata that can be checked to ensure that the cache being loaded was created with the expected + configuration. It combined with the [[BatchProcessor]] metadata to form the [[CacheMetadata]]. + + It is intended that caching it deterministic conditional on the input data, processor, and these options. + """ + + num_shard_groups: Optional[int] = 128 + """Number of groups to divide the shards into. This is used to parallelize the cache building process without + overloading Ray. If None, all shards will be in their own group.""" + shard_order_randomization_key: Optional[int] = 0 + """A key used to randomize the order of the shards before building and grouping.""" + batch_size: int = 128 + """The batch size to use when processing the data. 
This is used to control the memory usage of the cache building + process. Lower values will use less memory but take somewhat longer to build the cache.""" + + @staticmethod + def default(): + return CacheOptions() + + @staticmethod + def no_fanciness(batch_size: Optional[int] = None): + """ + For testing, disables all the fancy features of the cache. This makes it easier to predict the behavior + """ + if batch_size is None: + batch_size = 128 + return CacheOptions(num_shard_groups=None, shard_order_randomization_key=None, batch_size=batch_size) + + @staticmethod + def one_group(): + """ + For testing, disables all the fancy features of the cache. This makes it easier to predict the behavior + """ + return CacheOptions(num_shard_groups=1, shard_order_randomization_key=None, batch_size=128) def build_or_load_cache( @@ -65,8 +93,8 @@ def build_or_load_cache( processor: BatchProcessor[T, U], await_finished: bool = True, monitors: Optional[Sequence["MetricsMonitor"]] = None, - cache_config: Optional[Dict[str, Any]] = None, - items_per_write: int = MIN_ITEMS_TO_WRITE, + options: CacheOptions = CacheOptions.default(), + force_flush: bool = False, ) -> "TreeCache[U]": """ Produces a sharded cache of the dataset using Ray for distributed processing. The cache can be any path @@ -91,10 +119,9 @@ def build_or_load_cache( monitors: a list of MetricsMonitors to attach to the cache. These will be called periodically with metrics about the cache build process. If None, will add a LoggerMetricsMonitor. - cache_config: A dictionary of configuration options for the cache. This is passed to the cache writer. + options: Configuration for the cache. This is used to configure a few parts of the cache creation process - items_per_write: The number of items to write to the cache at a time. This is a performance tuning parameter, - and you probably don't need to change it. We mostly use it for testing. + force_flush: for testing, forces the cache to flush after every batch. This is useful for testing. Returns: (TreeCache) A TreeCache object that can be used to read the cache. @@ -105,8 +132,8 @@ def build_or_load_cache( cache_dir=cache_dir, shard_source=input_shards, processor=processor, - cache_config=cache_config, - items_per_write=items_per_write, + options=options, + force_flush=force_flush, ) if cache.is_finished: @@ -129,519 +156,551 @@ def build_or_load_cache( return cache -@dataclass_json -@dataclass -class CacheLedger: - # NB: unlike the old cache, the mere existence of a ledger doesn't mean the cache is finished - total_num_rows: int - shard_rows: Dict[str, int] - is_finished: bool = False - finished_shards: List[str] = dataclasses.field(default_factory=list) - field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) - metadata: Dict[str, Any] = dataclasses.field(default_factory=dict) - - -@dataclass -class ShardStatus: - shard_name: str - num_rows_committed: int - is_finished: bool - - -class SerialCacheWriter(AbstractContextManager): - """ - Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray. - Mostly for scripts and debugging. - - Examples: - >>> with SerialCacheWriter(cache_dir, exemplar) as writer: - ... for batch in process_batches(): - ... 
writer.write_batch(batch) - """ +class TreeCache(AsyncDataset[T_co]): + ledger: Optional["CacheLedger"] + _builder: Optional[ActorHandle] # handle of _TreeStoreCacheBuilder + # monitor_thread waits for new metrics and also periodically reloads the cache + _monitor_thread: Optional[threading.Thread] + _metrics_monitors: List[MetricsMonitor] def __init__( self, cache_dir: str, - exemplar: T, - cache_config: Optional[Dict[str, Any]] = None, + exemplar: T_co, + ledger: Optional["CacheLedger"], + _broker, # handle of _TreeStoreCacheBuilder ): self.cache_dir = cache_dir - self.cache_config = cache_config + self.ledger = ledger + self._was_already_finished = ledger is not None and ledger.is_finished + self._builder = _broker self._exemplar = exemplar - self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="w") # type: ignore - self._is_closed = False - def __enter__(self) -> "SerialCacheWriter": - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # if successful, write the ledger - # TODO: store field counts in the ledger - ledger = CacheLedger( - total_num_rows=len(self._tree_store), - is_finished=True, - shard_rows={"": len(self._tree_store)}, - finished_shards=[""], - field_counts={}, - ) + self._metrics_monitors = [] + name = os.path.join(*cache_dir.split("/")[-2:]) + self.logger = pylogging.getLogger(f"TreeCache.{name}") + self._store_future: threading_Future[TreeStore] = threading_Future() + self._stop = False + # assert _broker is None - if exc_type is None: - _serialize_json_and_commit(os.path.join(self.cache_dir, LEDGER_FILE_NAME), ledger) - logger.info(f"Cache ledger written to {self.cache_dir}") - self._is_closed = True + if self._builder is not None: + self._monitor_thread = threading.Thread(target=self._monitor_metrics, daemon=True) + self._monitor_thread.start() + else: + self._attempt_to_load_store() + assert self._store_future.done() - def result(self) -> "TreeCache": - if not self._is_closed: - raise RuntimeError("Cannot get result until TreeCacheWriter is closed") - return TreeCache.load(self.cache_dir, self._exemplar) + @property + def store(self) -> TreeStore[T_co]: + return self._store_future.result() - def write_batch(self, batch: BatchResult): - if isinstance(batch, pa.RecordBatch): - raise NotImplementedError("Only non-RecordBatch batches are supported for now") + async def store_async(self) -> TreeStore[T_co]: + if self._builder is not None: + return await asyncio.wrap_future(self._store_future) + else: + return self.store - batch = _canonicalize_batch(batch) # type: ignore + async def async_len(self) -> int: + if self._builder is not None: + self.await_finished() - self._tree_store.extend(batch) + return len(await self.store_async()) + def __len__(self): + self.await_finished() -def _load_or_initialize_ledger(path): - try: - with fsspec.open(path, "r") as file: - return CacheLedger.from_json(file.read()) - except FileNotFoundError: - return CacheLedger(0, {}) + return len(self.store) + async def final_length_is_known(self) -> bool: + return self.ledger is not None and self.ledger.is_finished -@ray.remote(num_cpus=0.5) # type: ignore -class _OrderedCacheWriter: - """ - This cache writer receives examples from some number of shards (generally out of order) and writes them to the store - in a defined round-robin order. It also keeps track of the metadata for each shard. + def is_finite(self) -> bool: + return True - Once a shard finishes sending batches, it notifies this writer, which then updates the metadata and writes it to disk. 
- """ + async def current_len(self) -> int: + if not self._store_future.done(): + return 0 - def __init__( - self, - parent, - name, - exemplar, - batch_size, - cache_dir: str, - shards: Sequence[str], - min_items_to_write=MIN_ITEMS_TO_WRITE, - ): - pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) - with log_failures_to(parent): - self._parent = parent - self.cache_dir = cache_dir - self.shards = shards - self.batch_size = batch_size - self._min_items_to_write = min_items_to_write - self._failed = False - self._logger = pylogging.getLogger(name) - - # these are batches that we've received but haven't ordered them for writing yet - self._batch_queue = GroupRoundRobinBuffer(shards) # type: ignore - self._total_queue_length = 0 - self._was_overwhelmed = False # whether the queue has gotten too big - # writes are very slow (~2s) so we want to batch them up - self._ordered_but_unwritten_items: list = [] - self._batches_in_next_write_by_shard: dict[str, int] = {shard: 0 for shard in shards} - # we also want to write every so often - self._last_write_time = time.time() - - self._ledger = _load_or_initialize_ledger(os.path.join(cache_dir, LEDGER_FILE_NAME)) - self._expected_num_rows: dict[str, Optional[int]] = {shard: None for shard in shards} - - self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") - # careful: trim the store to the total number of rows in the cache that we've committed to - self._tree_store.trim_to_size(self._ledger.total_num_rows) - # we also have to tell the queue how many rows for each shard we've already written - for shard, num_rows in self._ledger.shard_rows.items(): - if num_rows > 0: - self._logger.info(f"Already written {num_rows} rows for shard {shard}") - - # careful: this is in terms of batch size - # Have to round up to the nearest batch size - self._batch_queue.fast_forward(shard, div_round_up(num_rows, self.batch_size)) - if shard in self._ledger.finished_shards: - self._expected_num_rows[shard] = num_rows - self._batch_queue.group_total_known(shard, div_round_up(num_rows, self.batch_size)) - - # double check that we're not finished by committing the ledger - self._attempt_to_write_batches() - - if not self._ledger.is_finished: - self._actual_writer_thread = threading.Thread(target=self._write_loop, daemon=True) - self._stop_loop = threading.Event() - self._actual_writer_thread.start() - - def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box): - with log_failures_to(self._parent): - if self._failed: - self._logger.warning("Received batch after failure. 
Ignoring.") - return + return len(await self.store_async()) - if isinstance(batch_result_box, RefBox): - batch_result = ray.get(batch_result_box.ref) - else: - batch_result = batch_result_box - - # we need to keep track of the order of the batches so that we can write them out in order - self._total_queue_length += len(batch_result) - self._batch_queue.append_to_group(shard_name, shard_batch_idx, batch_result) - next_missing_item = self._batch_queue.next_missing_item_index() - - overwhelmed = self.is_overwhelmed() - if overwhelmed: - if not self._was_overwhelmed: - self._logger.warning(f"Writer queue is getting long ({self._total_queue_length}).") - self._parent.signal_backpressure.remote(next_missing_item) - elif self._was_overwhelmed: - self._logger.info(f"Writer queue is no longer overwhelmed ({self._total_queue_length}).") - self._parent.signal_backpressure.remote(None) - - self._was_overwhelmed = overwhelmed - - def shard_failed(self, shard_name: str, batch_id: int, exc_info: ExceptionInfo): - with log_failures_to(self._parent): - self._failed = True - self._stop_loop.set() - logger.error(f"Shard {shard_name} failed at batch {batch_id}", exc_info=exc_info.restore()) - self._parent.shard_failed.remote(shard_name, exc_info) - - def shard_finished_reading(self, shard_name: str, expected_num_rows: int): - with log_failures_to(self._parent): - # careful: this is in terms of batch size - self._batch_queue.group_total_known(shard_name, div_round_up(expected_num_rows, self.batch_size)) - self._expected_num_rows[shard_name] = expected_num_rows - logger.debug( - f"Attempting to write batches because {shard_name} finished reading with {expected_num_rows} batches." - ) - self.flush() + async def get_batch(self, indices: Sequence[int] | slice): + # this is tricky: we want to wait until either the cache is finished or we have the max index + if isinstance(indices, slice): + start, step, stop = await self._get_start_stops_async(indices) + await self._wait_for_len(max(stop, start)) + indices = range(start, stop, step) - def flush(self): - self._attempt_to_write_batches() + max_index = max(indices) + await self._wait_for_len(max_index + 1) - def get_shard_status(self, shard_name: str): - with log_failures_to(self._parent): - rows = self._ledger.shard_rows.get(shard_name, 0) - is_finished = shard_name in self._ledger.finished_shards - return ShardStatus(shard_name, rows, is_finished) + return await self.store.get_batch(indices) - def get_ledger(self): - return self._ledger + async def _wait_for_len(self, needed_len: int): + if self._builder is not None: + while needed_len > await self.current_len(): + new_ledger: CacheLedger = await self._builder.updated_ledger.remote() - def _attempt_to_write_batches(self): - if self._ledger.is_finished: - return + if needed_len <= new_ledger.total_num_rows: + break - if self._failed: - logger.warning("Not writing batches because of failure.") - return + if new_ledger.is_finished: + if needed_len >= new_ledger.total_num_rows: + raise IndexError( + f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" + ) + break + else: + if needed_len > len(self.store): + raise IndexError(f"Index {needed_len} out of bounds for cache of size {len(self.store)}") - self._dequeue_ready_batches() - updated_shards = self._write_available_batches() + def _wait_for_len_sync(self, needed_len, timeout: Optional[float] = None): + time_in = time.time() + t_max = time_in + (timeout or 1e6) + if self._builder is not None: + while needed_len > len(self.store): + cur_time 
= time.time() + if cur_time > t_max: + raise TimeoutError(f"Timed out waiting for cache to reach {needed_len}") + try: + new_ledger: CacheLedger = ray.get( + self._builder.updated_ledger.remote(), timeout=max(t_max - cur_time, 10) + ) + except TimeoutError: + continue - logger.debug(f"Updated shards: {updated_shards}") + if needed_len <= new_ledger.total_num_rows: + break - need_to_commit = len(updated_shards) > 0 - total_rows = self._ledger.total_num_rows + sum(updated_shards.values()) + if new_ledger.is_finished: + if needed_len >= new_ledger.total_num_rows: + raise IndexError( + f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" + ) + break + else: + if needed_len > len(self.store): + raise IndexError(f"Index {needed_len} out of bounds for cache of size {len(self.store)}") - for shard, num_rows in updated_shards.items(): - self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows + @staticmethod + def load(cache_dir: str, exemplar: T, options: Optional["CacheMetadata"] = None) -> "TreeCache": + """Loads a cache from disk or an object store. Raises FileNotFoundError if the cache doesn't exist""" + logger.info(f"Loading cache from {cache_dir}") + ledger = CacheLedger.load(cache_dir, options) + if not ledger.is_finished: + raise FileNotFoundError(f"Cache at {cache_dir} is not finished. Use build_or_load to build it.") + return TreeCache(cache_dir, exemplar, ledger, None) - futures_to_await_shards, need_to_commit_for_shards = self._check_for_finished_shards() + @staticmethod + def build_or_load( + cache_dir: str, + shard_source: ShardedDataSource[T], + processor: BatchProcessor[T, U], + options: Optional["CacheOptions"] = None, + force_flush: bool = False, + ) -> "TreeCache[U]": + if options is None: + options = CacheOptions.default() + metadata = CacheMetadata(options=options, preprocessor_metadata=processor.metadata) + try: + return TreeCache.load(cache_dir, processor.output_exemplar, metadata) + except FileNotFoundError: + broker = _get_builder_actor( + cache_dir=cache_dir, + shard_source=shard_source, + processor=processor, + options=options, + force_flush=force_flush, + ) + return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) - need_to_commit = need_to_commit or need_to_commit_for_shards + def finished_sentinel(self): + """Returns a Ray-awaitable object that will be set when the cache is finished""" + if self._builder is None: + return ray.remote(num_cpus=0)(lambda: None).remote() + else: + return self._builder.finished_sentinel.remote() - futures_to_await = [] - if need_to_commit: - self._ledger.total_num_rows = total_rows - _serialize_json_and_commit(os.path.join(self.cache_dir, LEDGER_FILE_NAME), self._ledger) + @property + def is_finished(self): + return self.ledger is not None and self.ledger.is_finished - futures_to_await.append(self._parent._updated_ledger.remote(self._ledger)) + def __getitem__(self, item): + if isinstance(item, slice): + start, step, stop = self._get_start_stops(item) + # TODO: wait for store to be set + return self.store[start:stop:step] + else: + if item < 0: + item += len(self) + if item < 0 or item >= len(self): + raise IndexError(f"Index {item} out of bounds for cache of size {len(self)}") + return self.store[item] - if self._ledger.is_finished: - f = self._parent._finalize.remote() - futures_to_await.append(f) + def get_batch_sync(self, indices_or_slice, *, timeout: Optional[float] = None): + store = self.store + if isinstance(indices_or_slice, 
slice): + start, step, stop = self._get_start_stops(indices_or_slice) + indices_or_slice = range(start, stop, step) - ray.wait(futures_to_await + futures_to_await_shards) + max_index = max(indices_or_slice) - def _finish(self): - self._stop_loop.set() - self._actual_writer_thread.join() + self._wait_for_len_sync(max_index + 1, timeout=timeout) - def _write_loop(self): - while True: - try: - self._stop_loop.wait(1) - if self._stop_loop.is_set(): - break - except TimeoutError: - pass - self._attempt_to_write_batches() - if self._ledger.is_finished: - break - - def _dequeue_ready_batches(self): - for shard, batch in self._batch_queue.drain(): - logger.debug(f"Writing batch for {shard}") - self._total_queue_length -= len(batch) - self._ordered_but_unwritten_items.extend(batch) - self._batches_in_next_write_by_shard[shard] = self._batches_in_next_write_by_shard.get(shard, 0) + len( - batch - ) + return store.get_batch_sync(indices_or_slice) - def _write_available_batches(self): - if len(self._ordered_but_unwritten_items) == 0: - return {} + def _get_start_stops(self, slice): + start = slice.start or 0 + if slice.stop is None: + stop = len(self) + elif slice.stop < 0: + stop = len(self) + slice.stop + else: + stop = slice.stop + if start < 0: + start = len(self) + slice.start + step = slice.step or 1 + return start, step, stop - any_shard_finished_reading = any(num_rows is not None for num_rows in self._expected_num_rows.values()) - - if ( - len(self._ordered_but_unwritten_items) >= self._min_items_to_write - or (time.time() - self._last_write_time > MAX_TIME_BETWEEN_WRITES) - or any_shard_finished_reading - ): - time_in = time.time() - self._tree_store.extend(self._ordered_but_unwritten_items) - time_out = time.time() - logger.debug(f"Wrote {len(self._ordered_but_unwritten_items)} rows in {time_out - time_in:.2f} seconds") - self._ordered_but_unwritten_items = [] - - written_by_shard = self._batches_in_next_write_by_shard - self._batches_in_next_write_by_shard = {} - self._last_write_time = time.time() - return written_by_shard + async def _get_start_stops_async(self, slice): + start = slice.start or 0 + if slice.stop is None: + stop = await self.async_len() + elif slice.stop < 0: + stop = (await self.async_len()) + slice.stop else: - return {} + stop = slice.stop + if start < 0: + start = (await self.async_len()) + slice.start - def _check_for_finished_shards(self): - futures_to_await_shards = [] - need_to_commit_for_shards = False - for shard, expected_rows in self._expected_num_rows.items(): - if expected_rows is None: - continue - - current_rows = self._ledger.shard_rows.get(shard, 0) - if current_rows == expected_rows: - if shard not in self._ledger.finished_shards: - logger.info(f"Shard {shard} finished.") - self._ledger.finished_shards.append(shard) - futures_to_await_shards.append(self._parent.shard_finished.remote(shard)) - need_to_commit_for_shards = True - elif current_rows > expected_rows: - raise ValueError(f"Shard {shard} has more rows than expected: {current_rows} > {expected_rows}") - - if len(self._ledger.finished_shards) == len(self.shards) and set(self._ledger.finished_shards) == set( - self.shards - ): - self._ledger.is_finished = True - need_to_commit_for_shards = True - return futures_to_await_shards, need_to_commit_for_shards - - def is_overwhelmed(self) -> bool: - max_queue_size = self._min_items_to_write * 3 - return self._total_queue_length > max_queue_size - - def __del__(self): - self._finish() + step = slice.step or 1 + return start, step, stop + def 
await_finished(self, timeout: Optional[float] = None): + if self._builder is None: + return + x = ray.get(self.finished_sentinel(), timeout=timeout) + self._attempt_to_load_store() + return x -def _to_list_of_dicts(batch: dict) -> List[dict]: - """ - Convert a batch of dictionaries to a list of dictionaries, suitable for writing to a cache. - """ - keys = list(batch.keys()) - values = list(batch.values()) - num_rows = len(values[0]) - return [{key: values[i][j] for i, key in enumerate(keys)} for j in range(num_rows)] + async def finished(self): + if self._builder is None: + return + x = await self.finished_sentinel() + # TODO: make an async version of this + self._attempt_to_load_store() + return x + def _attempt_to_load_store(self): + if self._store_future.done(): + return -def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: - if isinstance(batch, pa.RecordBatch): - batch = dict_from_record_batch(batch) + try: + store = TreeStore.open(self._exemplar, self.cache_dir, mode="r") + except FileNotFoundError: + assert self._builder is not None + ledger = ray.get(self._builder.current_ledger.remote()) + metrics = _ledger_to_metrics(ledger) + if metrics.rows_finished == 0 and metrics.is_finished: + # this means we built an empty cache. go with it + store = TreeStore.open(self._exemplar, f"memory://{self.cache_dir}", mode="a") + else: + raise + try: + self._store_future.set_result(store) + except concurrent.futures.InvalidStateError: + pass - if isinstance(batch, dict): - return _to_list_of_dicts(batch) - else: - return batch + def attach_metrics_monitor(self, monitor: MetricsMonitor): + if self._builder is None: + logger.warning("Cannot attach metrics monitor to finished cache.") + # TODO: decide what to do about attaching if the cache is already finished + # maybe get the final metrics? + return + + self._metrics_monitors.append(monitor) + + def _monitor_metrics(self): + while not self._stop: + try: + try: + # it's better to let the Ray actor handle the timeout + ledger_or_timeout = ray.get(self._builder.updated_ledger.remote(timeout=4.0), timeout=10.0) + if isinstance(ledger_or_timeout, Exception): + raise ledger_or_timeout + self.ledger = ledger_or_timeout + metrics = _ledger_to_metrics(self.ledger) + for monitor in self._metrics_monitors: + monitor(metrics) + if metrics.is_finished: + break + except TimeoutError: + pass + except Exception as e: + if str(e).startswith("Failed to submit task to actor"): + logger.warning("Cache builder actor is gone. Stopping monitoring.") + break + else: + raise + try: + self._attempt_to_load_store() + except FileNotFoundError: + pass + except Exception as e: + if str(e).startswith("Failed to submit task to actor"): + logger.warning("Cache builder actor is gone. 
Stopping monitoring.") + break + else: + self.logger.exception("Error while reading metrics from shard cache.") + raise e -# thinking through the design of the cache system +@dataclass_json +@dataclass +class CacheLedger: + # NB: unlike the old cache, the mere existence of a ledger doesn't mean the cache is finished + total_num_rows: int + shard_rows: Dict[str, int] + is_finished: bool = False + finished_shards: List[str] = dataclasses.field(default_factory=list) + field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) + metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata(CacheOptions(), {})) + + @staticmethod + def load_or_initialize( + cache_dir: str, source: ShardedDataSource, processor: BatchProcessor, config: "CacheOptions" + ): + metadata = CacheMetadata(options=config, preprocessor_metadata=processor.metadata) + try: + return CacheLedger.load(cache_dir, metadata) + except FileNotFoundError: + return CacheLedger( + total_num_rows=0, + shard_rows={shard: 0 for shard in source.shard_names}, + is_finished=False, + metadata=metadata, + ) -# we decided to use Ray, which was maybe a mistake, but here we are. -# Ray doesn't like it when the number of actors gets too large, so we can't have one actor per shard. -# we have N nodes and K shards. + @staticmethod + def load(cache_dir: str, metadata: Optional["CacheMetadata"] = None) -> "CacheLedger": + ledger_path = os.path.join(cache_dir, LEDGER_FILE_NAME) + try: + logger.debug(f"Attempting to load cache ledger from {ledger_path}") + with fsspec.open(ledger_path) as file: + cache_ledger = CacheLedger.from_json(file.read()) # type: ignore + if metadata: + diff = cache_ledger.metadata.compare_to(metadata) + if not diff: + logger.debug("Metadata matches") + else: + logger.warning(f"Metadata mismatch: {pprint.pformat(diff, indent=2)}") + return cache_ledger + except FileNotFoundError: + raise FileNotFoundError(f"Cache ledger not found at {ledger_path}") -# at a high level, we have 3 steps: -# 1. read batches from the shard source -# 2. process batches -# 3. write batches to the cache for that shard + def _serialize_and_commit(self, cache_dir): + path = os.path.join(cache_dir, LEDGER_FILE_NAME) + return _serialize_json_and_commit(path, self) # type: ignore -# The difficulty is that we want parallelism, and we want to control the order of the written data. -# Reading batches requires CPU and network. -# ==> This means we should limit the number of shard groups to roughly the number of nodes, maybe times 2. -# We ideally want to read from shards roughly evenly (at least within a group of shards) +@dataclass_json +@dataclass(frozen=True) +class CacheMetadata: + options: CacheOptions = CacheOptions.default() + preprocessor_metadata: Optional[dict[str, Any]] = None -def _shard_reader_generator(shard_source: ShardedDataSource[T], shard_name: str, start_row: int, batch_size: int): - shard_iter = shard_source.open_shard_at_row(shard_name, start_row) - batch = [] - for row in shard_iter: - batch.append(row) + def compare_to(self, other: "CacheMetadata") -> deepdiff.DeepDiff: + """ + Compare this metadata to another set of metadata. This is used to check if the cache being loaded + was created with the expected configuration. - if len(batch) == batch_size: - yield batch - batch = [] + if other.preprocessor_metadata is None, we ignore it for the purposes of comparison. 
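+        An empty diff means the on-disk cache was built with a compatible configuration; a non-empty
+        diff is logged as a warning by the loader but does not prevent the cache from being used.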
+ """ + if other.preprocessor_metadata is None: + sorta_self = dataclasses.replace(self, preprocessor_metadata=None) + else: + sorta_self = self + return deepdiff.DeepDiff(sorta_self, other) - if len(batch) > 0: - yield batch + @staticmethod + def empty(): + return CacheMetadata() @dataclass -class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): - name: str - builder_ref: ray.actor.ActorHandle # _TreeStoreCacheBuilder - writer: ray.actor.ActorHandle # _GroupedShardWriter - shard_source: ShardedDataSource - shard_names: Sequence[str] - priority_fn: Callable[[int, int], float] - processor_actor: ray.actor.ActorHandle # BatchProcessorQueue - batch_size: int - group_id: int - - def build(self) -> "PriorityWorkTaskGroup": - return ShardGroupTaskGroup(self) - - -class ShardGroupTaskGroup(PriorityWorkTaskGroup): - def __init__(self, spec: ShardGroupToBeProcessed): - self.spec: ShardGroupToBeProcessed = spec - self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") - - current_shard_status: dict[str, ShardStatus] = {} - for shard_name in self.spec.shard_names: - try: - current_shard_status[shard_name] = ray.get(self.spec.writer.get_shard_status.remote(shard_name)) - except Exception as e: - self.spec.builder_ref.shard_failed.remote(shard_name, ser_exc_info()) - raise e +class _ShardStatus: + shard_name: str + num_rows_committed: int + is_finished: bool - batch_size = self.spec.batch_size - self._items: list[PriorityWorkItem] = [] +class SerialCacheWriter(AbstractContextManager): + """ + Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray. + Mostly for scripts and debugging. - for shard_name in self.spec.shard_names: - try: - status = current_shard_status[shard_name] - if status.is_finished: - self.logger.info(f"Shard {shard_name} already finished. Skipping.") - continue + Examples: + >>> with SerialCacheWriter(cache_dir,exemplar) as writer: + ... for batch in process_batches(): + ... 
writer.write_batch(batch) + """ - reader = _shard_reader_generator( - self.spec.shard_source, shard_name, status.num_rows_committed, batch_size - ) + def __init__( + self, + cache_dir: str, + exemplar: T, + metadata: Optional["CacheMetadata"] = None, + ): + self.cache_dir = cache_dir + self.metadata = metadata + self._exemplar = exemplar + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="w", cache_metadata=True) + self._is_closed = False + + def __enter__(self) -> "SerialCacheWriter": + return self - task_name = f"shard_reader.{self.spec.name}.{shard_name}" + def __exit__(self, exc_type, exc_val, exc_tb): + # if successful, write the ledger + # TODO: store field counts in the ledger + ledger = CacheLedger( + total_num_rows=len(self._tree_store), + is_finished=True, + shard_rows={"": len(self._tree_store)}, + finished_shards=[""], + field_counts={}, + metadata=self.metadata or CacheMetadata.empty(), + ) - batch_idx = status.num_rows_committed // batch_size + if exc_type is None: + ledger._serialize_and_commit(self.cache_dir) + logger.info(f"Cache ledger written to {self.cache_dir}") + self._is_closed = True - shard_idx = self.spec.shard_source.shard_names.index(shard_name) - item = ShardReaderItem( - self, - task_name, - shard_name, - shard_idx, - batch_idx=batch_idx, - reader=reader, - current_row=status.num_rows_committed, - ) + def result(self) -> "TreeCache": + if not self._is_closed: + raise RuntimeError("Cannot get result until TreeCacheWriter is closed") + return TreeCache.load(self.cache_dir, self._exemplar, self.metadata) - heapq.heappush(self._items, item) - except Exception as e: - self.logger.exception(f"Error while initializing shard {shard_name}") - self.spec.writer[shard_name].shard_failed.remote(ser_exc_info()) - raise e + def write_batch(self, batch: BatchResult): + if isinstance(batch, pa.RecordBatch): + raise NotImplementedError("Only non-RecordBatch batches are supported for now") - @property - def name(self): - return self.spec.name + batch = _canonicalize_batch(batch) # type: ignore - def items(self) -> Sequence["PriorityWorkItem"]: - return self._items + self._tree_store.extend(batch) -# NB This class is stateful -@dataclass -class ShardReaderItem(PriorityWorkItem): +class ShardedCacheWriter: """ - Each time execute is called, this class reads a batch of examples from the shard - and dispatches them to the processor. + Similar to SerialCacheWriter, but tracks shard metadata. + + Similar to _OrderedCacheWriter, it also supports resuming, and it + groups together batches before writing (at some interval) in order to improve performance. 
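+    A rough usage sketch (here ``batch`` is a list of processed rows and ``num_rows`` is the shard's
+    total row count; both names are placeholders, not values defined in this module):
+
+        writer = ShardedCacheWriter(cache_dir, initial_ledger, exemplar)
+        writer.write_batch("shard_0", batch)
+        writer.flush()
+        writer.finish_shard("shard_0", num_rows)
+        writer.finish()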
""" - group: ShardGroupTaskGroup - name: str - shard_name: str - shard_idx: int - batch_idx: int - reader: Iterator[list] - current_row: int = 0 + def __init__( + self, + cache_dir: str, + initial_ledger: CacheLedger, + exemplar: T, + on_write: Optional[Callable[[CacheLedger], None]] = None, + ): + self.cache_dir = cache_dir + self._on_write = on_write + + self._ledger = copy.deepcopy(initial_ledger) + + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore + self._tree_store.trim_to_size(self._ledger.total_num_rows) + self._items_ready_to_write: list = [] @property - def priority(self): - return self.group.spec.priority_fn(self.shard_idx, self.batch_idx) + def ledger(self): + return self._ledger + + # we have both versions b/c we need this one for actors + def get_ledger(self): + return self._ledger @property - def spec(self): - return self.group.spec + def is_finished(self): + return self._ledger.is_finished - def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: - writer = self.spec.writer - write_finished_ref = None + def finish_shard(self, shard_name: str, num_rows: int): + self.flush() + current_rows = self._ledger.shard_rows.get(shard_name, 0) + if current_rows != num_rows: + raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") - self.group.logger.debug(f"Reading one batch of shard {self.shard_name}: {self.batch_idx}") + self._ledger.finished_shards.append(shard_name) + self._ledger._serialize_and_commit(self.cache_dir) - try: - batch = next(self.reader, None) - exhausted_shard = batch is None or (len(batch) < self.spec.batch_size) + def write_batch(self, shard_name: str, batch: BatchResult): + if self.is_finished: + raise RuntimeError("Cannot write to a finished cache") - if batch: - priority = self.spec.priority_fn(self.shard_idx, self.batch_idx) - try: - batch_result_ref = ray.get( - self.spec.processor_actor.submit.remote( - priority=priority, - desc=f"{self.shard_name}.{self.batch_idx}", - batch=RefBox(ray.put(batch)), - ) - ) - logger.debug(f"Got batch result: {batch_result_ref}") - write_finished_ref = writer.batch_finished.remote( - self.shard_name, self.batch_idx, RefBox(batch_result_ref) - ) - self.batch_idx += 1 - self.current_row += len(batch) - except Exception as e: - self.group.logger.exception(f"Error while processing batch {self.batch_idx}") - # fire and forget - writer.shard_failed.remote(self.shard_name, self.batch_idx, ser_exc_info()) - raise e + if isinstance(batch, pa.RecordBatch): + raise NotImplementedError("Only non-RecordBatch batches are supported for now") - if exhausted_shard: - logger.info(f"Shard {self.shard_name} exhausted. 
Expecting {self.current_row} rows.") - writer.shard_finished_reading.remote(self.shard_name, self.current_row) + batch = _canonicalize_batch(batch) # type: ignore - self.group.logger.debug(f"Finished reading one batch of shard {self.shard_name}: {self.batch_idx}") + self._items_ready_to_write.append((shard_name, batch)) - return exhausted_shard, write_finished_ref - except Exception as e: # noqa - self.group.logger.exception(f"Error while processing shard {self.shard_name}") - # fire and forget - writer.shard_failed.remote(self.shard_name, self.batch_idx, ser_exc_info()) - raise e + def flush(self): + self._attempt_to_write_batches() + + def finish(self): + self.flush() + + # if successful, write the ledger + logger.info("Finished writing cache") + self._ledger.is_finished = True + self._ledger._serialize_and_commit(self.cache_dir) + if self._on_write: + self._on_write(self._ledger) + + return self._tree_store + + def _attempt_to_write_batches(self): + if self._ledger.is_finished: + return + + if not self._items_ready_to_write: + return + + updated_shards = self._write_available_batches() + + logger.debug(f"Updated shards: {updated_shards}") + + did_write = len(updated_shards) > 0 + + if did_write: + + for shard, num_rows in updated_shards.items(): + self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows + + total_rows = self._ledger.total_num_rows + sum(updated_shards.values()) + self._ledger.total_num_rows = total_rows + self._ledger._serialize_and_commit(self.cache_dir) + + if self._on_write: + self._on_write(self._ledger) + + def _write_available_batches(self): + ready = self._items_ready_to_write + self._items_ready_to_write = [] + + if len(ready) == 0: + return {} + + to_write = [] + written_by_shard = {} + for shard, batch in ready: + to_write.extend(batch) + written_by_shard[shard] = written_by_shard.get(shard, 0) + len(batch) + + self._tree_store.extend(to_write) + return written_by_shard def _serialize_json_and_commit(path, obj): @@ -657,17 +716,6 @@ def _serialize_json_and_commit(path, obj): fs.rename(f"{path}.tmp", path) -def _load_cache_ledger(cache_dir) -> CacheLedger: - try: - ledger_path = os.path.join(cache_dir, LEDGER_FILE_NAME) - logger.debug(f"Attempting to load cache ledger from {ledger_path}") - with fsspec.open(ledger_path) as file: - cache_ledger = CacheLedger.from_json(file.read()) # type: ignore - return cache_ledger - except FileNotFoundError: - raise FileNotFoundError(f"Cache ledger not found at {ledger_path}") - - @ray.remote(num_cpus=0.1) # keep this small b/c it doesn't do a lot class _TreeStoreCacheBuilder(SnitchRecipient): """ @@ -682,119 +730,42 @@ def __init__( name: str, source: ShardedDataSource[T], processor: BatchProcessor[T, U], - cache_config: Dict[str, Any], - min_items_to_write: int, + options: CacheOptions, + force_flush: bool, ): pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") - self.source = source - self._cache_dir = cache_dir - # self._metrics = InProgressCacheMetrics() - self._updated_ledger_condition = asyncio.Condition() - self._ledger = CacheLedger(0, {}) - self.shards_in_progress: set[str] = set() - exemplar = processor.output_exemplar - self._finished_promise: asyncio.Future[None] = asyncio.Future() - # used to subscribe to metrics updates - self._cache_config = cache_config - path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) - name = f"broker::{path_for_name}" - self.logger = pylogging.getLogger(f"{name}") - 
self._cache_writer: Optional[ActorHandle] = _OrderedCacheWriter.remote( # type: ignore - current_actor_handle(), - f"writer::{path_for_name}", - exemplar, - processor.batch_size, - cache_dir, - source.shard_names, - min_items_to_write, - ) - try: - cache_ledger = _load_cache_ledger(self._cache_dir) - self._ledger = cache_ledger - except FileNotFoundError: - pass - - if self._ledger.is_finished: - self._finished_promise.set_result(None) - self._start_workers(cache_dir, name, processor, source) - - def _start_workers(self, cache_dir, name, processor, source): - if len(source.shard_names) == 0: - self.logger.warning("No shards to index?!?") - self._finalize() - else: - self.logger.debug(f"Starting cache build for {source.shard_names}") - self.logger.info(f"Starting cache build for {len(source.shard_names)} shards") - - self_ref = current_actor_handle() - - self._shard_writers = [] - self._shard_readers = [] - self._processor_actors = [] - - for shard_name in source.shard_names: - self.shards_in_progress.add(shard_name) - - num_shards = len(source.shard_names) - num_worker_groups = len(ray.nodes()) - num_shard_groups = max(min(num_worker_groups, num_shards), 1) - - # if we have a bunch of caches to build with one shard, we don't want them all - # assigned to the same node, so we use an offset based on the hash of the name (for stability) - # in an attempt to spread them out - group_offset = int(hash(name) % num_worker_groups) - - shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)] - for i, shard_name in enumerate(source.shard_names): - shard_groups[i % num_shard_groups].append(shard_name) - - def priority_fn(shard_idx, batch_idx): - return batch_idx * num_shards + shard_idx - - for group_id, shard_group in enumerate(shard_groups): - # TODO: would probably be better if we didn't create one of these per shard group - processor_actor = _BatchProcessorQueue.remote(processor) # type: ignore - self._processor_actors.append(processor_actor) - - assert self._cache_writer is not None - - work_item = ShardGroupToBeProcessed( - name=name, - builder_ref=self_ref, - writer=self._cache_writer, - shard_source=source, - shard_names=shard_group, - priority_fn=priority_fn, - processor_actor=processor_actor, - batch_size=processor.batch_size, - group_id=group_id, - ) + self.source = source + self._cache_dir = cache_dir + self._options = options + self._updated_ledger_condition = asyncio.Condition() # used to subscribe to metrics updates - # we want global names so that different tasks can coordinate priorities - worker_to_assign = (group_id + group_offset) % num_worker_groups - priority_actor_name = f"priority_processor.{worker_to_assign}" + self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor, options) - reader_actor = WorkQueueDispatcherActor.options( # type: ignore - name=priority_actor_name, get_if_exists=True - ).remote() - - reader_actor.assign_work.remote(work_item) - self._shard_readers.append(reader_actor) + if self._ledger.is_finished: + self._finished_promise.set_result(None) - def shard_finished(self, shard_name: str): - """Callback method for when a shard worker has finished.""" - self.shards_in_progress.remove(shard_name) + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"broker::{path_for_name}" + self.logger = pylogging.getLogger(f"{name}") - def shard_failed(self, shard_name: str, error: ExceptionInfo): - """Callback method for when a shard worker has failed.""" - self._writer_exception(shard_name, error) + if 
self._ledger.is_finished: + self.logger.info("Cache already finished. Nothing to do.") + return + self._cache_writer = _core_writer_task.remote( + current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush + ) + except Exception: + # Ray behaves poorly if the constructor of an actor fails, so we catch and log here + # this also propagates to the finished promise, so we can handle it there + self._writer_exception(None, ser_exc_info()) - def _updated_ledger(self, ledger: CacheLedger): - self._ledger = ledger - self._do_notify() + def current_ledger(self): + if self._finished_promise.done() and self._finished_promise.exception() is not None: + raise self._finished_promise.exception() + return self._ledger def other_failed(self, error: ExceptionInfo): """Callback method for when a shard worker has failed.""" @@ -805,21 +776,37 @@ def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): self._writer_exception(None, exception) def is_finished(self): + if self.failed(): + return False return self._ledger.is_finished + def failed(self): + return self._finished_promise.done() and self._finished_promise.exception() is not None + async def finished_sentinel(self): await self._finished_promise - async def updated_ledger(self) -> CacheLedger: + async def updated_ledger(self, timeout: float | None = None) -> CacheLedger | TimeoutError: + """ + NB: we **return** a timeout error, we don't raise it. This is because we want to find real failures + in the ray dashboard, and it's a real pain to find exceptions in the logs. + """ if self._finished_promise.done(): if self._finished_promise.exception() is not None: raise self._finished_promise.exception() # type: ignore else: return self._ledger - async with self._updated_ledger_condition: - await self._updated_ledger_condition.wait() + try: + async with self._updated_ledger_condition: + cond = self._updated_ledger_condition.wait() + if timeout is not None: + await asyncio.wait_for(cond, timeout=timeout) + else: + await cond return self._ledger + except asyncio.TimeoutError: + return TimeoutError("Timed out waiting for cache to update") def _writer_exception(self, shard_name, exc_info: ExceptionInfo): info = exc_info.restore() @@ -834,6 +821,26 @@ def _writer_exception(self, shard_name, exc_info: ExceptionInfo): pass self._do_notify() + def _notify_updated_ledger(self, ledger: CacheLedger): + """ + Called by the cache writer when it has updated the ledger. 
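+        Wakes any tasks waiting on ``updated_ledger``; if the new ledger is marked finished, this also
+        resolves the finished promise and drops the writer handle.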
+ """ + was_finished = self._ledger.is_finished + self._ledger = ledger + + if was_finished: + raise RuntimeError("Ledger was already finished") + + if self._ledger.is_finished: + logger.info(f"Finalizing cache {self._cache_dir}...") + # guard against invalid state errors + if not self._finished_promise.done(): + self._finished_promise.set_result(None) + + self._cache_writer = None + + self._do_notify() + def _do_notify(self): async def _do_notify_async(): async with self._updated_ledger_condition: @@ -841,36 +848,8 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) - def current_ledger(self): - return self._ledger - - def _finalize(self): - logger.info(f"Finalizing cache {self._cache_dir}...") - - self._ledger.is_finished = True - self._finished_promise.set_result(None) - - # notify metrics subscribers - self._do_notify() - self._cache_writer = None - - def signal_backpressure(self, next_item_desired: Optional[int]): - # get the priority of the item we want - if next_item_desired is not None: - self.logger.debug(f"Signaling backpressure for {next_item_desired}") - # our priority function above is basically (batch_index, shard_index). We just ask we don't get more - # than one round of batches ahead - max_priority = (next_item_desired + 1) * len(self.source.shard_names) - - for reader in self._shard_readers: - reader.set_max_dispatch_priority.remote(max_priority) - else: - self.logger.debug("Signaling no backpressure") - for reader in self._shard_readers: - reader.set_max_dispatch_priority.remote(None) - -def _get_builder_actor(cache_dir, input_shards, processor, cache_config=None, items_per_write=MIN_ITEMS_TO_WRITE): +def _get_builder_actor(cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False): name = f"lev_cache_manager::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" @@ -878,478 +857,445 @@ def _get_builder_actor(cache_dir, input_shards, processor, cache_config=None, it return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote( # type: ignore name=name_for_display, cache_dir=cache_dir, - source=input_shards, + source=shard_source, processor=processor, - cache_config=cache_config, - min_items_to_write=items_per_write, + options=options, + force_flush=force_flush, ) -class TreeCache(AsyncDataset[T_co]): - ledger: Optional[CacheLedger] - _broker: Optional[ActorHandle] - # monitor_thread waits for new metrics and also periodically reloads the cache - _monitor_thread: Optional[threading.Thread] - _metrics_monitors: List[MetricsMonitor] +##### +# Core implementation starts below. +##### +# The main idea is to have a bunch of reader tasks that read batches, dispatch tokenization tasks, producing +# a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache. +# The reader tasks are given a group of shards, which are implicitly concatenated together. - def __init__( - self, - cache_dir: str, - exemplar: T_co, - ledger: Optional[CacheLedger], - _broker, # handle of _TreeStoreCacheBuilder - ): - self.cache_dir = cache_dir - self.ledger = ledger - self._was_already_finished = ledger is not None and ledger.is_finished - self._broker = _broker - self._exemplar = exemplar +# This is still much slower than I would like but I haven't figured out why yet. 
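As a rough illustration of the interleaving pattern described in the comments above (a simplified
sketch for exposition only, not code added by this patch; the real implementation interleaves
RayPrefetchQueue readers rather than plain generators and supports resuming from an arbitrary group):

    from typing import Iterator, Sequence, TypeVar

    T = TypeVar("T")

    def interleave(generators: Sequence[Iterator[T]], first_index: int = 0) -> Iterator[T]:
        # Round-robin over the generators, dropping each one from the rotation once it is exhausted.
        finished: set[int] = set()
        while len(finished) < len(generators):
            for i in range(first_index, len(generators)):
                if i in finished:
                    continue
                try:
                    yield next(generators[i])
                except StopIteration:
                    finished.add(i)
            # after the first pass, each subsequent round starts from the beginning of the rotation
            first_index = 0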
+# TODO: +# - [ ] Profile the tokenization process more (see TIME comments) +# - [ ] Try Ray's autoscaling actorpool if the issue is tokenization isn't fast enough +# https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +# - [ ] More observability into what's queued and how long work items take - self._metrics_monitors = [] - name = os.path.join(*cache_dir.split("/")[-2:]) - self.logger = pylogging.getLogger(f"TreeCache.{name}") - self._store_future: threading_Future[TreeStore] = threading_Future() - self._stop = False - # assert _broker is None - - if self._broker is not None: - self._monitor_thread = threading.Thread(target=self._monitor_metrics, daemon=True) - self._monitor_thread.start() - else: - self._attempt_to_load_store() - assert self._store_future.done() - @property - def store(self) -> TreeStore[T_co]: - return self._store_future.result() +@dataclass +class _Batch: + """ + A batch of data that has either been read or tokenized. + """ - async def store_async(self) -> TreeStore[T_co]: - if self._broker is not None: - return await asyncio.wrap_future(self._store_future) - else: - return self.store + shard_name: str + row_indices: List[int] + payload: ray.ObjectRef - async def async_len(self) -> int: - if self._broker is not None: - self.await_finished() - return len(await self.store_async()) +@dataclass +class _ShardFinished: + """ + A message indicating that a shard has finished. + """ - def __len__(self): - self.await_finished() + shard_name: str + total_rows: int - return len(self.store) - async def final_length_is_known(self) -> bool: - if self._broker is not None: - return await self._broker.is_finished.remote() +_Message = _Batch | _ShardFinished +""" +A message that can be sent from a reader task to the writer task. +""" - return True +_TIME_BETWEEN_WRITES = 20.0 # seconds +_MAX_WRITE_BATCHES = 1000 +_MIN_WRITE_BATCHES = 100 - def is_finite(self) -> bool: - return True - async def current_len(self) -> int: - if not self._store_future.done(): - return 0 +@ray.remote(num_cpus=1) +def _core_writer_task( + parent, + cache_dir, + initial_ledger: CacheLedger, + source: ShardedDataSource, + processor, + force_flush: bool, +): + """ + This is the main task that processes the data and writes it to the cache. 
- return len(await self.store_async()) + It chains together: + * 1 generator per shard group + * interleaving of the generators + * processing of the batches + * writing of the batches to the cache + """ + logger.setLevel(DEFAULT_LOG_LEVEL) + logger.info("Starting writer task") - async def get_batch(self, indices: Sequence[int] | slice): - # this is tricky: we want to wait until either the cache is finished or we have the max index - if isinstance(indices, slice): - start, step, stop = await self._get_start_stops_async(indices) - await self._wait_for_len(max(stop, start)) - indices = range(start, stop, step) + name = str(os.path.join(*cache_dir.split("/")[-2:])) + # append a small random number to the name to avoid collisions + name += f"::{random.randint(0, 1000)}" - max_index = max(indices) - await self._wait_for_len(max_index + 1) + def on_write(ledger): + ray.get(parent._notify_updated_ledger.remote(ledger)) - return await self.store.get_batch(indices) + with log_failures_to(parent): + sharded_cache_writer = ray.remote(ShardedCacheWriter).remote( + cache_dir, initial_ledger, processor.output_exemplar, on_write=on_write + ) - async def _wait_for_len(self, needed_len: int): - if self._broker is not None: - while needed_len > await self.current_len(): - new_ledger: CacheLedger = await self._broker.updated_ledger.remote() + interleave: RayPrefetchQueue = RayPrefetchQueue( + lambda: _make_interleave(name, source, initial_ledger, processor), + 4096, + producer_options={"num_cpus": 1, "name": f"{name}::interleave"}, + ) - if needed_len <= new_ledger.total_num_rows: - break + total_time = Stopwatch() + loading_time = Stopwatch() + append_time = Stopwatch() + flush_time = Stopwatch() + flush_amortized_time = Stopwatch() - if new_ledger.is_finished: - if needed_len >= new_ledger.total_num_rows: - raise IndexError( - f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" - ) - break - else: - if needed_len > len(self.store): - raise IndexError(f"Index {needed_len} out of bounds for cache of size {len(self.store)}") + i = 0 + batches_since_last_write = 0 + time_of_last_write = time.time() + last_flush_future: Optional[ray.ObjectRef] = None + # start_of_last_flush = time_of_last_write - def _wait_for_len_sync(self, needed_len, timeout: Optional[float] = None): - time_in = time.time() - t_max = time_in + (timeout or 1e6) - if self._broker is not None: - while needed_len > len(self.store): - cur_time = time.time() - if cur_time > t_max: - raise TimeoutError(f"Timed out waiting for cache to reach {needed_len}") + # for i, batch_box in enumerate(interleave): + while True: + with total_time: # 0.014 try: - new_ledger: CacheLedger = ray.get( - self._broker.updated_ledger.remote(), timeout=max(t_max - cur_time, 10) - ) - except TimeoutError: - continue - - if needed_len <= new_ledger.total_num_rows: + cur_time = time.time() + time_since_last_write = cur_time - time_of_last_write + remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write + + if batches_since_last_write > 0: + with flush_amortized_time: + if remaining_time <= 0 or batches_since_last_write >= _MAX_WRITE_BATCHES or force_flush: + with flush_time: + # TODO: don't block? 
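+                            # Only one flush is allowed in flight at a time: wait for the previous
+                            # flush (if any) to land before issuing the next one, then reset the
+                            # batch counter and the write timer.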
+ if last_flush_future: + ray.get(last_flush_future) + # print( + # f"Flushed {batches_since_last_write} batches in" + # f" {time.time() - start_of_last_flush} seconds" + # ) + last_flush_future = sharded_cache_writer.flush.remote() + # start_of_last_flush = time.time() + batches_since_last_write = 0 + time_of_last_write = time.time() + continue + else: + remaining_time = _TIME_BETWEEN_WRITES + + with loading_time: + try: + message = interleave.get_next(timeout=max(remaining_time, 0.1)) + except QueueEmpty: + logger.info("Writer running ahead of reader.") + continue + + with append_time: + match message: + case _Batch(shard, _, payload): + # TODO: ensure indices are what we expect + sharded_cache_writer.write_batch.remote(shard, payload) + batches_since_last_write += 1 + i += 1 + case _ShardFinished(shard, total_rows): + ray.get(sharded_cache_writer.finish_shard.remote(shard, total_rows)) + case _: + raise AssertionError(f"Unexpected message type {type(message)}") + + # if i % 1000 == 0: + # print( + # f"Processed {i} batches: {loading_time.average()}s load," + # f" {append_time.average()}s append, {flush_time.average()}s flush blocked, " + # f"{flush_amortized_time.average()}s amortized flush, " + # f"{total_time.average()}s total" + # ) + except StopIteration: + logger.info("Finished all shards") break + except Exception as e: + logger.exception("Error while processing batch") + raise e - if new_ledger.is_finished: - if needed_len >= new_ledger.total_num_rows: - raise IndexError( - f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" - ) - break - else: - if needed_len > len(self.store): - raise IndexError(f"Index {needed_len} out of bounds for cache of size {len(self.store)}") - - @staticmethod - def load(cache_dir: str, exemplar: T) -> "TreeCache": - """Loads a cache from disk or an object store. Raises FileNotFoundError if the cache doesn't exist""" - logger.info(f"Loading cache from {cache_dir}") - ledger = _load_cache_ledger(cache_dir) - if not ledger.is_finished: - raise FileNotFoundError(f"Cache at {cache_dir} is not finished. 
Use build_or_load to build it.") - return TreeCache(cache_dir, exemplar, ledger, None) - - @staticmethod - def build_or_load( - cache_dir: str, - shard_source: ShardedDataSource[T], - processor: BatchProcessor[T, U], - cache_config: Optional[Dict[str, Any]] = None, - items_per_write: int = MIN_ITEMS_TO_WRITE, - ) -> "TreeCache[U]": - try: - return TreeCache.load(cache_dir, processor.output_exemplar) - except FileNotFoundError: - broker = _get_builder_actor( - cache_dir=cache_dir, - input_shards=shard_source, - processor=processor, - cache_config=cache_config, - items_per_write=items_per_write, - ) - return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) - - def finished_sentinel(self): - """Returns a Ray-awaitable object that will be set when the cache is finished""" - if self._broker is None: - return ray.remote(num_cpus=0)(lambda: None).remote() - else: - return self._broker.finished_sentinel.remote() - - @property - def is_finished(self): - if self._broker is None: - return True - else: - return ray.get(self._broker.is_finished.remote()) - - def __getitem__(self, item): - if isinstance(item, slice): - start, step, stop = self._get_start_stops(item) - # TODO: wait for store to be set - return self.store[start:stop:step] - else: - if item < 0: - item += len(self) - if item < 0 or item >= len(self): - raise IndexError(f"Index {item} out of bounds for cache of size {len(self)}") - return self.store[item] - - def get_batch_sync(self, indices_or_slice, *, timeout: Optional[float] = None): - store = self.store - if isinstance(indices_or_slice, slice): - start, step, stop = self._get_start_stops(indices_or_slice) - indices_or_slice = range(start, stop, step) - - max_index = max(indices_or_slice) - - self._wait_for_len_sync(max_index + 1, timeout=timeout) - - return store.get_batch_sync(indices_or_slice) - - def _get_start_stops(self, slice): - start = slice.start or 0 - if slice.stop is None: - stop = len(self) - elif slice.stop < 0: - stop = len(self) + slice.stop - else: - stop = slice.stop - if start < 0: - start = len(self) + slice.start - step = slice.step or 1 - return start, step, stop - - async def _get_start_stops_async(self, slice): - start = slice.start or 0 - if slice.stop is None: - stop = await self.async_len() - elif slice.stop < 0: - stop = (await self.async_len()) + slice.stop - else: - stop = slice.stop - if start < 0: - start = (await self.async_len()) + slice.start + sharded_cache_writer.finish.remote() - step = slice.step or 1 - return start, step, stop + out = sharded_cache_writer.get_ledger.remote() + return out - def await_finished(self, timeout: Optional[float] = None): - if self._broker is None: - return - x = ray.get(self.finished_sentinel(), timeout=timeout) - self._attempt_to_load_store() - return x - async def finished(self): - if self._broker is None: - return - x = await self.finished_sentinel() - # TODO: make an async version of this - self._attempt_to_load_store() - return x - - def _attempt_to_load_store(self): - if self._store_future.done(): - return - - try: - store = TreeStore.open(self._exemplar, self.cache_dir, mode="r") - except FileNotFoundError: - logger.error(f"Cache at {self.cache_dir} not found.") - assert self._broker is not None - ledger = ray.get(self._broker.current_ledger.remote()) - metrics = _ledger_to_metrics(ledger) - if metrics.rows_finished == 0 and metrics.is_finished: - # this means we built an empty cache. 
go with it - store = TreeStore.open(self._exemplar, f"memory://{self.cache_dir}", mode="a") - else: - raise - try: - self._store_future.set_result(store) - except concurrent.futures.InvalidStateError: - pass - - def attach_metrics_monitor(self, monitor: MetricsMonitor): - if self._broker is None: - logger.warning("Cannot attach metrics monitor to finished cache.") - # TODO: decide what to do about attaching if the cache is already finished - # maybe get the final metrics? - return +def _interleave_shards(readers: Sequence[RayPrefetchQueue], first_index: int) -> Iterator[T]: # _Message + """ + Interleaves the results of multiple iterators. To support resume, + we need to be able to start from not the "first" iterator. - self._metrics_monitors.append(monitor) + Args: + readers: A list of iterators + first_index: The index of the first iterator to start from. We use this to support resuming. + """ - def _monitor_metrics(self): - while not self._stop: - try: + finished: set[int] = set() + total = 0 + while len(finished) < len(readers): + for i in range(first_index, len(readers)): + reader = readers[i] + if i not in finished: try: - ledger = ray.get(self._broker.updated_ledger.remote(), timeout=10.0) - metrics = _ledger_to_metrics(ledger) - for monitor in self._metrics_monitors: - monitor(metrics) - if metrics.is_finished: - break - except TimeoutError: - pass + message = reader.get_next() + total += 1 + yield message + except StopIteration: + finished.add(i) except Exception as e: - if str(e).startswith("Failed to submit task to actor"): - logger.warning("Cache builder actor is gone. Stopping monitoring.") - break - try: - self._attempt_to_load_store() - except FileNotFoundError: - pass - except Exception as e: - if str(e).startswith("Failed to submit task to actor"): - logger.warning("Cache builder actor is gone. Stopping monitoring.") - break - else: - self.logger.exception("Error while reading metrics from shard cache.") + logger.exception(f"Error while processing group {i}") raise e + first_index = 0 -def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics: - return InProgressCacheMetrics( - rows_finished=ledger.total_num_rows, - is_finished=ledger.is_finished, - # shard_rows=ledger.shard_rows, - # finished_shards=ledger.finished_shards, - field_counts=ledger.field_counts, - ) + logger.info(f"Finished all shards, got {total} batches") -class GroupRoundRobinBuffer(Generic[T]): +def _assign_shards_to_groups(shards: Sequence[_ShardStatus], num_groups: int) -> list["_ShardGroup"]: """ - A buffer that holds items from multiple groups and returns them in a round-robin fashion. - The groups need not have the same number of items. If a group is exhausted, it is removed from the rotation. + Assigns shards to groups in a round-robin fashion. 
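+    For example, 5 shards split into 2 groups yields [[s0, s2, s4], [s1, s3]].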
""" + groups: list[list] = [[] for _ in range(num_groups)] + for i, shard in enumerate(shards): + groups[i % num_groups].append(shard) + return [_ShardGroup(group) for group in groups] - def __init__(self, groups: Sequence[str]): - self.groups = groups - self._current_group = 0 - self.buffers: dict[str, list[tuple[int, T]]] = {group: [] for group in groups} - self._remaining_groups = set(groups) - self._totals_written: dict[str, int] = {group: 0 for group in groups} - self._totals_expected: dict[str, Optional[int]] = {group: None for group in groups} - def __len__(self): - return sum(len(buffer) for buffer in self.buffers.values()) +def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: + prng = random.Random(seed) + shuffled = list(shards) + prng.shuffle(shuffled) + return shuffled - def append_to_group(self, group: str, item_serial: int, item: T): - if group not in self.groups: - raise ValueError(f"Group {group} not in {self.groups}") - if group not in self._remaining_groups: - raise ValueError(f"Group {group} already finished") +class _ShardGroup: + """ + Given a group of shards and a list of statuses, implicitly concatenates the shards and reads from them. - logger.debug(f"Appending item {item_serial} to {group}") + This class mostly exists for resuming: we want to be able to start from the last shard we were working on. + """ - heapq.heappush(self.buffers[group], (item_serial, item)) + def __init__(self, group: list[_ShardStatus]): + self.shards = group + self.total_rows_committed, _all_finished = self._impute_total_rows_committed_and_check_invariants() + + def _impute_total_rows_committed_and_check_invariants(self): + # we also want to ensure that we haven't started any shards until we've finished the previous ones + total_committed = 0 + last_shard_name = None + last_was_finished = True + all_finished = True + + for status in self.shards: + shard_name = status.shard_name + if not last_was_finished and status.num_rows_committed > 0: + raise ValueError( + f"Shard {shard_name} has rows committed but previous shard in group {last_shard_name} " + "is not finished. Something about the cache configuration has changed: either the " + "number/order of shards, the shard shuffle random seed, or the number of groups." + ) + total_committed += status.num_rows_committed + if not status.is_finished: + all_finished = False + last_was_finished = status.is_finished + last_shard_name = shard_name - def group_total_known(self, group: str, total: int): - if group not in self.groups: - raise ValueError(f"Group {group} not in {self.groups}") + return total_committed, all_finished - if group not in self._remaining_groups: - raise ValueError(f"Group {group} already finished: {total} vs {self._totals_expected[group]}") - self._totals_expected[group] = total +def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor: BatchProcessor): + """ + Given a list of ShardStatus objects and sources, creates an interleaving generator + that reads from shards and tokenizes them in parallel. - if self._totals_written[group] == total: - assert len(self.buffers[group]) == 0 - self._remaining_groups.remove(group) - elif self._totals_written[group] > total: - raise ValueError(f"Group {group} has written more than expected: {self._totals_written[group]} > {total}") + We use ShardStatus objects to track the progress of each shard. If we're preempted, we can resume + from the last shard we were working on. 
This function starts each shard at the last committed row + and starts interleaving from the next shard (i.e. the one with the fewest rows that isn't finished). + """ + logger.setLevel(DEFAULT_LOG_LEVEL) + statuses = _get_shard_statuses(initial_ledger, source) - def is_finished(self): - return len(self._remaining_groups) == 0 + options = initial_ledger.metadata.options - def pop(self) -> Optional[tuple[str, T]]: - group = self._next_group_to_read_from() - if group is None: - return None + unfinished_shards = _check_current_shard_progress(statuses) - if len(self.buffers[group]) == 0: - return None + if not unfinished_shards: + logger.info("All shards finished. Nothing to do.") + return - cur_serial, item = self.buffers[group][0] + group_names, groups = _randomize_and_group_shards(name, options, statuses) - # logger.debug( - # f"group: {group}, cur_serial: {cur_serial}, totals_written: {self._totals_written[group]}," - # f" totals_expected: {self._totals_expected.get(group)}" - # ) + logger.warning(f"Starting cache build with {len(statuses)} shards, in {len(groups)} groups") - if cur_serial > self._totals_written[group]: - return None - elif cur_serial < self._totals_written[group]: - raise ValueError(f"Duplicate serial {cur_serial} for group {group}") + process_task = _mk_process_task(processor) + processor_ref = ray.put(processor) - heapq.heappop(self.buffers[group]) - logger.debug(f"Read item {cur_serial} from {group}") + def _make_generator_fn(group: _ShardGroup): + def generator(): + pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) + for message in _shard_reader_generator(source, group, options.batch_size): + match message: + case _Batch(): + # processed = ray.put(process_task(ray.get(message.payload))) + processed = process_task.remote(processor_ref, message.payload) + yield dataclasses.replace(message, payload=processed) + case _ShardFinished(): + yield message + case _: + raise AssertionError(f"Unexpected message type {type(message)}") - self._totals_written[group] += 1 + return generator - if self._totals_written[group] == self._totals_expected[group]: - assert len(self.buffers[group]) == 0 - assert group in self._remaining_groups - self._remaining_groups.remove(group) + generator_fns = [_make_generator_fn(group) for group in groups] - self._current_group = (self._current_group + 1) % len(self.groups) + readers = [ + RayPrefetchQueue(fn, 128, producer_options=dict(name=name, scheduling_strategy="SPREAD")) + for name, fn in zip(group_names, generator_fns) + ] - return group, item + # then figure out the first shard to start from. This is the first unfinished shard with the minimum number of rows + first_group_to_start = min( + range(len(groups)), + key=lambda i: groups[i].total_rows_committed, + ) - def drain(self) -> Iterator[tuple[str, T]]: - while True: - item = self.pop() - if item is None: - break - yield item + yield from _interleave_shards(readers, first_group_to_start) - def _next_group_to_read_from(self): - """ - Returns the next group to read from. This is always the group with the least that is not finished. - """ - if len(self._remaining_groups) == 0: - return None - # careful: this is only correct if self._current_group is correct. 
whenever we fast forward, we have to - # recompute it - while True: - group = self.groups[self._current_group] - if group not in self._remaining_groups: - assert self._totals_written[group] == self._totals_expected[group] - assert len(self.buffers[group]) == 0 - self._current_group = (self._current_group + 1) % len(self.groups) - else: - break - return group +def _check_current_shard_progress(statuses): + unfinished_shards: list[_ShardStatus] = [] + shards_with_progress: dict[str, int] = {} + for status in statuses: + if not status.is_finished: + unfinished_shards.append(status) + if status.num_rows_committed > 0: + shards_with_progress[status.shard_name] = status.num_rows_committed + if unfinished_shards and shards_with_progress: + formatted = ", ".join(f"{k}: {v}" for k, v in shards_with_progress.items()) + logger.info(f"Resuming from shards with progress: {formatted}") + return unfinished_shards - def fast_forward(self, group, num_rows): - """ - Fast forwards the buffer for a group to a certain number of rows. This sets the "next" item to be the - num_rows-th item. - """ - if group not in self.groups: - raise ValueError(f"Group {group} not in {self.groups}") - if self._totals_written[group] != 0: - raise ValueError(f"Group {group} already written to: {self._totals_written[group]}") +def _randomize_and_group_shards(name, options, statuses): + if options.shard_order_randomization_key is not None: + seed = options.shard_order_randomization_key + logger.info(f"Randomizing shard order with seed {seed}") + statuses = _randomize_shards(statuses, seed) - self._totals_written[group] = num_rows + num_groups = min( + options.num_shard_groups if options.num_shard_groups is not None else len(statuses), len(statuses) + ) + if num_groups == 1: + group_names = [f"generator::{name}::all_shards"] + elif len(statuses) == num_groups: + group_names = [f"generator::{name}::{status.shard_name}" for status in statuses] + else: + group_names = [f"generator::{name}::group_{i}" for i in range(num_groups)] - self._fix_current_group() + groups = _assign_shards_to_groups(statuses, num_groups) + return group_names, groups - def _fix_current_group(self): - # This is always the minimum total written group that is not finished - self._current_group = 0 - min_total = None - for i, group in enumerate(self.groups): - if group not in self._remaining_groups: - continue - total = self._totals_written[group] - if min_total is None or total < min_total: - min_total = total - self._current_group = i +def _shard_reader_generator( + shard_source: ShardedDataSource[T], group: _ShardGroup, batch_size: int +) -> Iterator[_Message]: + """ + Given a group of shards, implicitly concatenates the shards and reads from them. 
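+    Already-finished shards are skipped. Each remaining shard is opened at its last committed row,
+    yielding _Batch messages of at most ``batch_size`` rows, followed by one _ShardFinished message
+    carrying the shard's total row count.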
+ """ + for status in group.shards: + if status.is_finished: + logger.info(f"Skipping finished shard {status.shard_name}") + continue + start_row = status.num_rows_committed + logger.info(f"Opening shard {status.shard_name} at row {start_row}") + shard_iter = shard_source.open_shard_at_row(status.shard_name, start_row) + + batch = [] + batch_idxes = [] + row_idx = start_row + for row in shard_iter: + batch.append(row) + batch_idxes.append(row_idx) + row_idx += 1 + + if len(batch) == batch_size: + yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) + batch = [] + batch_idxes = [] + + if len(batch) > 0: + yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) + + logger.info(f"Finished generating shard {status.shard_name} with {row_idx} rows") + yield _ShardFinished(status.shard_name, row_idx) + + +def _mk_process_task(processor: BatchProcessor[T, U]) -> RemoteFunction: + """ + Returns a Ray remote function that processes a batch of data. Basically it takes the resources from + the processor and wraps its call + """ + # processor_ref = ray.put(processor) + # exemplar = { + # "input_ids": np.random.randint(0, 100, size=(4096,)) + # } - def next_missing_item_index(self): - """ - Returns the index of the next item that is not in the buffer - (i.e. what's stopping us from yielding the next item). - """ - if len(self._remaining_groups) == 0: - return None + @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) + def process_task(processor, batch_payload): + try: + result = processor(batch_payload) # TIME: 0.03 seconds + result = _canonicalize_batch(result) # type: ignore + logger.debug("Finished processing batch") + return result + except Exception as e: + logger.exception("Error while processing batch") + raise e + finally: + pass - group = self.groups[self._current_group] - if group not in self._remaining_groups: - self._fix_current_group() - return self.next_missing_item_index() + return process_task - if len(self.buffers[group]) == 0: - return self._totals_written[group] - cur_serial, _ = self.buffers[group][0] +def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: + if isinstance(batch, pa.RecordBatch): + batch = dict_from_record_batch(batch) - if cur_serial > self._totals_written[group]: - return self._totals_written[group] - elif cur_serial < self._totals_written[group]: - raise ValueError(f"Duplicate serial {cur_serial} for group {group}") + if isinstance(batch, dict): + return _to_list_of_dicts(batch) + else: + return batch - return None + +def _to_list_of_dicts(batch: dict) -> List[dict]: + """ + Convert a batch of dictionaries to a list of dictionaries, suitable for writing to a cache. 
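+    For example, {"a": [1, 2], "b": [3, 4]} becomes [{"a": 1, "b": 3}, {"a": 2, "b": 4}].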
+ """ + keys = list(batch.keys()) + values = list(batch.values()) + num_rows = len(values[0]) + return [{key: values[i][j] for i, key in enumerate(keys)} for j in range(num_rows)] + + +def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics: + # TODO: remove this + return InProgressCacheMetrics( + rows_finished=ledger.total_num_rows, + is_finished=ledger.is_finished, + # shard_rows=ledger.shard_rows, + shards_finished=len(ledger.finished_shards), + field_counts=ledger.field_counts, + ) -def div_round_up(x, y): - return (x + y - 1) // y +def _get_shard_statuses(ledger: CacheLedger, source: ShardedDataSource): + return [ + _ShardStatus(name, ledger.shard_rows.get(name, 0), name in ledger.finished_shards) + for name in source.shard_names + ] diff --git a/src/levanter/store/jagged_array.py b/src/levanter/store/jagged_array.py index 8b3a26a54..b236641c9 100644 --- a/src/levanter/store/jagged_array.py +++ b/src/levanter/store/jagged_array.py @@ -14,7 +14,7 @@ from levanter.utils.thread_utils import future_from_value -# zarr suggests 1MB chunk size (in bytes, but whatever) +# zarr suggests 1MB chunk size # at 4 bytes this is 256k elements DEFAULT_CHUNK_SIZE = 256 * 1024 DEFAULT_WRITE_CHUNK_SIZE = DEFAULT_CHUNK_SIZE * 512 @@ -38,9 +38,14 @@ class JaggedArrayStore: data: ts.TensorStore shapes: Optional[ts.TensorStore] # (len(offsets), len(data.shape)-1) item_rank: int = 1 + _cache_metadata: bool = False + _cached_num_rows: Optional[int] = None + _cached_data_size: Optional[int] = None @staticmethod - async def open_async(path: Optional[str], *, mode="a", item_rank=1, dtype) -> "JaggedArrayStore": + async def open_async( + path: Optional[str], *, mode="a", item_rank=1, dtype, cache_metadata: bool = False + ) -> "JaggedArrayStore": offset_path = _extend_path(path, "offsets") offsets = _ts_open_async(offset_path, jnp.int64, [1], mode=mode) @@ -53,10 +58,12 @@ async def open_async(path: Optional[str], *, mode="a", item_rank=1, dtype) -> "J else: shapes = None - return JaggedArrayStore(await offsets, await data, await shapes if shapes is not None else None, item_rank) + return JaggedArrayStore( + await offsets, await data, await shapes if shapes is not None else None, item_rank, cache_metadata + ) @staticmethod - def open(path: Optional[str], *, mode="a", item_rank=1, dtype) -> "JaggedArrayStore": + def open(path: Optional[str], *, mode="a", item_rank=1, dtype, cache_metadata: bool = False) -> "JaggedArrayStore": offset_path = _extend_path(path, "offsets") offsets = _ts_open_sync(offset_path, jnp.int64, [1], mode=mode) @@ -69,18 +76,42 @@ def open(path: Optional[str], *, mode="a", item_rank=1, dtype) -> "JaggedArraySt else: shapes = None - return JaggedArrayStore(offsets, data, shapes, item_rank) + return JaggedArrayStore(offsets, data, shapes, item_rank, cache_metadata) @property def num_rows(self): - return int(self.offsets[0].read().result()) + if self._cached_num_rows is not None: + return self._cached_num_rows + result = int(self.offsets[0].read().result()) + if self._cache_metadata: + self._cached_num_rows = result + return result async def num_rows_async(self): - return int(await self.offsets[0].read()) + if self._cached_num_rows is not None: + return self._cached_num_rows + result = int(await self.offsets[0].read()) + if self._cache_metadata: + self._cached_num_rows = result + return result @property def data_size(self): - return int(self.offsets[self.num_rows].read().result()) + # return int(self.offsets[self.num_rows].read().result()) + if self._cached_data_size is not None: + 
return self._cached_data_size + result = int(self.offsets[self.num_rows].read().result()) + if self._cache_metadata: + self._cached_data_size = result + return result + + async def data_size_async(self): + if self._cached_data_size is not None: + return self._cached_data_size + result = int(await self.offsets[self.num_rows].read()) + if self._cache_metadata: + self._cached_data_size = result + return result async def append_async(self, data: jax.Array): await self.extend_async([data]) @@ -122,6 +153,10 @@ async def trim_to_size_async(self, size: int): await data_fut await offsets_fut + if self._cache_metadata: + self._cached_num_rows = size + self._cached_data_size = new_max + def trim_to_size(self, size: int): if size >= self.num_rows: return @@ -151,6 +186,10 @@ def trim_to_size(self, size: int): if shape_fut is not None: shape_fut.result() + if self._cache_metadata: + self._cached_num_rows = size + self._cached_data_size = new_max + async def extend_async(self, arrays: Sequence[jax.Array]): data, new_offsets, shapes = self._prepare_batch(arrays) @@ -165,13 +204,14 @@ async def extend_async(self, arrays: Sequence[jax.Array]): ] if self.shapes is not None: write_tasks.append(self.shapes[num_rows : num_rows + num_added].write(shapes)) - await asyncio.gather(*write_tasks) # Update num_rows - int(self.offsets[self.num_rows].read().result()) await self.offsets[0].write(num_rows + len(arrays)) - # print("done") + + if self._cache_metadata: + self._cached_num_rows = num_rows + len(arrays) + self._cached_data_size = current_data_size + len(data) def extend(self, arrays: Sequence[jax.Array]): data, new_offsets, shapes = self._prepare_batch(arrays) @@ -187,12 +227,16 @@ def extend(self, arrays: Sequence[jax.Array]): if self.shapes is not None: write_tasks.append(self.shapes[num_rows : num_rows + num_added].write(shapes)) + # Update num_rows. We want to make sure this comes after the other data is committed to avoid a race for task in write_tasks: task.result() - # Update num_rows. We want to make sure this comes after the other data is committed to avoid a race self.offsets[0].write(num_rows + len(arrays)).result() + if self._cache_metadata: + self._cached_num_rows = num_rows + len(arrays) + self._cached_data_size = current_data_size + len(data) + def _prepare_batch(self, arrays): if self.shapes is not None: for data in arrays: diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py index 0b1e93bff..cd29e5a4c 100644 --- a/src/levanter/store/tree_store.py +++ b/src/levanter/store/tree_store.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Generic, List, TypeVar +from typing import Generic, List, Sequence, TypeVar import jax import jax.numpy as jnp @@ -50,17 +50,17 @@ def __init__(self, tree, path: str, mode: str): self.tree = tree @staticmethod - def open(exemplar: T, path: str, *, mode="a") -> "TreeStore": + def open(exemplar: T, path: str, *, mode="a", cache_metadata: bool = False) -> "TreeStore": """ Open a TreeStoreBuilder from a file. """ - tree = _construct_builder_tree(exemplar, path, mode) + tree = _construct_builder_tree(exemplar, path, mode, cache_metadata) return TreeStore(tree, path, mode) def append(self, ex: T): return self.extend([ex]) - def extend(self, batch: List[T]): + def extend(self, batch: Sequence[T]): """ Append a batch of data to the store. 
""" @@ -168,12 +168,18 @@ def get_batch_sync(self, indices) -> List[T]: return out -def _construct_builder_tree(exemplar, path, mode): +def _construct_builder_tree(exemplar, path, mode, cache_metadata): def open_builder(tree_path, item): item = np.asarray(item) rank = item.ndim render_tree_path = "/".join(_render_path_elem(x) for x in tree_path) - return JaggedArrayStore.open(os.path.join(path, render_tree_path), mode=mode, item_rank=rank, dtype=item.dtype) + return JaggedArrayStore.open( + os.path.join(path, render_tree_path), + mode=mode, + item_rank=rank, + dtype=item.dtype, + cache_metadata=cache_metadata, + ) return jtu.tree_map_with_path(open_builder, exemplar, is_leaf=heuristic_is_leaf) diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index a796dd6af..8431e1c3a 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -1,5 +1,6 @@ import os import sys +import time from dataclasses import dataclass from typing import Callable, TypeVar @@ -181,3 +182,37 @@ def actual_sizeof(obj): need_to_see.extend(obj) objects = need_to_see return size + + +class Stopwatch: + """Resumable stop watch for tracking time per call""" + + def __init__(self): + self._start_time = time.time() + self._elapsed = 0.0 + self._n = 0 + + def start(self): + self._start_time = time.time() + self._n += 1 + + def stop(self): + self._elapsed += time.time() - self._start_time + + def reset(self): + self._elapsed = 0.0 + + def elapsed(self): + return self._elapsed + + def average(self): + if self._n == 0: + return 0.0 + return self._elapsed / self._n + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() diff --git a/src/levanter/utils/ray_utils.py b/src/levanter/utils/ray_utils.py index 8a299720e..40c76b614 100644 --- a/src/levanter/utils/ray_utils.py +++ b/src/levanter/utils/ray_utils.py @@ -1,6 +1,7 @@ import contextlib import dataclasses import logging +import logging as pylogging import sys from dataclasses import dataclass from typing import Optional @@ -52,6 +53,9 @@ class RefBox: ref: ray.ObjectRef + def get(self): + return ray.get(self.ref) + class DoneSentinel: pass @@ -78,7 +82,7 @@ def current_actor_handle() -> ray.actor.ActorHandle: class SnitchRecipient: logger: logging.Logger - def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): + def _child_failed(self, child: ray.actor.ActorHandle | str | None, exception: ExceptionInfo): info = exception.restore() self.logger.error(f"Child {child} failed with exception {info[1]}", exc_info=info) exception.reraise() @@ -90,6 +94,40 @@ def log_failures_to(parent, suppress=False): try: yield except Exception as e: - parent._child_failed.remote(current_actor_handle(), ser_exc_info(e)) + try: + handle = current_actor_handle() + except RuntimeError: + handle = ray.runtime_context.get_runtime_context().get_task_id() + + parent._child_failed.remote(handle, ser_exc_info(e)) if not suppress: raise e + + +DEFAULT_LOG_LEVEL = logging.INFO +LOG_FORMAT = "%(asctime)s %(levelname)s: %(message)s" + + +@ray.remote +class StopwatchActor: + def __init__(self): + pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) + self._logger = pylogging.getLogger("StopwatchActor") + self._times_per = {} + self._counts_per = {} + self._total = 0 + + def measure(self, name: str, time: float): + self._times_per[name] = self._times_per.get(name, 0) + time + self._counts_per[name] = self._counts_per.get(name, 0) + 1 + self._total += 1 + + if 
self._total % 1000 == 0: + for name, time in self._times_per.items(): + self._logger.info(f"{name}: {time / self._counts_per[name]}") + + def get(self, name: str): + return self._times_per.get(name, 0), self._counts_per.get(name, 0) + + def average(self, name: str): + return self._times_per.get(name, 0) / self._counts_per.get(name, 1) diff --git a/tests/test_audio.py b/tests/test_audio.py index 8d3015431..3ad9b09b3 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -80,3 +80,13 @@ def test_hf_audio_serial_cache(): assert ex["input_features"].shape == (80, 3000), ex["input_features"].shape assert ex["input_ids"].shape == (1024,), ex["input_ids"].shape assert ex["attention_mask"].shape == (1024,), ex["attention_mask"].shape + + +@skip_if_no_soundlibs +@skip_if_hf_model_not_accessible("openai/whisper-tiny") +def test_metadata_works(): + processor = AutoProcessor.from_pretrained("openai/whisper-tiny") + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") + batch_processor = BatchAudioProcessor(processor, tokenizer) + # test this doesn't throw + assert len(batch_processor.metadata) diff --git a/tests/test_jagged_array.py b/tests/test_jagged_array.py index 24ed24b08..c89a2c625 100644 --- a/tests/test_jagged_array.py +++ b/tests/test_jagged_array.py @@ -10,9 +10,10 @@ class TestJaggedArrayStore: - def test_append_and_get(self): + @pytest.mark.parametrize("cache_metadata", [True, False]) + def test_append_and_get(self, cache_metadata): with tempfile.TemporaryDirectory() as tmpdir: - builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32, cache_metadata=cache_metadata) data1 = jnp.array([[1.0, 2.0], [3.0, 4.0]]) data2 = jnp.array([[5.0]]) @@ -31,9 +32,10 @@ def test_append_and_get(self): # result_slice = builder[0:2] # assert isinstance(result_slice, JaggedArray) - def test_extend_with_multiple(self): + @pytest.mark.parametrize("cache_metadata", [True, False]) + def test_extend_with_multiple(self, cache_metadata): with tempfile.TemporaryDirectory() as tmpdir: - builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32, cache_metadata=cache_metadata) data1 = jnp.array([[1.0, 2.0], [3.0, 4.0]]) data2 = jnp.array([[5.0]]) @@ -54,9 +56,10 @@ def test_append_error(self): with pytest.raises(ValueError): builder.append(jnp.array([[1.0, 2.0]])) - def test_append_single_rank(self): + @pytest.mark.parametrize("cache_metadata", [True, False]) + def test_append_single_rank(self, cache_metadata): with tempfile.TemporaryDirectory() as tmpdir: - builder = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.float32) + builder = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.float32, cache_metadata=cache_metadata) data = jnp.array([1.0, 2.0, 3.0]) builder.append(data) @@ -66,9 +69,10 @@ def test_append_single_rank(self): result = builder[0] assert jnp.all(result == data) - def test_append_multi_rank(self): + @pytest.mark.parametrize("cache_metadata", [True, False]) + def test_append_multi_rank(self, cache_metadata): with tempfile.TemporaryDirectory() as tmpdir: - builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32, cache_metadata=cache_metadata) data1 = jnp.array([[1.0, 2.0], [3.0, 4.0]]) data2 = jnp.array([[5.0, 6.0], [7.0, 8.0]]) @@ -105,14 +109,18 @@ def test_step_slicing(self): # builder[::2] -async def 
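For reference, the cache_metadata flag these tests parametrize is passed straight to JaggedArrayStore.open; a minimal sketch (directory and values are illustrative):

    import tempfile
    import jax.numpy as jnp
    from levanter.store.jagged_array import JaggedArrayStore

    with tempfile.TemporaryDirectory() as tmpdir:
        # cache_metadata=True keeps num_rows/data_size cached in memory instead of re-reading the offsets array
        builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32, cache_metadata=True)
        builder.extend([jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([[5.0]])])
        assert len(builder) == 2
        assert jnp.all(builder[0] == jnp.array([[1.0, 2.0], [3.0, 4.0]]))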
create_builder_with_data(directory, num_sequences: int, sequence_length: int | tuple[int, ...]): +async def create_builder_with_data( + directory, num_sequences: int, sequence_length: int | tuple[int, ...], cache_metadata: bool = True +) -> JaggedArrayStore: if isinstance(sequence_length, int): sequence_length = (sequence_length,) """Helper function to create a JaggedArrayStore with specific data.""" seed = jax.random.PRNGKey(num_sequences * math.prod(sequence_length)) - builder = await JaggedArrayStore.open_async(directory, item_rank=len(sequence_length), dtype=jnp.int64) + builder = await JaggedArrayStore.open_async( + directory, item_rank=len(sequence_length), dtype=jnp.int64, cache_metadata=cache_metadata + ) for i in range(num_sequences): key, seed = jax.random.split(seed) data = jax.random.randint(key, sequence_length, 0, 100) @@ -122,7 +130,7 @@ async def create_builder_with_data(directory, num_sequences: int, sequence_lengt def create_builder_with_data_sync( - directory, num_sequences: int, sequence_length: int | tuple[int, ...] + directory, num_sequences: int, sequence_length: int | tuple[int, ...], cache_metadata: bool = True ) -> JaggedArrayStore: if isinstance(sequence_length, int): sequence_length = (sequence_length,) @@ -130,7 +138,9 @@ def create_builder_with_data_sync( """Helper function to create a JaggedArrayStore with specific data.""" seed = jax.random.PRNGKey(num_sequences * math.prod(sequence_length)) - builder = JaggedArrayStore.open(directory, item_rank=len(sequence_length), dtype=jnp.int64) + builder = JaggedArrayStore.open( + directory, item_rank=len(sequence_length), dtype=jnp.int64, cache_metadata=cache_metadata + ) for i in range(num_sequences): key, seed = jax.random.split(seed) data = jax.random.randint(key, sequence_length, 0, 100) @@ -190,9 +200,12 @@ async def test_trim_to_size_larger_than_current(): @pytest.mark.asyncio -async def test_trim_to_size_with_shapes_async(): +@pytest.mark.parametrize("cache_metadata", [True, False]) +async def test_trim_to_size_with_shapes_async(cache_metadata): tmpdir = tempfile.TemporaryDirectory().name - builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=(10, 100)) + builder = await create_builder_with_data( + tmpdir, num_sequences=10, sequence_length=(10, 100), cache_metadata=cache_metadata + ) expected_shapes = list(await builder.shapes[0:10].read()) # Trim to smaller size @@ -205,9 +218,12 @@ async def test_trim_to_size_with_shapes_async(): assert np.array_equal(trimmed_shapes, jnp.stack(expected_shapes[:5])) -def test_trim_to_size(): +@pytest.mark.parametrize("cache_metadata", [True, False]) +def test_trim_to_size_sync(cache_metadata): tmpdir = tempfile.TemporaryDirectory().name - builder = create_builder_with_data_sync(tmpdir, num_sequences=10, sequence_length=1000) + builder = create_builder_with_data_sync( + tmpdir, num_sequences=10, sequence_length=1000, cache_metadata=cache_metadata + ) # Initial size initial_size = len(builder) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index b6132e548..af6fa885f 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -1,42 +1,43 @@ import asyncio +import copy import logging +import os import tempfile -from typing import Iterator, Sequence -from unittest.mock import MagicMock +from typing import Any, Dict, Iterator, Sequence import numpy as np import pytest import ray -from ray.exceptions import RayTaskError from levanter.data import BatchProcessor, ShardedDataSource, batched from levanter.data.sharded_datasource 
import TextUrlDataSource from levanter.store.cache import ( + LEDGER_FILE_NAME, + CacheLedger, + CacheOptions, SerialCacheWriter, + ShardedCacheWriter, TreeStore, _get_builder_actor, - _OrderedCacheWriter, + _serialize_json_and_commit, build_or_load_cache, ) from levanter.utils.py_utils import logical_cpu_core_count -from levanter.utils.ray_utils import ExceptionInfo, SnitchRecipient, ser_exc_info +from levanter.utils.ray_utils import ExceptionInfo, SnitchRecipient class TestProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): - def __init__(self, batch_size: int = 8): - self._batch_size = batch_size - def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, np.ndarray]]: # return pa.RecordBatch.from_arrays([pa.array(batch)], ["test"]) return [{"test": np.asarray(x)} for x in batch] @property - def output_exemplar(self): - return {"test": np.array([0], dtype=np.int64)} + def metadata(self) -> Dict[str, Any]: + return {} @property - def batch_size(self) -> int: - return self._batch_size + def output_exemplar(self): + return {"test": np.array([0], dtype=np.int64)} @property def num_cpus(self) -> int: @@ -52,8 +53,7 @@ def simple_process(processor, source): return result -def process_interleave(processor, source): - batch_size = processor.batch_size +def process_interleave(processor, source, batch_size): shard_iterators = { shard_name: batched(iter(source.open_shard(shard_name)), batch_size) for shard_name in source.shard_names } @@ -82,16 +82,9 @@ def teardown_module(module): class SimpleProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): - def __init__(self, batch_size: int = 8): - self._batch_size = batch_size - def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, Sequence[int]]]: return [{"data": x} for x in batch] - @property - def batch_size(self) -> int: - return self._batch_size - @property def num_cpus(self) -> int: return 1 @@ -100,6 +93,10 @@ def num_cpus(self) -> int: def output_exemplar(self) -> dict[str, np.ndarray]: return {"data": np.array([0], dtype=np.int64)} + @property + def metadata(self) -> Dict[str, Any]: + return {} + class SimpleShardSource(ShardedDataSource[list[int]]): def __init__(self, num_shards: int = 4): @@ -124,7 +121,7 @@ def test_serial_cache_writer(): with SerialCacheWriter(tmpdir1, exemplar) as writer: for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), processor.batch_size): + for ex in batched(source.open_shard(shard_name), 32): writer.write_batch(processor(ex)) _ = writer.result() @@ -181,7 +178,7 @@ def shard_finished(self, shard_name): def get_finished_shards(self): return self._finished_shards - def _updated_ledger(self, ledger): + def _notify_updated_ledger(self, ledger): if ledger.is_finished: self._finished = True @@ -193,421 +190,56 @@ def _finalize(self): def is_finished(self): return self._finished - def signal_backpressure(self, desired_next_item: float): - self._desired_next_item = desired_next_item - - def desired_next_item(self): - return self._desired_next_item - - -@pytest.mark.asyncio -async def test_batch_finished(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 - ) - - try: - shard_idx = "shard1" - shard_batch_idx = 0 - batch_result = [np.array([1, 2, 3])] - - await 
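The processors in this file reflect the updated BatchProcessor interface: batch_size is gone and each processor now exposes a metadata property. A rough sketch of a user-defined processor (the class name and metadata values are made up):

    from typing import Any, Dict, Sequence

    import numpy as np

    from levanter.data import BatchProcessor

    class DoublingProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]):
        def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, np.ndarray]]:
            return [{"doubled": np.asarray(x) * 2} for x in batch]

        @property
        def output_exemplar(self) -> dict[str, np.ndarray]:
            return {"doubled": np.array([0], dtype=np.int64)}

        @property
        def metadata(self) -> Dict[str, Any]:
            # arbitrary configuration recorded with the cache so it can be identified later
            return {"doubling_factor": 2}

        @property
        def num_cpus(self) -> int:
            return 1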
writer.batch_finished.remote(shard_idx, shard_batch_idx, batch_result) - await writer.flush.remote() - shard_status = await writer.get_shard_status.remote("shard1") - assert shard_status.num_rows_committed == 1 - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_shard_finished_reading(): - parent = PretendParent.remote() - exemplar = MagicMock() - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) - - try: - shard_name = "shard1" - expected_batches = 5 - - await writer.shard_finished_reading.remote(shard_name, expected_batches) - shard_status = await writer.get_shard_status.remote(shard_name) - assert shard_status.is_finished is False - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_get_shard_status(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) - - try: - shard_name = "shard1" - shard_status = await writer.get_shard_status.remote(shard_name) - - assert shard_status.shard_name == shard_name - assert shard_status.num_rows_committed == 0 - assert not shard_status.is_finished - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_shard_failed(): - parent = PretendParent.remote() - exemplar = MagicMock() - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) - - try: - shard_name = "shard1" - batch_id = 0 - try: - raise Exception("Test Exception") - except: # noqa - exc_info = ser_exc_info() - - await writer.shard_failed.remote(shard_name, batch_id, exc_info) - exception_received = await parent.wait_for_failure.remote() - assert str(exception_received.ex) == str(exc_info.ex) - finally: - ray.kill(parent) - ray.kill(writer) - - -DEFAULT_BATCH_SIZE = 128 - - -@pytest.mark.asyncio -async def test_attempt_to_write_batches(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=2 - ) - - try: - shard1_batch = [np.asarray([1, 2, 3])] - shard2_batch = [np.asarray([4, 5, 6, 7])] - - await writer.batch_finished.remote("shard1", 0, shard1_batch) - await writer.batch_finished.remote("shard2", 0, shard2_batch) - - await writer.flush.remote() - - ledger = await writer.get_ledger.remote() - assert ledger.is_finished is False - assert ledger.total_num_rows == 2 # Assuming each batch has 1 row for simplicity - - store = TreeStore.open(exemplar, cache_dir, mode="r") - assert len(store) == 2 - np.testing.assert_array_equal(store[0], shard1_batch[0]) - np.testing.assert_array_equal(store[1], shard2_batch[0]) - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_finalize_cache(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) - - 
try: - shard1_batch = [np.array([1, 2, 3])] - shard2_batch = [np.array([4, 5, 6, 7])] - - await writer.batch_finished.remote("shard1", 0, shard1_batch) - await writer.shard_finished_reading.remote("shard1", 1) - await writer.shard_finished_reading.remote("shard2", 1) - await writer.batch_finished.remote("shard2", 0, shard2_batch) - await writer.flush.remote() - - ledger = await writer.get_ledger.remote() - assert ledger.is_finished is False - assert ledger.total_num_rows == 2 # Assuming each batch has 1 row for simplicity - - await writer.shard_finished_reading.remote("shard3", 0) - finished_shards = await parent.get_finished_shards.remote() - assert len(finished_shards) == 3 - - ledger = await writer.get_ledger.remote() - assert ledger.is_finished is True - assert ledger.total_num_rows == 2 - assert await parent.is_finished.remote() is True - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_error_handling(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) - - try: - with pytest.raises(TypeError): - await writer.batch_finished.remote("shard1", 0, None) - - exception_received = await parent.wait_for_failure.remote() - assert exception_received is not None - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_out_of_order_batches_same_shard(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=2 - ) - - try: - # Sending batch 1 before batch 0 for shard1 - shard1_batch0 = [np.array([1, 2, 3])] - shard1_batch1 = [np.array([4, 5, 6])] - - await writer.batch_finished.remote("shard1", 1, shard1_batch1) - await writer.batch_finished.remote("shard1", 0, shard1_batch0) - await writer.flush.remote() - - store = TreeStore.open(exemplar, cache_dir, mode="r") - assert len(store) == 2 - np.testing.assert_array_equal(store[0], shard1_batch0[0]) - np.testing.assert_array_equal(store[1], shard1_batch1[0]) - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_out_of_order_batches_different_shards(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=3 - ) - - try: - # Sending batches out of order across different shards - shard1_batch0 = [np.array([1, 2, 3])] - shard2_batch0 = [np.array([4, 5, 6])] - shard1_batch1 = [np.array([7, 8, 9])] - - await writer.batch_finished.remote("shard1", 1, shard1_batch1) - await writer.batch_finished.remote("shard2", 0, shard2_batch0) - await writer.batch_finished.remote("shard1", 0, shard1_batch0) - await writer.flush.remote() - - store = TreeStore.open(exemplar, cache_dir, mode="r") - assert len(store) == 3 - np.testing.assert_array_equal(store[0], shard1_batch0[0]) - np.testing.assert_array_equal(store[1], shard2_batch0[0]) - np.testing.assert_array_equal(store[2], shard1_batch1[0]) - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_batches_different_orders_all_shards(): - parent = 
PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=2 - ) - - try: - # Sending batches in different orders across all shards - shard1_batch0 = [np.array([1, 2, 3])] - shard1_batch1 = [np.array([4, 5, 6])] - shard2_batch0 = [np.array([7, 8, 9])] - shard3_batch0 = [np.array([10, 11, 12])] - - await writer.batch_finished.remote("shard2", 0, shard2_batch0) - await writer.batch_finished.remote("shard3", 0, shard3_batch0) - await writer.batch_finished.remote("shard1", 1, shard1_batch1) - await writer.batch_finished.remote("shard1", 0, shard1_batch0) - await writer.flush.remote() - - store = TreeStore.open(exemplar, cache_dir, mode="r") - assert len(store) == 4 - np.testing.assert_array_equal(store[0], shard1_batch0[0]) - np.testing.assert_array_equal(store[1], shard2_batch0[0]) - np.testing.assert_array_equal(store[2], shard3_batch0[0]) - np.testing.assert_array_equal(store[3], shard1_batch1[0]) - finally: - ray.kill(parent) - ray.kill(writer) - -@pytest.mark.asyncio -async def test_intermixed_batches_same_and_different_shards(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 +@pytest.mark.ray +def test_full_end_to_end_cache(): + td = tempfile.TemporaryDirectory() + with td as tmpdir: + ray_ds = build_or_load_cache( + tmpdir, + SimpleShardSource(num_shards=2), + TestProcessor(), + await_finished=True, + options=CacheOptions.no_fanciness(8), ) - try: - # Sending intermixed batches from the same and different shards - shard1_batch0 = [np.array([1, 2, 3])] - shard2_batch0 = [np.array([4, 5, 6])] - shard1_batch1 = [np.array([7, 8, 9])] - shard3_batch0 = [np.array([10, 11, 12])] - shard2_batch1 = [np.array([13, 14, 15])] - - await writer.batch_finished.remote("shard2", 0, shard2_batch0) - await writer.batch_finished.remote("shard3", 0, shard3_batch0) - await writer.batch_finished.remote("shard1", 1, shard1_batch1) - await writer.batch_finished.remote("shard2", 1, shard2_batch1) - await writer.batch_finished.remote("shard1", 0, shard1_batch0) - await writer.flush.remote() - - store = TreeStore.open(exemplar, cache_dir, mode="r") - assert len(store) == 5 - np.testing.assert_array_equal(store[0], shard1_batch0[0]) - np.testing.assert_array_equal(store[1], shard2_batch0[0]) - np.testing.assert_array_equal(store[2], shard3_batch0[0]) - np.testing.assert_array_equal(store[3], shard1_batch1[0]) - np.testing.assert_array_equal(store[4], shard2_batch1[0]) - finally: - ray.kill(parent) - ray.kill(writer) - + expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=2), 8) -@pytest.mark.asyncio -async def test_duplicate_batches_same_shard(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1"] - writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) - - try: - # Sending duplicate batches for the same shard - shard1_batch0 = [np.array([1, 2, 3])] - - await writer.batch_finished.remote("shard1", 0, shard1_batch0) - await writer.flush.remote() - with pytest.raises(RayTaskError): - await 
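These end-to-end tests exercise the new CacheOptions plumbing; condensed, the call pattern is roughly as follows (the directory is hypothetical, and SimpleShardSource/TestProcessor are the toy helpers defined earlier in this file):

    import tempfile

    from levanter.store.cache import CacheOptions, build_or_load_cache

    # assumes a running Ray cluster (the tests call ray.init in setup_module)
    with tempfile.TemporaryDirectory() as tmpdir:
        # small deterministic settings for tests; alternatively group shards explicitly, e.g.
        # CacheOptions(num_shard_groups=2, batch_size=8, shard_order_randomization_key=None)
        opts = CacheOptions.no_fanciness(8)
        cache = build_or_load_cache(
            tmpdir,
            SimpleShardSource(num_shards=5),
            TestProcessor(),
            await_finished=True,  # block until the cache is fully built
            options=opts,
        )
        rows = cache[:]  # read everything back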
writer.batch_finished.remote("shard1", 0, shard1_batch0) # Duplicate - finally: - ray.kill(parent) - ray.kill(writer) - - -@pytest.mark.asyncio -async def test_mixed_order_batches_multiple_shards(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 - ) + all_data = ray_ds[:] - try: - # Sending batches in mixed order for multiple shards - shard1_batch0 = [np.array([1, 2, 3])] - shard2_batch0 = [np.array([4, 5, 6])] - shard1_batch1 = [np.array([7, 8, 9])] - shard2_batch1 = [np.array([10, 11, 12])] - shard3_batch0 = [np.array([13, 14, 15])] - shard3_batch1 = [np.array([16, 17, 18])] - - await writer.batch_finished.remote("shard3", 0, shard3_batch0) - await writer.batch_finished.remote("shard1", 1, shard1_batch1) - await writer.batch_finished.remote("shard2", 0, shard2_batch0) - await writer.batch_finished.remote("shard2", 1, shard2_batch1) - await writer.batch_finished.remote("shard1", 0, shard1_batch0) - await writer.batch_finished.remote("shard3", 1, shard3_batch1) - await writer.flush.remote() - - store = TreeStore.open(exemplar, cache_dir, mode="r") - assert len(store) == 6 - np.testing.assert_array_equal(store[0], shard1_batch0[0]) - np.testing.assert_array_equal(store[1], shard2_batch0[0]) - np.testing.assert_array_equal(store[2], shard3_batch0[0]) - np.testing.assert_array_equal(store[3], shard1_batch1[0]) - np.testing.assert_array_equal(store[4], shard2_batch1[0]) - np.testing.assert_array_equal(store[5], shard3_batch1[0]) - finally: - ray.kill(parent) - ray.kill(writer) + check_datasets_equal(all_data, expected) @pytest.mark.ray -def test_full_end_to_end_cache_simple(): +def test_full_end_to_end_cache_with_groups(): td = tempfile.TemporaryDirectory() with td as tmpdir: ray_ds = build_or_load_cache( tmpdir, - SimpleShardSource(num_shards=1), + SimpleShardSource(num_shards=5), TestProcessor(), await_finished=True, + options=CacheOptions(num_shard_groups=2, batch_size=8, shard_order_randomization_key=None), ) - simple_processed = simple_process(TestProcessor(), SimpleShardSource()) + expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=5), 8) all_data = ray_ds[:] - check_datasets_equal(all_data, simple_processed) + # check_datasets_equal(all_data, expected) + assert len(all_data) == len(list(expected)) @pytest.mark.ray def test_cache_remembers_its_cached(): directory = tempfile.TemporaryDirectory() with directory as tmpdir: - ds1 = build_or_load_cache(tmpdir, SimpleShardSource(), TestProcessor()) + ds1 = build_or_load_cache(tmpdir, SimpleShardSource(), TestProcessor(), await_finished=True) - class ThrowingProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + class ThrowingProcessor(TestProcessor): def __call__(self, batch: Sequence[Sequence[int]]): raise RuntimeError("This should not be called") - @property - def output_exemplar(self) -> dict[str, np.ndarray]: - return {"test": np.array([0], dtype=np.int64)} - - @property - def batch_size(self) -> int: - return 8 - - @property - def num_cpus(self) -> int: - return 1 - # testing this doesn't throw ds2 = build_or_load_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True) @@ -615,6 +247,9 @@ def num_cpus(self) -> int: def check_datasets_equal(ds1, ds2): + ds1 = list(ds1) + ds2 = list(ds2) + assert len(ds1) == len(ds2) for r1, r2 in zip(ds1, ds2): 
assert r1.keys() == r2.keys() for key in r1.keys(): @@ -672,7 +307,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # compare to the original with no crash reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), await_finished=True) - assert len(list(reader1)) == 40 check_datasets_equal(reader1, reader2) @@ -699,23 +333,26 @@ def shard_names(self) -> Sequence[str]: return ["shard_0", "shard_1"] def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: + assert shard_name in self.shard_names max_count = 40 if shard_name == "shard_1" else 20 shard_id = int(shard_name.split("_")[1]) for i in range(0, max_count): yield [i * 10 + shard_id] * 10 with tempfile.TemporaryDirectory() as tmpdir: + processor = TestProcessor() cache = build_or_load_cache( tmpdir, SlowShardSource(), - TestProcessor(1), + processor, await_finished=False, + options=CacheOptions.no_fanciness(16), ) # now block until the cache is done - cache.await_finished(timeout=10) + cache.await_finished(timeout=30) - expected = process_interleave(TestProcessor(1), SlowShardSource()) + expected = process_interleave(processor, SlowShardSource(), 16) check_datasets_equal(list(cache[:]), expected) @@ -750,8 +387,13 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: with tempfile.TemporaryDirectory() as tmpdir: cache = build_or_load_cache( - tmpdir, SlowShardSource(), TestProcessor(5), await_finished=False, items_per_write=5 - ) + tmpdir, + SlowShardSource(), + TestProcessor(), + await_finished=False, + force_flush=True, + options=CacheOptions.no_fanciness(5), + ) # we need force_flush to ensure the cache is written to disk # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] @@ -782,22 +424,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: @pytest.mark.skip("This test segfaults in CI. 
I think a ray bug") @pytest.mark.ray def test_shard_cache_crashes_if_processor_throws(): - class ThrowingProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + class ThrowingProcessor(SimpleProcessor): def __call__(self, batch: Sequence[Sequence[int]]): raise RuntimeError("exc") - @property - def output_exemplar(self) -> dict: - return {"test": np.array([0], dtype=np.int64)} - - @property - def batch_size(self) -> int: - return 8 - - @property - def num_cpus(self) -> int: - return 1 - with tempfile.TemporaryDirectory() as tmpdir: with pytest.raises(RuntimeError): build_or_load_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True) @@ -880,60 +510,81 @@ def test_shard_cache_fails_gracefully_with_unknown_file_type(): del cache -@pytest.mark.ray -@pytest.mark.asyncio -async def test_backpressure_mechanism(): - parent = PretendParent.remote() - exemplar = np.array([1, 2, 3]) - with tempfile.TemporaryDirectory() as cache_dir: - shards = ["shard1", "shard2", "shard3"] - writer = _OrderedCacheWriter.remote( - parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 - ) +def test_sharded_cache_writer(): + with tempfile.TemporaryDirectory() as tmpdir: + source = SimpleShardSource(num_shards=4) + processor = SimpleProcessor() + ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(8)) + + exemplar = {"data": np.array([0], dtype=np.int64)} + + writer = ShardedCacheWriter(tmpdir, ledger, exemplar) + for shard_name in source.shard_names: + for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size): + writer.write_batch(shard_name, processor(ex)) + + store = writer.finish() + + data_path = store.path + + del store + + builder = TreeStore.open(exemplar, data_path, mode="r") + + assert len(builder) == 40 + + for i, x in enumerate(builder): + np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) + + # check totals for the ledger + ledger = writer.ledger + assert ledger.total_num_rows == 40 + assert ledger.is_finished + + for shard_name in source.shard_names: + assert ledger.shard_rows[shard_name] == 10 + + +def test_sharded_cache_writer_trims_on_resume(): + with tempfile.TemporaryDirectory() as tmpdir: + source = SimpleShardSource(num_shards=4) + processor = SimpleProcessor() + + exemplar = {"data": np.array([0], dtype=np.int64)} + + ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(batch_size=8)) + + writer = ShardedCacheWriter(tmpdir, ledger, exemplar) + for shard_name in source.shard_names: + for ex in batched(source.open_shard(shard_name), 8): + writer.write_batch(shard_name, processor(ex)) + + writer.finish() + + # now deliberately truncate the ledger a bit + ledger = copy.deepcopy(writer.ledger) + assert ledger.total_num_rows == 40 + assert ledger.is_finished + ledger.total_num_rows = 24 + ledger.shard_rows["shard_0"] = 8 + ledger.shard_rows["shard_1"] = 8 + ledger.shard_rows["shard_2"] = 8 + ledger.shard_rows["shard_3"] = 0 + ledger.is_finished = False + + _serialize_json_and_commit(os.path.join(tmpdir, LEDGER_FILE_NAME), ledger) + + writer = ShardedCacheWriter(tmpdir, ledger, exemplar) + + # ensure it got truncated + assert writer.ledger.total_num_rows == 24 + assert writer.ledger.is_finished is False + assert writer.ledger.shard_rows["shard_0"] == 8 + assert writer.ledger.shard_rows["shard_1"] == 8 + assert writer.ledger.shard_rows["shard_2"] == 8 + assert writer.ledger.shard_rows["shard_3"] == 0 + + 
new_store = writer._tree_store + new_data = new_store[:] - # Simulate batches being processed - shard1_batch = [np.array([1, 2, 3])] - shard2_batch = [np.array([4, 5, 6])] - shard3_batch = [np.array([7, 8, 9])] - - # await writer.batch_finished.remote("shard1", 0, shard1_batch) - await writer.batch_finished.remote("shard2", 0, shard2_batch) - await writer.batch_finished.remote("shard3", 0, shard3_batch) - await writer.batch_finished.remote("shard1", 1, shard3_batch) - await writer.batch_finished.remote("shard1", 2, shard3_batch) - await writer.batch_finished.remote("shard1", 3, shard3_batch) - await writer.flush.remote() - - # Check if backpressure is signaled - is_overwhelmed = await writer.is_overwhelmed.remote() - assert is_overwhelmed is True - await writer.flush.remote() - - for i in range(4): - if (await parent.desired_next_item.remote()) == 0: - break - - await asyncio.sleep(0.1 * (i + 1) * (i + 1)) - else: - assert False, "Backpressure wasn't sent" - - await writer.batch_finished.remote("shard1", 0, shard1_batch) - - # Reduce the queue size to relieve backpressure - # Check if backpressure is relieved - is_overwhelmed = await writer.is_overwhelmed.remote() - count = 0 - while is_overwhelmed and count < 10: - await writer.flush.remote() - await asyncio.sleep(0.4) - is_overwhelmed = await writer.is_overwhelmed.remote() - count += 1 - assert is_overwhelmed is False - - for i in range(4): - if (await parent.desired_next_item.remote()) is None: - break - - await asyncio.sleep(0.1 * (i + 1) * (i + 1)) - else: - assert False, "Backpressure wasn't relieved" + assert len(new_data) == 24 diff --git a/tests/test_prefetch_actor.py b/tests/test_prefetch_actor.py new file mode 100644 index 000000000..e48546fc1 --- /dev/null +++ b/tests/test_prefetch_actor.py @@ -0,0 +1,137 @@ +import time +from typing import Iterator + +import pytest +import ray + +from levanter.store._prefetch_actor import RayPrefetchQueue + + +def _sleep_until(condition, timeout=5, message="Condition not met within timeout"): + start = time.time() + while not condition(): + if time.time() - start > timeout: + pytest.fail(message) + time.sleep(0.1) + + +@pytest.fixture(scope="module", autouse=True) +def ray_init_and_shutdown(): + ray.init("local", ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.ray +def test_initialization_and_basic_functionality(): + def simple_producer(): + for i in range(10): + yield i + + actor = RayPrefetchQueue(simple_producer) + results = [actor.get_next() for _ in range(10)] + assert results == list(range(10)) + + +@pytest.mark.ray +def test_queue_size_limit(): + def simple_producer() -> Iterator[ray.ObjectRef]: + for i in range(100): + yield i + + actor = RayPrefetchQueue(simple_producer, max_queue_size=10) + # Allow some time for the queue to fill up + _sleep_until(lambda: actor.queue_size() == 10) + + # get a few items to make some space + [actor.get_next() for _ in range(5)] + _sleep_until(lambda: actor.queue_size() == 10, message="Queue size did not reach 10") + + +@pytest.mark.ray +def test_stop_functionality(): + def simple_producer(): + for i in range(10000): + yield i + + actor = RayPrefetchQueue(simple_producer) + actor.stop() + + _sleep_until(lambda: actor.is_stopped(), message="Actor did not stop") + + +@pytest.mark.ray +def test_exception_handling(): + def faulty_producer(): + for i in range(5): + yield i + raise ValueError("Test exception") + + actor = RayPrefetchQueue(faulty_producer) + results = [] + try: + for _ in range(10): + results.append(actor.get_next()) + 
except ValueError as e: + assert "Test exception" in str(e) # Ray puts a lot of crap in the exception message + assert results == list(range(5)) + + +@pytest.mark.ray +def test_empty_producer(): + def empty_producer() -> Iterator[ray.ObjectRef]: + if False: + yield + + actor = RayPrefetchQueue(empty_producer) + with pytest.raises(StopIteration): + actor.get_next() + + +@pytest.mark.ray +def test_multiple_consumers(): + def simple_producer() -> Iterator[ray.ObjectRef]: + for i in range(20): + yield i + + actor = RayPrefetchQueue(simple_producer) + results = [actor.get_next() for _ in range(10)] + results += [actor.get_next() for _ in range(10)] + assert results == list(range(20)) + + +@pytest.mark.ray +def test_producer_completion(): + def simple_producer(): + for i in range(10): + yield i + + actor = RayPrefetchQueue(simple_producer) + results = [] + try: + while True: + results.append(actor.get_next()) + except StopIteration: + pass + assert results == list(range(10)) + + +@pytest.mark.ray +def test_drain_queue(): + def simple_producer(): + for i in range(10): + yield i + + actor = RayPrefetchQueue(simple_producer) + + all_results = [] + + for tot in range(0, 5): + out = actor.drain_available(tot) + assert len(out) <= tot + all_results.extend(out) + + while len(all_results) < 10: + all_results.append(actor.get_next()) + + assert all_results == list(range(10)) diff --git a/tests/test_tree_store.py b/tests/test_tree_store.py index e25ef7928..66131ca48 100644 --- a/tests/test_tree_store.py +++ b/tests/test_tree_store.py @@ -1,5 +1,5 @@ import tempfile -from typing import Iterator, List, Sequence +from typing import Any, Dict, Iterator, List, Sequence import numpy as np import pytest @@ -11,9 +11,6 @@ class SimpleProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): - def __init__(self, batch_size: int = 8): - self._batch_size = batch_size - def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, Sequence[int]]]: return [{"data": x} for x in batch] @@ -21,14 +18,14 @@ def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, Sequenc def output_exemplar(self) -> dict[str, Sequence[int]]: return {"data": np.array([0], dtype=np.int64)} - @property - def batch_size(self) -> int: - return self._batch_size - @property def num_cpus(self) -> int: return 1 + @property + def metadata(self) -> Dict[str, Any]: + return {} + class SimpleShardSource(ShardedDataSource[List[int]]): def __init__(self, num_shards: int = 4): @@ -52,7 +49,7 @@ def test_tree_builder_with_processor(): processor = SimpleProcessor() source = SimpleShardSource() - for batch in batched(source, processor.batch_size): + for batch in batched(source, 8): processed = processor(batch) builder.extend(processed) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6206ec2ff..a86217549 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ import glob import os from functools import reduce -from typing import Callable, List, Optional, Sequence, TypeVar +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar import draccus import equinox as eqx @@ -204,6 +204,10 @@ def output_exemplar(self): def num_cpus(self) -> int: return 0 + @property + def metadata(self) -> Dict[str, Any]: + return {} + class ShardsDataSource(ShardedDataSource[T]): def __init__(self, docs: List[List[T]]):
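Stepping back, the RayPrefetchQueue tests above describe a small producer/consumer API; a condensed usage sketch (the producer and sizes are illustrative):

    import ray

    from levanter.store._prefetch_actor import RayPrefetchQueue

    ray.init("local", ignore_reinit_error=True)

    def producer():
        for i in range(100):
            yield i

    queue = RayPrefetchQueue(producer, max_queue_size=10)  # prefetches up to 10 items ahead of the consumer
    first = queue.get_next()          # blocks for the next item; raises StopIteration once the producer is exhausted
    ready = queue.drain_available(5)  # returns at most 5 items that are already queued
    queue.stop()                      # tell the actor to stop producing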