Skip to content

Commit

Permalink
refactor: get synthesis callbacks now returns a dict
Browse files (browse the repository at this point in the history)
  • Loading branch information
joanise committed Jan 8, 2025
1 parent 38064eb commit 520222e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 56 deletions.
22 changes: 12 additions & 10 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,18 @@ def synthesize_helper(
accelerator=accelerator,
devices=devices,
max_epochs=model.config.training.max_epochs,
callbacks=get_synthesis_output_callbacks(
output_type=output_type,
output_dir=output_dir,
config=model.config,
output_key=model.output_key,
device=device,
global_step=global_step,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
callbacks=list(
get_synthesis_output_callbacks(
output_type=output_type,
output_dir=output_dir,
config=model.config,
output_key=model.output_key,
device=device,
global_step=global_step,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
).values()
),
)
if teacher_forcing_directory is not None:
Expand Down
68 changes: 31 additions & 37 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,12 @@ def get_synthesis_output_callbacks(
vocoder_model: Optional[HiFiGAN] = None,
vocoder_config: Optional[HiFiGANConfig] = None,
vocoder_global_step: Optional[int] = None,
) -> list[Callback]:
) -> dict[SynthesizeOutputFormats, Callback]:
"""
Given a list of desired output file formats, return the proper callbacks
that will generate those files.
"""
callbacks: list[Callback] = []
if SynthesizeOutputFormats.spec in output_type:
callbacks.append(
PredictionWritingSpecCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
)
if SynthesizeOutputFormats.textgrid in output_type:
callbacks.append(
PredictionWritingTextGridCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
)
if SynthesizeOutputFormats.readalong_xml in output_type:
callbacks.append(
PredictionWritingReadAlongCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
)
callbacks: dict[SynthesizeOutputFormats, Callback] = {}
if (
SynthesizeOutputFormats.wav in output_type
or SynthesizeOutputFormats.readalong_html in output_type
Expand All @@ -77,22 +50,43 @@ def get_synthesis_output_callbacks(
raise ValueError(
"We cannot synthesize waveforms without a vocoder. Please ensure that a vocoder is specified."
)
callbacks.append(
PredictionWritingWavCallback(
callbacks[SynthesizeOutputFormats.wav] = PredictionWritingWavCallback(
config=config,
device=device,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
if SynthesizeOutputFormats.spec in output_type:
callbacks[SynthesizeOutputFormats.spec] = PredictionWritingSpecCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
if SynthesizeOutputFormats.textgrid in output_type:
callbacks[SynthesizeOutputFormats.textgrid] = PredictionWritingTextGridCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
if SynthesizeOutputFormats.readalong_xml in output_type:
callbacks[SynthesizeOutputFormats.readalong_xml] = (
PredictionWritingReadAlongCallback(
config=config,
device=device,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
)
if SynthesizeOutputFormats.readalong_html in output_type:
wav_callback = callbacks[-1]
wav_callback = callbacks[SynthesizeOutputFormats.wav]
assert isinstance(wav_callback, PredictionWritingWavCallback)
callbacks.append(
callbacks[SynthesizeOutputFormats.readalong_html] = (
PredictionWritingOfflineRASCallback(
config=config,
global_step=global_step,
Expand Down
22 changes: 13 additions & 9 deletions fs2/tests/test_writing_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,15 @@ def test_filenames_not_truncated(self):
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
with silence_c_stderr():
writer = get_synthesis_output_callbacks(
writers = get_synthesis_output_callbacks(
[SynthesizeOutputFormats.spec],
config=FastSpeech2Config(contact=self.contact),
global_step=77,
output_dir=tmp_dir,
output_key=self.output_key,
device=torch.device("cpu"),
)[0]
)
writer = next(iter(writers.values()))
writer.on_predict_batch_end(
_trainer=None,
_pl_module=None,
Expand Down Expand Up @@ -162,14 +163,15 @@ def test_filenames_not_truncated(self):
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
with silence_c_stderr():
writer = get_synthesis_output_callbacks(
writers = get_synthesis_output_callbacks(
[SynthesizeOutputFormats.textgrid],
config=FastSpeech2Config(contact=self.contact),
global_step=77,
output_dir=tmp_dir,
output_key=self.output_key,
device=torch.device("cpu"),
)[0]
)
writer = next(iter(writers.values()))
writer.on_predict_batch_end(
_trainer=None,
_pl_module=None,
Expand Down Expand Up @@ -213,14 +215,15 @@ def test_writing_readalong(self):
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
with silence_c_stderr():
writer = get_synthesis_output_callbacks(
writers = get_synthesis_output_callbacks(
[SynthesizeOutputFormats.readalong_xml],
config=FastSpeech2Config(contact=self.contact),
global_step=77,
output_dir=tmp_dir,
output_key=self.output_key,
device=torch.device("cpu"),
)[0]
)
writer = next(iter(writers.values()))
writer.on_predict_batch_end(
_trainer=None,
_pl_module=None,
Expand Down Expand Up @@ -271,7 +274,7 @@ def test_writing_offline_ras(self):
vocoder_config=vocoder.config,
vocoder_global_step=10,
)
for writer in writers:
for writer in writers.values():
writer.on_predict_batch_end(
_trainer=None,
_pl_module=None,
Expand Down Expand Up @@ -327,7 +330,7 @@ def test_filenames_not_truncated(self):
vocoder, vocoder_path = get_dummy_vocoder(tmp_dir)

with silence_c_stderr():
writer = get_synthesis_output_callbacks(
writers = get_synthesis_output_callbacks(
[SynthesizeOutputFormats.wav],
config=FastSpeech2Config(
contact=self.contact,
Expand All @@ -340,7 +343,8 @@ def test_filenames_not_truncated(self):
vocoder_model=vocoder,
vocoder_config=vocoder.config,
vocoder_global_step=10,
)[0]
)
writer = next(iter(writers.values()))
writer.on_predict_batch_end(
_trainer=None,
_pl_module=None,
Expand Down

0 comments on commit 520222e

Please sign in to comment.