diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 3b7fd61..2c14dee 100644
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -2392,7 +2392,7 @@ def load_model(
     backend : str, optional
         Backend to use. Either "transformers" or "openai-whisper".
     download_root : str, optional
-        Root folder to download the model to. If None, use the default download root.
+        Root folder to download the model to. If None, use the default download root (typically: ~/.cache).
     in_memory : bool, optional
         Whether to preload the model weights into host memory.
     """
@@ -2405,11 +2405,12 @@ def load_model(
             name = f"openai/whisper-{name}"
         # TODO: use download_root
         # TODO: does in_memory makes sense?
+        cache_dir = os.path.join(download_root, "huggingface", "hub") if download_root else None
         try:
-            generation_config = transformers.GenerationConfig.from_pretrained(name)
+            generation_config = transformers.GenerationConfig.from_pretrained(name, cache_dir=cache_dir)
         except OSError:
-            generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny")
-        processor = transformers.WhisperProcessor.from_pretrained(name)
+            generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny", cache_dir=cache_dir)
+        processor = transformers.WhisperProcessor.from_pretrained(name, cache_dir=cache_dir)
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         precision = torch.float32
@@ -2421,6 +2422,7 @@
             # torch_dtype=torch.bfloat16,
             # attn_implementation="flash_attention_2",
             # attn_implementation="sdpa",
+            cache_dir=cache_dir,
         )

         # model = model.to_bettertransformer()
@@ -2433,7 +2435,12 @@
     extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None

     if name in whisper.available_models() or extension == ".pt":
-        return whisper.load_model(name, device=device, download_root=download_root, in_memory=in_memory)
+        return whisper.load_model(
+            name,
+            device=device,
+            download_root=os.path.join(download_root, "whisper") if download_root else None,
+            in_memory=in_memory
+        )

     # Otherwise, assume transformers
     if extension in [".ckpt", ".bin"]:
@@ -2446,7 +2453,11 @@
         raise ImportError(f"If you are trying to download a HuggingFace model with {name}, please install first the transformers library")

     from transformers.utils import cached_file
-    kwargs = dict(cache_dir=download_root, use_auth_token=None, revision=None)
+    kwargs = dict(
+        cache_dir=os.path.join(download_root, "huggingface", "hub") if download_root else None,
+        use_auth_token=None,
+        revision=None,
+    )
     try:
         model_path = cached_file(name, "pytorch_model.bin", **kwargs)
     except OSError as err:
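
A minimal usage sketch of the patched behavior, not part of the patch itself: it assumes the usual `import whisper_timestamped as whisper` convention, and the model name "tiny" and the "/data/models" path are illustrative only. The point is that with these changes a single download_root fans out into backend-specific cache subdirectories instead of being ignored or used inconsistently.

import whisper_timestamped as whisper

# Transformers backend: per this patch, weights, processor, and generation
# config are all cached under <download_root>/huggingface/hub.
model = whisper.load_model(
    "tiny",
    backend="transformers",
    download_root="/data/models",  # illustrative path
)

# openai-whisper backend: per this patch, weights go to <download_root>/whisper.
model = whisper.load_model(
    "tiny",
    backend="openai-whisper",
    download_root="/data/models",
)

Keeping the two backends in sibling subdirectories of the same root means a user can set one download_root and switch backends without the caches colliding or duplicating under ambiguous paths.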