diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index be5a2d22a..e74102847 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -105,7 +105,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog elapsed_time = time.time() - all_start_time all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0 all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0) - status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress" + status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "in_progress" if any(p.status == "in_progress" for p in file_progress.values()) else "not_started" return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status) async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: @@ -147,12 +147,12 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes speed = downloaded_this_session / (time.time() - start_time) eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) - file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time) + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else 
"in_progress", start_time) on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)) if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}") for file in filtered_file_list: downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0 - file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time()) + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time()) semaphore = asyncio.Semaphore(max_parallel_downloads) async def download_with_semaphore(file): diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index 6543f0b80..8e336dce1 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -61,12 +61,13 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No return model +_executor = ThreadPoolExecutor(max_workers=1) # singleton so tinygrad always runs on the same thread class TinygradDynamicShardInferenceEngine(InferenceEngine): def __init__(self, shard_downloader: ShardDownloader): self.shard = None self.shard_downloader = shard_downloader - self.executor = ThreadPoolExecutor(max_workers=1) self.states = OrderedDict() + self.executor = _executor def poll_state(self, x, request_id: str, max_states=2): if request_id not in self.states: @@ -79,8 +80,8 @@ def poll_state(self, x, request_id: str, max_states=2): return {"start_pos": state.start, "cache": state.cache} async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray: - logits = x[:, 
-1, :] def sample_wrapper(): + logits = x[:, -1, :] return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int) return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper) @@ -112,9 +113,9 @@ def wrap_infer(): state = self.poll_state(h, request_id) out = self.model.forward(h, **state) self.states[request_id].start += x.shape[1] - return out.realize() + return out.numpy() output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer) - return output_data.numpy(), inference_state + return output_data, inference_state async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss): def step(x, y, l): diff --git a/exo/main.py b/exo/main.py index 5b20808b8..e16daea94 100644 --- a/exo/main.py +++ b/exo/main.py @@ -206,14 +206,16 @@ def preemptively_load_shard(request_id: str, opaque_status: str): traceback.print_exc() node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard) -last_broadcast_time = 0 +last_events: dict[str, tuple[float, RepoProgressEvent]] = {} def throttled_broadcast(shard: Shard, event: RepoProgressEvent): - global last_broadcast_time + global last_events current_time = time.time() if event.status == "not_started": return - if event.status == "complete" or current_time - last_broadcast_time >= 0.1: - last_broadcast_time = current_time - asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))) + last_event = last_events.get(shard.model_id) + if last_event and last_event[1].status == "complete" and event.status == "complete": return + if last_event and last_event[1].status == event.status and current_time - last_event[0] < 0.2: return + last_events[shard.model_id] = (current_time, event) + asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, 
"progress": event.to_dict()}))) shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast) async def run_model_cli(node: Node, model_name: str, prompt: str): diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index 12b161cfa..82064b423 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -89,16 +89,16 @@ def _generate_prompt_output_layout(self) -> Panel: # Calculate available height for content panel_height = 15 # Fixed panel height available_lines = panel_height - 2 # Subtract 2 for panel borders - lines_per_entry = available_lines // len(requests) if requests else 0 + lines_per_request = available_lines // len(requests) if requests else 0 for (prompt, output) in reversed(requests): prompt_icon, output_icon = "💬️", "🤖" - # Calculate max lines for prompt and output - max_prompt_lines = max(3, lines_per_entry // 2) # Ensure at least 3 lines for prompt - max_output_lines = lines_per_entry - max_prompt_lines - 1 # Remaining space minus spacing + # Equal space allocation for prompt and output + max_prompt_lines = lines_per_request // 2 + max_output_lines = lines_per_request - max_prompt_lines - 1 # -1 for spacing - # Process prompt with more generous line allocation + # Process prompt prompt_lines = [] for line in prompt.split('\n'): words = line.split() @@ -118,53 +118,55 @@ def _generate_prompt_output_layout(self) -> Panel: if current_line: prompt_lines.append(' '.join(current_line)) - # Show more prompt content and append ellipses to last line if needed + # Truncate prompt if needed if len(prompt_lines) > max_prompt_lines: prompt_lines = prompt_lines[:max_prompt_lines] - # Append ellipses to last line if there's room, otherwise truncate last line - last_line = prompt_lines[-1] - if len(last_line) + 4 <= max_width: # +4 for " ..." - prompt_lines[-1] = last_line + " ..." - else: - prompt_lines[-1] = last_line[:max_width-4] + " ..." 
+ if prompt_lines: + last_line = prompt_lines[-1] + if len(last_line) + 4 <= max_width: + prompt_lines[-1] = last_line + " ..." + else: + prompt_lines[-1] = last_line[:max_width-4] + " ..." prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue") prompt_text.append('\n'.join(prompt_lines), style="white") + content.append(prompt_text) - # Process output - same word-aware wrapping - output_lines = [] - for line in output.split('\n'): - words = line.split() - current_line = [] - current_length = 0 - - for word in words: - if current_length + len(word) + 1 <= max_width: - current_line.append(word) - current_length += len(word) + 1 - else: - if current_line: - output_lines.append(' '.join(current_line)) - current_line = [word] - current_length = len(word) - - if current_line: - output_lines.append(' '.join(current_line)) - - if len(output_lines) > max_output_lines: - output_lines = output_lines[:max_output_lines] - last_line = output_lines[-1] if output_lines else None - if last_line: - if len(last_line) + 4 <= max_width: - output_lines[-1] = last_line + " ..." - else: - output_lines[-1] = last_line[:max_width-4] + " ..." 
- - output_text = Text(f"\n{output_icon} ", style="bold bright_magenta") - output_text.append('\n'.join(output_lines), style="white") + # Process output with similar word wrapping + if output: # Only process output if it exists + output_lines = [] + for line in output.split('\n'): + words = line.split() + current_line = [] + current_length = 0 + + for word in words: + if current_length + len(word) + 1 <= max_width: + current_line.append(word) + current_length += len(word) + 1 + else: + if current_line: + output_lines.append(' '.join(current_line)) + current_line = [word] + current_length = len(word) + + if current_line: + output_lines.append(' '.join(current_line)) + + # Truncate output if needed + if len(output_lines) > max_output_lines: + output_lines = output_lines[:max_output_lines] + if output_lines: + last_line = output_lines[-1] + if len(last_line) + 4 <= max_width: + output_lines[-1] = last_line + " ..." + else: + output_lines[-1] = last_line[:max_width-4] + " ..." + + output_text = Text(f"{output_icon} ", style="bold bright_magenta") + output_text.append('\n'.join(output_lines), style="white") + content.append(output_text) - content.append(prompt_text) - content.append(output_text) content.append(Text()) # Empty line between entries return Panel( diff --git a/setup.py b/setup.py index 54e787761..a158d4430 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "transformers==4.46.3", "uuid==1.30", "uvloop==0.21.0", - "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79", + "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8", ] extras_require = {