From c90cff9bef4341cea376d4e330dcddabeaf4d86a Mon Sep 17 00:00:00 2001
From: Nikil Ravi <55033516+nikil-ravi@users.noreply.github.com>
Date: Thu, 5 Dec 2024 13:51:46 -0800
Subject: [PATCH 1/3] make eval_harness part of levanter namespace (#833)

For being able to import as `from levanter.eval_harness import __`
---
 src/levanter/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py
index f9570aaf7..6b2cbeb1e 100644
--- a/src/levanter/__init__.py
+++ b/src/levanter/__init__.py
@@ -3,6 +3,7 @@
 import levanter.data as data
 import levanter.distributed as distributed
 import levanter.eval as eval
+import levanter.eval_harness as eval_harness
 import levanter.models as models
 import levanter.optim as optim
 import levanter.tracker as tracker

From 091f1cd2a8ffa4c43b79f7c68a08eff27b007a73 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed <ahmedah@stanford.edu>
Date: Thu, 5 Dec 2024 22:06:58 -0800
Subject: [PATCH 2/3] fix toml to capture dev transformers (#834)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 233af26f5..f2a63f7ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
     "equinox>=0.11.7",
     "jaxtyping>=0.2.34",
     "tokenizers>=0.15.2",
-    "transformers>=4.41.2,<4.48.0",
+    "transformers>=4.41.2,<4.49.0",
     "optax>=0.1.9",
     "wandb>=0.17.8",
     "draccus>=0.9.3",

From 19b5f93d908f6875e35cddd36da6927032a4b01e Mon Sep 17 00:00:00 2001
From: David Hall <dlwh@cs.stanford.edu>
Date: Fri, 6 Dec 2024 15:19:53 -0800
Subject: [PATCH 3/3] remove a bunch of old unused stuff (#832)

---
 pyproject.toml                              |   1 -
 src/levanter/data/_preprocessor.py          |  53 ----
 src/levanter/data/metrics_monitor.py        |  87 ------
 src/levanter/data/shard_cache.py            |   0
 src/levanter/mesh.py                        |  58 -----
 src/levanter/models/longformer.py           | 114 ---------
 src/levanter/store/_prefetch_actor.py       | 156 -----------
 src/levanter/store/stress_test_new_cache.py | 148 ----------
 src/levanter/utils/actor_pool.py            | 270 --------------------
 src/levanter/utils/py_utils.py              |  95 -------
 tests/test_actor_pool.py                    | 167 ------------
 tests/test_longformer.py                    | 102 --------
 tests/test_prefetch_actor.py                | 137 ----------
 13 files changed, 1388 deletions(-)
 delete mode 100644 src/levanter/data/shard_cache.py
 delete mode 100644 src/levanter/mesh.py
 delete mode 100644 src/levanter/models/longformer.py
 delete mode 100644 src/levanter/store/_prefetch_actor.py
 delete mode 100644 src/levanter/store/stress_test_new_cache.py
 delete mode 100644 src/levanter/utils/actor_pool.py
 delete mode 100644 tests/test_actor_pool.py
 delete mode 100644 tests/test_longformer.py
 delete mode 100644 tests/test_prefetch_actor.py

diff --git a/pyproject.toml b/pyproject.toml
index f2a63f7ae..f07ff3864 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,6 @@ dependencies = [
     "dataclasses-json~=0.6.4",
     "ray[default]>=2.34.0",
     "pydantic<3",
-    "rich~=13.0",
     "filelock~=3.13",
     "async-lru~=2.0",
     "tqdm-loggable>=0.2",

diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py
index dd6578667..dd810857b 100644
--- a/src/levanter/data/_preprocessor.py
+++ b/src/levanter/data/_preprocessor.py
@@ -1,13 +1,8 @@
-import logging
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Generic, Iterable, Mapping, Sequence, TypeVar, Union

 import numpy as np
 import pyarrow as pa
-import ray
-
-from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase
-from levanter.utils.ray_utils import RefBox


 T = TypeVar("T")
TypeVar("T") @@ -236,54 +231,6 @@ def to_hf_batched(x): return {b.field(i).name: to_hf_batched(b.column(i).to_numpy(zero_copy_only=False)) for i in range(b.num_columns)} -@ray.remote(num_cpus=0.1) # keep this low b/c it doesn't do much -class BatchProcessorPool: - def __init__(self, processor: BatchProcessor, min_size: int = 1, max_size: int = 10): - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(message)s") - processor_ref = ray.put(processor) - self.actor_pool = AutoScalingActorPool( - lambda: _create_batch_processor_actor(processor, processor_ref), min_size, max_size - ) - - async def process_batch(self, batch_ref: RefBox): - return await self.actor_pool.submit( - lambda a, b: a.process_batch.remote(b), batch_ref.ref, obj_ref=batch_ref.ref - ) - - def num_pending_tasks(self): - return self.actor_pool.num_pending_tasks - - def resize_pool(self, *, min_size: int | None = None, max_size: int | None = None): - self.actor_pool.resize_pool(min_size=min_size, max_size=max_size) - - def ensure_max_at_least(self, size: int): - self.actor_pool.resize_pool(max_size=max(size, self.actor_pool.get_max_size())) - - -def _create_batch_processor_actor(processor: BatchProcessor, processor_ref): - cpus = processor.num_cpus - gpus = processor.num_gpus - resources = processor.resources - return _BatchProcessorActor.options( # type: ignore - num_cpus=cpus, num_gpus=gpus, resources=resources, scheduling_strategy="SPREAD" - ).remote(processor_ref) - - -@ray.remote -class _BatchProcessorActor(PoolWorkerBase): - def __init__(self, processor: BatchProcessor): - from levanter.store.tree_store import TreeBatchPreparer - - self.processor = processor - self.preparer = TreeBatchPreparer(processor.output_exemplar) - - def process_batch(self, batch): - result = self.processor(batch) - result = _canonicalize_batch(result) - prepared = self.preparer(result) - return prepared - - def _canonicalize_batch(batch: Union[dict, list[dict]]) -> list[dict]: if isinstance(batch, pa.RecordBatch): batch = dict_from_record_batch(batch) diff --git a/src/levanter/data/metrics_monitor.py b/src/levanter/data/metrics_monitor.py index 96c17ec65..d6cd496f8 100644 --- a/src/levanter/data/metrics_monitor.py +++ b/src/levanter/data/metrics_monitor.py @@ -1,21 +1,11 @@ import dataclasses import logging as pylogging -import threading import time from dataclasses import dataclass from typing import Any, Dict, Optional, Protocol, Union import jax from dataclasses_json import dataclass_json -from rich.progress import ( - BarColumn, - Progress, - TaskID, - TaskProgressColumn, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) import levanter.tracker @@ -35,53 +25,6 @@ def __call__(self, metrics: InProgressCacheMetrics): ... 
-class RichMetricsMonitor(MetricsMonitor):
-
-    progress: Optional[Progress]  # type: ignore
-    task: Optional[TaskID]
-
-    def __init__(self, num_shards, **kwargs):
-        """kwargs are passed to rich.progress.Progress"""
-        self.kwargs = kwargs
-        self.progress: Optional[Progress] = None
-        self.task = None
-        self.num_shards = num_shards
-
-    def __call__(self, metrics: InProgressCacheMetrics):
-        if self.progress is None:
-            self._init_progress(metrics)
-
-        self.progress.update(self.task, completed=metrics.shards_finished, **dataclasses.asdict(metrics))  # type: ignore
-
-        self.progress.refresh()  # type: ignore
-
-        if metrics.is_finished:
-            self.progress.stop()  # type: ignore
-
-    def _init_progress(self, metrics):
-        columns = [
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("| {task.fields[rows_finished]} docs", justify="center"),
-        ]
-
-        for field in metrics.field_counts:
-            columns.append(TextColumn(f"| {{task.fields[field_counts][{field}]}} {field}", justify="center"))
-
-        columns.append(TimeElapsedColumn())
-        columns.append(TimeRemainingColumn())
-
-        self.progress = Progress(
-            *columns,
-            **self.kwargs,
-        )
-
-        self.task = self.progress.add_task(
-            "Shards", total=self.num_shards, completed=metrics.shards_finished, **dataclasses.asdict(metrics)
-        )
-        self.progress.start()
-
-
 class LoggingMetricsMonitor(MetricsMonitor):
     last_metrics: Optional[InProgressCacheMetrics]
     last_time: Optional[float]
@@ -109,16 +52,6 @@ def __call__(self, metrics: InProgressCacheMetrics):
         if metrics.is_finished:
             to_log[f"{self.prefix}/finished"] = 1

-        # estimate the rate of progress
-        # if self.last_metrics is not None:
-        #     assert self.last_time is not None
-        #     elapsed = time.time() - self.last_time
-        #     to_log[f"{self.prefix}/shards_per_s"] = (metrics.shards_finished - self.last_metrics.shards_finished) / elapsed
-        #     to_log[f"{self.prefix}/rows_per_s"] = (metrics.rows_finished - self.last_metrics.rows_finished) / elapsed
-        #
-        #     for field, count in metrics.field_counts.items():
-        #         to_log[f"{self.prefix}/{field}_per_s"] = (count - self.last_metrics.field_counts[field]) / elapsed
-
         self.last_metrics = metrics
         self.last_time = time.time()

@@ -153,23 +86,3 @@ def __call__(self, metrics: InProgressCacheMetrics):
         if metrics.is_finished:
             self.logger.info("Cache creation finished")
-
-
-class WaitTimeReportingThread(threading.Thread):
-    def __init__(self, report, interval=60):
-        super().__init__()
-        self.report = report
-        self.interval = interval
-        self.shutdown_event = threading.Event()
-
-    def run(self):
-        total_waited = 0
-        while True:
-            if self.shutdown_event.wait(self.interval):
-                break
-            if total_waited > 0:
-                self.report(total_waited)
-            total_waited += self.interval
-
-    def shutdown(self):
-        self.shutdown_event.set()

diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/levanter/mesh.py b/src/levanter/mesh.py
deleted file mode 100644
index 9783ce5eb..000000000
--- a/src/levanter/mesh.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import Optional
-
-import jax
-import numpy as np
-from jax.sharding import Mesh
-
-
-def local_device_grid_positions(mesh, process_index: Optional[int] = None) -> tuple[np.ndarray, np.ndarray]:
-    """Returns a tuple of nd arrays, one for each axis, indicating the position of each device on the grid.
-    Analogous to what np.where would return."""
-    if process_index is None:
-        process_index = jax.process_index()
-
-    my_device_pos = np.vectorize(lambda dev: dev.process_index == process_index)(mesh.devices)
-    return my_device_pos.nonzero()
-
-
-def local_devices_mapping(mesh: Mesh, process_index: Optional[int] = None) -> dict[int, int]:
-    """
-    Handles the case when different devices in same process share the same data in TP.
-    Returns a mapping from local devices' DP/FSDP group index in global mesh to local indices
-    """
-    local_device_pos = local_device_grid_positions(mesh, process_index)[:2]  # first 2 axes are DP axes.
-    result = {}
-    uid = 0
-    for local_device_index in range(len(local_device_pos[0])):
-        key = local_device_pos[0][local_device_index] * mesh.devices.shape[1] + local_device_pos[1][local_device_index]
-        if key not in result:
-            # when two devices maps to the same key (different TP index), they will get the same data
-            result[key] = uid
-            uid += 1
-    return result
-
-
-def process_mesh_mapping(mesh) -> dict[int, int]:
-    """
-    Handles the case when different processes share the same data in TP.
-    If we envision each process as a subgrid of the mesh for its devices, this is the position of the process
-    in the coarsened process-level mesh
-    """
-    devices = mesh.devices
-    result = {}
-    uid = 0
-    leftmost2uid = {}
-    # basic logic: process index -> upper-left device -> TP index 0 device -> process index -> uid
-    for process_index in range(jax.process_count()):
-        tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, process_index)]
-        tmp[-1] = 0  # we want the device with TP group index 0 in the same DP/FSDP group
-        upper_left_position = tuple(tmp)  # in order to index into devices
-        upper_left_process = devices[upper_left_position].process_index
-        # assign uid to each process that has a device with TP group index 0
-        if upper_left_process not in leftmost2uid:
-            leftmost2uid[upper_left_process] = uid
-            uid += 1
-        this_uid = leftmost2uid[upper_left_process]
-        result[process_index] = this_uid
-
-    return result

diff --git a/src/levanter/models/longformer.py b/src/levanter/models/longformer.py
deleted file mode 100644
index c956aec14..000000000
--- a/src/levanter/models/longformer.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from typing import Optional
-
-import jax.lax
-import jax.numpy as jnp
-
-import haliax as hax
-from haliax import Axis, NamedArray
-from haliax.types import PrecisionLike
-
-
-def causal_sliding_window_attention(
-    Pos: Axis,
-    Window: Axis,
-    Head: Axis,
-    query: NamedArray,
-    key: NamedArray,
-    value: NamedArray,
-    bias: Optional[NamedArray] = None,
-    attention_dtype: Optional[jnp.dtype] = None,
-    precision: PrecisionLike = None,
-) -> NamedArray:
-    """
-    Computes sliding window attention a la Longformer. This method uses blocking because Jax can't figure it
-    out automatically.
-    """
-    # We use the window size as the block size
-    # The basic idea is that we want to compute attention one block (of query) at a time, where a block is a window
-    # of the sequence. Each q can attend to the prior window_size-1 positions plus itself
-    assert Window.size <= Pos.size, "Window size must be at most the sequence length"
-    assert Pos.size % Window.size == 0, "Sequence length must be divisible by window size"
-
-    if Window.size == Pos.size:
-        # we can just use regular attention
-        # we have to special case this because jax won't like the attend_block_N function
-        # which doesn't actually get executed but does get traced
-        K = Pos.alias("K")
-        return hax.nn.attention.dot_product_attention(
-            K,
-            Head,
-            query,
-            key.rename({Pos: K}),
-            value.rename({Pos: K}),
-            mask=hax.nn.attention.causal_mask(Pos, K),
-            bias=bias,
-            attention_dtype=attention_dtype,
-            precision=precision,
-        )
-
-    # the attention structure is that each query attends to the prior window_size positions (q - window_size, q]
-    # We extract one query block of length Window.size at a time (the block size and window size could be different, but
-    # this seems fine)
-    # For each query block, we extract a key and value block of length KWindow == Window.size * 2 - 1
-    # The key block is [query_block_start - window_size + 1, query_block_start + window_size)
-
-    # TODO: relax?
-    Block = Axis("Block", Pos.size // Window.size)
-    KWindow = Axis("KWindow", Window.size * 2 - 1)  # this is what we need to grab from the key/value
-
-    # this makes code a bit easier to read below
-    Q = Window
-    K = KWindow
-
-    # for our attention masks, each q can attend to the prior window_size-1 positions plus itself
-    # that is, each q can attend to k s.t. k \in [q - window_size + 1, q]
-    # however, note that K is offset by window_size - 1, so we need to shift the mask by that amount
-    # this means that we want to mask out k s.t. k \in [q, q + window_size)
-    # equivalently, k - q \in [0, window_size)
-    diff = hax.arange(K) - hax.arange(Q).broadcast_axis(K)
-    attn_mask = (diff >= 0) & (diff < Window.size)
-
-    def attend_block_N(block_idx):
-        block_idx = block_idx.scalar()
-        query_block = query.slice(Pos, Q, start=block_idx * Q.size)
-        # extract the relevant window from the key and value
-        # this spans [query_block_start - window_size + 1, query_block_start + window_size)
-        key_block = key.slice(Pos, K, start=(block_idx - 1) * Q.size + 1)
-        value_block = value.slice(Pos, K, start=(block_idx - 1) * Q.size + 1)
-
-        if bias is not None:
-            bias_block = bias.slice(Pos, K, start=(block_idx - 1) * Q.size + 1)
-        else:
-            bias_block = None
-
-        return hax.nn.attention.dot_product_attention(
-            K, Head, query_block, key_block, value_block, attn_mask, bias_block, attention_dtype, precision
-        )
-
-    # for the 0th block, we have to worry about the out-of-bounds. just use a causal mask and do normal causal attention
-    # NB if you change it so that the block size and window size aren't the same, you'll need to change this
-    K0 = Q.alias("K0")
-    attn_mask_0 = hax.nn.attention.causal_mask(Q, K0)
-
-    def attend_block_0(block_idx):
-        query_block = query.slice(Pos, Q, start=0)
-        key_block = key.slice(Pos, K0, start=0)
-        value_block = value.slice(Pos, K0, start=0)
-        if bias is not None:
-            bias_block = bias.slice(Pos, K0, start=0)
-        else:
-            bias_block = None
-        return hax.nn.attention.dot_product_attention(
-            K0, Head, query_block, key_block, value_block, attn_mask_0, bias_block, attention_dtype, precision
-        )
-
-    # extra arg/return for dummy scan accumulator
-    def attend_block(_, block_idx):
-        return None, jax.lax.cond(block_idx.scalar() == 0, attend_block_0, attend_block_N, block_idx)
-
-    # we use scan here to encourage jax to do the blocking
-    _, blocked_attn = hax.scan(attend_block, Block)(None, hax.arange(Block))  # type: ignore
-
-    # now we need to unblock the attention
-    # TODO: see if the rearrange and flatten_axes have perf implications
-    return blocked_attn.flatten_axes((Block, Q), Pos).rearrange(value.axes)

diff --git a/src/levanter/store/_prefetch_actor.py b/src/levanter/store/_prefetch_actor.py
deleted file mode 100644
index 6b3c302c2..000000000
--- a/src/levanter/store/_prefetch_actor.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import asyncio
-import logging
-from dataclasses import dataclass
-from queue import Empty as QueueEmpty
-from typing import Callable, Generic, Iterator, List, Optional, TypeVar
-
-import ray
-
-from levanter.utils.ray_utils import ExceptionInfo, ser_exc_info
-
-
-T = TypeVar("T")
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class _PrefetchException:
-    info: ExceptionInfo
-
-
-class _Sentinel:
-    pass
-
-
-_SENTINEL = _Sentinel()
-
-
-class RayPrefetchQueue(Generic[T]):
-    def __init__(
-        self, producer: Callable[[], Iterator[T]], max_queue_size: int = 100, producer_options: dict | None = None
-    ):
-        self.max_queue_size = max_queue_size
-        if producer_options is None:
-            producer_options = {}
-        self.queue_actor = _QueueActor.remote(max_queue_size)  # type: ignore
-        self.producer_task = _run_producer.options(**producer_options).remote(self.queue_actor, producer)
-        self._stopped = False
-        self._finished = False
-
-    def queue_size(self):
-        return ray.get(self.queue_actor.qsize.remote())
-
-    def __next__(self):
-        return self.get_next()
-
-    def __iter__(self):
-        return self
-
-    def get_next(self, timeout: float | None = None) -> T:
-        """
-        Get the next item from the producer. If the producer raises an exception, it will be reraised here.
-
-        If the producer is done, this will raise StopIteration.
-
-        Args:
-            timeout (float|None): Timeout in seconds for getting the next item. If None, will block indefinitely.
-
-        Raises:
-            Empty: If the queue is empty and the timeout is reached.
- """ - if self._finished: - raise StopIteration - # time_in = time.time() - item = ray.get(self.queue_actor.get_next.remote(timeout)) - # time_out = time.time() - # if time_out - time_in > 0.1: - # current_name = ray.get_runtime_context().get_actor_name() - # print(f"{current_name} :: Queue get took {time_out - time_in} seconds :: {self.queue_size()}") - # logger.info(f"{current_name} :: Queue get took {time_out - time_in} seconds :: {self.queue_size()}") - if isinstance(item, _PrefetchException): - item.info.reraise() - if isinstance(item, _Sentinel): - self._finished = True - raise StopIteration - return item - - def stop(self): - ray.cancel(self.producer_task) - ray.get(self.queue_actor.stop.remote()) - self._stopped = True - - def is_stopped(self): - return self._stopped - - def drain_available(self, max_size: int) -> List[T]: - return ray.get(self.queue_actor.drain_available.remote(max_size)) - - -@ray.remote -class _QueueActor: - def __init__(self, max_queue_size: int): - self.queue: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size) - self._stopped = False - self._finished = False - - async def put(self, item): - await self.queue.put(item) - - async def get_next(self, timeout: Optional[float] = None): - try: - if timeout is not None: - item = await asyncio.wait_for(self.queue.get(), timeout) - else: - item = await self.queue.get() - if isinstance(item, _Sentinel): - self._finished = True - return item - except asyncio.TimeoutError: - raise QueueEmpty - - async def drain_available(self, max_size: int) -> List[T]: - items: list[T] = [] - while len(items) < max_size: - try: - item = self.queue.get_nowait() - if isinstance(item, _Sentinel): - self._finished = True - break - if isinstance(item, _PrefetchException): - item.info.reraise() - items.append(item) - except asyncio.QueueEmpty: - break - return items - - async def qsize(self): - return self.queue.qsize() - - async def stop(self): - self._stopped = True - - -@ray.remote -def _run_producer(queue_actor, producer_fn: Callable[[], Iterator[T]]): - async def _run_producer(queue_actor, producer_fn): - previous_put = None - try: - producer = producer_fn() - del producer_fn - - while True: - next_item = next(producer) - if previous_put is not None: - await previous_put - previous_put = queue_actor.put.remote(next_item) - except StopIteration: - if previous_put is not None: - await previous_put - await queue_actor.put.remote(_SENTINEL) - except Exception as e: - if previous_put is not None: - await previous_put - await queue_actor.put.remote(_PrefetchException(ser_exc_info(e))) - - asyncio.run(_run_producer(queue_actor, producer_fn)) diff --git a/src/levanter/store/stress_test_new_cache.py b/src/levanter/store/stress_test_new_cache.py deleted file mode 100644 index 66d002abd..000000000 --- a/src/levanter/store/stress_test_new_cache.py +++ /dev/null @@ -1,148 +0,0 @@ -# Reads an old-style ShardCache and compares to -import asyncio -import logging -import os - -import jax.random -import numpy as np -import tensorstore as ts - -from levanter.data import PermutationDataset -from levanter.data.text import TokenSeqDataset -from levanter.store.cache import LEDGER_FILE_NAME, CacheLedger, TreeCache, _serialize_json_and_commit -from levanter.store.tree_store import TreeStore -from levanter.tracker import capture_time -from levanter.utils import fsspec_utils - - -logging.basicConfig(level=logging.INFO) - - -SEQ_LEN = 1024 -BS = 8 -BATCHES = 1000 - -# want to test reading from: -# 1) old cache sequentially -# 2) new cache sequentially -# 3) new 
-
-
-def bench_new_cache_serial(exemplar, new_cache_path):
-    jagged_array = TreeStore.open(exemplar, new_cache_path).tree["input_ids"]
-    len_cache = jagged_array.data_size
-    new_cache = jagged_array.data
-    num_batches = len_cache // SEQ_LEN
-    for b in range(BATCHES):
-        elems = []
-        with ts.Batch():
-            for j in range(BS):
-                idx = b * BS + j
-                idx = idx % num_batches
-                arr1 = new_cache[idx * SEQ_LEN : (idx + 1) * SEQ_LEN].read()
-                elems.append(arr1)
-
-        for elem in elems:
-            elem.result()
-
-
-def bench_new_cache_random(exemplar, new_cache_path):
-    jagged_array = TreeStore.open(exemplar, new_cache_path).tree["input_ids"]
-    len_cache = jagged_array.data_size
-    new_cache = jagged_array.data
-    num_batches = len_cache // SEQ_LEN
-    for b in range(BATCHES):
-        elems = []
-        with ts.Batch():
-            for j in range(BS):
-                idx = np.random.randint(0, num_batches)
-                arr1 = new_cache[idx * SEQ_LEN : (idx + 1) * SEQ_LEN].read()
-                elems.append(arr1)
-
-        for elem in elems:
-            elem.result()
-
-
-async def bench_new_cache_serial_tokenseq(exemplar, new_cache_path):
-    ensure_cache(new_cache_path)
-    cache = TreeCache.load(new_cache_path, exemplar)
-
-    ds = TokenSeqDataset(cache, SEQ_LEN)
-
-    num_batches = await ds.async_len()
-
-    for b in range(BATCHES):
-        indices = []
-        for j in range(BS):
-            idx = b * BS + j
-            idx = idx % num_batches
-            indices.append(idx)
-        elems = await ds.get_batch(indices)
-        del elems
-
-
-async def bench_new_cache_permutation_random(exemplar, new_cache_path):
-    ensure_cache(new_cache_path)
-    cache = TreeCache.load(new_cache_path, exemplar)
-
-    ds = TokenSeqDataset(cache, SEQ_LEN)
-    ds = PermutationDataset(ds, jax.random.PRNGKey(0))
-
-    num_batches = await ds.async_len()
-
-    for b in range(BATCHES):
-        indices = []
-        for j in range(BS):
-            idx = b * BS + j
-            idx = idx % num_batches
-            indices.append(idx)
-        elems = await ds.get_batch(indices)
-        del elems
-
-
-def ensure_cache(new_cache_path):
-    if not fsspec_utils.exists(os.path.join(new_cache_path, LEDGER_FILE_NAME)):
-        ledger = CacheLedger(100000, {}, True)
-        _serialize_json_and_commit(os.path.join(new_cache_path, LEDGER_FILE_NAME), ledger)
-
-
-if __name__ == "__main__":
-    import sys
-
-    if not len(sys.argv) == 2:
-        print("Usage: convert_to_new_cache.py new_cache_path")
-        sys.exit(1)
-
-    for split in ["validation", "train"]:
-        print(f"Split: {split}", flush=True)
-        cache_path = os.path.join(sys.argv[1], split)
-        # convert_to_new_cache(in_path, out_path)
-        # with capture_time() as time_fn:
-        #     bench_old_cache(in_path)
-        # tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-        # print(f"Old Cache: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        exemplar = {"input_ids": np.zeros((SEQ_LEN,), dtype=np.int32)}
-
-        with capture_time() as time_fn:
-            bench_new_cache_serial(exemplar, cache_path)
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-        print(f"New Cache Serial: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        with capture_time() as time_fn:
-            asyncio.run(bench_new_cache_serial_tokenseq(exemplar, cache_path))
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-
-        print(f"New Cache Serial TokenSeq: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        with capture_time() as time_fn:
-            bench_new_cache_random(exemplar, cache_path)
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-
-        print(f"New Cache Random: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        with capture_time() as time_fn:
-            asyncio.run(bench_new_cache_permutation_random(exemplar, cache_path))
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-
-        print(f"New Cache Permutation: {time_fn()} ({tokens_per_second} tps)", flush=True)

diff --git a/src/levanter/utils/actor_pool.py b/src/levanter/utils/actor_pool.py
deleted file mode 100644
index a694bee20..000000000
--- a/src/levanter/utils/actor_pool.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import asyncio
-import logging
-from abc import ABC
-from typing import Any, Callable, Dict, List, Optional, TypeVar
-
-import ray
-
-
-V = TypeVar("V")
-R = TypeVar("R")
-
-logger = logging.getLogger(__name__)
-
-# Copilot-Adapted from:
-# https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
-
-
-def _wrap_ray_future(ray_future):
-    # work around https://github.com/ray-project/ray/issues/45895#issuecomment-2165164129
-    return asyncio.wrap_future(ray_future.future())
-
-
-class AutoScalingActorPool:
-    """Utility class to operate on a dynamically scaling pool of actors."""
-
-    def __init__(
-        self,
-        create_actor_fn: Callable[[], "ray.actor.ActorHandle"],
-        min_size: int = 1,
-        max_size: int = 10,
-    ):
-        if max_size < min_size:
-            raise ValueError("max_size must be greater than or equal to min_size.")
-        self._create_actor_fn = create_actor_fn
-        self._min_size = min_size
-        self._max_size = max_size
-
-        self._idle_actors: List[ray.actor.ActorHandle] = []
-        self._busy_actors: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {}
-        self._pending_actors: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {}
-
-        self._actor_locations: Dict[ray.actor.ActorHandle, str] = {}
-        self._tasks_waiting_for_actor: list[asyncio.Future] = []
-        self._next_task_id = 0
-        self._scale_down_task: Optional[asyncio.Task] = None
-
-        self._scale_up(self._min_size)
-
-    @property
-    def num_pending_tasks(self):
-        return len(self._tasks_waiting_for_actor)
-
-    def resize_pool(self, *, min_size: Optional[int] = None, max_size: Optional[int] = None):
-        old_min_size = self._min_size
-        if min_size is not None:
-            self._min_size = min_size
-        old_max_size = self._max_size
-        if max_size is not None:
-            self._max_size = max_size
-
-        if old_min_size != self._min_size or old_max_size != self._max_size:
-            logger.info(f"Resizing pool to min_size: {self._min_size}, max_size: {self._max_size}")
-
-        self._adjust_pool_size()
-
-    def get_max_size(self):
-        return self._max_size
-
-    def get_min_size(self):
-        return self._min_size
-
-    def _scale_up(self, num_actors: int):
-        if self._scale_down_task and not self._scale_down_task.done():
-            self._scale_down_task.cancel()
-
-        for _ in range(num_actors):
-            try:
-                actor = self._create_actor_fn()
-                ready_ref = actor.get_location.remote()
-                self._pending_actors[ready_ref] = actor
-
-                async def wait_for_ready(actor, ready_ref):
-                    loc = await _wrap_ray_future(ready_ref)
-                    # pending -> floating
-                    if ready_ref not in self._pending_actors:
-                        logger.info("Actor was cancelled before it was ready.")
-                        return
-                    del self._pending_actors[ready_ref]
-                    self._assert_is_floating(actor)
-                    self._actor_locations[actor] = loc
-                    self._maybe_start_pending_task(actor)  # floating -> {idle, busy}
-
-                asyncio.ensure_future(wait_for_ready(actor, ready_ref))
-
-            except Exception as e:
-                logger.error("Failed to create actor.", exc_info=e)
-
-    def _scale_down(self, target_num_actors: int):
-        while len(self._idle_actors) + len(self._pending_actors) > target_num_actors:
-            if self._pending_actors:
-                actor = self._pending_actors.popitem()[1]
-                # let it die through gc
-                # ray.kill(actor)
-            elif self._idle_actors:
-                actor = self._idle_actors.pop()
-                del self._actor_locations[actor]
-                # let it die through gc
-                # ray.kill(actor)
-            else:
-                break
-
-    def _adjust_pool_size(self):
-        num_pending_tasks = self.num_pending_tasks
-        num_idle_actors = len(self._idle_actors)
-        num_busy_actors = len(self._busy_actors)
-        num_pending_actors = len(self._pending_actors)
-
-        num_nonworking_actors = num_idle_actors + num_pending_actors
-        total_actors = num_nonworking_actors + num_busy_actors
-
-        # TODO: better autoscale logic
-        if (
-            num_pending_actors == 0
-            and num_pending_tasks > 0
-            and num_idle_actors == 0
-            and total_actors < self._max_size
-        ):
-            logger.info(
-                f"Scaling up due to {num_pending_tasks} pending tasks. Current pool size: {total_actors}. Max size:"
-                f" {self._max_size}"
-            )
-            self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks))
-
-        # Schedule scale down if idle
-        elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size:
-            if self._scale_down_task is None:
-                self._scale_down_task = asyncio.create_task(self._schedule_scale_down())
-
-    async def _schedule_scale_down(self):
-        try:
-            await asyncio.sleep(10)
-            if self.num_pending_tasks == 0:
-                logger.info("Scaling down due to no pending tasks.")
-                self._scale_down(self._min_size)
-                self._scale_down_task = None
-        except asyncio.CancelledError:
-            logger.debug("Scale down task was cancelled due to new activity.")
-
-    def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]:
-        """Get the location of the given object reference."""
-        try:
-            locs = ray.experimental.get_object_locations([obj_ref])
-            nodes = locs[obj_ref]["node_ids"]
-            if nodes:
-                return nodes[0]
-        except Exception as e:
-            logger.error(f"Failed to get object location: {e}")
-        return None
-
-    def _pick_actor(self, obj_ref: Optional[ray.ObjectRef] = None) -> Optional[ray.actor.ActorHandle]:
-        """Pick an actor based on locality and busyness."""
-        # idle -> floating
-        if not self._idle_actors:
-            return None
-
-        if obj_ref:
-            preferred_loc = self._get_object_location(obj_ref)
-        else:
-            preferred_loc = None
-
-        def penalty_key(actor):
-            """Returns the key that should be minimized for the best actor."""
-            requires_remote_fetch = self._actor_locations[actor] != preferred_loc
-            return requires_remote_fetch
-
-        actor = min(self._idle_actors, key=penalty_key)
-        actor = self._idle_actors.pop(self._idle_actors.index(actor))
-        return actor
-
-    def submit(self, fn: Callable[["ray.actor.ActorHandle", V], R], value: V, obj_ref: Optional[ray.ObjectRef] = None):
-        actor = self._pick_actor(obj_ref)
-        if actor:
-            return self._assign_task_to_actor(actor, fn, value)
-        else:
-            actor_future: asyncio.Future = asyncio.Future()
-            self._tasks_waiting_for_actor.append(actor_future)
-            f = asyncio.ensure_future(self._enqueue_pending_task(fn, obj_ref, value, actor_future))
-            self._adjust_pool_size()
-            return f
-
-    def _assign_task_to_actor(self, actor, fn, value):
-        # floating -> busy
-        ray_future = fn(actor, value)
-        self._busy_actors[ray_future] = actor
-        if self._scale_down_task and not self._scale_down_task.done():
-            self._scale_down_task.cancel()
-        self._adjust_pool_size()
-
-        return asyncio.ensure_future(self._set_up_actor_return_on_finished(ray_future))
-
-    async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future):
-        actor = await actor_future
-        return await self._assign_task_to_actor(actor, fn, value)
-
-    def _assert_is_floating(self, actor):
-        assert actor not in self._idle_actors
-        assert actor not in self._busy_actors
-        assert actor not in self._pending_actors
-
-    def _maybe_start_pending_task(self, actor):
-        self._assert_is_floating(actor)
-        if self._tasks_waiting_for_actor:
-            # floating -> busy (inside the _enqueue_pending_task coroutine)
-            actor_future = self._tasks_waiting_for_actor.pop(0)
-            actor_future.set_result(actor)
-            assigned = True
-        else:
-            # floating -> idle
-            self._idle_actors.append(actor)
-            self._adjust_pool_size()
-            assigned = False
-        return assigned
-
-    async def _set_up_actor_return_on_finished(self, ray_future):
-        future = _wrap_ray_future(ray_future)
-        await asyncio.wait([future])
-        self._on_task_done(ray_future)
-        return await future
-
-    def _on_task_done(self, ray_future):
-        actor = self._busy_actors.pop(ray_future)
-        self._maybe_start_pending_task(actor)
-
-    async def map(
-        self,
-        fn: Callable[["ray.actor.ActorHandle", V], Any],
-        values: List[V],
-        obj_refs: Optional[List[Optional[ray.ObjectRef]]] = None,
-    ) -> List[Any]:
-        if obj_refs is None:
-            obj_refs = [None] * len(values)
-
-        tasks = [self.submit(fn, v, obj_ref) for v, obj_ref in zip(values, obj_refs)]
-        return await asyncio.gather(*tasks)
-
-    def has_free(self):
-        return bool(self._idle_actors)
-
-    def has_free_or_pending_actors(self):
-        return bool(self._idle_actors) or bool(self._pending_actors)
-
-    def pop_idle(self):
-        if self._idle_actors:
-            return self._idle_actors.pop()
-        return None
-
-    def push(self, actor: "ray.actor.ActorHandle"):
-        location = ray.get(actor.get_location.remote())
-        self._actor_locations[actor] = location
-        self._maybe_start_pending_task(actor)
-
-    def __del__(self):
-        if self._scale_down_task and not self._scale_down_task.done():
-            self._scale_down_task.cancel()
-        # just let ray kill the actors naturally
-
-
-class PoolWorkerBase(ABC):
-    def get_location(self) -> str:
-        return ray.get_runtime_context().get_node_id()

diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py
index 8431e1c3a..dab038452 100644
--- a/src/levanter/utils/py_utils.py
+++ b/src/levanter/utils/py_utils.py
@@ -2,7 +2,6 @@
 import sys
 import time
 from dataclasses import dataclass
-from typing import Callable, TypeVar


 def logical_cpu_core_count():
@@ -68,100 +67,6 @@ def wrap(cls):
     return wrap(_cls)


-# slightly modified from https://github.com/tensorflow/tensorflow/blob/14ea9d18c36946b09a1b0f4c0eb689f70b65512c/tensorflow/python/util/decorator_utils.py
-# to make TF happy
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-class classproperty(object):  # pylint: disable=invalid-name
-    """Class property decorator.
-
-    Example usage:
-
-    class MyClass(object):
-
-      @classproperty
-      def value(cls):
-        return '123'
-
-    > print MyClass.value
-    123
-    """
-
-    def __init__(self, func):
-        self._func = func
-
-    def __get__(self, owner_self, owner_cls):
-        return self._func(owner_cls)
-
-
-class _CachedClassProperty(object):
-    """Cached class property decorator.
-
-    Transforms a class method into a property whose value is computed once
-    and then cached as a normal attribute for the life of the class. Example
-    usage:
-
-    >>> class MyClass(object):
-    ...   @cached_classproperty
-    ...   def value(cls):
-    ...     print("Computing value")
-    ...     return '<property of %s>' % cls.__name__
-    >>> class MySubclass(MyClass):
-    ...   pass
-    >>> MyClass.value
-    Computing value
-    '<property of MyClass>'
-    >>> MyClass.value  # uses cached value
-    '<property of MyClass>'
-    >>> MySubclass.value
-    Computing value
-    '<property of MySubclass>'
-
-    This decorator is similar to `functools.cached_property`, but it adds a
-    property to the class, not to individual instances.
-    """
-
-    def __init__(self, func):
-        self._func = func
-        self._cache = {}
-
-    def __get__(self, obj, objtype):
-        if objtype not in self._cache:
-            self._cache[objtype] = self._func(objtype)
-        return self._cache[objtype]
-
-    def __set__(self, obj, value):
-        raise AttributeError("property %s is read-only" % self._func.__name__)
-
-    def __delete__(self, obj):
-        raise AttributeError("property %s is read-only" % self._func.__name__)
-
-
-# modification based on https://github.com/python/mypy/issues/2563
-PropReturn = TypeVar("PropReturn")
-
-
-def cached_classproperty(func: Callable[..., PropReturn]) -> PropReturn:
-    return _CachedClassProperty(func)  # type: ignore
-
-
-cached_classproperty.__doc__ = _CachedClassProperty.__doc__
-
-
 def actual_sizeof(obj):
     """similar to sys.getsizeof, but recurses into dicts and lists and other objects"""
     seen = set()

diff --git a/tests/test_actor_pool.py b/tests/test_actor_pool.py
deleted file mode 100644
index 08686eb30..000000000
--- a/tests/test_actor_pool.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import asyncio
-import time
-
-import pytest
-import ray
-
-from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase
-from levanter.utils.py_utils import logical_cpu_core_count
-
-
-@ray.remote
-class TestActor(PoolWorkerBase):
-    def __init__(self):
-        self.node_id = ray.get_runtime_context().get_node_id()
-
-    def get_node_id(self):
-        return self.node_id
-
-    def double(self, v):
-        return 2 * v
-
-
-@ray.remote
-class BlockerActor(PoolWorkerBase):
-    def __init__(self):
-        self.node_id = ray.get_runtime_context().get_node_id()
-        self.unblocked = False
-        self.unblock_event = asyncio.Event()
-
-    def get_node_id(self):
-        return self.node_id
-
-    async def block(self):
-        if not self.unblocked:
-            await self.unblock_event.wait()
-
-    async def unblock(self):
-        self.unblocked = True
-        self.unblock_event.set()
-
-
-@ray.remote
-class BlockingTestActor(PoolWorkerBase):
-    def __init__(self, blocker):
-        self.node_id = ray.get_runtime_context().get_node_id()
-        self.blocker = blocker
-
-    def get_node_id(self):
-        return self.node_id
-
-    def double(self, v, bypass_blocker=False):
-        if not bypass_blocker:
-            ray.get(self.blocker.block.remote())
-        return 2 * v
-
-
-# Helper function to create a TestActor
-def create_test_actor():
-    return TestActor.remote()
-
-
-def create_test_actor_blocker(blocker_handle):
-    return BlockingTestActor.remote(blocker_handle)
-
-
-def setup_module(module):
-    ray.init(
-        "local", num_cpus=max(2 * logical_cpu_core_count(), 8), ignore_reinit_error=True
-    )  # 2x cpu count is faster on my m1
-
-
-def teardown_module(module):
-    ray.shutdown()
-
-
-@pytest.mark.asyncio
-async def test_basic_submit():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    results = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)]
-    results = [await r for r in results]
-
-    assert results == [0, 2, 4, 6]
-
-
-@pytest.mark.asyncio
-async def test_basic_submit_no_idle():
-    pool = AutoScalingActorPool(create_test_actor, min_size=0, max_size=4)
-    results = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)]
-    results = [await r for r in results]
-
-    assert results == [0, 2, 4, 6]
-
-
-@pytest.mark.asyncio
-async def test_basic_functionality():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    results = list(await pool.map(lambda a, v: a.double.remote(v), [1, 2, 3, 4]))
-    assert results == [2, 4, 6, 8]
-
-
-@pytest.mark.asyncio
-async def test_scaling_up():
-    blocker = BlockerActor.remote()
-    pool = AutoScalingActorPool(lambda: create_test_actor_blocker(blocker), min_size=1, max_size=4)
-    f1 = pool.submit(lambda a, v: a.double.remote(v), 1)
-    f2 = pool.submit(lambda a, v: a.double.remote(v), 2)
-    f3 = pool.submit(lambda a, v: a.double.remote(v, True), 3)
-    f4 = pool.submit(lambda a, v: a.double.remote(v, True), 4)
-
-    shield_f2 = asyncio.shield(f2)
-    with pytest.raises(asyncio.TimeoutError):
-        await asyncio.wait_for(shield_f2, timeout=0.1)
-
-    assert (await asyncio.gather(f3, f4)) == [6, 8]
-
-    await blocker.unblock.remote()
-    # assert (await asyncio.gather(f1, f2)) == [2, 4]
-    assert (await f1) == 2
-    assert (await f2) == 4
-
-
-@pytest.mark.asyncio
-async def test_scaling_down():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    await pool.submit(lambda a, v: a.double.remote(v), 1)
-    await pool.submit(lambda a, v: a.double.remote(v), 2)
-    await pool.submit(lambda a, v: a.double.remote(v), 3)
-    await pool.submit(lambda a, v: a.double.remote(v), 4)
-    results = await asyncio.gather(
-        pool.submit(lambda a, v: a.double.remote(v), 1),
-        pool.submit(lambda a, v: a.double.remote(v), 2),
-        pool.submit(lambda a, v: a.double.remote(v), 3),
-        pool.submit(lambda a, v: a.double.remote(v), 4),
-    )
-    assert results == [2, 4, 6, 8]
-    assert len(pool._idle_actors) == 1
-    assert len(pool._busy_actors) == 0
-
-
-@pytest.mark.asyncio
-async def test_push_pop_idle():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    await pool.submit(lambda a, v: a.double.remote(v), 1)
-    actor = pool.pop_idle()
-    assert actor is not None
-    pool.push(actor)
-    assert len(pool._idle_actors) == 1
-
-
-@pytest.mark.asyncio
-async def test_submit_with_no_idle_actors():
-    blocker = BlockerActor.remote()
-    pool = AutoScalingActorPool(lambda: create_test_actor_blocker(blocker), min_size=1, max_size=4)
-    futs = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)]
-    f5 = pool.submit(lambda a, v: a.double.remote(v), 5)
-    await _sleep_until(lambda: pool.num_pending_tasks == 1, timeout=10)
-    await blocker.unblock.remote()
-    await asyncio.gather(*futs)
-    assert (await f5) == 10
-
-
-async def _sleep_until(condition, timeout=5, message="Condition not met within timeout"):
-    start = time.time()
-    while not condition():
-        if time.time() - start > timeout:
-            pytest.fail(message)
-        await asyncio.sleep(0.1)

diff --git a/tests/test_longformer.py b/tests/test_longformer.py
deleted file mode 100644
index c964499a0..000000000
--- a/tests/test_longformer.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import jax
-import jax.numpy as jnp
-import numpy as np
-from chex import assert_trees_all_close
-
-import haliax as hax
-from haliax import Axis
-from haliax.nn.attention import causal_mask
-
-from levanter.models.longformer import causal_sliding_window_attention
-
-
-def test_causal_sliding_window_attention_simple():
-    # test that we can't attend to something outside of the range
-    D = 2
-    for L, W in [(10, 5), (15, 5)]:
-        Pos = Axis("Pos", L)
-        Window = Axis("Window", W)
-        Head = Axis("Head", D)
-
-        keys = np.zeros((L, D), dtype=np.float32)
-        keys[0, 0] = 100.0  # really want to attend to this
-        values = np.zeros((L, D), dtype=np.float32)
-        values[0, 1] = 300.0  # check if we did attend
-
-        query = np.ones((L, D), dtype=np.float32)
-
-        query = hax.named(query, (Pos, Head))
-        keys = hax.named(keys, (Pos, Head))
-        values = hax.named(values, (Pos, Head))
-
-        result = causal_sliding_window_attention(Pos, Window, Head, query, keys, values)
-        # each position can attend to the previous W positions (including itself), so positions W and later
-        # can't attend to position 0 and can't get the 100.0 key
-        result = result.rearrange((Pos, Head)).array
-        assert_trees_all_close(result[0:W, 1], 300)
-        assert_trees_all_close(result[W:, 1], 0)
-
-
-def test_sliding_window_attention_fancier():
-    D = 4
-    for L, W in [(2, 1), (2, 2), (4, 2), (10, 5), (15, 5), (16, 2), (15, 3), (10, 10)]:
-        Pos = Axis("Pos", L)
-        Window = Axis("Window", W)
-        Head = Axis("Head", D)
-
-        q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
-
-        query = hax.random.uniform(q_key, (Pos, Head))
-        keys = hax.random.uniform(k_key, (Pos, Head))
-        values = hax.random.uniform(v_key, (Pos, Head))
-
-        result = causal_sliding_window_attention(Pos, Window, Head, query, keys, values)
-        result = result.rearrange((Pos, Head)).array
-
-        KPos = Axis("KPos", Pos.size)
-        keys = keys.rename({Pos: KPos})
-        values = values.rename({Pos: KPos})
-
-        diff = hax.arange(Pos).broadcast_axis(KPos) - hax.arange(KPos).broadcast_axis(Pos)
-        mask = causal_mask(Pos, KPos) & (diff < Window.size) & (diff >= 0)
-
-        # check that the result is the same as non-blocked attention with the right mask
-        expected = hax.nn.attention.dot_product_attention(KPos, Head, query, keys, values, mask=mask)
-
-        expected = expected.rearrange((Pos, Head)).array
-
-        assert_trees_all_close(result, expected, atol=1e-3, rtol=1e-3)
-
-
-def test_longformer_alibi_bias_pos_invariance():
-    D = 1
-    W = 32
-    H = 1
-
-    L = 4096
-
-    Head = Axis("Head", H)
-    Pos = Axis("Pos", L)
-    Window = Axis("Window", W)
-    Hidden = Axis("Hidden", D)
-
-    # this cycles [31, ..., 0, 31, ..., 0, ...]
-    cycle = np.flip(np.arange(W, dtype=np.float32))
-    v = np.tile(cycle, L // W).reshape((L, H, D))
-    v = hax.named(v, (Pos, Head, Hidden))
-
-    q = hax.ones((Pos, Head, Hidden), dtype=jnp.bfloat16) * 0.001
-    k = hax.ones((Pos, Head, Hidden), dtype=jnp.bfloat16) * 0.001
-
-    # bias gets geometrically larger as we go further in the sequence
-    # this is especially true if there are a lot of heads
-    big_head = hax.Axis("Head", 16)
-    # NB: this test doesn't work if you use bfloat16 for biases
-    bias = hax.nn.attention.alibi_attention_bias(big_head, Pos, dtype=jnp.float32).slice(big_head, Head, 0)
-
-    attn = causal_sliding_window_attention(Pos, Window, Hidden, q, k, v, bias=bias, attention_dtype=jnp.bfloat16)
-    attn = attn.rearrange((Pos, Head, Hidden)).array.reshape(L)
-
-    # final value for each cycle should be the same
-    finals = attn[W - 1 :: W]
-    assert np.isclose(finals, finals[0], rtol=2e-4).all(), f"finals: {finals}"

diff --git a/tests/test_prefetch_actor.py b/tests/test_prefetch_actor.py
deleted file mode 100644
index e48546fc1..000000000
--- a/tests/test_prefetch_actor.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import time
-from typing import Iterator
-
-import pytest
-import ray
-
-from levanter.store._prefetch_actor import RayPrefetchQueue
-
-
-def _sleep_until(condition, timeout=5, message="Condition not met within timeout"):
-    start = time.time()
-    while not condition():
-        if time.time() - start > timeout:
-            pytest.fail(message)
-        time.sleep(0.1)
-
-
-@pytest.fixture(scope="module", autouse=True)
-def ray_init_and_shutdown():
-    ray.init("local", ignore_reinit_error=True)
-    yield
-    ray.shutdown()
-
-
-@pytest.mark.ray
-def test_initialization_and_basic_functionality():
-    def simple_producer():
-        for i in range(10):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    results = [actor.get_next() for _ in range(10)]
-    assert results == list(range(10))
-
-
-@pytest.mark.ray
-def test_queue_size_limit():
-    def simple_producer() -> Iterator[ray.ObjectRef]:
-        for i in range(100):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer, max_queue_size=10)
-    # Allow some time for the queue to fill up
-    _sleep_until(lambda: actor.queue_size() == 10)
-
-    # get a few items to make some space
-    [actor.get_next() for _ in range(5)]
-    _sleep_until(lambda: actor.queue_size() == 10, message="Queue size did not reach 10")
-
-
-@pytest.mark.ray
-def test_stop_functionality():
-    def simple_producer():
-        for i in range(10000):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    actor.stop()
-
-    _sleep_until(lambda: actor.is_stopped(), message="Actor did not stop")
-
-
-@pytest.mark.ray
-def test_exception_handling():
-    def faulty_producer():
-        for i in range(5):
-            yield i
-        raise ValueError("Test exception")
-
-    actor = RayPrefetchQueue(faulty_producer)
-    results = []
-    try:
-        for _ in range(10):
-            results.append(actor.get_next())
-    except ValueError as e:
-        assert "Test exception" in str(e)  # Ray puts a lot of crap in the exception message
-    assert results == list(range(5))
-
-
-@pytest.mark.ray
-def test_empty_producer():
-    def empty_producer() -> Iterator[ray.ObjectRef]:
-        if False:
-            yield
-
-    actor = RayPrefetchQueue(empty_producer)
-    with pytest.raises(StopIteration):
-        actor.get_next()
-
-
-@pytest.mark.ray
-def test_multiple_consumers():
-    def simple_producer() -> Iterator[ray.ObjectRef]:
-        for i in range(20):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    results = [actor.get_next() for _ in range(10)]
-    results += [actor.get_next() for _ in range(10)]
-    assert results == list(range(20))
-
-
-@pytest.mark.ray
-def test_producer_completion():
-    def simple_producer():
-        for i in range(10):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    results = []
-    try:
-        while True:
-            results.append(actor.get_next())
-    except StopIteration:
-        pass
-    assert results == list(range(10))
-
-
-@pytest.mark.ray
-def test_drain_queue():
-    def simple_producer():
-        for i in range(10):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-
-    all_results = []
-
-    for tot in range(0, 5):
-        out = actor.drain_available(tot)
-        assert len(out) <= tot
-        all_results.extend(out)
-
-    while len(all_results) < 10:
-        all_results.append(actor.get_next())
-
-    assert all_results == list(range(10))
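
A quick usage sketch of what PATCH 1/3 enables: because `src/levanter/__init__.py` now eagerly imports `levanter.eval_harness`, the submodule is reachable as an attribute of the top-level package. Only `levanter.eval_harness` itself is taken from the patch; the commit message leaves the imported name elided (`__`), so no specific exported symbol is assumed here.

    import levanter

    # eval_harness is imported by levanter/__init__.py as of PATCH 1/3,
    # so plain attribute access works after importing the package:
    harness_module = levanter.eval_harness

    # Direct submodule imports also resolve, per the commit message:
    #   from levanter.eval_harness import <name>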