Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

temporary fix to get TTS working again #391

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/airunner/aihandler/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from airunner.aihandler.enums import EngineRequestCode, EngineResponseCode
from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.workers.tts_generator_worker import TTSGeneratorWorker
from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker
from airunner.workers.worker import Worker
from airunner.aihandler.llm import LLMController
from airunner.aihandler.logger import Logger
Expand Down Expand Up @@ -64,7 +66,6 @@ class Engine(QObject, MediatorMixin, SettingsMixin):

# Model controllers
llm_controller = None
tts_controller = None
stt_controller = None
ocr_controller = None

Expand Down Expand Up @@ -143,6 +144,14 @@ def __init__(self, **kwargs):

self.request_worker = self.create_worker(EngineRequestWorker)
self.response_worker = self.create_worker(EngineResponseWorker)

self.generator_worker = self.create_worker(TTSGeneratorWorker)
self.vocalizer_worker = self.create_worker(TTSVocalizerWorker)
self.register("tts_request", self)

@pyqtSlot(dict)
def on_tts_request(self, data: dict):
    """
    Slot for the "tts_request" signal.

    Forwards the request payload unchanged to the TTS generator worker's
    queue.

    :param data: the TTS request payload; its schema is defined by the
        generator worker, not here.
    """
    worker = self.generator_worker
    worker.add_to_queue(data)

def on_llm_controller_response_signal(self, message):
self.do_response(message)
Expand Down
206 changes: 18 additions & 188 deletions src/airunner/aihandler/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,172 +12,10 @@

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.workers.worker import Worker
from airunner.windows.main.settings_mixin import SettingsMixin


class VocalizerWorker(Worker):
    """
    Plays generated speech through a sounddevice output stream.

    Speech chunks (numpy arrays produced by the TTS generator) are placed on
    this worker's queue, either through the
    "TTSGeneratorWorker_add_to_stream_signal" signal or through
    handle_speech(). handle_message() then writes them to the stream.

    When reader_mode_active is set, playback is delayed until at least six
    chunks have been buffered; after that, chunks are written as they arrive.
    """
    # Class-level switch; when True the first chunks are buffered before playback.
    reader_mode_active = False
    logger = Logger(prefix="VocalizerWorker")

    def __init__(self, *args, **kwargs):
        # NOTE: positional and keyword arguments are accepted but not forwarded.
        super().__init__()
        self.queue = Queue()
        self.stream = sd.OutputStream(samplerate=24000, channels=1)
        self.stream.start()
        self.data = []  # chunks held back while buffering in reader mode
        self.started = False  # True once buffered playback has begun
        self.register("TTSGeneratorWorker_add_to_stream_signal", self)

    def on_TTSGeneratorWorker_add_to_stream_signal(self, response):
        """Queue a chunk of generated speech received from the generator."""
        self.queue.put(response)

    def handle_message(self, item):
        """
        Take the next chunk from the queue and play or buffer it.

        The ``item`` argument is ignored: the chunk is always read from
        ``self.queue`` with a one second timeout, so ``queue.Empty``
        propagates to the caller when nothing arrives in time.
        """
        chunk = self.queue.get(timeout=1)

        buffering = self.reader_mode_active and not self.started
        if not buffering:
            # Either playback already started or reader mode is off.
            self.stream.write(chunk)
            return

        self.data.append(chunk)
        if len(self.data) >= 6:
            # Enough audio buffered: flush everything and switch to live playback.
            for buffered in self.data:
                self.stream.write(buffered)
            self.started = True
            self.data = []

    def handle_speech(self, generated_speech):
        """Queue generated speech for playback, logging any failure."""
        self.logger.info("Adding speech to stream...")
        try:
            self.queue.put(generated_speech)
        except Exception as err:
            self.logger.error(f"Error while adding speech to stream: {err}")


