diff --git a/fs2/cli/synthesize.py b/fs2/cli/synthesize.py
index 0eb34d6..274be64 100644
--- a/fs2/cli/synthesize.py
+++ b/fs2/cli/synthesize.py
@@ -342,9 +342,10 @@ def synthesize(  # noqa: C901
         "--output-type",
         help="""Which format(s) to synthesize to.
         Multiple formats can be provided by repeating `--output-type`.
-        **wav** is the default and will synthesize to a playable audio file;
-        **spec** will generate predicted Mel spectrograms. Tensors are time-oriented (T, K) where T is equal to the number of frames and K is equal to the number of Mel bands.
-        **textgrid** will generate a Praat TextGrid with alignment labels. This can be helpful for evaluation.
+        '**wav**' is the default and will synthesize to a playable audio file;
+        '**spec**' will generate predicted Mel spectrograms. Tensors are time-oriented (T, K) where T is equal to the number of frames and K is equal to the number of Mel bands.
+        '**textgrid**' will generate a Praat TextGrid with alignment labels. This can be helpful for evaluation.
+        '**readalong**' will generate a ReadAlong from the given text and synthesized audio (see https://github.com/ReadAlongs).
         """,
     ),
     teacher_forcing_directory: Path = typer.Option(
diff --git a/fs2/prediction_writing_callback.py b/fs2/prediction_writing_callback.py
index 9ba4f16..8a8799f 100644
--- a/fs2/prediction_writing_callback.py
+++ b/fs2/prediction_writing_callback.py
@@ -9,6 +9,7 @@
 from loguru import logger
 from pympi import TextGrid
 from pytorch_lightning.callbacks import Callback
+from readalongs.api import Token, convert_to_readalong
 
 from .config import FastSpeech2Config
 from .type_definitions import SynthesizeOutputFormats
@@ -24,7 +25,7 @@ def get_synthesis_output_callbacks(
     vocoder_model: Optional[HiFiGAN] = None,
     vocoder_config: Optional[HiFiGANConfig] = None,
     vocoder_global_step: Optional[int] = None,
-):
+) -> list[Callback]:
     """
     Given a list of desired output file formats, return the proper callbacks
     that will generate those files.
@@ -48,6 +49,15 @@ def get_synthesis_output_callbacks(
                 output_key=output_key,
             )
         )
+    if SynthesizeOutputFormats.readalong in output_type:
+        callbacks.append(
+            PredictionWritingReadAlongCallback(
+                config=config,
+                global_step=global_step,
+                output_dir=output_dir,
+                output_key=output_key,
+            )
+        )
     if SynthesizeOutputFormats.wav in output_type:
         if (
             vocoder_model is None
@@ -76,11 +86,13 @@
 class PredictionWritingCallbackBase(Callback):
     def __init__(
         self,
+        config: FastSpeech2Config,
         file_extension: str,
         global_step: int,
         save_dir: Path,
     ) -> None:
         super().__init__()
+        self.config = config
         self.file_extension = file_extension
         self.global_step = f"ckpt={global_step}"
         self.save_dir = save_dir
@@ -88,21 +100,21 @@
 
         self.save_dir.mkdir(parents=True, exist_ok=True)
 
-    def _get_filename(self, basename: str, speaker: str, language: str) -> Path:
+    def get_filename(
+        self,
+        basename: str,
+        speaker: str,
+        language: str,
+        include_global_step: bool = False,
+    ) -> Path:
         # We don't truncate or alter the filename here because the basename is
         # already truncated/cleaned in cli/synthesize.py
-        path = self.save_dir / self.sep.join(
-            [
-                basename,
-                speaker,
-                language,
-                self.global_step,
-                self.file_extension,
-            ]
-        )
-        path.parent.mkdir(
-            parents=True, exist_ok=True
-        )  # synthesizing spec allows nested outputs
+        name_parts = [basename, speaker, language, self.file_extension]
+        if include_global_step:
+            name_parts.insert(-1, self.global_step)
+        path = self.save_dir / self.sep.join(name_parts)
+        # synthesizing spec allows nested outputs so we may need to make subdirs
+        path.parent.mkdir(parents=True, exist_ok=True)
         return path
@@ -119,33 +131,15 @@ def __init__(
         output_key: str,
     ):
         super().__init__(
+            config=config,
             global_step=global_step,
             file_extension=f"spec-pred-{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.pt",
             save_dir=output_dir / "synthesized_spec",
         )
         self.output_key = output_key
-        self.config = config
         logger.info(f"Saving pytorch output to {self.save_dir}")
 
-    def _get_filename(self, basename: str, speaker: str, language: str) -> Path:
-        # We don't truncate or alter the filename here because the basename is
-        # already truncated/cleaned in cli/synthesize.py
-        # the spec should not have the global step printed because it is used to fine-tune
-        # and the dataloader does not expect a global step in the filename
-        path = self.save_dir / self.sep.join(
-            [
-                basename,
-                speaker,
-                language,
-                self.file_extension,
-            ]
-        )
-        path.parent.mkdir(
-            parents=True, exist_ok=True
-        )  # synthesizing spec allows nested outputs
-        return path
-
     def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
         self,
         _trainer,
@@ -166,53 +160,52 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
     ):
             torch.save(
                 data[:unmasked_len].cpu(),
-                self._get_filename(
-                    basename=basename,
-                    speaker=speaker,
-                    language=language,
-                ),
+                self.get_filename(basename, speaker, language),
             )
 
 
-class PredictionWritingTextGridCallback(PredictionWritingCallbackBase):
+class PredictionWritingAlignedTextCallback(PredictionWritingCallbackBase):
     """
-    This callback runs inference on a provided text-to-spec model and saves the resulting textgrid of the predicted durations to disk. This can be used for evaluation.
+    This callback runs inference on a provided text-to-spec model and saves the
+    resulting time-aligned text to file. The output format depends on the subclass's
+    implementation of save_aligned_text_to_file.
     """
 
     def __init__(
         self,
         config: FastSpeech2Config,
         global_step: int,
-        output_dir: Path,
         output_key: str,
+        file_extension: str,
+        save_dir: Path,
     ):
         super().__init__(
+            config=config,
             global_step=global_step,
-            file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.TextGrid",
-            save_dir=output_dir / "textgrids",
+            file_extension=file_extension,
+            save_dir=save_dir,
         )
         self.text_processor = TextProcessor(config.text)
         self.output_key = output_key
-        self.config = config
         logger.info(f"Saving pytorch output to {self.save_dir}")
 
-    def _get_filename(self, basename: str, speaker: str, language: str) -> Path:
-        # We don't truncate or alter the filename here because the basename is
-        # already truncated/cleaned in cli/synthesize.py
-        # the textgrid should not have the global step printed because it is used to fine-tune
-        # and the dataloader does not expect a global step in the filename
-        path = self.save_dir / self.sep.join(
-            [
-                basename,
-                speaker,
-                language,
-                self.file_extension,
-            ]
-        )
-        path.parent.mkdir(
-            parents=True, exist_ok=True
-        )  # synthesizing spec allows nested outputs
-        return path
+    def save_aligned_text_to_file(
+        self,
+        max_seconds: float,
+        phones: list[tuple[float, float, str]],
+        words: list[tuple[float, float, str]],
+        language: str,
+        filename: Path,
+    ):  # pragma: no cover
+        """
+        Subclasses must implement this function to save the aligned text to file
+        in the desired format.
+
+        See for example PredictionWritingTextGridCallback.save_aligned_text_to_file
+        and PredictionWritingReadAlongCallback.save_aligned_text_to_file which save
+        the results to TextGrid and ReadAlong formats, respectively.
+ """ + raise NotImplementedError def frames_to_seconds(self, frames: int) -> float: return ( @@ -253,8 +246,6 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] ), f"can't synthesize {raw_text} because the number of predicted duration steps ({len(duration_frames)}) doesn't equal the number of input text labels ({len(text_labels)})" # get the duration of the audio: (sum_of_frames * hop_size) / sample_rate xmax_seconds = self.frames_to_seconds(sum(duration_frames)) - # create new textgrid - new_tg = TextGrid(xmax=xmax_seconds) # create the tiers words: list[tuple[float, float, str]] = [] phones: list[tuple[float, float, str]] = [] @@ -262,10 +253,6 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] current_word_duration = 0.0 last_phone_end = 0.0 last_word_end = 0.0 - phone_tier = new_tg.add_tier("phones") - phone_annotation_tier = new_tg.add_tier("phone annotations") - word_tier = new_tg.add_tier("words") - word_annotation_tier = new_tg.add_tier("word annotations") # skip padding text_labels_no_padding = [tl for tl in text_labels if tl != "\x80"] duration_frames_no_padding = duration_frames[: len(text_labels_no_padding)] @@ -277,8 +264,6 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] current_phone_end = last_phone_end + phone_duration interval = (last_phone_end, current_phone_end, label) phones.append(interval) - phone_annotation_tier.add_interval(interval[0], interval[1], "") - phone_tier.add_interval(*interval) last_phone_end = current_phone_end # accumulate phone to word label current_word_duration += phone_duration @@ -291,18 +276,90 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] raw_text_words[len(words)], ) words.append(interval) - word_tier.add_interval(*interval) - word_annotation_tier.add_interval(interval[0], interval[1], "") last_word_end = current_word_end current_word_duration = 0 + # get the filename - filename = self._get_filename( - basename=basename, - speaker=speaker, - language=language, + filename = self.get_filename(basename, speaker, language) + # Save the output (the subclass has to implement this) + self.save_aligned_text_to_file( + xmax_seconds, phones, words, language, filename ) - # write the file - new_tg.to_file(filename) + + +class PredictionWritingTextGridCallback(PredictionWritingAlignedTextCallback): + """ + This callback runs inference on a provided text-to-spec model and saves the resulting textgrid of the predicted durations to disk. This can be used for evaluation. 
+ """ + + def __init__( + self, + config: FastSpeech2Config, + global_step: int, + output_dir: Path, + output_key: str, + ): + super().__init__( + config=config, + global_step=global_step, + output_key=output_key, + file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.TextGrid", + save_dir=output_dir / "textgrids", + ) + + def save_aligned_text_to_file(self, max_seconds, phones, words, language, filename): + """Save the aligned text as a TextGrid with phones and words layers""" + new_tg = TextGrid(xmax=max_seconds) + phone_tier = new_tg.add_tier("phones") + phone_annotation_tier = new_tg.add_tier("phone annotations") + for interval in phones: + phone_annotation_tier.add_interval(interval[0], interval[1], "") + phone_tier.add_interval(*interval) + + word_tier = new_tg.add_tier("words") + word_annotation_tier = new_tg.add_tier("word annotations") + for interval in words: + word_tier.add_interval(*interval) + word_annotation_tier.add_interval(interval[0], interval[1], "") + + new_tg.to_file(filename) + + +class PredictionWritingReadAlongCallback(PredictionWritingAlignedTextCallback): + """ + This callback runs inference on a provided text-to-spec model and saves the resulting readalong of the predicted durations to disk. Combined with the .wav output, this can be loaded in the ReadAlongs Web-Component for viewing. + """ + + def __init__( + self, + config: FastSpeech2Config, + global_step: int, + output_dir: Path, + output_key: str, + ): + super().__init__( + config=config, + global_step=global_step, + output_key=output_key, + file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.readalong", + save_dir=output_dir / "readalongs", + ) + self.text_processor = TextProcessor(config.text) + self.output_key = output_key + logger.info(f"Saving pytorch output to {self.save_dir}") + + def save_aligned_text_to_file(self, max_seconds, phones, words, language, filename): + """Save the aligned text as a .readalong file""" + + ras_tokens: list[Token] = [] + for start, end, label in words: + if ras_tokens: + ras_tokens.append(Token(text=" ", is_word=False)) + ras_tokens.append(Token(text=label, time=start, dur=end - start)) + + readalong = convert_to_readalong([ras_tokens], [language]) + with open(filename, "w", encoding="utf8") as f: + f.write(readalong) class PredictionWritingWavCallback(PredictionWritingCallbackBase): @@ -322,6 +379,7 @@ def __init__( vocoder_global_step: int, ): super().__init__( + config=config, file_extension="pred.wav", global_step=global_step, save_dir=output_dir / "wav", @@ -329,7 +387,6 @@ def __init__( self.output_key = output_key self.device = device - self.config = config self.vocoder_model = vocoder_model self.vocoder_config = vocoder_config sampling_rate_change = ( @@ -403,10 +460,8 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] outputs["tgt_lens"], ): write( - self._get_filename( - basename=basename, - speaker=speaker, - language=language, + self.get_filename( + basename, speaker, language, include_global_step=True ), sr, # the vocoder output includes padding so we have to remove that diff --git a/fs2/tests/test_writing_callbacks.py b/fs2/tests/test_writing_callbacks.py index b67bbd4..ea70bed 100644 --- a/fs2/tests/test_writing_callbacks.py +++ b/fs2/tests/test_writing_callbacks.py @@ -11,11 +11,8 @@ from pytorch_lightning import Trainer from ..config import FastSpeech2Config, FastSpeech2TrainingConfig -from ..prediction_writing_callback 
-    PredictionWritingSpecCallback,
-    PredictionWritingTextGridCallback,
-    PredictionWritingWavCallback,
-)
+from ..prediction_writing_callback import get_synthesis_output_callbacks
+from ..type_definitions import SynthesizeOutputFormats
 from ..utils import BASENAME_MAX_LENGTH, truncate_basename
@@ -120,12 +117,14 @@ def test_filenames_not_truncated(self):
         with TemporaryDirectory() as tmp_dir:
             tmp_dir = Path(tmp_dir)
             with silence_c_stderr():
-                writer = PredictionWritingSpecCallback(
+                writer = get_synthesis_output_callbacks(
+                    [SynthesizeOutputFormats.spec],
                     config=FastSpeech2Config(contact=self.contact),
                     global_step=77,
                     output_dir=tmp_dir,
                     output_key=self.output_key,
-                )
+                    device=torch.device("cpu"),
+                )[0]
             writer.on_predict_batch_end(
                 _trainer=None,
                 _pl_module=None,
@@ -163,12 +162,14 @@ def test_filenames_not_truncated(self):
         with TemporaryDirectory() as tmp_dir:
             tmp_dir = Path(tmp_dir)
             with silence_c_stderr():
-                writer = PredictionWritingTextGridCallback(
+                writer = get_synthesis_output_callbacks(
+                    [SynthesizeOutputFormats.textgrid],
                     config=FastSpeech2Config(contact=self.contact),
                     global_step=77,
                     output_dir=tmp_dir,
                     output_key=self.output_key,
-                )
+                    device=torch.device("cpu"),
+                )[0]
             writer.on_predict_batch_end(
                 _trainer=None,
                 _pl_module=None,
@@ -178,7 +179,7 @@
                 _dataloader_idx=0,
             )
             output_dir = writer.save_dir
-            # print(output_dir, *output_dir.glob("**"))  # For debugging
+            # print(output_dir, *output_dir.glob("**/*"))  # For debugging
             self.assertTrue(output_dir.exists())
             self.assertTrue(
                 (output_dir / "short--spk1--lngA--22050-mel-librosa.TextGrid").exists()
@@ -203,6 +204,49 @@
             self.assertEqual(tiers[2].intervals[0][2], "W̱SÁNEĆ")
 
 
+class TestWritingReadAlong(WritingTestBase):
+    """
+    Testing the callback that writes .readalong files.
+    """
+
+    def test_writing_readalong(self):
+        with TemporaryDirectory() as tmp_dir:
+            tmp_dir = Path(tmp_dir)
+            with silence_c_stderr():
+                writer = get_synthesis_output_callbacks(
+                    [SynthesizeOutputFormats.readalong],
+                    config=FastSpeech2Config(contact=self.contact),
+                    global_step=77,
+                    output_dir=tmp_dir,
+                    output_key=self.output_key,
+                    device=torch.device("cpu"),
+                )[0]
+            writer.on_predict_batch_end(
+                _trainer=None,
+                _pl_module=None,
+                outputs=self.outputs,
+                batch=self.batch,
+                _batch_idx=0,
+                _dataloader_idx=0,
+            )
+            output_dir = writer.save_dir
+
+            # print(output_dir, *output_dir.glob("**/*"))  # For debugging
+            output_files = (
+                output_dir / "short--spk1--lngA--22050-mel-librosa.readalong",
+                output_dir
+                / "This utterance is way too long--spk2--lngB--22050-mel-librosa.readalong",
+            )
+            for output_file in output_files:
+                with self.subTest(output_file=output_file):
+                    self.assertTrue(output_file.exists())
+                    with open(output_file, "r", encoding="utf8") as f:
+                        readalong = f.read()
+                    # print(readalong)
+                    self.assertIn("
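
A few sketches follow for reviewers; none of them are part of the patch. First, how callers now obtain writers through the `get_synthesis_output_callbacks` factory instead of instantiating callback classes directly, mirroring the updated tests above. The contact information and output key are hypothetical placeholders for the values the test fixtures supply:

```python
from pathlib import Path

import torch

from fs2.config import FastSpeech2Config
from fs2.prediction_writing_callback import get_synthesis_output_callbacks
from fs2.type_definitions import SynthesizeOutputFormats

contact_info = ...  # hypothetical: a valid contact object, as FastSpeech2Config requires
output_key = "postnet_output"  # hypothetical; the tests use self.output_key

# The factory returns one configured writer per requested output format.
callbacks = get_synthesis_output_callbacks(
    [SynthesizeOutputFormats.readalong, SynthesizeOutputFormats.textgrid],
    config=FastSpeech2Config(contact=contact_info),
    global_step=77,
    output_dir=Path("out"),
    output_key=output_key,
    device=torch.device("cpu"),
)
```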
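Second, a minimal standalone sketch of the two `readalongs.api` calls the new ReadAlong callback relies on, with hypothetical word timings; `"und"` (undetermined) stands in for a real language code and is an assumption here:

```python
from readalongs.api import Token, convert_to_readalong

# Hypothetical (start, end, label) word intervals, shaped like the `words`
# list built in PredictionWritingAlignedTextCallback.on_predict_batch_end.
words = [(0.0, 0.4, "hello"), (0.4, 0.9, "world")]

ras_tokens: list[Token] = []
for start, end, label in words:
    if ras_tokens:
        ras_tokens.append(Token(text=" ", is_word=False))  # spacer between words
    ras_tokens.append(Token(text=label, time=start, dur=end - start))

# One sentence (list of tokens) per entry, with a parallel list of language codes.
readalong_xml = convert_to_readalong([ras_tokens], ["und"])
with open("sample.readalong", "w", encoding="utf8") as f:
    f.write(readalong_xml)
```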
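Finally, a worked example of the `get_filename` refactor: the `ckpt=` prefix comes from `self.global_step = f"ckpt={global_step}"` in the base class, while the `"--"` separator is inferred from the expected filenames in the tests (`self.sep` itself is not shown in this diff):

```python
# Mimics PredictionWritingCallbackBase.get_filename for basename "short",
# speaker "spk1", language "lngA", and extension "pred.wav".
name_parts = ["short", "spk1", "lngA", "pred.wav"]
include_global_step = True
if include_global_step:
    # insert(-1, ...) slots the global step in just before the file extension
    name_parts.insert(-1, "ckpt=77")
assert "--".join(name_parts) == "short--spk1--lngA--ckpt=77--pred.wav"
```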