Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat]: dotenv and multithreading #94

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env-example
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MS_TTS_KEY=<your_subscription_key> # for Azure
MS_TTS_REGION=<your_region> # for Azure
OPENAI_API_KEY=<your_openai_api_key> # for OpenAI
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -32,4 +32,10 @@ private_examples/
# custom
scripts/
*.onnx
*.onnx.json
*.onnx.json

# Models
piper_models/*

# Env files
.env
183 changes: 114 additions & 69 deletions audiobook_generator/tts_providers/azure_tts_provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import concurrent
import concurrent.futures
import html
import io
import logging
import math
import multiprocessing
import os
import threading
from datetime import datetime, timedelta
from time import sleep
from typing import Optional

import requests

from audiobook_generator.core.audio_tags import AudioTags
from audiobook_generator.config.general_config import GeneralConfig
from audiobook_generator.core.utils import split_text, set_audio_tags
from audiobook_generator.core.audio_tags import AudioTags
from audiobook_generator.core.utils import set_audio_tags, split_text
from audiobook_generator.tts_providers.base_tts_provider import BaseTTSProvider

logger = logging.getLogger(__name__)
@@ -20,7 +26,7 @@
class AzureTTSProvider(BaseTTSProvider):
def __init__(self, config: GeneralConfig):
# TTS provider specific config
config.voice_name = config.voice_name or "en-US-GuyNeural"
config.voice_name = config.voice_name or "en-US-DavisMultilingualNeural"
config.output_format = config.output_format or "audio-24khz-48kbitrate-mono-mp3"

# 16$ per 1 million characters
@@ -29,6 +35,7 @@ def __init__(self, config: GeneralConfig):
# access token and expiry time
self.access_token = None
self.token_expiry_time = datetime.utcnow()
self.token_lock = threading.RLock()
super().__init__(config)

subscription_key = os.environ.get("MS_TTS_KEY")
@@ -47,24 +54,31 @@ def __init__(self, config: GeneralConfig):

def __str__(self) -> str:
return (
super().__str__()
+ f", voice_name={self.config.voice_name}, language={self.config.language}, break_duration={self.config.break_duration}, output_format={self.config.output_format}"
super().__str__()
+ f", voice_name={self.config.voice_name}, language={self.config.language}, break_duration={self.config.break_duration}, output_format={self.config.output_format}"
)

def is_access_token_expired(self) -> bool:
return self.access_token is None or datetime.utcnow() >= self.token_expiry_time
with self.token_lock:
return (
self.access_token is None or datetime.utcnow() >= self.token_expiry_time
)

def auto_renew_access_token(self) -> str:
if self.access_token is None or self.is_access_token_expired():
logger.info(
f"azure tts access_token doesn't exist or is expired, getting new one"
)
self.access_token = self.get_access_token()
self.token_expiry_time = datetime.utcnow() + timedelta(minutes=9, seconds=1)
with self.token_lock:
if self.access_token is None or self.is_access_token_expired():
logger.info(
f"azure tts access_token doesn't exist or is expired, getting new one"
)
self.access_token = self.get_access_token()
self.token_expiry_time = datetime.utcnow() + timedelta(
minutes=9, seconds=1
)
return self.access_token

def get_access_token(self) -> str:
for retry in range(MAX_RETRIES):
response = None
try:
logger.info("Getting new access token")
response = requests.post(self.TOKEN_URL, headers=self.TOKEN_HEADERS)
@@ -77,74 +91,103 @@ def get_access_token(self) -> str:
f"Network error while getting access token (attempt {retry + 1}/{MAX_RETRIES}): {e}"
)
if retry < MAX_RETRIES - 1:
sleep(2 ** retry)
sleep(2**retry)
else:
raise e
finally:
if response is not None:
response.close()
raise Exception("Failed to get access token")

def process_chunk(
    self, chunk: str, audio_tags: AudioTags, i: int, total_chunks: int
) -> Optional[tuple[int, io.BytesIO]]:
    """Synthesize a single text chunk with Azure TTS.

    The chunk is HTML-escaped, the break marker returned by
    ``self.get_break_string()`` is replaced with an SSML ``<break>`` tag,
    and the resulting SSML document is POSTed to ``self.TTS_URL``.

    Args:
        chunk: Plain text to synthesize.
        audio_tags: Chapter metadata; only ``idx`` and ``title`` are read,
            and only for logging.
        i: Index of this chunk, returned unchanged so the caller can
            reorder results that complete out of order.
        total_chunks: Total number of chunks, used for logging only.

    Returns:
        ``(i, buffer)`` where ``buffer`` holds the response body, or
        ``None`` if no attempt was made (``MAX_RETRIES`` is not positive).

    Raises:
        requests.exceptions.RequestException: If the final attempt fails.
    """
    logger.debug(
        f"Processing chunk {i} of {total_chunks}, length={len(chunk)}, text=[{chunk}]"
    )
    escaped_text = html.escape(chunk)
    logger.debug(f"Escaped text: [{escaped_text}]")
    # Replace the break marker with a break tag for section/paragraph breaks.
    # The marker is stripped in case its leading blank is missing in the text.
    escaped_text = escaped_text.replace(
        self.get_break_string().strip(),
        f" <break time='{self.config.break_duration}ms' /> ",
    )
    logger.info(
        f"Processing chapter-{audio_tags.idx} <{audio_tags.title}>, chunk {i} of {total_chunks}"
    )
    ssml = f"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='{self.config.language}'><voice name='{self.config.voice_name}'>{escaped_text}</voice></speak>"
    logger.debug(f"SSML: [{ssml}]")
    payload = ssml.encode("utf-8")

    # (connect, read) timeouts in seconds. Without a timeout a stalled
    # connection would block this worker, and the executor, forever.
    # A timeout raises a RequestException subclass, so it is retried below.
    request_timeout = (10, 300)

    for retry in range(MAX_RETRIES):
        # Fetch the token on every attempt so a retry after a long
        # back-off never reuses an expired token.
        access_token = self.auto_renew_access_token()
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": self.config.output_format,
            "User-Agent": "Python",
        }
        response = None  # stays None if the request itself fails
        try:
            logger.info(
                "Sending request to Azure TTS, data length: " + str(len(ssml))
            )
            response = requests.post(
                self.TTS_URL,
                headers=headers,
                data=payload,
                timeout=request_timeout,
            )
            response.raise_for_status()  # raises HTTPError for 4XX/5XX
            logger.info(
                "Got response from Azure TTS, response length: "
                + str(len(response.content))
            )
            return (i, io.BytesIO(response.content))
        except requests.exceptions.RequestException as e:
            logger.warning(
                f"Error while converting text to speech (attempt {retry + 1}/{MAX_RETRIES}): {e}"
            )
            if retry < MAX_RETRIES - 1:
                sleep(2**retry)  # exponential back-off: 1s, 2s, 4s, ...
            else:
                raise  # out of attempts: keep the original traceback
        finally:
            if response is not None:
                response.close()

    # Only reachable when the loop body never ran.
    return None

def text_to_speech(
self,
text: str,
output_file: str,
audio_tags: AudioTags,
self,
text: str,
output_file: str,
audio_tags: AudioTags,
):
# Adjust this value based on your testing
max_chars = 1800 if self.config.language.startswith("zh") else 3000

text_chunks = split_text(text, max_chars, self.config.language)

audio_segments = []

for i, chunk in enumerate(text_chunks, 1):
logger.debug(
f"Processing chunk {i} of {len(text_chunks)}, length={len(chunk)}, text=[{chunk}]"
)
escaped_text = html.escape(chunk)
logger.debug(f"Escaped text: [{escaped_text}]")
# replace MAGIC_BREAK_STRING with a break tag for section/paragraph break
escaped_text = escaped_text.replace(
self.get_break_string().strip(),
f" <break time='{self.config.break_duration}ms' /> ",
) # strip in case leading bank is missing
logger.info(
f"Processing chapter-{audio_tags.idx} <{audio_tags.title}>, chunk {i} of {len(text_chunks)}"
)
ssml = f"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='{self.config.language}'><voice name='{self.config.voice_name}'>{escaped_text}</voice></speak>"
logger.debug(f"SSML: [{ssml}]")

for retry in range(MAX_RETRIES):
self.auto_renew_access_token()
headers = {
"Authorization": f"Bearer {self.access_token}",
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": self.config.output_format,
"User-Agent": "Python",
}
total_chunks = len(text_chunks)

audio_segments: list[tuple[int, io.BytesIO]] = []
with concurrent.futures.ThreadPoolExecutor(
max_workers=4 # multiprocessing.cpu_count()
) as executor:
futures = {
executor.submit(
self.process_chunk, chunk, audio_tags, i, total_chunks
): i
for i, chunk in enumerate(text_chunks, 1)
}
for i, chunk in enumerate(text_chunks, 1):
logger.debug(
f"Processing chunk {i} of {len(text_chunks)}, length={len(chunk)}, text=[{chunk}]"
)
for future in concurrent.futures.as_completed(futures):
try:
logger.info(
"Sending request to Azure TTS, data length: " + str(len(ssml))
)
response = requests.post(
self.TTS_URL, headers=headers, data=ssml.encode("utf-8")
)
response.raise_for_status() # Will raise HTTPError for 4XX or 5XX status
logger.info(
"Got response from Azure TTS, response length: "
+ str(len(response.content))
)
audio_segments.append(io.BytesIO(response.content))
break
except requests.exceptions.RequestException as e:
logger.warning(
f"Error while converting text to speech (attempt {retry + 1}): {e}"
)
if retry < MAX_RETRIES - 1:
sleep(2 ** retry)
else:
raise e

result = future.result()
if result:
audio_segments.append(result)
except Exception as e:
logger.error(f"Error processing chunk: {e}")
with open(output_file, "wb") as outfile:
for segment in audio_segments:

for _, segment in sorted(audio_segments, key=lambda x: x[0]):
segment.seek(0)
outfile.write(segment.read())

@@ -171,7 +214,9 @@ def get_output_file_extension(self):
elif self.config.output_format.endswith("mp3"):
return "mp3"
else:
raise NotImplementedError(f"Unknown file extension for output format: {self.config.output_format}")
raise NotImplementedError(
f"Unknown file extension for output format: {self.config.output_format}"
)

def validate_config(self):
# TODO: Need to dig into the Azure properties (I'm not familiar with them); look at OpenAI as a reference example
9 changes: 8 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import argparse
from ast import parse
import logging
import os

from audiobook_generator.config.general_config import GeneralConfig
from audiobook_generator.core.audiobook_generator import AudiobookGenerator
from audiobook_generator.tts_providers.base_tts_provider import (
get_supported_tts_providers,
)

from dotenv import load_dotenv

load_dotenv()


def handle_args():

parser = argparse.ArgumentParser(description="Convert text book to audiobook")
parser.add_argument("input_file", help="Path to the EPUB file")
parser.add_argument("output_folder", help="Path to the output folder")
@@ -179,7 +186,7 @@ def setup_logging(log_level):

def main():
config = handle_args()

config.tts
setup_logging(config.log)

AudiobookGenerator(config).run()
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -5,4 +5,5 @@ openai==1.35.7
requests==2.32.3
socksio==1.0.0
edge-tts==6.1.12
pydub==0.25.1
pydub==0.25.1
python-dotenv==1.0.1