class TTSGeneratorWorker(Worker):
    """
    Takes input text from any source and generates speech from it using the TTS class.

    Incoming messages are buffered in ``play_queue``. Generation starts once
    the buffer holds ``play_queue_buffer_length`` messages or the end of the
    message is reached; after that first flush every message is generated as
    soon as it arrives. Each generated chunk is emitted on
    "TTSGeneratorWorker_add_to_stream_signal".
    """
    def __init__(self, prefix="TTSGeneratorWorker"):
        super().__init__(prefix)
        self.tts = TTS()
        self.play_queue = []  # message texts waiting to be turned into speech
        self.play_queue_started = False  # True after the first buffer flush
        self.tts_settings = None  # settings dict taken from the latest request

    def handle_message(self, data):
        """
        Buffer one TTS request and generate speech when the buffer is ready.

        :param data: dict with the keys "tts_settings" (dict), "message"
            (text to speak) and "is_end_of_message" (bool).
        :raises KeyError: if a required key is missing from ``data`` or from
            the settings dict.
        """
        self.tts_settings = data["tts_settings"]
        is_end_of_message = data["is_end_of_message"]
        buffer_length = self.tts_settings["play_queue_buffer_length"]

        # Queue the text itself: generate() works on strings, not request dicts.
        self.play_queue.append(data["message"])

        buffer_ready = len(self.play_queue) >= buffer_length
        if not (is_end_of_message or buffer_ready or self.play_queue_started):
            return

        # Detach the pending items before generating so the buffer is already
        # empty if generate() raises part-way through.
        pending, self.play_queue = self.play_queue, []
        self.play_queue_started = True
        for message in pending:
            self.generate(message)

    def generate(self, text):
        """
        Generate speech for ``text`` and emit it to the vocalizer.

        Uses Bark when ``tts_settings["use_bark"]`` is truthy, otherwise the
        T5 path.
        """
        self.logger.info("Generating TTS...")
        text = text.replace("\n", " ").strip()

        if self.tts_settings["use_bark"]:
            response = self.generate_with_bark(text)
        else:
            response = self.generate_with_t5(text)

        self.emit("TTSGeneratorWorker_add_to_stream_signal", response)

    def move_inputs_to_device(self, inputs):
        """
        Move every value of ``inputs`` to CUDA when ``use_cuda`` is enabled.

        If any value has no ``.cuda()`` method the inputs are returned
        unchanged.
        """
        # NOTE(review): memory_settings is not defined in this class; it is
        # presumably supplied by the Worker base or a settings mixin - confirm.
        use_cuda = self.memory_settings["use_cuda"]
        if use_cuda:
            self.logger.info("Moving inputs to CUDA")
            try:
                inputs = {k: v.cuda() for k, v in inputs.items()}
            except AttributeError:
                # Best effort: leave the inputs where they are.
                pass
        return inputs

    def generate_with_bark(self, text):
        """Generate speech for ``text`` with the Bark model; returns a numpy array."""
        self.logger.info("Generating TTS...")
        text = text.replace("\n", " ").strip()

        self.logger.info("Processing inputs...")
        # NOTE(review): processor/model/device are read from self.parent, not
        # from self.tts created in __init__ - confirm which object owns them.
        inputs = self.parent.processor(text, voice_preset=self.tts_settings["voice"]).to(self.parent.device)
        inputs = self.move_inputs_to_device(inputs)

        self.logger.info("Generating speech...")
        start = time.time()
        params = dict(
            **inputs,
            fine_temperature=self.tts_settings["fine_temperature"],
            coarse_temperature=self.tts_settings["coarse_temperature"],
            semantic_temperature=self.tts_settings["semantic_temperature"],
        )
        speech = self.parent.model.generate(**params)
        self.logger.info("Generated speech in " + str(time.time() - start) + " seconds")

        # Only the first generated sequence is used.
        response = speech[0].cpu().float().numpy()
        return response

    def generate_with_t5(self, text):
        """Generate speech for ``text`` with the T5 model; returns a numpy array."""
        self.logger.info("Generating TTS...")
        text = text.replace("\n", " ").strip()

        self.logger.info("Processing inputs...")

        inputs = self.parent.processor(text=text, return_tensors="pt")
        inputs = self.move_inputs_to_device(inputs)

        self.logger.info("Generating speech...")
        start = time.time()
        params = dict(
            **inputs,
            speaker_embeddings=self.parent.speaker_embeddings,
            vocoder=self.parent.vocoder,
            max_length=100,
        )
        speech = self.parent.model.generate(**params)
        self.logger.info("Generated speech in " + str(time.time() - start) + " seconds")
        response = speech.cpu().float().numpy()
        return response


