Skip to content

Commit

Permalink
feat: add readalong-html synthesize output format
Browse files Browse the repository at this point in the history
And rename readalong -> readalong-xml for clarity, as suggested by @roedoejet
  • Loading branch information
joanise committed Dec 17, 2024
1 parent 9549082 commit ff1b0d0
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 34 deletions.
24 changes: 19 additions & 5 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,21 @@ def synthesize( # noqa: C901
"--output-type",
help="""Which format(s) to synthesize to.
Multiple formats can be provided by repeating `--output-type`.
'**wav**' is the default and will synthesize to a playable audio file;
'**spec**' will generate predicted Mel spectrograms. Tensors are Mel band-oriented (K, T) where K is equal to the number of Mel bands and T is equal to the number of frames.
'**wav**' is the default and will synthesize to a playable audio file. Requires --vocoder-path.
'**spec**' will generate predicted Mel spectrograms. Tensors are Mel band-oriented (K, T) where K is equal to the number of Mel bands and T is equal to the number of frames.
'**textgrid**' will generate a Praat TextGrid with alignment labels. This can be helpful for evaluation.
'**readalong**' will generate a ReadAlong from the given text and synthesized audio (see https://github.com/ReadAlongs).
'**readalong-xml**' will generate a ReadAlong from the given text and synthesized audio in XML .readalong format (see https://github.com/ReadAlongs).
'**readalong-html**' will generate a single file Offline HTML ReadAlong that can be further edited in the ReadAlong Studio Editor, and opened by itself. Also implies '--output-type wav'. Requires --vocoder-path.
""",
),
teacher_forcing_directory: Path = typer.Option(
Expand Down Expand Up @@ -394,9 +405,12 @@ def synthesize( # noqa: C901
sys.exit(1)

# output to .wav will require a valid spec-to-wav model
if SynthesizeOutputFormats.wav in output_type and not vocoder_path:
if (
SynthesizeOutputFormats.wav in output_type
or SynthesizeOutputFormats.readalong_html in output_type
) and not vocoder_path:
print(
"Missing --vocoder-path option, which is required when the output type is 'wav'.",
"Missing --vocoder-path option, which is required when the output type includes 'wav' or 'offline-ras'.",
file=sys.stderr,
)
sys.exit(1)
Expand Down
127 changes: 109 additions & 18 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Optional, Sequence

Expand All @@ -9,7 +11,11 @@
from loguru import logger
from pympi import TextGrid
from pytorch_lightning.callbacks import Callback
from readalongs.api import Token, convert_to_readalong
from readalongs.api import (
Token,
convert_prealigned_text_to_offline_html,
convert_prealigned_text_to_readalong,
)

from .config import FastSpeech2Config
from .type_definitions import SynthesizeOutputFormats
Expand Down Expand Up @@ -49,7 +55,7 @@ def get_synthesis_output_callbacks(
output_key=output_key,
)
)
if SynthesizeOutputFormats.readalong in output_type:
if SynthesizeOutputFormats.readalong_xml in output_type:
callbacks.append(
PredictionWritingReadAlongCallback(
config=config,
Expand All @@ -58,7 +64,10 @@ def get_synthesis_output_callbacks(
output_key=output_key,
)
)
if SynthesizeOutputFormats.wav in output_type:
if (
SynthesizeOutputFormats.wav in output_type
or SynthesizeOutputFormats.readalong_html in output_type
):
if (
vocoder_model is None
or vocoder_config is None
Expand All @@ -79,6 +88,18 @@ def get_synthesis_output_callbacks(
vocoder_global_step=vocoder_global_step,
)
)
if SynthesizeOutputFormats.readalong_html in output_type:
wav_callback = callbacks[-1]
assert isinstance(wav_callback, PredictionWritingWavCallback)
callbacks.append(
PredictionWritingOfflineRASCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
wav_callback=wav_callback,
)
)

return callbacks

Expand Down Expand Up @@ -196,17 +217,12 @@ def save_aligned_text_to_file(
max_seconds: float,
phones: list[tuple[float, float, str]],
words: list[tuple[float, float, str]],
basename: str,
speaker: str,
language: str,
filename: Path,
): # pragma: no cover
"""
Subclasses must implement this function to save the aligned text to file
in the desired format.
See for example PredictionWritingTextGridCallback.save_aligned_text_to_file
and PredictionWritingReadAlongCallback.save_aligned_text_to_file which save
the results to TextGrid and ReadAlong formats, respectively.
"""
"""Subclasses must implement this function to save the aligned text to file
in the desired format."""
raise NotImplementedError

def frames_to_seconds(self, frames: int) -> float:
Expand Down Expand Up @@ -281,11 +297,9 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
last_word_end = current_word_end
current_word_duration = 0

# get the filename
filename = self.get_filename(basename, speaker, language)
# Save the output (the subclass has to implement this)
self.save_aligned_text_to_file(
xmax_seconds, phones, words, language, filename
xmax_seconds, phones, words, basename, speaker, language
)


Expand All @@ -309,7 +323,15 @@ def __init__(
save_dir=output_dir / "textgrids",
)

def save_aligned_text_to_file(self, max_seconds, phones, words, language, filename):
def save_aligned_text_to_file(
self,
max_seconds: float,
phones: list[tuple[float, float, str]],
words: list[tuple[float, float, str]],
basename: str,
speaker: str,
language: str,
):
"""Save the aligned text as a TextGrid with phones and words layers"""
new_tg = TextGrid(xmax=max_seconds)
phone_tier = new_tg.add_tier("phones")
Expand All @@ -324,6 +346,7 @@ def save_aligned_text_to_file(self, max_seconds, phones, words, language, filena
word_tier.add_interval(*interval)
word_annotation_tier.add_interval(interval[0], interval[1], "")

filename = self.get_filename(basename, speaker, language)
new_tg.to_file(filename)


Expand All @@ -350,7 +373,15 @@ def __init__(
self.output_key = output_key
logger.info(f"Saving pytorch output to {self.save_dir}")

def save_aligned_text_to_file(self, max_seconds, phones, words, language, filename):
def save_aligned_text_to_file(
self,
max_seconds: float,
phones: list[tuple[float, float, str]],
words: list[tuple[float, float, str]],
basename: str,
speaker: str,
language: str,
):
"""Save the aligned text as a .readalong file"""

ras_tokens: list[Token] = []
Expand All @@ -359,11 +390,71 @@ def save_aligned_text_to_file(self, max_seconds, phones, words, language, filena
ras_tokens.append(Token(text=" ", is_word=False))
ras_tokens.append(Token(text=label, time=start, dur=end - start))

readalong = convert_to_readalong([ras_tokens], [language])
readalong = convert_prealigned_text_to_readalong([ras_tokens], [language])
filename = self.get_filename(basename, speaker, language)
with open(filename, "w", encoding="utf8") as f:
f.write(readalong)


class PredictionWritingOfflineRASCallback(PredictionWritingAlignedTextCallback):
"""
This callback runs inference on a provided text-to-spec model and saves the
resulting readalong of the predicted durations to disk as a single file
Offline HTML. This can be loaded in the ReadAlongs Studio-Web Editor for
further modification.
"""

def __init__(
self,
config: FastSpeech2Config,
global_step: int,
output_dir: Path,
output_key: str,
wav_callback: PredictionWritingWavCallback,
):
super().__init__(
config=config,
global_step=global_step,
output_key=output_key,
file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.html",
save_dir=output_dir / "readalongs",
)
self.text_processor = TextProcessor(config.text)
self.output_key = output_key
self.wav_callback = wav_callback
logger.info(f"Saving pytorch output to {self.save_dir}")

def save_aligned_text_to_file(
self,
max_seconds: float,
phones: list[tuple[float, float, str]],
words: list[tuple[float, float, str]],
basename: str,
speaker: str,
language: str,
):
"""Save the aligned text as an Offline HTML readalong file"""

ras_tokens: list[Token] = []
for start, end, label in words:
if ras_tokens:
ras_tokens.append(Token(text=" ", is_word=False))
ras_tokens.append(Token(text=label, time=start, dur=end - start))

wav_file_name = self.wav_callback.get_filename(
basename, speaker, language, include_global_step=True
)
readalong_html, _readalong_xml = convert_prealigned_text_to_offline_html(
[ras_tokens],
wav_file_name,
[language],
title="ReadAlong generated using EveryVoice",
)
filename = self.get_filename(basename, speaker, language)
with open(filename, "w", encoding="utf8") as f:
f.write(readalong_html)


class PredictionWritingWavCallback(PredictionWritingCallbackBase):
"""
Given text-to-spec, this callback does spec-to-wav and writes wav files.
Expand Down
76 changes: 66 additions & 10 deletions fs2/tests/test_writing_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def test_writing_readalong(self):
tmp_dir = Path(tmp_dir)
with silence_c_stderr():
writer = get_synthesis_output_callbacks(
[SynthesizeOutputFormats.readalong],
[SynthesizeOutputFormats.readalong_xml],
config=FastSpeech2Config(contact=self.contact),
global_step=77,
output_dir=tmp_dir,
Expand Down Expand Up @@ -247,6 +247,70 @@ def test_writing_readalong(self):
self.assertIn('<w time="0.0" dur=', readalong)


class TestWritingOfflineRAS(WritingTestBase):
"""
Testing the callback that writes Offline HTML readalong files.
"""

def test_writing_offline_ras(self):
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
vocoder, vocoder_path = get_dummy_vocoder(tmp_dir)
with silence_c_stderr():
writers = get_synthesis_output_callbacks(
[SynthesizeOutputFormats.readalong_html],
config=FastSpeech2Config(
contact=self.contact,
training=FastSpeech2TrainingConfig(vocoder_path=vocoder_path),
),
global_step=77,
output_dir=tmp_dir,
output_key=self.output_key,
device=torch.device("cpu"),
vocoder_model=vocoder,
vocoder_config=vocoder.config,
vocoder_global_step=10,
)
for writer in writers:
writer.on_predict_batch_end(
_trainer=None,
_pl_module=None,
outputs=self.outputs,
batch=self.batch,
_batch_idx=0,
_dataloader_idx=0,
)
output_dir = writer.save_dir

# print(output_dir, *output_dir.glob("**/*")) # For debugging
output_files = (
output_dir / "short--spk1--lngA--22050-mel-librosa.html",
output_dir
/ "This utterance is way too long--spk2--lngB--22050-mel-librosa.html",
)
for output_file in output_files:
with self.subTest(output_file=output_file):
self.assertTrue(output_file.exists())
with open(output_file, "r", encoding="utf8") as f:
readalong = f.read()
# print(readalong)
self.assertIn("<read-along ", readalong)
self.assertIn("<span slot", readalong)


def get_dummy_vocoder(tmp_dir: Path) -> tuple[HiFiGAN, Path]:
contact_info = ContactInformation(
contact_name="Test Runner", contact_email="[email protected]"
)
vocoder = HiFiGAN(HiFiGANConfig(contact=contact_info))
with silence_c_stderr():
trainer = Trainer(default_root_dir=str(tmp_dir), barebones=True)
trainer.strategy.connect(vocoder)
vocoder_path = tmp_dir / "vocoder"
trainer.save_checkpoint(vocoder_path)
return vocoder, vocoder_path


class TestWritingWav(WritingTestBase):
"""
Testing the callback that writes wav files.
Expand All @@ -260,15 +324,7 @@ def test_filenames_not_truncated(self):
"""
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
contact_info = ContactInformation(
contact_name="Test Runner", contact_email="[email protected]"
)
vocoder = HiFiGAN(HiFiGANConfig(contact=contact_info))
with silence_c_stderr():
trainer = Trainer(default_root_dir=str(tmp_dir), barebones=True)
trainer.strategy.connect(vocoder)
vocoder_path = Path(tmp_dir) / "vocoder"
trainer.save_checkpoint(vocoder_path)
vocoder, vocoder_path = get_dummy_vocoder(tmp_dir)

with silence_c_stderr():
writer = get_synthesis_output_callbacks(
Expand Down
3 changes: 2 additions & 1 deletion fs2/type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ class SynthesizeOutputFormats(str, Enum):
wav = "wav"
spec = "spec"
textgrid = "textgrid"
readalong = "readalong"
readalong_xml = "readalong-xml"
readalong_html = "readalong-html"

0 comments on commit ff1b0d0

Please sign in to comment.