From 17b83e2f4edc06f33a7883b20f98266c1831a3d4 Mon Sep 17 00:00:00 2001 From: xinsen <1324154699@qq.com> Date: Tue, 6 Aug 2024 16:17:33 +0800 Subject: [PATCH] patch stream bug --- TTS/tts/layers/xtts/stream_generator.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index e12f8995cf..67e9ae597d 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -183,10 +183,14 @@ def generate( requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + pad_token_tensor = torch.tensor([generation_config.pad_token_id], + device=inputs_tensor.device) if generation_config.pad_token_id is not None else None + eos_token_tensor = torch.tensor([generation_config.eos_token_id], + device=inputs_tensor.device) if generation_config.eos_token_id is not None else None model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, - generation_config.pad_token_id, - generation_config.eos_token_id, + pad_token_tensor, + eos_token_tensor, ) # decoder-only models should use left-padding for generation @@ -409,7 +413,8 @@ def generate( ) elif is_sample_gen_stream_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + # logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=inputs_tensor.device) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation(