From c90cff9bef4341cea376d4e330dcddabeaf4d86a Mon Sep 17 00:00:00 2001
From: Nikil Ravi <55033516+nikil-ravi@users.noreply.github.com>
Date: Thu, 5 Dec 2024 13:51:46 -0800
Subject: [PATCH 1/3] make eval_harness part of levanter namespace (#833)

Allows importing as `from levanter.eval_harness import __`.
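
A quick sketch of the resulting usage (no particular symbol from the module is
assumed):

    import levanter

    levanter.eval_harness               # the submodule now loads with the package
    from levanter import eval_harness   # equivalently, bind the name directly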
---
 src/levanter/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py
index f9570aaf7..6b2cbeb1e 100644
--- a/src/levanter/__init__.py
+++ b/src/levanter/__init__.py
@@ -3,6 +3,7 @@
 import levanter.data as data
 import levanter.distributed as distributed
 import levanter.eval as eval
+import levanter.eval_harness as eval_harness
 import levanter.models as models
 import levanter.optim as optim
 import levanter.tracker as tracker

From 091f1cd2a8ffa4c43b79f7c68a08eff27b007a73 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed <ahmedah@stanford.edu>
Date: Thu, 5 Dec 2024 22:06:58 -0800
Subject: [PATCH 2/3] fix toml to capture dev transformers (#834)
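
Widens the `transformers` upper bound from `<4.48.0` to `<4.49.0` so that
4.48 pre-release/dev builds are no longer excluded by the pin.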

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 233af26f5..f2a63f7ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
     "equinox>=0.11.7",
     "jaxtyping>=0.2.34",
     "tokenizers>=0.15.2",
-    "transformers>=4.41.2,<4.48.0",
+    "transformers>=4.41.2,<4.49.0",
     "optax>=0.1.9",
     "wandb>=0.17.8",
     "draccus>=0.9.3",

From 19b5f93d908f6875e35cddd36da6927032a4b01e Mon Sep 17 00:00:00 2001
From: David Hall <dlwh@cs.stanford.edu>
Date: Fri, 6 Dec 2024 15:19:53 -0800
Subject: [PATCH 3/3] remove a bunch of old unused stuff (#832)
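
Deletes long-unused code: the Ray auto-scaling actor pool and batch-processor
pool, the prefetch actor and its stress-test script, the longformer
sliding-window attention, the mesh device-position helpers, the rich-based
metrics monitor (along with the `rich` dependency), the classproperty helpers
in py_utils, the empty shard_cache.py stub, and the corresponding tests.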

---
 pyproject.toml                              |   1 -
 src/levanter/data/_preprocessor.py          |  53 ----
 src/levanter/data/metrics_monitor.py        |  87 -------
 src/levanter/data/shard_cache.py            |   0
 src/levanter/mesh.py                        |  58 -----
 src/levanter/models/longformer.py           | 114 ---------
 src/levanter/store/_prefetch_actor.py       | 156 -----------
 src/levanter/store/stress_test_new_cache.py | 148 -----------
 src/levanter/utils/actor_pool.py            | 270 --------------------
 src/levanter/utils/py_utils.py              |  95 -------
 tests/test_actor_pool.py                    | 167 ------------
 tests/test_longformer.py                    | 102 --------
 tests/test_prefetch_actor.py                | 137 ----------
 13 files changed, 1388 deletions(-)
 delete mode 100644 src/levanter/data/shard_cache.py
 delete mode 100644 src/levanter/mesh.py
 delete mode 100644 src/levanter/models/longformer.py
 delete mode 100644 src/levanter/store/_prefetch_actor.py
 delete mode 100644 src/levanter/store/stress_test_new_cache.py
 delete mode 100644 src/levanter/utils/actor_pool.py
 delete mode 100644 tests/test_actor_pool.py
 delete mode 100644 tests/test_longformer.py
 delete mode 100644 tests/test_prefetch_actor.py

diff --git a/pyproject.toml b/pyproject.toml
index f2a63f7ae..f07ff3864 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,6 @@ dependencies = [
     "dataclasses-json~=0.6.4",
     "ray[default]>=2.34.0",
     "pydantic<3",
-    "rich~=13.0",
     "filelock~=3.13",
     "async-lru~=2.0",
     "tqdm-loggable>=0.2",
diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py
index dd6578667..dd810857b 100644
--- a/src/levanter/data/_preprocessor.py
+++ b/src/levanter/data/_preprocessor.py
@@ -1,13 +1,8 @@
-import logging
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Generic, Iterable, Mapping, Sequence, TypeVar, Union
 
 import numpy as np
 import pyarrow as pa
-import ray
-
-from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase
-from levanter.utils.ray_utils import RefBox
 
 
 T = TypeVar("T")
@@ -236,54 +231,6 @@ def to_hf_batched(x):
     return {b.field(i).name: to_hf_batched(b.column(i).to_numpy(zero_copy_only=False)) for i in range(b.num_columns)}
 
 
-@ray.remote(num_cpus=0.1)  # keep this low b/c it doesn't do much
-class BatchProcessorPool:
-    def __init__(self, processor: BatchProcessor, min_size: int = 1, max_size: int = 10):
-        logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(message)s")
-        processor_ref = ray.put(processor)
-        self.actor_pool = AutoScalingActorPool(
-            lambda: _create_batch_processor_actor(processor, processor_ref), min_size, max_size
-        )
-
-    async def process_batch(self, batch_ref: RefBox):
-        return await self.actor_pool.submit(
-            lambda a, b: a.process_batch.remote(b), batch_ref.ref, obj_ref=batch_ref.ref
-        )
-
-    def num_pending_tasks(self):
-        return self.actor_pool.num_pending_tasks
-
-    def resize_pool(self, *, min_size: int | None = None, max_size: int | None = None):
-        self.actor_pool.resize_pool(min_size=min_size, max_size=max_size)
-
-    def ensure_max_at_least(self, size: int):
-        self.actor_pool.resize_pool(max_size=max(size, self.actor_pool.get_max_size()))
-
-
-def _create_batch_processor_actor(processor: BatchProcessor, processor_ref):
-    cpus = processor.num_cpus
-    gpus = processor.num_gpus
-    resources = processor.resources
-    return _BatchProcessorActor.options(  # type: ignore
-        num_cpus=cpus, num_gpus=gpus, resources=resources, scheduling_strategy="SPREAD"
-    ).remote(processor_ref)
-
-
-@ray.remote
-class _BatchProcessorActor(PoolWorkerBase):
-    def __init__(self, processor: BatchProcessor):
-        from levanter.store.tree_store import TreeBatchPreparer
-
-        self.processor = processor
-        self.preparer = TreeBatchPreparer(processor.output_exemplar)
-
-    def process_batch(self, batch):
-        result = self.processor(batch)
-        result = _canonicalize_batch(result)
-        prepared = self.preparer(result)
-        return prepared
-
-
 def _canonicalize_batch(batch: Union[dict, list[dict]]) -> list[dict]:
     if isinstance(batch, pa.RecordBatch):
         batch = dict_from_record_batch(batch)
diff --git a/src/levanter/data/metrics_monitor.py b/src/levanter/data/metrics_monitor.py
index 96c17ec65..d6cd496f8 100644
--- a/src/levanter/data/metrics_monitor.py
+++ b/src/levanter/data/metrics_monitor.py
@@ -1,21 +1,11 @@
 import dataclasses
 import logging as pylogging
-import threading
 import time
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Protocol, Union
 
 import jax
 from dataclasses_json import dataclass_json
-from rich.progress import (
-    BarColumn,
-    Progress,
-    TaskID,
-    TaskProgressColumn,
-    TextColumn,
-    TimeElapsedColumn,
-    TimeRemainingColumn,
-)
 
 import levanter.tracker
 
@@ -35,53 +25,6 @@ def __call__(self, metrics: InProgressCacheMetrics):
         ...
 
 
-class RichMetricsMonitor(MetricsMonitor):
-
-    progress: Optional[Progress]  # type: ignore
-    task: Optional[TaskID]
-
-    def __init__(self, num_shards, **kwargs):
-        """kwargs are passed to rich.progress.Progress"""
-        self.kwargs = kwargs
-        self.progress: Optional[Progress] = None
-        self.task = None
-        self.num_shards = num_shards
-
-    def __call__(self, metrics: InProgressCacheMetrics):
-        if self.progress is None:
-            self._init_progress(metrics)
-
-        self.progress.update(self.task, completed=metrics.shards_finished, **dataclasses.asdict(metrics))  # type: ignore
-
-        self.progress.refresh()  # type: ignore
-
-        if metrics.is_finished:
-            self.progress.stop()  # type: ignore
-
-    def _init_progress(self, metrics):
-        columns = [
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("| {task.fields[rows_finished]} docs", justify="center"),
-        ]
-
-        for field in metrics.field_counts:
-            columns.append(TextColumn(f"| {{task.fields[field_counts][{field}]}} {field}", justify="center"))
-
-        columns.append(TimeElapsedColumn())
-        columns.append(TimeRemainingColumn())
-
-        self.progress = Progress(
-            *columns,
-            **self.kwargs,
-        )
-
-        self.task = self.progress.add_task(
-            "Shards", total=self.num_shards, completed=metrics.shards_finished, **dataclasses.asdict(metrics)
-        )
-        self.progress.start()
-
-
 class LoggingMetricsMonitor(MetricsMonitor):
     last_metrics: Optional[InProgressCacheMetrics]
     last_time: Optional[float]
@@ -109,16 +52,6 @@ def __call__(self, metrics: InProgressCacheMetrics):
         if metrics.is_finished:
             to_log[f"{self.prefix}/finished"] = 1
 
-        # estimate the rate of progress
-        # if self.last_metrics is not None:
-        #     assert self.last_time is not None
-        #     elapsed = time.time() - self.last_time
-        #     to_log[f"{self.prefix}/shards_per_s"] = (metrics.shards_finished - self.last_metrics.shards_finished) / elapsed
-        #     to_log[f"{self.prefix}/rows_per_s"] = (metrics.rows_finished - self.last_metrics.rows_finished) / elapsed
-        #
-        #     for field, count in metrics.field_counts.items():
-        #         to_log[f"{self.prefix}/{field}_per_s"] = (count - self.last_metrics.field_counts[field]) / elapsed
-
         self.last_metrics = metrics
         self.last_time = time.time()
 
@@ -153,23 +86,3 @@ def __call__(self, metrics: InProgressCacheMetrics):
 
         if metrics.is_finished:
             self.logger.info("Cache creation finished")
-
-
-class WaitTimeReportingThread(threading.Thread):
-    def __init__(self, report, interval=60):
-        super().__init__()
-        self.report = report
-        self.interval = interval
-        self.shutdown_event = threading.Event()
-
-    def run(self):
-        total_waited = 0
-        while True:
-            if self.shutdown_event.wait(self.interval):
-                break
-            if total_waited > 0:
-                self.report(total_waited)
-            total_waited += self.interval
-
-    def shutdown(self):
-        self.shutdown_event.set()
diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/levanter/mesh.py b/src/levanter/mesh.py
deleted file mode 100644
index 9783ce5eb..000000000
--- a/src/levanter/mesh.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import Optional
-
-import jax
-import numpy as np
-from jax.sharding import Mesh
-
-
-def local_device_grid_positions(mesh, process_index: Optional[int] = None) -> tuple[np.ndarray, np.ndarray]:
-    """Returns a tuple of nd arrays, one for each axis, indicating the position of each device on the grid.
-    Analogous to what np.where would return."""
-    if process_index is None:
-        process_index = jax.process_index()
-
-    my_device_pos = np.vectorize(lambda dev: dev.process_index == process_index)(mesh.devices)
-    return my_device_pos.nonzero()
-
-
-def local_devices_mapping(mesh: Mesh, process_index: Optional[int] = None) -> dict[int, int]:
-    """
-    Handles the case when different devices in same process share the same data in TP.
-    Returns a mapping from local devices' DP/FSDP group index in global mesh to local indices
-    """
-    local_device_pos = local_device_grid_positions(mesh, process_index)[:2]  # first 2 axes are DP axes.
-    result = {}
-    uid = 0
-    for local_device_index in range(len(local_device_pos[0])):
-        key = local_device_pos[0][local_device_index] * mesh.devices.shape[1] + local_device_pos[1][local_device_index]
-        if key not in result:
-            # when two devices maps to the same key (different TP index), they will get the same data
-            result[key] = uid
-            uid += 1
-    return result
-
-
-def process_mesh_mapping(mesh) -> dict[int, int]:
-    """
-    Handles the case when different processes share the same data in TP.
-    If we envision each process as a subgrid of the mesh for its devices, this is the position of the process
-    in the coarsened process-level mesh
-    """
-    devices = mesh.devices
-    result = {}
-    uid = 0
-    leftmost2uid = {}
-    # basic logic: process index -> upper-left device -> TP index 0 device -> process index -> uid
-    for process_index in range(jax.process_count()):
-        tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, process_index)]
-        tmp[-1] = 0  # we want the device with TP group index 0 in the same DP/FSDP group
-        upper_left_position = tuple(tmp)  # in order to index into devices
-        upper_left_process = devices[upper_left_position].process_index
-        # assign uid to each process that has a device with TP group index 0
-        if upper_left_process not in leftmost2uid:
-            leftmost2uid[upper_left_process] = uid
-            uid += 1
-        this_uid = leftmost2uid[upper_left_process]
-        result[process_index] = this_uid
-
-    return result
diff --git a/src/levanter/models/longformer.py b/src/levanter/models/longformer.py
deleted file mode 100644
index c956aec14..000000000
--- a/src/levanter/models/longformer.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from typing import Optional
-
-import jax.lax
-import jax.numpy as jnp
-
-import haliax as hax
-from haliax import Axis, NamedArray
-from haliax.types import PrecisionLike
-
-
-def causal_sliding_window_attention(
-    Pos: Axis,
-    Window: Axis,
-    Head: Axis,
-    query: NamedArray,
-    key: NamedArray,
-    value: NamedArray,
-    bias: Optional[NamedArray] = None,
-    attention_dtype: Optional[jnp.dtype] = None,
-    precision: PrecisionLike = None,
-) -> NamedArray:
-    """
-    Computes sliding window attention a la Longformer. This method uses blocking because Jax can't figure it
-    out automatically.
-    """
-    # We use the window size as the block size
-    # The basic idea is that we want to compute attention one block (of query) at a time, where a block is a window
-    # of the sequence. Each q can attend to the prior window_size-1 positions plus itself
-    assert Window.size <= Pos.size, "Window size must be at least 2x sequence length"
-    assert Pos.size % Window.size == 0, "Sequence length must be divisible by window size"
-
-    if Window.size == Pos.size:
-        # we can just use regular attention
-        # we have to special case this because jax won't like the attend_block_N function
-        # which doesn't actually get executed but does get traced
-        K = Pos.alias("K")
-        return hax.nn.attention.dot_product_attention(
-            K,
-            Head,
-            query,
-            key.rename({Pos: K}),
-            value.rename({Pos: K}),
-            mask=hax.nn.attention.causal_mask(Pos, K),
-            bias=bias,
-            attention_dtype=attention_dtype,
-            precision=precision,
-        )
-
-    # the attention structure is that each query attends to the prior window_size positions (q - window_size, q]
-    # We extract one query block of length Window.size at a time (the block size and window size could be different, but
-    # this seems fine)
-    # For each query block, we extract a key and value block of length KWindow == Window.size * 2 - 1
-    # The key block is [query_block_start - window_size + 1, query_block_start + window_size)
-
-    # TODO: relax?
-    Block = Axis("Block", Pos.size // Window.size)
-    KWindow = Axis("KWindow", Window.size * 2 - 1)  # this is what we need to grab from the key/value
-
-    # this makes code a bit easier to read below
-    Q = Window
-    K = KWindow
-
-    # for our attention masks, each q can attend to the prior window_size-1 positions plus itself
-    # that is, each q can attend to k s.t. k \in [q - window_size + 1, q]
-    # however, note that K is offset by window_size - 1, so we need to shift the mask by that amount
-    # this means that we want to mask out k s.t. k \in [q, q + window_size)
-    # equivalently, k - q \in [0, window_size)
-    diff = hax.arange(K) - hax.arange(Q).broadcast_axis(K)
-    attn_mask = (diff >= 0) & (diff < Window.size)
-
-    def attend_block_N(block_idx):
-        block_idx = block_idx.scalar()
-        query_block = query.slice(Pos, Q, start=block_idx * Q.size)
-        # extract the relevant window from the key and value
-        # this spans [query_block_start - window_size + 1, query_block_start + window_size)
-        key_block = key.slice(Pos, K, start=(block_idx - 1) * Q.size + 1)
-        value_block = value.slice(Pos, K, start=(block_idx - 1) * Q.size + 1)
-
-        if bias is not None:
-            bias_block = bias.slice(Pos, K, start=(block_idx - 1) * Q.size + 1)
-        else:
-            bias_block = None
-
-        return hax.nn.attention.dot_product_attention(
-            K, Head, query_block, key_block, value_block, attn_mask, bias_block, attention_dtype, precision
-        )
-
-    # for the 0th block, we have to worry about the out-of-bounds. just use a causal mask and do normal causal attention
-    # NB if you change it so that the block size and window size aren't the same, you'll need to change this
-    K0 = Q.alias("K0")
-    attn_mask_0 = hax.nn.attention.causal_mask(Q, K0)
-
-    def attend_block_0(block_idx):
-        query_block = query.slice(Pos, Q, start=0)
-        key_block = key.slice(Pos, K0, start=0)
-        value_block = value.slice(Pos, K0, start=0)
-        if bias is not None:
-            bias_block = bias.slice(Pos, K0, start=0)
-        else:
-            bias_block = None
-        return hax.nn.attention.dot_product_attention(
-            K0, Head, query_block, key_block, value_block, attn_mask_0, bias_block, attention_dtype, precision
-        )
-
-    # extra arg/return for dummy scan accumulator
-    def attend_block(_, block_idx):
-        return None, jax.lax.cond(block_idx.scalar() == 0, attend_block_0, attend_block_N, block_idx)
-
-    # we use scan here to encourage jax to do the blocking
-    _, blocked_attn = hax.scan(attend_block, Block)(None, hax.arange(Block))  # type: ignore
-
-    # now we need to unblock the attention
-    # TODO: see if the rearrange and flatten_axes have perf implications
-    return blocked_attn.flatten_axes((Block, Q), Pos).rearrange(value.axes)
diff --git a/src/levanter/store/_prefetch_actor.py b/src/levanter/store/_prefetch_actor.py
deleted file mode 100644
index 6b3c302c2..000000000
--- a/src/levanter/store/_prefetch_actor.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import asyncio
-import logging
-from dataclasses import dataclass
-from queue import Empty as QueueEmpty
-from typing import Callable, Generic, Iterator, List, Optional, TypeVar
-
-import ray
-
-from levanter.utils.ray_utils import ExceptionInfo, ser_exc_info
-
-
-T = TypeVar("T")
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class _PrefetchException:
-    info: ExceptionInfo
-
-
-class _Sentinel:
-    pass
-
-
-_SENTINEL = _Sentinel()
-
-
-class RayPrefetchQueue(Generic[T]):
-    def __init__(
-        self, producer: Callable[[], Iterator[T]], max_queue_size: int = 100, producer_options: dict | None = None
-    ):
-        self.max_queue_size = max_queue_size
-        if producer_options is None:
-            producer_options = {}
-        self.queue_actor = _QueueActor.remote(max_queue_size)  # type: ignore
-        self.producer_task = _run_producer.options(**producer_options).remote(self.queue_actor, producer)
-        self._stopped = False
-        self._finished = False
-
-    def queue_size(self):
-        return ray.get(self.queue_actor.qsize.remote())
-
-    def __next__(self):
-        return self.get_next()
-
-    def __iter__(self):
-        return self
-
-    def get_next(self, timeout: float | None = None) -> T:
-        """
-        Get the next item from the producer. If the producer raises an exception, it will be reraised here.
-
-        If the producer is done, this will raise StopIteration.
-
-        Args:
-            timeout (float|None): Timeout in seconds for getting the next item. If None, will block indefinitely.
-
-        Raises:
-            Empty: If the queue is empty and the timeout is reached.
-        """
-        if self._finished:
-            raise StopIteration
-        # time_in = time.time()
-        item = ray.get(self.queue_actor.get_next.remote(timeout))
-        # time_out = time.time()
-        # if time_out - time_in > 0.1:
-        #     current_name = ray.get_runtime_context().get_actor_name()
-        #     print(f"{current_name} :: Queue get took {time_out - time_in} seconds :: {self.queue_size()}")
-        #     logger.info(f"{current_name} :: Queue get took {time_out - time_in} seconds :: {self.queue_size()}")
-        if isinstance(item, _PrefetchException):
-            item.info.reraise()
-        if isinstance(item, _Sentinel):
-            self._finished = True
-            raise StopIteration
-        return item
-
-    def stop(self):
-        ray.cancel(self.producer_task)
-        ray.get(self.queue_actor.stop.remote())
-        self._stopped = True
-
-    def is_stopped(self):
-        return self._stopped
-
-    def drain_available(self, max_size: int) -> List[T]:
-        return ray.get(self.queue_actor.drain_available.remote(max_size))
-
-
-@ray.remote
-class _QueueActor:
-    def __init__(self, max_queue_size: int):
-        self.queue: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size)
-        self._stopped = False
-        self._finished = False
-
-    async def put(self, item):
-        await self.queue.put(item)
-
-    async def get_next(self, timeout: Optional[float] = None):
-        try:
-            if timeout is not None:
-                item = await asyncio.wait_for(self.queue.get(), timeout)
-            else:
-                item = await self.queue.get()
-            if isinstance(item, _Sentinel):
-                self._finished = True
-            return item
-        except asyncio.TimeoutError:
-            raise QueueEmpty
-
-    async def drain_available(self, max_size: int) -> List[T]:
-        items: list[T] = []
-        while len(items) < max_size:
-            try:
-                item = self.queue.get_nowait()
-                if isinstance(item, _Sentinel):
-                    self._finished = True
-                    break
-                if isinstance(item, _PrefetchException):
-                    item.info.reraise()
-                items.append(item)
-            except asyncio.QueueEmpty:
-                break
-        return items
-
-    async def qsize(self):
-        return self.queue.qsize()
-
-    async def stop(self):
-        self._stopped = True
-
-
-@ray.remote
-def _run_producer(queue_actor, producer_fn: Callable[[], Iterator[T]]):
-    async def _run_producer(queue_actor, producer_fn):
-        previous_put = None
-        try:
-            producer = producer_fn()
-            del producer_fn
-
-            while True:
-                next_item = next(producer)
-                if previous_put is not None:
-                    await previous_put
-                previous_put = queue_actor.put.remote(next_item)
-        except StopIteration:
-            if previous_put is not None:
-                await previous_put
-            await queue_actor.put.remote(_SENTINEL)
-        except Exception as e:
-            if previous_put is not None:
-                await previous_put
-            await queue_actor.put.remote(_PrefetchException(ser_exc_info(e)))
-
-    asyncio.run(_run_producer(queue_actor, producer_fn))
diff --git a/src/levanter/store/stress_test_new_cache.py b/src/levanter/store/stress_test_new_cache.py
deleted file mode 100644
index 66d002abd..000000000
--- a/src/levanter/store/stress_test_new_cache.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Reads an old-style ShardCache and compares to
-import asyncio
-import logging
-import os
-
-import jax.random
-import numpy as np
-import tensorstore as ts
-
-from levanter.data import PermutationDataset
-from levanter.data.text import TokenSeqDataset
-from levanter.store.cache import LEDGER_FILE_NAME, CacheLedger, TreeCache, _serialize_json_and_commit
-from levanter.store.tree_store import TreeStore
-from levanter.tracker import capture_time
-from levanter.utils import fsspec_utils
-
-
-logging.basicConfig(level=logging.INFO)
-
-
-SEQ_LEN = 1024
-BS = 8
-BATCHES = 1000
-
-# want to test reading from:
-# 1) old cache sequentially
-# 2) new cache sequentially
-# 3) new cache randomly
-
-
-def bench_new_cache_serial(exemplar, new_cache_path):
-    jagged_array = TreeStore.open(exemplar, new_cache_path).tree["input_ids"]
-    len_cache = jagged_array.data_size
-    new_cache = jagged_array.data
-    num_batches = len_cache // SEQ_LEN
-    for b in range(BATCHES):
-        elems = []
-        with ts.Batch():
-            for j in range(BS):
-                idx = b * BS + j
-                idx = idx % num_batches
-                arr1 = new_cache[idx * SEQ_LEN : (idx + 1) * SEQ_LEN].read()
-                elems.append(arr1)
-
-        for elem in elems:
-            elem.result()
-
-
-def bench_new_cache_random(exemplar, new_cache_path):
-    jagged_array = TreeStore.open(exemplar, new_cache_path).tree["input_ids"]
-    len_cache = jagged_array.data_size
-    new_cache = jagged_array.data
-    num_batches = len_cache // SEQ_LEN
-    for b in range(BATCHES):
-        elems = []
-        with ts.Batch():
-            for j in range(BS):
-                idx = np.random.randint(0, num_batches)
-                arr1 = new_cache[idx * SEQ_LEN : (idx + 1) * SEQ_LEN].read()
-                elems.append(arr1)
-
-        for elem in elems:
-            elem.result()
-
-
-async def bench_new_cache_serial_tokenseq(exemplar, new_cache_path):
-    ensure_cache(new_cache_path)
-    cache = TreeCache.load(new_cache_path, exemplar)
-
-    ds = TokenSeqDataset(cache, SEQ_LEN)
-
-    num_batches = await ds.async_len()
-
-    for b in range(BATCHES):
-        indices = []
-        for j in range(BS):
-            idx = b * BS + j
-            idx = idx % num_batches
-            indices.append(idx)
-        elems = await ds.get_batch(indices)
-        del elems
-
-
-async def bench_new_cache_permutation_random(exemplar, new_cache_path):
-    ensure_cache(new_cache_path)
-    cache = TreeCache.load(new_cache_path, exemplar)
-
-    ds = TokenSeqDataset(cache, SEQ_LEN)
-    ds = PermutationDataset(ds, jax.random.PRNGKey(0))
-
-    num_batches = await ds.async_len()
-
-    for b in range(BATCHES):
-        indices = []
-        for j in range(BS):
-            idx = b * BS + j
-            idx = idx % num_batches
-            indices.append(idx)
-        elems = await ds.get_batch(indices)
-        del elems
-
-
-def ensure_cache(new_cache_path):
-    if not fsspec_utils.exists(os.path.join(new_cache_path, LEDGER_FILE_NAME)):
-        ledger = CacheLedger(100000, {}, True)
-        _serialize_json_and_commit(os.path.join(new_cache_path, LEDGER_FILE_NAME), ledger)
-
-
-if __name__ == "__main__":
-    import sys
-
-    if not len(sys.argv) == 2:
-        print("Usage: convert_to_new_cache.py new_cache_path")
-        sys.exit(1)
-
-    for split in ["validation", "train"]:
-        print(f"Split: {split}", flush=True)
-        cache_path = os.path.join(sys.argv[1], split)
-        # convert_to_new_cache(in_path, out_path)
-        # with capture_time() as time_fn:
-        #     bench_old_cache(in_path)
-        # tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-        # print(f"Old Cache: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        exemplar = {"input_ids": np.zeros((SEQ_LEN,), dtype=np.int32)}
-
-        with capture_time() as time_fn:
-            bench_new_cache_serial(exemplar, cache_path)
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-        print(f"New Cache Serial: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        with capture_time() as time_fn:
-            asyncio.run(bench_new_cache_serial_tokenseq(exemplar, cache_path))
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-
-        print(f"New Cache Serial TokenSeq: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        with capture_time() as time_fn:
-            bench_new_cache_random(exemplar, cache_path)
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-
-        print(f"New Cache Random: {time_fn()} ({tokens_per_second} tps)", flush=True)
-
-        with capture_time() as time_fn:
-            asyncio.run(bench_new_cache_permutation_random(exemplar, cache_path))
-        tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
-
-        print(f"New Cache Permutation: {time_fn()} ({tokens_per_second} tps)", flush=True)
diff --git a/src/levanter/utils/actor_pool.py b/src/levanter/utils/actor_pool.py
deleted file mode 100644
index a694bee20..000000000
--- a/src/levanter/utils/actor_pool.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import asyncio
-import logging
-from abc import ABC
-from typing import Any, Callable, Dict, List, Optional, TypeVar
-
-import ray
-
-
-V = TypeVar("V")
-R = TypeVar("R")
-
-logger = logging.getLogger(__name__)
-
-# Copilot-Adapted from:
-# https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
-
-
-def _wrap_ray_future(ray_future):
-    # work around https://github.com/ray-project/ray/issues/45895#issuecomment-2165164129
-    return asyncio.wrap_future(ray_future.future())
-
-
-class AutoScalingActorPool:
-    """Utility class to operate on a dynamically scaling pool of actors."""
-
-    def __init__(
-        self,
-        create_actor_fn: Callable[[], "ray.actor.ActorHandle"],
-        min_size: int = 1,
-        max_size: int = 10,
-    ):
-        if max_size < min_size:
-            raise ValueError("max_size must be greater than or equal to min_size.")
-        self._create_actor_fn = create_actor_fn
-        self._min_size = min_size
-        self._max_size = max_size
-
-        self._idle_actors: List[ray.actor.ActorHandle] = []
-        self._busy_actors: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {}
-        self._pending_actors: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {}
-
-        self._actor_locations: Dict[ray.actor.ActorHandle, str] = {}
-        self._tasks_waiting_for_actor: list[asyncio.Future] = []
-        self._next_task_id = 0
-        self._scale_down_task: Optional[asyncio.Task] = None
-
-        self._scale_up(self._min_size)
-
-    @property
-    def num_pending_tasks(self):
-        return len(self._tasks_waiting_for_actor)
-
-    def resize_pool(self, *, min_size: Optional[int] = None, max_size: Optional[int] = None):
-        old_min_size = self._min_size
-        if min_size is not None:
-            self._min_size = min_size
-        old_max_size = self._max_size
-        if max_size is not None:
-            self._max_size = max_size
-
-        if old_min_size != self._min_size or old_max_size != self._max_size:
-            logger.info(f"Resizing pool to min_size: {self._min_size}, max_size: {self._max_size}")
-
-        self._adjust_pool_size()
-
-    def get_max_size(self):
-        return self._max_size
-
-    def get_min_size(self):
-        return self._min_size
-
-    def _scale_up(self, num_actors: int):
-        if self._scale_down_task and not self._scale_down_task.done():
-            self._scale_down_task.cancel()
-
-        for _ in range(num_actors):
-            try:
-                actor = self._create_actor_fn()
-                ready_ref = actor.get_location.remote()
-                self._pending_actors[ready_ref] = actor
-
-                async def wait_for_ready(actor, ready_ref):
-                    loc = await _wrap_ray_future(ready_ref)
-                    # pending -> floating
-                    if ready_ref not in self._pending_actors:
-                        logger.info("Actor was cancelled before it was ready.")
-                        return
-                    del self._pending_actors[ready_ref]
-                    self._assert_is_floating(actor)
-                    self._actor_locations[actor] = loc
-                    self._maybe_start_pending_task(actor)  # floating -> {idle, busy}
-
-                asyncio.ensure_future(wait_for_ready(actor, ready_ref))
-
-            except Exception as e:
-                logger.error("Failed to create actor.", exc_info=e)
-
-    def _scale_down(self, target_num_actors: int):
-        while len(self._idle_actors) + len(self._pending_actors) > target_num_actors:
-            if self._pending_actors:
-                actor = self._pending_actors.popitem()[1]
-                # let it die through gc
-                # ray.kill(actor)
-            elif self._idle_actors:
-                actor = self._idle_actors.pop()
-                del self._actor_locations[actor]
-                # let it die through gc
-                # ray.kill(actor)
-            else:
-                break
-
-    def _adjust_pool_size(self):
-        num_pending_tasks = self.num_pending_tasks
-        num_idle_actors = len(self._idle_actors)
-        num_busy_actors = len(self._busy_actors)
-        num_pending_actors = len(self._pending_actors)
-
-        num_nonworking_actors = num_idle_actors + num_pending_actors
-        total_actors = num_nonworking_actors + num_busy_actors
-
-        # TODO: better autoscale logic
-        if (
-            num_pending_actors == 0
-            and num_pending_tasks > 0
-            and num_idle_actors == 0
-            and total_actors < self._max_size
-        ):
-            logger.info(
-                f"Scaling up due to {num_pending_tasks} pending tasks. Current pool size: {total_actors}. Max size:"
-                f" {self._max_size}"
-            )
-            self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks))
-
-        # Schedule scale down if idle
-        elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size:
-            if self._scale_down_task is None:
-                self._scale_down_task = asyncio.create_task(self._schedule_scale_down())
-
-    async def _schedule_scale_down(self):
-        try:
-            await asyncio.sleep(10)
-            if self.num_pending_tasks == 0:
-                logger.info("Scaling down due to no pending tasks.")
-                self._scale_down(self._min_size)
-                self._scale_down_task = None
-        except asyncio.CancelledError:
-            logger.debug("Scale down task was cancelled due to new activity.")
-
-    def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]:
-        """Get the location of the given object reference."""
-        try:
-            locs = ray.experimental.get_object_locations([obj_ref])
-            nodes = locs[obj_ref]["node_ids"]
-            if nodes:
-                return nodes[0]
-        except Exception as e:
-            logger.error(f"Failed to get object location: {e}")
-        return None
-
-    def _pick_actor(self, obj_ref: Optional[ray.ObjectRef] = None) -> Optional[ray.actor.ActorHandle]:
-        """Pick an actor based on locality and busyness."""
-        # idle -> floating
-        if not self._idle_actors:
-            return None
-
-        if obj_ref:
-            preferred_loc = self._get_object_location(obj_ref)
-        else:
-            preferred_loc = None
-
-        def penalty_key(actor):
-            """Returns the key that should be minimized for the best actor."""
-            requires_remote_fetch = self._actor_locations[actor] != preferred_loc
-            return requires_remote_fetch
-
-        actor = min(self._idle_actors, key=penalty_key)
-        actor = self._idle_actors.pop(self._idle_actors.index(actor))
-        return actor
-
-    def submit(self, fn: Callable[["ray.actor.ActorHandle", V], R], value: V, obj_ref: Optional[ray.ObjectRef] = None):
-        actor = self._pick_actor(obj_ref)
-        if actor:
-            return self._assign_task_to_actor(actor, fn, value)
-        else:
-            actor_future: asyncio.Future = asyncio.Future()
-            self._tasks_waiting_for_actor.append(actor_future)
-            f = asyncio.ensure_future(self._enqueue_pending_task(fn, obj_ref, value, actor_future))
-            self._adjust_pool_size()
-            return f
-
-    def _assign_task_to_actor(self, actor, fn, value):
-        # floating -> busy
-        ray_future = fn(actor, value)
-        self._busy_actors[ray_future] = actor
-        if self._scale_down_task and not self._scale_down_task.done():
-            self._scale_down_task.cancel()
-        self._adjust_pool_size()
-
-        return asyncio.ensure_future(self._set_up_actor_return_on_finished(ray_future))
-
-    async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future):
-        actor = await actor_future
-        return await self._assign_task_to_actor(actor, fn, value)
-
-    def _assert_is_floating(self, actor):
-        assert actor not in self._idle_actors
-        assert actor not in self._busy_actors
-        assert actor not in self._pending_actors
-
-    def _maybe_start_pending_task(self, actor):
-        self._assert_is_floating(actor)
-        if self._tasks_waiting_for_actor:
-            # floating -> busy (inside the _enqueue_pending_task coroutine)
-            actor_future = self._tasks_waiting_for_actor.pop(0)
-            actor_future.set_result(actor)
-            assigned = True
-        else:
-            # floating -> idle
-            self._idle_actors.append(actor)
-            self._adjust_pool_size()
-            assigned = False
-        return assigned
-
-    async def _set_up_actor_return_on_finished(self, ray_future):
-        future = _wrap_ray_future(ray_future)
-        await asyncio.wait([future])
-        self._on_task_done(ray_future)
-        return await future
-
-    def _on_task_done(self, ray_future):
-        actor = self._busy_actors.pop(ray_future)
-        self._maybe_start_pending_task(actor)
-
-    async def map(
-        self,
-        fn: Callable[["ray.actor.ActorHandle", V], Any],
-        values: List[V],
-        obj_refs: Optional[List[Optional[ray.ObjectRef]]] = None,
-    ) -> List[Any]:
-        if obj_refs is None:
-            obj_refs = [None] * len(values)
-
-        tasks = [self.submit(fn, v, obj_ref) for v, obj_ref in zip(values, obj_refs)]
-        return await asyncio.gather(*tasks)
-
-    def has_free(self):
-        return bool(self._idle_actors)
-
-    def has_free_or_pending_actors(self):
-        return bool(self._idle_actors) or bool(self._pending_actors)
-
-    def pop_idle(self):
-        if self._idle_actors:
-            return self._idle_actors.pop()
-        return None
-
-    def push(self, actor: "ray.actor.ActorHandle"):
-        location = ray.get(actor.get_location.remote())
-        self._actor_locations[actor] = location
-        self._maybe_start_pending_task(actor)
-
-    def __del__(self):
-        if self._scale_down_task and not self._scale_down_task.done():
-            self._scale_down_task.cancel()
-        # just let ray kill the actors naturally
-
-
-class PoolWorkerBase(ABC):
-    def get_location(self) -> str:
-        return ray.get_runtime_context().get_node_id()
diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py
index 8431e1c3a..dab038452 100644
--- a/src/levanter/utils/py_utils.py
+++ b/src/levanter/utils/py_utils.py
@@ -2,7 +2,6 @@
 import sys
 import time
 from dataclasses import dataclass
-from typing import Callable, TypeVar
 
 
 def logical_cpu_core_count():
@@ -68,100 +67,6 @@ def wrap(cls):
         return wrap(_cls)
 
 
-# slightly modified from https://github.com/tensorflow/tensorflow/blob/14ea9d18c36946b09a1b0f4c0eb689f70b65512c/tensorflow/python/util/decorator_utils.py
-# to make TF happy
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-class classproperty(object):  # pylint: disable=invalid-name
-    """Class property decorator.
-
-    Example usage:
-
-    class MyClass(object):
-
-      @classproperty
-      def value(cls):
-        return '123'
-
-    > print MyClass.value
-    123
-    """
-
-    def __init__(self, func):
-        self._func = func
-
-    def __get__(self, owner_self, owner_cls):
-        return self._func(owner_cls)
-
-
-class _CachedClassProperty(object):
-    """Cached class property decorator.
-
-    Transforms a class method into a property whose value is computed once
-    and then cached as a normal attribute for the life of the class.  Example
-    usage:
-
-    >>> class MyClass(object):
-    ...   @cached_classproperty
-    ...   def value(cls):
-    ...     print("Computing value")
-    ...     return '<property of %s>' % cls.__name__
-    >>> class MySubclass(MyClass):
-    ...   pass
-    >>> MyClass.value
-    Computing value
-    '<property of MyClass>'
-    >>> MyClass.value  # uses cached value
-    '<property of MyClass>'
-    >>> MySubclass.value
-    Computing value
-    '<property of MySubclass>'
-
-    This decorator is similar to `functools.cached_property`, but it adds a
-    property to the class, not to individual instances.
-    """
-
-    def __init__(self, func):
-        self._func = func
-        self._cache = {}
-
-    def __get__(self, obj, objtype):
-        if objtype not in self._cache:
-            self._cache[objtype] = self._func(objtype)
-        return self._cache[objtype]
-
-    def __set__(self, obj, value):
-        raise AttributeError("property %s is read-only" % self._func.__name__)
-
-    def __delete__(self, obj):
-        raise AttributeError("property %s is read-only" % self._func.__name__)
-
-
-# modification based on https://github.com/python/mypy/issues/2563
-PropReturn = TypeVar("PropReturn")
-
-
-def cached_classproperty(func: Callable[..., PropReturn]) -> PropReturn:
-    return _CachedClassProperty(func)  # type: ignore
-
-
-cached_classproperty.__doc__ = _CachedClassProperty.__doc__
-
-
 def actual_sizeof(obj):
     """similar to sys.getsizeof, but recurses into dicts and lists and other objects"""
     seen = set()
diff --git a/tests/test_actor_pool.py b/tests/test_actor_pool.py
deleted file mode 100644
index 08686eb30..000000000
--- a/tests/test_actor_pool.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import asyncio
-import time
-
-import pytest
-import ray
-
-from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase
-from levanter.utils.py_utils import logical_cpu_core_count
-
-
-@ray.remote
-class TestActor(PoolWorkerBase):
-    def __init__(self):
-        self.node_id = ray.get_runtime_context().get_node_id()
-
-    def get_node_id(self):
-        return self.node_id
-
-    def double(self, v):
-        return 2 * v
-
-
-@ray.remote
-class BlockerActor(PoolWorkerBase):
-    def __init__(self):
-        self.node_id = ray.get_runtime_context().get_node_id()
-        self.unblocked = False
-        self.unblock_event = asyncio.Event()
-
-    def get_node_id(self):
-        return self.node_id
-
-    async def block(self):
-        if not self.unblocked:
-            await self.unblock_event.wait()
-
-    async def unblock(self):
-        self.unblocked = True
-        self.unblock_event.set()
-
-
-@ray.remote
-class BlockingTestActor(PoolWorkerBase):
-    def __init__(self, blocker):
-        self.node_id = ray.get_runtime_context().get_node_id()
-        self.blocker = blocker
-
-    def get_node_id(self):
-        return self.node_id
-
-    def double(self, v, bypass_blocker=False):
-        if not bypass_blocker:
-            ray.get(self.blocker.block.remote())
-        return 2 * v
-
-
-# Helper function to create a TestActor
-def create_test_actor():
-    return TestActor.remote()
-
-
-def create_test_actor_blocker(blocker_handle):
-    return BlockingTestActor.remote(blocker_handle)
-
-
-def setup_module(module):
-    ray.init(
-        "local", num_cpus=max(2 * logical_cpu_core_count(), 8), ignore_reinit_error=True
-    )  # 2x cpu count is faster on my m1
-
-
-def teardown_module(module):
-    ray.shutdown()
-
-
-@pytest.mark.asyncio
-async def test_basic_submit():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    results = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)]
-    results = [await r for r in results]
-
-    assert results == [0, 2, 4, 6]
-
-
-@pytest.mark.asyncio
-async def test_basic_submit_no_idle():
-    pool = AutoScalingActorPool(create_test_actor, min_size=0, max_size=4)
-    results = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)]
-    results = [await r for r in results]
-
-    assert results == [0, 2, 4, 6]
-
-
-@pytest.mark.asyncio
-async def test_basic_functionality():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    results = list(await pool.map(lambda a, v: a.double.remote(v), [1, 2, 3, 4]))
-    assert results == [2, 4, 6, 8]
-
-
-@pytest.mark.asyncio
-async def test_scaling_up():
-    blocker = BlockerActor.remote()
-    pool = AutoScalingActorPool(lambda: create_test_actor_blocker(blocker), min_size=1, max_size=4)
-    f1 = pool.submit(lambda a, v: a.double.remote(v), 1)
-    f2 = pool.submit(lambda a, v: a.double.remote(v), 2)
-    f3 = pool.submit(lambda a, v: a.double.remote(v, True), 3)
-    f4 = pool.submit(lambda a, v: a.double.remote(v, True), 4)
-
-    shield_f2 = asyncio.shield(f2)
-    with pytest.raises(asyncio.TimeoutError):
-        await asyncio.wait_for(shield_f2, timeout=0.1)
-
-    assert (await asyncio.gather(f3, f4)) == [6, 8]
-
-    await blocker.unblock.remote()
-    # assert (await asyncio.gather(f1, f2)) == [2, 4]
-    assert (await f1) == 2
-    assert (await f2) == 4
-
-
-@pytest.mark.asyncio
-async def test_scaling_down():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    await pool.submit(lambda a, v: a.double.remote(v), 1)
-    await pool.submit(lambda a, v: a.double.remote(v), 2)
-    await pool.submit(lambda a, v: a.double.remote(v), 3)
-    await pool.submit(lambda a, v: a.double.remote(v), 4)
-    results = await asyncio.gather(
-        pool.submit(lambda a, v: a.double.remote(v), 1),
-        pool.submit(lambda a, v: a.double.remote(v), 2),
-        pool.submit(lambda a, v: a.double.remote(v), 3),
-        pool.submit(lambda a, v: a.double.remote(v), 4),
-    )
-    assert results == [2, 4, 6, 8]
-    assert len(pool._idle_actors) == 1
-    assert len(pool._busy_actors) == 0
-
-
-@pytest.mark.asyncio
-async def test_push_pop_idle():
-    pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4)
-    await pool.submit(lambda a, v: a.double.remote(v), 1)
-    actor = pool.pop_idle()
-    assert actor is not None
-    pool.push(actor)
-    assert len(pool._idle_actors) == 1
-
-
-@pytest.mark.asyncio
-async def test_submit_with_no_idle_actors():
-    blocker = BlockerActor.remote()
-    pool = AutoScalingActorPool(lambda: create_test_actor_blocker(blocker), min_size=1, max_size=4)
-    futs = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)]
-    f5 = pool.submit(lambda a, v: a.double.remote(v), 5)
-    await _sleep_until(lambda: pool.num_pending_tasks == 1, timeout=10)
-    await blocker.unblock.remote()
-    await asyncio.gather(*futs)
-    assert (await f5) == 10
-
-
-async def _sleep_until(condition, timeout=5, message="Condition not met within timeout"):
-    start = time.time()
-    while not condition():
-        if time.time() - start > timeout:
-            pytest.fail(message)
-        await asyncio.sleep(0.1)
diff --git a/tests/test_longformer.py b/tests/test_longformer.py
deleted file mode 100644
index c964499a0..000000000
--- a/tests/test_longformer.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import jax
-import jax.numpy as jnp
-import numpy as np
-from chex import assert_trees_all_close
-
-import haliax as hax
-from haliax import Axis
-from haliax.nn.attention import causal_mask
-
-from levanter.models.longformer import causal_sliding_window_attention
-
-
-def test_causal_sliding_window_attention_simple():
-    # test that we can't attend to something outside of the range
-    D = 2
-    for L, W in [(10, 5), (15, 5)]:
-        Pos = Axis("Pos", L)
-        Window = Axis("Window", W)
-        Head = Axis("Head", D)
-
-        keys = np.zeros((L, D), dtype=np.float32)
-        keys[0, 0] = 100.0  # really want to attend to this
-        values = np.zeros((L, D), dtype=np.float32)
-        values[0, 1] = 300.0  # check if we did attend
-
-        query = np.ones((L, D), dtype=np.float32)
-
-        query = hax.named(query, (Pos, Head))
-        keys = hax.named(keys, (Pos, Head))
-        values = hax.named(values, (Pos, Head))
-
-        result = causal_sliding_window_attention(Pos, Window, Head, query, keys, values)
-        # we should be able to attend to the previous W positions for each position (including current), so 6-10 can't attend
-        # to 0-4 and can't get the 100.0 key
-        result = result.rearrange((Pos, Head)).array
-        assert_trees_all_close(result[0:W, 1], 300)
-        assert_trees_all_close(result[W:, 1], 0)
-
-
-def test_sliding_window_attention_fancier():
-    D = 4
-    for L, W in [(2, 1), (2, 2), (4, 2), (10, 5), (15, 5), (16, 2), (15, 3), (10, 10)]:
-        Pos = Axis("Pos", L)
-        Window = Axis("Window", W)
-        Head = Axis("Head", D)
-
-        q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
-
-        query = hax.random.uniform(q_key, (Pos, Head))
-        keys = hax.random.uniform(k_key, (Pos, Head))
-        values = hax.random.uniform(v_key, (Pos, Head))
-
-        result = causal_sliding_window_attention(Pos, Window, Head, query, keys, values)
-        result = result.rearrange((Pos, Head)).array
-
-        KPos = Axis("KPos", Pos.size)
-        keys = keys.rename({Pos: KPos})
-        values = values.rename({Pos: KPos})
-
-        diff = hax.arange(Pos).broadcast_axis(KPos) - hax.arange(KPos).broadcast_axis(Pos)
-        mask = causal_mask(Pos, KPos) & (diff < Window.size) & (diff >= 0)
-
-        # check that the result is the same as non-blocked attention with the right mask
-        expected = hax.nn.attention.dot_product_attention(KPos, Head, query, keys, values, mask=mask)
-
-        expected = expected.rearrange((Pos, Head)).array
-
-        assert_trees_all_close(result, expected, atol=1e-3, rtol=1e-3)
-
-
-def test_longformer_alibi_bias_pos_invariance():
-    D = 1
-    W = 32
-    H = 1
-
-    L = 4096
-
-    Head = Axis("Head", H)
-    Pos = Axis("Pos", L)
-    Window = Axis("Window", W)
-    Hidden = Axis("Hidden", D)
-
-    # this cycles [31, ..., 0, 31, ..., 0, ...]
-    cycle = np.flip(np.arange(W, dtype=np.float32))
-    v = np.tile(cycle, L // W).reshape((L, H, D))
-    v = hax.named(v, (Pos, Head, Hidden))
-
-    q = hax.ones((Pos, Head, Hidden), dtype=jnp.bfloat16) * 0.001
-    k = hax.ones((Pos, Head, Hidden), dtype=jnp.bfloat16) * 0.001
-
-    # bias gets geometrically larger as we go further in the sequence
-    # this is especially true if there are a lot of heads
-    big_head = hax.Axis("Head", 16)
-    # NB: this test doesn't work if you use bfloat16 for biases
-    bias = hax.nn.attention.alibi_attention_bias(big_head, Pos, dtype=jnp.float32).slice(big_head, Head, 0)
-
-    attn = causal_sliding_window_attention(Pos, Window, Hidden, q, k, v, bias=bias, attention_dtype=jnp.bfloat16)
-    attn = attn.rearrange((Pos, Head, Hidden)).array.reshape(L)
-
-    # final value for each cycle should be the same
-    finals = attn[W - 1 :: W]
-    assert np.isclose(finals, finals[0], rtol=2e-4).all(), f"finals: {finals}"
diff --git a/tests/test_prefetch_actor.py b/tests/test_prefetch_actor.py
deleted file mode 100644
index e48546fc1..000000000
--- a/tests/test_prefetch_actor.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import time
-from typing import Iterator
-
-import pytest
-import ray
-
-from levanter.store._prefetch_actor import RayPrefetchQueue
-
-
-def _sleep_until(condition, timeout=5, message="Condition not met within timeout"):
-    start = time.time()
-    while not condition():
-        if time.time() - start > timeout:
-            pytest.fail(message)
-        time.sleep(0.1)
-
-
-@pytest.fixture(scope="module", autouse=True)
-def ray_init_and_shutdown():
-    ray.init("local", ignore_reinit_error=True)
-    yield
-    ray.shutdown()
-
-
-@pytest.mark.ray
-def test_initialization_and_basic_functionality():
-    def simple_producer():
-        for i in range(10):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    results = [actor.get_next() for _ in range(10)]
-    assert results == list(range(10))
-
-
-@pytest.mark.ray
-def test_queue_size_limit():
-    def simple_producer() -> Iterator[ray.ObjectRef]:
-        for i in range(100):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer, max_queue_size=10)
-    # Allow some time for the queue to fill up
-    _sleep_until(lambda: actor.queue_size() == 10)
-
-    # get a few items to make some space
-    [actor.get_next() for _ in range(5)]
-    _sleep_until(lambda: actor.queue_size() == 10, message="Queue size did not reach 10")
-
-
-@pytest.mark.ray
-def test_stop_functionality():
-    def simple_producer():
-        for i in range(10000):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    actor.stop()
-
-    _sleep_until(lambda: actor.is_stopped(), message="Actor did not stop")
-
-
-@pytest.mark.ray
-def test_exception_handling():
-    def faulty_producer():
-        for i in range(5):
-            yield i
-        raise ValueError("Test exception")
-
-    actor = RayPrefetchQueue(faulty_producer)
-    results = []
-    try:
-        for _ in range(10):
-            results.append(actor.get_next())
-    except ValueError as e:
-        assert "Test exception" in str(e)  # Ray puts a lot of crap in the exception message
-    assert results == list(range(5))
-
-
-@pytest.mark.ray
-def test_empty_producer():
-    def empty_producer() -> Iterator[ray.ObjectRef]:
-        if False:
-            yield
-
-    actor = RayPrefetchQueue(empty_producer)
-    with pytest.raises(StopIteration):
-        actor.get_next()
-
-
-@pytest.mark.ray
-def test_multiple_consumers():
-    def simple_producer() -> Iterator[ray.ObjectRef]:
-        for i in range(20):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    results = [actor.get_next() for _ in range(10)]
-    results += [actor.get_next() for _ in range(10)]
-    assert results == list(range(20))
-
-
-@pytest.mark.ray
-def test_producer_completion():
-    def simple_producer():
-        for i in range(10):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-    results = []
-    try:
-        while True:
-            results.append(actor.get_next())
-    except StopIteration:
-        pass
-    assert results == list(range(10))
-
-
-@pytest.mark.ray
-def test_drain_queue():
-    def simple_producer():
-        for i in range(10):
-            yield i
-
-    actor = RayPrefetchQueue(simple_producer)
-
-    all_results = []
-
-    for tot in range(0, 5):
-        out = actor.drain_available(tot)
-        assert len(out) <= tot
-        all_results.extend(out)
-
-    while len(all_results) < 10:
-        all_results.append(actor.get_next())
-
-    assert all_results == list(range(10))