class TTSController(QObject, MediatorMixin):
    """
    Handles TTS requests from the main thread and passes them to the generator worker.
    Also handles speech from the generator worker and passes it to the vocalizer worker.
    Responses from the vocalizer worker are passed back to the main thread.
    """
    def __init__(self, *args, **kwargs):
        """
        :param engine: required keyword argument; stored as ``self.engine``
            and removed from ``kwargs`` before the base initialisers run.
        :raises KeyError: if ``engine`` is not supplied.
        """
        self.engine = kwargs.pop("engine")
        super().__init__(*args, **kwargs)
        MediatorMixin.__init__(self)

        # The generator class defined in this module is TTSGeneratorWorker;
        # the bare name GeneratorWorker is not defined here.
        self.generator_worker = self.create_worker(TTSGeneratorWorker)
        self.vocalizer_worker = self.create_worker(VocalizerWorker)

        # NOTE(review): every other register() call visible here passes
        # (signal_name, self), and no on_GeneratorWorker handler exists in
        # this class - confirm whether this registration is still needed.
        self.register("GeneratorWorker")

        self.register("tts_request", self)

    @pyqtSlot(dict)
    def on_tts_request(self, data: dict):
        """Forward a "tts_request" payload to the generator worker's queue."""
        self.generator_worker.add_to_queue(data)


class TTS(QObject, MediatorMixin):
class TTS(QObject, MediatorMixin, SettingsMixin):
"""
Generates speech from given text.
Responsible for managing the model, processor, vocoder, and speaker embeddings.
Expand Down Expand Up @@ -241,68 +79,61 @@ def device(self):
def torch_dtype(self):
return torch.float16 if self.use_cuda else torch.float32

@property
def settings(self):
return self._settings

@settings.setter
def settings(self, value):
self._settings = value

@property
def word_chunks(self):
return self.settings["word_chunks"]
return self.tts_settings["word_chunks"]

@property
def use_bark(self):
return self.settings["use_bark"]
return self.tts_settings["use_bark"]

@property
def cuda_index(self):
return self.settings["cuda_index"]
return self.tts_settings["cuda_index"]

@property
def voice_preset(self):
return self.settings["voice"]
return self.tts_settings["voice"]

@property
def use_cuda(self):
return self.settings["use_cuda"] and torch.cuda.is_available()
return self.tts_settings["use_cuda"] and torch.cuda.is_available()

@property
def fine_temperature(self):
return self.settings["fine_temperature"] / 100
return self.tts_settings["fine_temperature"] / 100

@property
def coarse_temperature(self):
return self.settings["coarse_temperature"] / 100
return self.tts_settings["coarse_temperature"] / 100

@property
def semantic_temperature(self):
return self.settings["semantic_temperature"] / 100
return self.tts_settings["semantic_temperature"] / 100

@property
def enable_cpu_offload(self):
return self.settings["enable_cpu_offload"]
return self.tts_settings["enable_cpu_offload"]

@property
def play_queue_buffer_length(self):
return self.settings["play_queue_buffer_length"]
return self.tts_settings["play_queue_buffer_length"]

@property
def use_word_chunks(self):
return self.settings["use_word_chunks"]
return self.tts_settings["use_word_chunks"]

@property
def use_sentence_chunks(self):
return self.settings["use_sentence_chunks"]
return self.tts_settings["use_sentence_chunks"]

@property
def sentence_chunks(self):
return self.settings["sentence_chunks"]
return self.tts_settings["sentence_chunks"]

def __init__(self, *args, **kwargs):
super().__init__()
SettingsMixin.__init__(self)
MediatorMixin.__init__(self)
self.logger.info("Loading")
self.corpus = []
Expand All @@ -313,7 +144,7 @@ def __init__(self, *args, **kwargs):
self.sentences = []

@pyqtSlot(np.ndarray)
def on_add_to_stream(self, generated_speech: np.ndarray):
def on_add_to_stream_signal(self, generated_speech: np.ndarray):
"""
This function is called from the generator worker when speech has been generated.
It adds the generated speech to the vocalizer's queue.
Expand Down Expand Up @@ -352,7 +183,7 @@ def move_to_device(self):
self.speaker_embeddings = self.speaker_embeddings.to(self.device)

def initialize(self):
target_model = "bark" if self.use_bark else "t5"
target_model = "bark" if self.tts_settings["use_bark"] else "t5"
if target_model != self.current_model:
self.unload()

Expand Down Expand Up @@ -476,7 +307,6 @@ def process_sentences(self):
self.sentences.append(sentence)

def add_text(self, data: dict, is_end_of_message: bool):
self.settings = data["tts_settings"]
self.initialize()
self.message += data["message"]
#if is_end_of_message:
Expand Down
Loading