XTTS: add inference_stream_text (slightly friendlier for text-streaming) #3724

Closed · wants to merge 2 commits
TTS/tts/models/xtts.py (132 additions & 54 deletions)
@@ -209,6 +209,8 @@ def __init__(self, config: Coqpit):

```python
        self.decoder_checkpoint = self.args.decoder_checkpoint  # TODO: check if this is even needed
        self.models_dir = config.model_dir
        self.gpt_batch_size = self.args.gpt_batch_size
        self._stream_text_holder = []
        self._stream_generator = None

        self.tokenizer = VoiceBpeTokenizer()
        self.gpt = None
```
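These two attributes carry the whole text-streaming session state. A sketch of their lifecycle (illustrative only, assuming a loaded `model` plus `gpt_cond_latent`/`speaker_embedding` as in the docs examples further down):

```python
assert model._stream_text_holder == [] and model._stream_generator is None  # idle

inf_gen = model.inference_stream_text("en", gpt_cond_latent, speaker_embedding)
assert model._stream_generator is inf_gen  # session open, holder primed to [None, None]

model.inference_add_text("Hello!", enable_text_splitting=False)
assert model._stream_text_holder == ["Hello!", False]  # text waiting to be consumed

model.inference_finalize_text()
assert model._stream_generator is None  # session closed, holder cleared
```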
@@ -632,64 +634,140 @@ def inference_stream(

```python
        length_scale = 1.0 / max(speed, 0.05)
        gpt_cond_latent = gpt_cond_latent.to(self.device)
        speaker_embedding = speaker_embedding.to(self.device)
        text_streaming = (text is None)

        while True:
            if text_streaming:
                yield None
                if len(self._stream_text_holder) == 0:
                    return
                text, enable_text_splitting = self._stream_text_holder

            if enable_text_splitting:
                text = split_sentence(text, language, self.tokenizer.char_limits[language])
            else:
                text = [text]

            for sent in text:
                sent = sent.strip().lower()
                text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

                assert (
                    text_tokens.shape[-1] < self.args.gpt_max_text_tokens
                ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

                fake_inputs = self.gpt.compute_embeddings(
                    gpt_cond_latent.to(self.device),
                    text_tokens,
                )
                gpt_generator = self.gpt.get_generator(
                    fake_inputs=fake_inputs,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    do_sample=do_sample,
                    num_beams=1,
                    num_return_sequences=1,
                    length_penalty=float(length_penalty),
                    repetition_penalty=float(repetition_penalty),
                    output_attentions=False,
                    output_hidden_states=True,
                    **hf_generate_kwargs,
                )

                last_tokens = []
                all_latents = []
                wav_gen_prev = None
                wav_overlap = None
                is_end = False

                while not is_end:
                    try:
                        x, latent = next(gpt_generator)
                        last_tokens += [x]
                        all_latents += [latent]
                    except StopIteration:
                        is_end = True

                    if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                        gpt_latents = torch.cat(all_latents, dim=0)[None, :]
                        if length_scale != 1.0:
                            gpt_latents = F.interpolate(
                                gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
                            ).transpose(1, 2)
                        wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                        wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                            wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                        )
                        last_tokens = []
                        yield wav_chunk

            if not text_streaming:
                return

    def inference_stream_text(
        self,
        language,
        gpt_cond_latent,
        speaker_embedding,
        # Streaming
        stream_chunk_size=20,
        overlap_wav_len=1024,
        # GPT inference
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=50,
        top_p=0.85,
        do_sample=True,
        speed=1.0,
        **hf_generate_kwargs,
    ):
        if self._stream_generator is not None:
            raise Exception('Inference text-streaming already in progress. '
                            'Did you forget to call inference_finalize_text?')

        # Arguments `text` and `enable_text_splitting` given through holder
        self._stream_text_holder = [None, None]
        self._stream_generator = self.inference_stream(
            None,
            language,
            gpt_cond_latent,
            speaker_embedding,
            stream_chunk_size=stream_chunk_size,
            overlap_wav_len=overlap_wav_len,
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            speed=speed,
            **hf_generate_kwargs,
        )
        # Start the generator and return it
        _ = next(self._stream_generator)
        return self._stream_generator

    def inference_add_text(self, text: str, enable_text_splitting=False):
        if self._stream_generator is None:
            raise Exception('Inference text-streaming not started. '
                            'Please call inference_stream_text first')
        self._stream_text_holder[0] = text
        self._stream_text_holder[1] = enable_text_splitting

    def inference_finalize_text(self):
        if self._stream_generator is None:
            raise Exception('Inference text-streaming was not started '
                            '(start with inference_stream_text)')
        # Finalize and reset the generator
        self._stream_text_holder.clear()
        try:
            _ = next(self._stream_generator)
        except StopIteration:
            pass
        self._stream_generator = None

    def forward(self):
        raise NotImplementedError(
```
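The handshake between these methods is the crux of the change: the generator yields `None` whenever it is ready for more text, and the caller passes text in through the shared holder list instead of `generator.send()`. A standalone sketch of the same pattern, with illustrative names (`stream`, `holder`) and plain strings standing in for audio chunks:

```python
def stream(holder):
    """Yield None when ready for input, then process whatever the holder contains."""
    while True:
        yield None                 # handshake: caller may now fill the holder
        if not holder:             # an emptied holder signals finalization
            return
        text = holder[0]
        for word in text.split():  # stand-in for chunked audio generation
            yield word

holder = [None]
gen = stream(holder)
next(gen)                          # prime the generator up to the first handshake

for sentence in ["hello there", "general kenobi"]:
    holder[0] = sentence
    for chunk in gen:
        if chunk is None:          # generator is waiting for the next sentence
            break
        print(chunk)

holder.clear()                     # ask the generator to finish
try:
    next(gen)
except StopIteration:
    pass
```

Passing input through a mutable holder rather than `send()` keeps `inference_stream` usable both as a plain generator (text given up front) and, when `text is None`, as this kind of text-streaming coroutine.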
docs/source/models/xtts.md (38 additions & 4 deletions)

@@ -220,7 +220,7 @@ torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)


##### Streaming inference

Here the goal is to stream the audio as it is being generated, which is useful for real-time applications.
Streaming inference is typically slower than regular inference, but it lets you get the first chunk of audio faster.
@@ -253,16 +253,50 @@ chunks = model.inference_stream(

```python
    speaker_embedding
)

wav_chunks = []
for i, chunk in enumerate(chunks):
    if i == 0:
        print(f"Time to first chunk: {time.time() - t0}")
    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
    wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
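Each chunk is a 24 kHz mono tensor, so instead of collecting chunks you can also play them as they arrive. A minimal sketch, assuming the third-party `sounddevice` package (`pip install sounddevice`, not part of TTS) and the same `model`, `gpt_cond_latent`, and `speaker_embedding` as above:

```python
import numpy as np
import sounddevice as sd  # third-party audio playback, not part of TTS

# Output stream matching XTTS's 24 kHz mono output
out = sd.OutputStream(samplerate=24000, channels=1, dtype="float32")
out.start()
chunks = model.inference_stream(
    "This chunk is playing while the next one is still being generated.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
)
for chunk in chunks:
    # write() blocks until the device has consumed the samples
    out.write(chunk.squeeze().cpu().numpy().astype(np.float32).reshape(-1, 1))
out.stop()
out.close()
```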

If you also need to stream the input text, you can use `inference_stream_text`, like so:

```python
# ...same setup as before

def text_streaming_generator():
    yield "It took me quite a long time to develop a voice and now that I have it I am not going to be silent."
    yield "Having discovered not just one, but many voices, I will champion each."

print("Inference with text streaming...")

text_gen = text_streaming_generator()
inf_gen = model.inference_stream_text(
    "en",
    gpt_cond_latent,
    speaker_embedding
)

wav_chunks = []
for text in text_gen:
    # Add text progressively
    model.inference_add_text(text, enable_text_splitting=True)
    for chunk in inf_gen:
        if chunk is None:
            break  # all chunks generated for the current text
        print(f"Received chunk {len(wav_chunks)} of audio length {chunk.shape[-1]}")
        wav_chunks.append(chunk)

# Call finalize to discard the inference generator
model.inference_finalize_text()

wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming_text.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
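Note that `inference_stream_text` raises if a previous text-streaming session was never finalized, so it is worth guaranteeing the cleanup. A small variation of the loop above, using only the calls already shown:

```python
inf_gen = model.inference_stream_text("en", gpt_cond_latent, speaker_embedding)
wav_chunks = []
try:
    for text in text_streaming_generator():
        model.inference_add_text(text, enable_text_splitting=True)
        for chunk in inf_gen:
            if chunk is None:
                break
            wav_chunks.append(chunk)
finally:
    # Reset the model's streaming state even if generation fails,
    # so the next inference_stream_text call does not raise.
    model.inference_finalize_text()
```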

### Training
