diff --git a/everyvoice/cli.py b/everyvoice/cli.py
index 90e71bb9..681bdaa9 100644
--- a/everyvoice/cli.py
+++ b/everyvoice/cli.py
@@ -1,4 +1,3 @@
-import enum
 import json
 import platform
 import subprocess
@@ -37,6 +36,9 @@
 from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.cli.train import (
     train as train_fs2,
 )
+from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import (
+    SynthesizeOutputFormats,
+)
 from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.cli import (
     HFG_EXPORT_LONG_HELP,
     HFG_EXPORT_SHORT_HELP,
@@ -558,7 +560,7 @@ def check_data(
 )(inspect_checkpoint)
 
 
-TestSuites = enum.Enum("TestSuites", {name: name for name in SUITE_NAMES})  # type: ignore
+TestSuites = Enum("TestSuites", {name: name for name in SUITE_NAMES})  # type: ignore
 
 
 @app.command(hidden=True)
@@ -571,6 +573,12 @@ def test(suite: TestSuites = typer.Argument("dev")):
 SCHEMAS_TO_OUTPUT: dict[str, Any] = {}  # dict[str, type[BaseModel]]
 
 
+AllowedDemoOutputFormats = Enum(  # type: ignore
+    "AllowedDemoOutputFormats",
+    [("all", "all")] + [(i.name, i.value) for i in SynthesizeOutputFormats],
+)
+
+
 @app.command()
 def demo(
     text_to_spec_model: Path = typer.Argument(
@@ -608,13 +616,19 @@ def demo(
         ["all"],
         "--language",
         "-l",
-        help="Specify languages to be included in the demo. Example: everyvoice demo --language eng --language fin",
+        help="Specify languages to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --language eng --language fin",
     ),
     speakers: List[str] = typer.Option(
         ["all"],
         "--speaker",
         "-s",
-        help="Specify speakers to be included in the demo. Example: everyvoice demo --speaker speaker_1 --speaker Sue",
+        help="Specify speakers to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --speaker speaker_1 --speaker Sue",
+    ),
+    outputs: list[AllowedDemoOutputFormats] = typer.Option(
+        ["all"],
+        "--output-format",
+        "-O",
+        help="Specify output formats to be included in the demo. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --output-format wav --output-format readalong-html",
     ),
     output_dir: Path = typer.Option(
         "synthesis_output",
@@ -625,9 +639,13 @@ def demo(
         help="The directory where your synthesized audio should be written",
         shell_complete=complete_path,
     ),
-    accelerator: str = typer.Option("auto", "--accelerator", "-a"),
+    accelerator: str = typer.Option(
+        "auto",
+        "--accelerator",
+        "-a",
+        help="Specify the PyTorch Lightning accelerator to use",
+    ),
 ):
-
     if allowlist and denylist:
         raise ValueError(
             "You provided a value for both the allowlist and the denylist but you can only provide one."
@@ -652,6 +670,7 @@ def demo(
         spec_to_wav_model_path=spec_to_wav_model,
         languages=languages,
         speakers=speakers,
+        outputs=outputs,
         output_dir=output_dir,
         accelerator=accelerator,
         allowlist=allowlist_data,
diff --git a/everyvoice/demo/app.py b/everyvoice/demo/app.py
index 04ddf0cf..5bce0227 100644
--- a/everyvoice/demo/app.py
+++ b/everyvoice/demo/app.py
@@ -7,7 +7,6 @@
 
 import gradio as gr
 import torch
-from gradio.processing_utils import convert_to_16_bit_wav
 from loguru import logger
 
 from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel
@@ -17,12 +16,16 @@
 from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.model import (
     FastSpeech2,
 )
-from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.prediction_writing_callback import (
-    PredictionWritingWavCallback,
+from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import (
+    SynthesizeOutputFormats,
+)
+from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.utils import (
+    truncate_basename,
 )
 from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.utils import (
     load_hifigan_from_checkpoint,
 )
+from everyvoice.utils import slugify
 from everyvoice.utils.heavy import get_device_from_accelerator
 
 os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
@@ -33,6 +36,7 @@ def synthesize_audio(
     duration_control,
     language,
     speaker,
+    output_format,
     text_to_spec_model,
     vocoder_model,
     vocoder_config,
@@ -47,6 +51,7 @@ def synthesize_audio(
             "Text for synthesis was not provided. Please type the text you want to be synthesized into the textfield."
         )
     norm_text = normalize_text(text)
+    basename = truncate_basename(slugify(text))
     if allowlist and norm_text not in allowlist:
         raise gr.Error(
             f"Oops, the word {text} is not allowed to be synthesized by this model. Please contact the model owner."
@@ -62,7 +67,9 @@ def synthesize_audio(
         raise gr.Error("Language is not selected. Please select a language.")
     if speaker is None:
         raise gr.Error("Speaker is not selected. Please select a speaker.")
-    config, device, predictions = synthesize_helper(
+    if output_format is None:
+        raise gr.Error("Output format is not selected. Please select an output format.")
+    config, device, predictions, callbacks = synthesize_helper(
         model=text_to_spec_model,
         vocoder_model=vocoder_model,
         vocoder_config=vocoder_config,
@@ -71,9 +78,9 @@ def synthesize_audio(
         accelerator=accelerator,
         devices="1",
         device=device,
-        global_step=1,
-        vocoder_global_step=1,  # dummy value since the vocoder step is not used
-        output_type=[],
+        global_step=text_to_spec_model.config.training.max_steps,
+        vocoder_global_step=vocoder_model.config.training.max_steps,
+        output_type=(output_format, SynthesizeOutputFormats.wav),
         text_representation=TargetTrainingTextRepresentationLevel.characters,
         output_dir=output_dir,
         speaker=speaker,
@@ -83,24 +90,16 @@ def synthesize_audio(
         batch_size=1,
         num_workers=1,
     )
 
-    output_key = (
-        "postnet_output" if text_to_spec_model.config.model.use_postnet else "output"
-    )
-    wav_writer = PredictionWritingWavCallback(
-        output_dir=output_dir,
-        config=config,
-        output_key=output_key,
-        device=device,
-        global_step=1,
-        vocoder_global_step=1,  # dummy value since the vocoder step is not used
-        vocoder_model=vocoder_model,
-        vocoder_config=vocoder_config,
-    )
-    # move to device because lightning accumulates predictions on cpu
-    predictions[0][output_key] = predictions[0][output_key].to(device)
-    wav, sr = wav_writer.synthesize_audio(predictions[0])
-    return sr, convert_to_16_bit_wav(wav.numpy())
+    wav_writer = callbacks[SynthesizeOutputFormats.wav]
+    wav_output = wav_writer.get_filename(basename, speaker, language)
+
+    file_output = None
+    if output_format != SynthesizeOutputFormats.wav:
+        file_writer = callbacks[output_format]
+        file_output = file_writer.get_filename(basename, speaker, language)
+
+    return wav_output, file_output
 
 
 def require_ffmpeg():
@@ -158,15 +157,38 @@ def normalize_text(text: str) -> str:
 
 
 def create_demo_app(
-    text_to_spec_model_path,
-    spec_to_wav_model_path,
-    languages,
-    speakers,
-    output_dir,
-    accelerator,
+    text_to_spec_model_path: os.PathLike,
+    spec_to_wav_model_path: os.PathLike,
+    languages: list[str],
+    speakers: list[str],
+    outputs: list,  # list[str | AllowedDemoOutputFormats]
+    output_dir: os.PathLike,
+    accelerator: str,
     allowlist: list[str] = [],
     denylist: list[str] = [],
 ) -> gr.Blocks:
+    # Early argument validation where possible
+    possible_outputs = [x.value for x in SynthesizeOutputFormats]
+
+    # This used to be `if outputs == ["all"]:`, but the Enum() constructor used
+    # for AllowedDemoOutputFormats breaks that comparison, and enum.StrEnum
+    # does not appear until Python 3.11, so it cannot be used here yet.
+    if len(outputs) == 1 and getattr(outputs[0], "value", outputs[0]) == "all":
+        output_list = possible_outputs
+    else:
+        if not outputs:
+            raise ValueError(
+                f"Empty outputs list. Please specify ['all'] or one or more of {possible_outputs}"
+            )
+        output_list = []
+        for output in outputs:
+            value = getattr(output, "value", output)  # Enum -> value as str / str -> str
+            if value not in possible_outputs:
+                raise ValueError(
+                    f"Unknown output format '{value}'. Valid output values are ['all'] or one or more of {possible_outputs}"
+                )
+            output_list.append(value)
+
     require_ffmpeg()
     device = get_device_from_accelerator(accelerator)
     vocoder_ckpt = torch.load(spec_to_wav_model_path, map_location=device)
@@ -215,6 +237,7 @@ def create_demo_app(
             print(
                 f"Attention: The model have not been trained for speech synthesis with '{speaker}' speaker. The '{speaker}' speaker option will not be available for selection."
             )
+
     if lang_list == []:
         raise ValueError(
             f"Language option has been activated, but valid languages have not been provided. The model has been trained in {model_languages} languages. Please select either 'all' or at least some of them."
@@ -227,6 +250,8 @@ def create_demo_app(
     interactive_lang = len(lang_list) > 1
     default_speak = speak_list[0]
     interactive_speak = len(speak_list) > 1
+    default_output = output_list[0]
+    interactive_output = len(output_list) > 1
     with gr.Blocks() as demo:
         gr.Markdown(
             """
@@ -255,12 +280,25 @@ def create_demo_app(
                         interactive=interactive_speak,
                         label="Speaker",
                     )
+                with gr.Row():
+                    output_format = gr.Dropdown(
+                        choices=output_list,
+                        value=default_output,
+                        interactive=interactive_output,
+                        label="Output Format",
+                    )
                 btn = gr.Button("Synthesize")
             with gr.Column():
-                out_audio = gr.Audio(format="mp3")
+                out_audio = gr.Audio(format="wav")
+                if output_list == [SynthesizeOutputFormats.wav]:
+                    # When the only output option is wav, don't show the File Output box
+                    outputs = [out_audio]
+                else:
+                    out_file = gr.File(label="File Output")
+                    outputs = [out_audio, out_file]
         btn.click(
             synthesize_audio_preset,
-            inputs=[inp_text, inp_slider, inp_lang, inp_speak],
-            outputs=[out_audio],
+            inputs=[inp_text, inp_slider, inp_lang, inp_speak, output_format],
+            outputs=outputs,
         )
     return demo
diff --git a/everyvoice/model/feature_prediction/FastSpeech2_lightning b/everyvoice/model/feature_prediction/FastSpeech2_lightning
index 87213a92..fab09c51 160000
--- a/everyvoice/model/feature_prediction/FastSpeech2_lightning
+++ b/everyvoice/model/feature_prediction/FastSpeech2_lightning
@@ -1 +1 @@
-Subproject commit 87213a92bf3b186a948e90cc4e736b8e0994936e
+Subproject commit fab09c5184342dc8e3120f1bc60a4e2e79e7fe33
diff --git a/everyvoice/tests/test_cli.py b/everyvoice/tests/test_cli.py
index e74fa4b1..174bd2b7 100644
--- a/everyvoice/tests/test_cli.py
+++ b/everyvoice/tests/test_cli.py
@@ -1,3 +1,4 @@
+import enum
 import json
 import os
 import subprocess
@@ -21,6 +22,7 @@
 from everyvoice.base_cli.helpers import save_configuration_to_log_dir
 from everyvoice.cli import SCHEMAS_TO_OUTPUT, app
 from everyvoice.config.shared_types import ContactInformation
+from everyvoice.demo.app import create_demo_app
 from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.config import (
     FastSpeech2Config,
 )
@@ -56,6 +58,7 @@ def setUp(self) -> None:
             "preprocess",
             "inspect-checkpoint",
             "evaluate",
+            "demo",
         ]
 
     def test_version(self):
@@ -323,6 +326,48 @@ def test_expensive_imports_are_tucked_away(self):
         self.assertNotIn(b"shared_types", result.stderr, msg.format("shared_types.py"))
         self.assertNotIn(b"pydantic", result.stderr, msg.format("pydantic"))
 
+    def test_demo_with_bad_args(self):
+        result = self.runner.invoke(app, ["demo"])
+        self.assertNotEqual(result.exit_code, 0)
+        self.assertIn("Missing argument", result.output)
+
+        result = self.runner.invoke(
+            app, ["demo", os.devnull, os.devnull, "--output-format", "not-a-format"]
+        )
+        self.assertNotEqual(result.exit_code, 0)
+        self.assertIn("Invalid value", result.output)
+
+    def test_create_demo_app_with_errors(self):
+        # outputs is the first thing to get checked, because it can be done as
+        # a quick check before loading any models.
+        with self.assertRaises(ValueError) as cm:
+            create_demo_app(
+                text_to_spec_model_path=None,
+                spec_to_wav_model_path=None,
+                languages=[],
+                speakers=[],
+                outputs=[],
+                output_dir=None,
+                accelerator=None,
+            )
+        self.assertIn("Empty outputs list", str(cm.exception))
+
+        class WrongEnum(str, enum.Enum):
+            foo = "foo"
+
+        for outputs in (["wav", WrongEnum.foo], ["textgrid", "foo"]):
+            with self.assertRaises(ValueError) as cm:
+                create_demo_app(
+                    text_to_spec_model_path=None,
+                    spec_to_wav_model_path=None,
+                    languages=[],
+                    speakers=[],
+                    outputs=outputs,
+                    output_dir=None,
+                    accelerator=None,
+                )
+            self.assertIn("Unknown output format 'foo'", str(cm.exception))
+
 
 class TestBaseCLIHelper(TestCase):
     def test_save_configuration_to_log_dir(self):
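Side note on the Enum comparison quirk that the comment in `create_demo_app` refers to: a member of a plain functional `Enum` does not compare equal to its string value, which is why the old `outputs == ["all"]` check stopped working once enum members started arriving from Typer, and why `getattr(x, "value", x)` is used to normalize members and bare strings alike. A minimal standalone sketch of the behaviour; the `Fmt`/`StrFmt` names below are illustrative stand-ins, not part of this diff:

```python
from enum import Enum

# Stand-in built with the same functional Enum() constructor as
# AllowedDemoOutputFormats in cli.py.
Fmt = Enum("Fmt", [("all", "all"), ("wav", "wav")])

# A plain Enum member is not equal to its string value...
assert Fmt.all != "all"
assert [Fmt.all] != ["all"]  # ...so `outputs == ["all"]` fails for enum input

# getattr(x, "value", x) yields the string for members and strings alike,
# which is what the validation loop in create_demo_app relies on.
assert getattr(Fmt.all, "value", Fmt.all) == "all"
assert getattr("all", "value", "all") == "all"

# A str-mixin Enum (what enum.StrEnum provides from Python 3.11 on) would
# compare equal directly, which is the alternative the comment rules out.
class StrFmt(str, Enum):
    all = "all"

assert StrFmt.all == "all"
```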