Merge pull request #651 from exo-explore/parallelise_model_loadin
parallelise model loading
AlexCheema authored Jan 29, 2025
2 parents 75091e2 + 4887be5 commit f6ed830
Showing 2 changed files with 20 additions and 22 deletions.
exo/inference/mlx/sharded_inference_engine.py (17 additions, 13 deletions)

@@ -24,6 +24,7 @@ def __init__(self, shard_downloader: ShardDownloader):
     self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
     self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
     self.session = {}
+    self._shard_lock = asyncio.Lock()

   async def _eval_mlx(self, *args):
     await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)

@@ -157,19 +158,22 @@ def train_step(inp, tar, lng):
     return score, first_layer

   async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-    if self.shard != shard:
-      model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
-      if hasattr(model_shard, "tokenizer"):
-        self.tokenizer = model_shard.tokenizer
-      else:
-        self.tokenizer = await resolve_tokenizer(model_path)
-      self.shard = shard
-      self.model = model_shard
-      self.caches = OrderedDict()
-      self.session = {}
+    async with self._shard_lock:
+      if self.shard == shard: return
+      model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
+      if self.shard != shard:
+        model_shard = await asyncio.get_running_loop().run_in_executor(
+          self._mlx_thread,
+          lambda: load_model_shard(model_path, shard, lazy=False)
+        )
+        if hasattr(model_shard, "tokenizer"):
+          self.tokenizer = model_shard.tokenizer
+        else:
+          self.tokenizer = await resolve_tokenizer(model_path)
+        self.shard = shard
+        self.model = model_shard
+        self.caches = OrderedDict()
+        self.session = {}

   async def cleanup(self):
     self._mlx_thread.shutdown(wait=True)
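The point of the new _shard_lock is that ensure_shard can now be called concurrently, e.g. by the preemptive loader in main.py and by the request path itself, without loading the model twice: the first caller holds the lock through the load, and later callers wake up, hit the early-return check, and bail out. Below is a minimal standalone sketch of the same double-checked pattern; the names (Loader, ensure, _blocking_load) are hypothetical stand-ins, not exo's actual classes:

import asyncio
from concurrent.futures import ThreadPoolExecutor

class Loader:
  # Illustrative stand-in for the inference engine, not exo's actual class.
  def __init__(self):
    self.current = None                             # currently loaded shard
    self.model = None
    self._lock = asyncio.Lock()                     # serialises concurrent loads
    self._pool = ThreadPoolExecutor(max_workers=1)  # blocking work stays off the event loop

  def _blocking_load(self, shard):
    # Placeholder for load_model_shard(...): slow, synchronous work.
    import time; time.sleep(0.2)
    return f"model-for-{shard}"

  async def ensure(self, shard):
    async with self._lock:
      if self.current == shard:                     # later callers return here
        return
      loop = asyncio.get_running_loop()
      self.model = await loop.run_in_executor(self._pool, self._blocking_load, shard)
      self.current = shard

async def main():
  loader = Loader()
  # Two racing callers; the lock guarantees exactly one load.
  await asyncio.gather(loader.ensure("A"), loader.ensure("A"))
  print(loader.model)

asyncio.run(main())

The single-worker pool mirrors _mlx_thread: the blocking load runs off the event loop, and the await keeps the lock held until it finishes, so the critical section stays serialised even though the work happens on another thread.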
exo/main.py (3 additions, 9 deletions)

@@ -193,33 +193,27 @@ def update_prompt_viz(request_id, opaque_status: str):
     traceback.print_exc()
 node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)

-def preemptively_start_download(request_id: str, opaque_status: str):
+def preemptively_load_shard(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
     if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
     current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
     if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-    asyncio.create_task(shard_downloader.ensure_shard(current_shard, node.inference_engine.__class__.__name__))
+    asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()

-node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
+node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)

 last_broadcast_time = 0

 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   global last_broadcast_time
   current_time = time.time()
   if event.status == "not_started": return
   if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
     last_broadcast_time = current_time
     asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))

 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

 async def run_model_cli(node: Node, model_name: str, prompt: str):
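The main.py side of the change upgrades the preemptive hook: instead of merely starting the shard download via shard_downloader.ensure_shard, it now calls the engine's ensure_shard, so the weights are already loaded into memory by the time the prompt reaches this node, and the lock added above keeps this background load from colliding with the request path. A rough sketch of the fire-and-forget event-hook pattern, with a hypothetical Engine stand-in (exo's real registration goes through node.on_opaque_status):

import asyncio, json

class Engine:
  # Hypothetical stand-in for node.inference_engine.
  async def ensure_shard(self, shard):
    print(f"preloading {shard} before the prompt arrives")

engine = Engine()

def preemptively_load_shard(request_id: str, opaque_status: str):
  try:
    status = json.loads(opaque_status)
    # Only react when a node announces it is about to process a prompt.
    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt":
      return
    shard = status.get("shard")  # exo resolves this via node.get_current_shard(...)
    # Fire-and-forget: the load proceeds while the prompt is still in flight.
    asyncio.create_task(engine.ensure_shard(shard))
  except Exception as e:
    print(f"Failed to preemptively load shard: {e}")

async def main():
  preemptively_load_shard("req-1", json.dumps(
    {"type": "node_status", "status": "start_process_prompt", "shard": "A"}
  ))
  await asyncio.sleep(0)  # give the created task a chance to run

asyncio.run(main())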

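throttled_broadcast, shown unchanged as context above, is its own small pattern: progress events are forwarded at most once per 100 ms, except that "complete" always passes and "not_started" is always dropped. A generic sketch under those assumptions (class and method names are illustrative, not exo's API):

import time

class Throttler:
  # At most one event per interval, but terminal events always pass.
  def __init__(self, interval: float = 0.1):
    self.interval = interval
    self.last_sent = 0.0

  def should_send(self, status: str) -> bool:
    now = time.time()
    if status == "not_started":
      return False                      # suppress no-op updates entirely
    if status == "complete" or now - self.last_sent >= self.interval:
      self.last_sent = now              # record send time only when we emit
      return True
    return False

t = Throttler()
assert t.should_send("in_progress") is True     # first event passes
assert t.should_send("in_progress") is False    # within 100 ms: dropped
assert t.should_send("complete") is True        # terminal event always passes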