diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py index 7ab609b30f..9771a37956 100644 --- a/litgpt/scripts/download.py +++ b/litgpt/scripts/download.py @@ -3,6 +3,7 @@ import os from concurrent.futures import ProcessPoolExecutor from contextlib import contextmanager +import importlib.util from pathlib import Path from typing import List, Optional, Tuple @@ -56,6 +57,8 @@ def download_from_hub( return from huggingface_hub import snapshot_download + if importlib.util.find_spec("hf_transfer") is None: + print("It is recommended to install hf_transfer for faster checkpoint download speeds: `pip install hf_transfer`") download_files = ["tokenizer*", "generation_config.json", "config.json"] if not tokenizer_only: