Skip to content

Commit

Permalink
Raise error if disk is full before downloading weights
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Jan 8, 2025
1 parent 91f3752 commit 9f563d3
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions litgpt/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
import importlib.util
from pathlib import Path
import shutil
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -62,7 +63,38 @@ def download_from_hub(

download_files = ["tokenizer*", "generation_config.json", "config.json"]
if not tokenizer_only:
bins, safetensors = find_weight_files(repo_id, access_token)
bins, safetensors, info = find_weight_files(repo_id, access_token)

total_weight_size_bytes = 0
if bins:
total_weight_size_bytes = sum(
(file.size or 0)
for file in info.siblings
if file.rfilename.endswith(".bin") or file.rfilename.endswith(".bin.index.json")
)
elif safetensors:
total_weight_size_bytes = sum(
(file.size or 0)
for file in info.siblings
if file.rfilename.endswith(".safetensors")
)
else:
raise ValueError(f"Couldn't find weight files for {repo_id}")

weight_size_gb = total_weight_size_bytes / (1024**3)
free_space_bytes = shutil.disk_usage(str(checkpoint_dir)).free
free_space_gb = free_space_bytes / (1024**3)

if weight_size_gb > free_space_gb:
if os.getenv("LIGHTNING_CLUSTER_ID") is not None:
studio_text = " Please switch to a larger Studio with more disk space."
else:
studio_text = ""
raise RuntimeError(
f"Not enough disk space to download {repo_id} weights. "
f"Needed: ~{weight_size_gb:.2f} GB, free: ~{free_space_gb:.2f} GB.{studio_text}"
)

if bins:
# covers `.bin` files and `.bin.index.json`
download_files.append("*.bin*")
Expand Down Expand Up @@ -104,11 +136,11 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
from huggingface_hub.utils import filter_repo_objects

with gated_repo_catcher(repo_id, access_token):
info = repo_info(repo_id, token=access_token)
info = repo_info(repo_id, token=access_token, files_metadata=True)
filenames = [f.rfilename for f in info.siblings]
bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
return bins, safetensors
return bins, safetensors, info


@contextmanager
Expand Down

0 comments on commit 9f563d3

Please sign in to comment.