diff --git a/exo/main.py b/exo/main.py
index 5b20808b8..e16daea94 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -206,14 +206,21 @@ def preemptively_load_shard(request_id: str, opaque_status: str):
       traceback.print_exc()
 node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
 
-last_broadcast_time = 0
+# Last broadcast per model id: (time.time() of that broadcast, the event that was sent).
+last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-  global last_broadcast_time
+  """Broadcast a download_progress status for `event`, throttled per `shard.model_id`."""
+  global last_events
   current_time = time.time()
   if event.status == "not_started": return
-  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-    last_broadcast_time = current_time
-    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+  last_event = last_events.get(shard.model_id)
+  # Do not re-broadcast "complete" when the previous broadcast for this model was already "complete".
+  if last_event and last_event[1].status == "complete" and event.status == "complete": return
+  # Drop an update whose status matches the last broadcast one if it arrives within 0.2s of it.
+  # The status lives on the stored event (index 1); index 0 is the timestamp.
+  if last_event and last_event[1].status == event.status and current_time - last_event[0] < 0.2: return
+  last_events[shard.model_id] = (current_time, event)
+  asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 async def run_model_cli(node: Node, model_name: str, prompt: str):