Skip to content

Commit

Permalink
fix eta/speed for resuming an existing download, using the session downloaded bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jan 27, 2025
1 parent 90e0e27 commit 7c64908
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions exo/download/new_shard_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
    """Aggregate per-file progress into a single repo-level progress event.

    Args:
        shard: Passed through unchanged to the resulting event.
        repo_id: Passed through unchanged to the resulting event.
        revision: Passed through unchanged to the resulting event.
        file_progress: Per-file progress keyed by file path. Each value must
            expose ``total``, ``downloaded`` and ``downloaded_this_session``.
        all_start_time: ``time.time()`` timestamp taken when the overall
            download started; used as the baseline for the speed estimate.

    Returns:
        A ``RepoProgressEvent`` carrying the completed/total file counts, the
        summed byte counters, the aggregate speed and ETA, the per-file map
        and a status of ``"not_started"``, ``"in_progress"`` or ``"complete"``.
    """
    events = list(file_progress.values())
    all_total_bytes = sum(p.total for p in events)
    all_downloaded_bytes = sum(p.downloaded for p in events)
    all_downloaded_bytes_this_session = sum(p.downloaded_this_session for p in events)
    elapsed_time = time.time() - all_start_time

    # Speed is derived from the bytes fetched during this session only, so
    # bytes accounted for before this session do not inflate the rate.
    all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0

    # The remaining work is still measured against the full byte total.
    if all_speed > 0:
        all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
    else:
        all_eta = timedelta(seconds=0)

    if all_downloaded_bytes == 0:
        status = "not_started"
    elif all_downloaded_bytes == all_total_bytes:
        status = "complete"
    else:
        status = "in_progress"

    completed_files = sum(1 for p in events if p.downloaded == p.total)
    return RepoProgressEvent(shard, repo_id, revision, completed_files, len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)

async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
Expand Down Expand Up @@ -143,9 +144,10 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
file_progress: Dict[str, RepoFileProgressEvent] = {}
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
    """Record progress for one file and broadcast the updated repo progress.

    Args:
        file: File descriptor dict; only ``file["path"]`` is read here.
        curr_bytes: Bytes of this file accounted for so far.
        total_bytes: Total size of this file in bytes.
    """
    path = file["path"]
    previous = file_progress.get(path)
    if previous is None:
        # First callback for this file: start its clock now.
        # NOTE(review): this counts the first reported curr_bytes as session
        # bytes; confirm what the downloader reports on its first callback
        # when resuming a partially downloaded file.
        start_time = time.time()
        downloaded_this_session = curr_bytes
    else:
        start_time = previous.start_time
        # Accumulate only the delta since the previous callback.
        downloaded_this_session = previous.downloaded_this_session + (curr_bytes - previous.downloaded)

    # Guard both divisions: the first callback can arrive with zero elapsed
    # time, and a zero speed would otherwise make the ETA division fail.
    elapsed = time.time() - start_time
    speed = downloaded_this_session / elapsed if elapsed > 0 else 0
    eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)

    file_progress[path] = RepoFileProgressEvent(repo_id, revision, path, curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time)
    on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
    if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
for file in filtered_file_list:
Expand Down

0 comments on commit 7c64908

Please sign in to comment.