
Commit

clarify cache folder
Jeronymous committed Mar 3, 2024
1 parent f7e6fff commit 0609d7c
Showing 1 changed file with 17 additions and 6 deletions.
whisper_timestamped/transcribe.py (17 additions, 6 deletions)
@@ -2392,7 +2392,7 @@ def load_model(
     backend : str, optional
         Backend to use. Either "transformers" or "openai-whisper".
     download_root : str, optional
-        Root folder to download the model to. If None, use the default download root.
+        Root folder to download the model to. If None, use the default download root (typically: ~/.cache)
     in_memory : bool, optional
         Whether to preload the model weights into host memory.
     """
@@ -2405,11 +2405,12 @@ def load_model(
         name = f"openai/whisper-{name}"
         # TODO: use download_root
         # TODO: does in_memory make sense?
+        cache_dir = os.path.join(download_root, "huggingface", "hub") if download_root else None
         try:
-            generation_config = transformers.GenerationConfig.from_pretrained(name)
+            generation_config = transformers.GenerationConfig.from_pretrained(name, cache_dir=cache_dir)
         except OSError:
-            generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny")
-        processor = transformers.WhisperProcessor.from_pretrained(name)
+            generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny", cache_dir=cache_dir)
+        processor = transformers.WhisperProcessor.from_pretrained(name, cache_dir=cache_dir)
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         precision = torch.float32
@@ -2421,6 +2422,7 @@ def load_model(
             # torch_dtype=torch.bfloat16,
             # attn_implementation="flash_attention_2",
             # attn_implementation="sdpa",
+            cache_dir=cache_dir,
         )
         # model = model.to_bettertransformer()

@@ -2433,7 +2435,12 @@ def load_model(
     extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None

     if name in whisper.available_models() or extension == ".pt":
-        return whisper.load_model(name, device=device, download_root=download_root, in_memory=in_memory)
+        return whisper.load_model(
+            name,
+            device=device,
+            download_root=os.path.join(download_root, "whisper") if download_root else None,
+            in_memory=in_memory
+        )

     # Otherwise, assume transformers
     if extension in [".ckpt", ".bin"]:
@@ -2446,7 +2453,11 @@ def load_model(
         raise ImportError(f"If you are trying to download a HuggingFace model with {name}, please install first the transformers library")
     from transformers.utils import cached_file

-    kwargs = dict(cache_dir=download_root, use_auth_token=None, revision=None)
+    kwargs = dict(
+        cache_dir=os.path.join(download_root, "huggingface", "hub") if download_root else None,
+        use_auth_token=None,
+        revision=None,
+    )
     try:
         model_path = cached_file(name, "pytorch_model.bin", **kwargs)
     except OSError as err:
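
Taken together, the path changes give each backend its own subfolder of a user-supplied download_root rather than letting both write into the same directory. A minimal sketch of the resulting layout (the /data/models value is hypothetical; the joined subfolder names come straight from the diff):

    import os

    download_root = "/data/models"  # hypothetical user-chosen cache folder

    # transformers backend: HuggingFace files are cached under
    # <download_root>/huggingface/hub, mirroring the default ~/.cache/huggingface/hub
    hf_cache_dir = os.path.join(download_root, "huggingface", "hub") if download_root else None

    # openai-whisper backend: .pt checkpoints are stored under <download_root>/whisper,
    # mirroring the default ~/.cache/whisper
    whisper_root = os.path.join(download_root, "whisper") if download_root else None

    print(hf_cache_dir)   # /data/models/huggingface/hub
    print(whisper_root)   # /data/models/whisper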
