forked from lucataco/cog-whisperspeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
66 lines (57 loc) · 2.83 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import time
import subprocess
# Single local folder that all model caches are redirected into, so the
# downloaded weights persist inside the container image.
WEIGHTS_FOLDER = "/src/models/"
# NOTE: these environment variables must be set BEFORE the ML library imports
# below, because those libraries read their cache locations at import time.
os.environ['HF_HOME'] = WEIGHTS_FOLDER
os.environ['HF_HUB_CACHE'] = WEIGHTS_FOLDER
os.environ['TORCH_HOME'] = WEIGHTS_FOLDER
os.environ['PYANNOTE_CACHE'] = WEIGHTS_FOLDER
from cog import BasePredictor, Input, Path
import torch
from whisperspeech.pipeline import Pipeline
from speechbrain.pretrained import EncoderClassifier
# Tarball containing all pre-trained weights; fetched once in Predictor.setup().
MODELS_URL = "https://weights.replicate.delivery/default/whisper-speech/models.tar"
MODELS_PATH = WEIGHTS_FOLDER
def download_weights(url, dest, extract=True):
    """Fetch weights with the `pget` downloader, timing the transfer.

    Args:
        url: Remote location of the weights archive.
        dest: Local path to download (and extract) into.
        extract: When True, pass ``-x`` so pget unpacks the tarball in place.

    Raises:
        subprocess.CalledProcessError: if the pget process exits non-zero.
    """
    started_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    command = ["pget", "-x"] if extract else ["pget"]
    command += [url, dest]
    subprocess.check_call(command, close_fds=False)
    print("downloading took: ", time.time() - started_at)
class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Download model weights to the cache on first run only; the extracted
        # tar creates the "speechbrain" subfolder we test for, so a warm
        # container skips the download entirely.
        if not os.path.exists(Path(MODELS_PATH) / "speechbrain"):
            download_weights(MODELS_URL, MODELS_PATH)
        self.pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-small-en+pl.model')
        # FIX: pick the device at runtime instead of hard-coding "cuda", so the
        # speaker encoder also loads on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # source: https://github.com/collabora/WhisperSpeech/blob/a4f9c2de1a7e2e0b77f2acb08374de347414e2fa/whisperspeech/pipeline.py#L68-L72
        self.pipe.encoder = EncoderClassifier.from_hparams(
            "speechbrain/spkrec-ecapa-voxceleb",
            savedir=Path(WEIGHTS_FOLDER) / "speechbrain",
            run_opts={"device": device},
        )

    def predict(
        self,
        prompt: str = Input(description="Text to synthesize", default="This is the first demo of Whisper Speech, a fully open source text-to-speech model trained by Collabora and Lion on the Juwels supercomputer."),
        language: str = Input(
            description="Language to synthesize", default="en",
            choices=["en", "pl"]
        ),
        speaker: str = Input(
            description="URL of an audio file for zero-shot voice cloning. Supported audio formats are OGG, WAV and MP3. (example: https://upload.wikimedia.org/wikipedia/commons/7/75/Winston_Churchill_-_Be_Ye_Men_of_Valour.ogg)",
            default="",
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Returns:
            Path to the synthesized WAV file.
        """
        output_path = "/tmp/output.wav"
        # Only forward `speaker` when a reference clip was supplied; an empty
        # string means plain synthesis with the model's default voice.
        extra = {"speaker": speaker} if speaker else {}
        self.pipe.generate_to_file(output_path, prompt, lang=language, **extra)
        return Path(output_path)