Skip to content

Commit

Permalink
throttle repo progress events and only send them out if something cha…
Browse files Browse the repository at this point in the history
…nged
  • Loading branch information
AlexCheema committed Jan 29, 2025
1 parent 96f1aec commit 3675804
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,16 @@ def preemptively_load_shard(request_id: str, opaque_status: str):
traceback.print_exc()
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)

last_broadcast_time = 0
last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
global last_broadcast_time
global last_events
current_time = time.time()
if event.status == "not_started": return
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
last_broadcast_time = current_time
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
last_event = last_events.get(shard.model_id)
if last_event and last_event[1].status == "complete" and event.status == "complete": return
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
last_events[shard.model_id] = (current_time, event)
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

async def run_model_cli(node: Node, model_name: str, prompt: str):
Expand Down

0 comments on commit 3675804

Please sign in to comment.