Skip to content

Commit

Permalink
feat: synthesize_helper returns callbacks so caller can get filenames
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise committed Jan 8, 2025
1 parent 520222e commit fab09c5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
29 changes: 15 additions & 14 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,24 +233,23 @@ def synthesize_helper(

from ..prediction_writing_callback import get_synthesis_output_callbacks

callbacks = get_synthesis_output_callbacks(
output_type=output_type,
output_dir=output_dir,
config=model.config,
output_key=model.output_key,
device=device,
global_step=global_step,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
trainer = Trainer(
logger=False, # We don't need to log things to tensorboard during inference
accelerator=accelerator,
devices=devices,
max_epochs=model.config.training.max_epochs,
callbacks=list(
get_synthesis_output_callbacks(
output_type=output_type,
output_dir=output_dir,
config=model.config,
output_key=model.output_key,
device=device,
global_step=global_step,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
).values()
),
callbacks=list(callbacks.values()),
)
if teacher_forcing_directory is not None:
teacher_forcing = True
Expand All @@ -274,6 +273,7 @@ def synthesize_helper(
),
return_predictions=True,
),
callbacks,
)


Expand Down Expand Up @@ -468,7 +468,8 @@ def synthesize( # noqa: C901
vocoder_model = None
vocoder_config = None
vocoder_global_step = None
return synthesize_helper(

synthesize_helper(
model=model,
texts=texts,
language=language,
Expand Down
4 changes: 1 addition & 3 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __init__(
)
self.text_processor = TextProcessor(config.text)
self.output_key = output_key
logger.info(f"Saving pytorch output to {self.save_dir}")
logger.info(f"Saving text output to {self.save_dir}")

def save_aligned_text_to_file(
self,
Expand Down Expand Up @@ -379,7 +379,6 @@ def __init__(
)
self.text_processor = TextProcessor(config.text)
self.output_key = output_key
logger.info(f"Saving pytorch output to {self.save_dir}")

def save_aligned_text_to_file(
self,
Expand Down Expand Up @@ -430,7 +429,6 @@ def __init__(
self.text_processor = TextProcessor(config.text)
self.output_key = output_key
self.wav_callback = wav_callback
logger.info(f"Saving pytorch output to {self.save_dir}")

def save_aligned_text_to_file(
self,
Expand Down

0 comments on commit fab09c5

Please sign in to comment.