Improve HF download speed #1899

Merged: 3 commits, Jan 8, 2025
Changes from 1 commit
17 changes: 10 additions & 7 deletions litgpt/scripts/download.py
@@ -70,14 +70,17 @@ def download_from_hub(
         else:
             raise ValueError(f"Couldn't find weight files for {repo_id}")

+    # Get and set env variable to improve download speed
+    user_env_value = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
+
+    if user_env_value is None:
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+        print("Setting HF_HUB_ENABLE_HF_TRANSFER=1 by default")
+
     import huggingface_hub._snapshot_download as download
     import huggingface_hub.constants as constants

-    previous = constants.HF_HUB_ENABLE_HF_TRANSFER
-    if _HF_TRANSFER_AVAILABLE and not previous:
-        print("Setting HF_HUB_ENABLE_HF_TRANSFER=1")
-        constants.HF_HUB_ENABLE_HF_TRANSFER = True
-        download.HF_HUB_ENABLE_HF_TRANSFER = True
+    previous_flag = constants.HF_HUB_ENABLE_HF_TRANSFER # this may be redundant
rasbt marked this conversation as resolved.

     directory = checkpoint_dir / repo_id
     with gated_repo_catcher(repo_id, access_token):
@@ -88,8 +91,8 @@ def download_from_hub(
             token=access_token,
         )

-    constants.HF_HUB_ENABLE_HF_TRANSFER = previous
-    download.HF_HUB_ENABLE_HF_TRANSFER = previous
+    constants.HF_HUB_ENABLE_HF_TRANSFER = previous_flag
+    download.HF_HUB_ENABLE_HF_TRANSFER = previous_flag

     if convert_checkpoint and not tokenizer_only:
         print("Converting checkpoint files to LitGPT format.")
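For anyone reading this change in isolation: the "default only when the user has not chosen" pattern can be sketched as a small context manager. This is a minimal sketch, not LitGPT's code. The helper name `default_env` is hypothetical, and unlike the diff above it also undoes the default on exit. It deliberately does not import `huggingface_hub`, since as far as I can tell that library reads `HF_HUB_ENABLE_HF_TRANSFER` into `huggingface_hub.constants` when the module is first imported, so the variable likely has to be set before that import to have any effect.

```python
import os
from contextlib import contextmanager

ENV_VAR = "HF_HUB_ENABLE_HF_TRANSFER"


@contextmanager
def default_env(name: str, default: str):
    """Hypothetical helper: apply `default` only if the user left `name` unset."""
    user_value = os.environ.get(name)
    if user_value is None:
        # The user expressed no preference, so opt in to the default.
        os.environ[name] = default
    try:
        yield os.environ[name]
    finally:
        if user_value is None:
            # We set the variable ourselves, so remove it again.
            # A value the user provided is never modified.
            os.environ.pop(name, None)


if __name__ == "__main__":
    with default_env(ENV_VAR, "1") as value:
        print(f"{ENV_VAR} is {value} inside the block")
```

An explicit `HF_HUB_ENABLE_HF_TRANSFER=0` from the user is left alone, which matches the intent of the `user_env_value is None` check in the diff.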