Skip to content

Commit

Permalink
refactor: factor out the alignment token creation function
Browse files: browse the repository at this point in the history
  • Loading branch information
roedoejet committed Dec 20, 2024
1 parent 42f729b commit 9d7ff68
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 48 deletions.
1 change: 1 addition & 0 deletions fs2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def forward(self, batch, control=InferenceControl(), inference=False):
"energy_target": variance_adaptor_out["energy_target"],
"pitch_prediction": variance_adaptor_out["pitch_prediction"],
"pitch_target": variance_adaptor_out["pitch_target"],
"text_input": text_inputs,
}

def check_and_upgrade_checkpoint(self, checkpoint):
Expand Down
118 changes: 70 additions & 48 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Optional, Sequence

import numpy as np
import numpy.typing as npt
import torch
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.config import HiFiGANConfig
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.model import HiFiGAN
Expand All @@ -21,6 +22,72 @@
from .type_definitions import SynthesizeOutputFormats


def frames_to_seconds(frames: int, fft_hop_size: int, sampling_rate: int) -> float:
    """Convert a number of frames into a duration in seconds.

    The frame count is multiplied by the hop size and the product is divided
    by the sampling rate.

    Args:
        frames: number of frames to convert.
        fft_hop_size: hop size that each frame is multiplied by.
        sampling_rate: rate that the resulting product is divided by.

    Returns:
        ``(frames * fft_hop_size) / sampling_rate`` as a float.
    """
    total_samples = frames * fft_hop_size
    return total_samples / sampling_rate


def get_tokens_from_duration_and_labels(
    duration_predictions: torch.Tensor,
    text: npt.NDArray[np.float32],
    raw_text: str,
    text_processor: TextProcessor,
    config: FastSpeech2Config,
) -> tuple[float, list[tuple[float, float, str]], list[tuple[float, float, str]]]:
    """Build phone-level and word-level alignment intervals for one utterance.

    Each predicted duration ``d`` is turned into a frame count with
    ``round(exp(d) - 1)`` clamped to a minimum of zero, then converted to
    seconds with ``frames_to_seconds``. Phone intervals are laid end to end
    starting at 0.0; a word interval is closed whenever a ``" "`` label is
    seen or the last non-padding label is reached.

    Args:
        duration_predictions: one predicted duration per input token.
        text: the input tokens; decoded back to labels with
            ``text_processor.decode_tokens(tokens, join_character=None)``.
        raw_text: the original text; split on whitespace to label the words.
        text_processor: object providing ``decode_tokens``.
        config: provides ``preprocessing.audio.fft_hop_size`` and
            ``preprocessing.audio.output_sampling_rate``.

    Returns:
        A tuple ``(xmax_seconds, phones, words)`` where ``xmax_seconds`` is the
        duration computed from the sum of *all* predicted frame counts
        (padding positions included), and ``phones`` / ``words`` are lists of
        ``(start_seconds, end_seconds, label)`` tuples.

    Raises:
        AssertionError: if the number of predicted durations differs from the
            number of decoded labels.
    """
    # Get all durations in frames
    duration_frames = (
        torch.clamp(torch.round(torch.exp(duration_predictions) - 1), min=0)
        .int()
        .tolist()
    )
    # Get all input labels
    tokens: list[int] = text.tolist()
    text_labels = text_processor.decode_tokens(tokens, join_character=None)
    assert len(duration_frames) == len(
        text_labels
    ), f"can't synthesize {raw_text} because the number of predicted duration steps ({len(duration_frames)}) doesn't equal the number of input text labels ({len(text_labels)})"

    # Both values are loop-invariant, so read them from the config once.
    audio_config = config.preprocessing.audio
    fft_hop_size = audio_config.fft_hop_size
    sampling_rate = audio_config.output_sampling_rate

    # get the duration of the audio: (sum_of_frames * hop_size) / sample_rate
    xmax_seconds = frames_to_seconds(sum(duration_frames), fft_hop_size, sampling_rate)

    # create the tiers
    words: list[tuple[float, float, str]] = []
    phones: list[tuple[float, float, str]] = []
    raw_text_words = raw_text.split()
    current_word_duration = 0.0
    current_word_labels: list[str] = []
    last_phone_end = 0.0
    last_word_end = 0.0

    # skip padding
    # NOTE(review): slicing the durations by the non-padding count assumes the
    # padding labels only occur at the end of the sequence - confirm.
    text_labels_no_padding = [tl for tl in text_labels if tl != "\x80"]
    duration_frames_no_padding = duration_frames[: len(text_labels_no_padding)]
    label_count = len(text_labels_no_padding)

    for label, frames in zip(text_labels_no_padding, duration_frames_no_padding):
        # add phone label
        phone_duration = frames_to_seconds(frames, fft_hop_size, sampling_rate)
        current_phone_end = last_phone_end + phone_duration
        phones.append((last_phone_end, current_phone_end, label))
        last_phone_end = current_phone_end

        # accumulate phone to word label
        current_word_duration += phone_duration
        current_word_labels.append(label)

        # if label is space or the last phone, add the word and recount
        if label == " " or len(phones) == label_count:
            current_word_end = last_word_end + current_word_duration
            if len(words) < len(raw_text_words):
                word_label = raw_text_words[len(words)]
            else:
                # The labels produced more space-delimited groups than the raw
                # text has whitespace-separated words (e.g. consecutive or
                # trailing spaces). Label the group with its own phone labels
                # instead of failing with a bare IndexError.
                word_label = "".join(current_word_labels).strip()
            words.append((last_word_end, current_word_end, word_label))
            last_word_end = current_word_end
            # Reset with a float so the accumulator keeps a single type.
            current_word_duration = 0.0
            current_word_labels = []

    return xmax_seconds, phones, words


def get_synthesis_output_callbacks(
output_type: Sequence[SynthesizeOutputFormats],
output_dir: Path,
Expand Down Expand Up @@ -226,11 +293,6 @@ def save_aligned_text_to_file(
in the desired format."""
raise NotImplementedError

def frames_to_seconds(self, frames: int) -> float:
    """Convert a number of frames into a duration in seconds.

    The hop size and output sampling rate are read from
    ``self.config.preprocessing.audio``.

    Args:
        frames: number of frames to convert.

    Returns:
        ``(frames * fft_hop_size) / output_sampling_rate`` as a float.
    """
    audio_settings = self.config.preprocessing.audio
    total_samples = frames * audio_settings.fft_hop_size
    return total_samples / audio_settings.output_sampling_rate

def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
self,
_trainer,
Expand All @@ -253,50 +315,10 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
batch["text"], # type: ignore
outputs["duration_prediction"],
):
# Get all durations in frames
duration_frames = (
torch.clamp(torch.round(torch.exp(duration) - 1), min=0).int().tolist()
# Get the phone/word alignment tokens
xmax_seconds, phones, words = get_tokens_from_duration_and_labels(
duration, text, raw_text, self.text_processor, self.config
)
# Get all input labels
tokens: list[int] = text.tolist()
text_labels = self.text_processor.decode_tokens(tokens, join_character=None)
assert len(duration_frames) == len(
text_labels
), f"can't synthesize {raw_text} because the number of predicted duration steps ({len(duration_frames)}) doesn't equal the number of input text labels ({len(text_labels)})"
# get the duration of the audio: (sum_of_frames * hop_size) / sample_rate
xmax_seconds = self.frames_to_seconds(sum(duration_frames))
# create the tiers
words: list[tuple[float, float, str]] = []
phones: list[tuple[float, float, str]] = []
raw_text_words = raw_text.split()
current_word_duration = 0.0
last_phone_end = 0.0
last_word_end = 0.0
# skip padding
text_labels_no_padding = [tl for tl in text_labels if tl != "\x80"]
duration_frames_no_padding = duration_frames[: len(text_labels_no_padding)]
for label, duration in zip(
text_labels_no_padding, duration_frames_no_padding
):
# add phone label
phone_duration = self.frames_to_seconds(duration)
current_phone_end = last_phone_end + phone_duration
interval = (last_phone_end, current_phone_end, label)
phones.append(interval)
last_phone_end = current_phone_end
# accumulate phone to word label
current_word_duration += phone_duration
# if label is space or the last phone, add the word and recount
if label == " " or len(phones) == len(text_labels_no_padding):
current_word_end = last_word_end + current_word_duration
interval = (
last_word_end,
current_word_end,
raw_text_words[len(words)],
)
words.append(interval)
last_word_end = current_word_end
current_word_duration = 0

# Save the output (the subclass has to implement this)
self.save_aligned_text_to_file(
Expand Down

0 comments on commit 9d7ff68

Please sign in to comment.