feat: add basic optional support for global style token module #100

Merged: 6 commits, Jan 20, 2025
32 changes: 32 additions & 0 deletions LICENSE
@@ -64,3 +64,35 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

For fs2/gst/attn.py:

Copyright (c) 2019, Shigeki Karita. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

For fs2/gst/model.py (sourced from ESPNet2):

Copyright (c) 2020, Nagoya University (Tomoki Hayashi). All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
2 changes: 1 addition & 1 deletion fs2/attn/attention.py
@@ -239,7 +239,7 @@ def forward(
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
# compute log likelihood from a gaussian
attn = -0.0005 * attn.sum(1, keepdim=True)
- if attn_prior is not None:
+ if torch.is_tensor(attn_prior):
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

attn_logprob = attn.clone()
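A side note on the guard change above: `attn_prior is not None` lets any non-None placeholder into the prior branch, whereas `torch.is_tensor` only admits actual tensors. A minimal sketch of the difference (not from the PR; the placeholder values are illustrative only):

```python
import torch

# Only the real tensor should take the prior branch; a list or None should not.
for attn_prior in (None, [0.1, 0.9], torch.rand(1, 4, 6)):
    if torch.is_tensor(attn_prior):
        print("apply prior:", attn_prior.shape)
    else:
        print("skip prior:", type(attn_prior).__name__)
```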
43 changes: 42 additions & 1 deletion fs2/cli/synthesize.py
@@ -73,6 +73,7 @@ def prepare_data(
model: Any,
text_representation: DatasetTextRepresentation,
duration_control: float,
style_reference: Path | None,
) -> list[dict[str, Any]]:
""""""
from everyvoice.utils import slugify
@@ -154,9 +155,36 @@ def prepare_data(
multi=model.config.model.multispeaker,
)

# We only allow a single style reference right now, so it's fine to load it once here.
if style_reference:
from everyvoice.utils.heavy import get_spectral_transform

spectral_transform = get_spectral_transform(
model.config.preprocessing.audio.spec_type,
model.config.preprocessing.audio.n_fft,
model.config.preprocessing.audio.fft_window_size,
model.config.preprocessing.audio.fft_hop_size,
f_min=model.config.preprocessing.audio.f_min,
f_max=model.config.preprocessing.audio.f_max,
sample_rate=model.config.preprocessing.audio.output_sampling_rate,
n_mels=model.config.preprocessing.audio.n_mels,
)
import torchaudio

style_reference_audio, style_reference_sr = torchaudio.load(style_reference)
if style_reference_sr != model.config.preprocessing.audio.input_sampling_rate:
style_reference_audio = torchaudio.functional.resample(
style_reference_audio,
style_reference_sr,
model.config.preprocessing.audio.input_sampling_rate,
)
style_reference_spec = spectral_transform(style_reference_audio)
# Add duration_control
for item in data:
item["duration_control"] = duration_control
# Add style reference
if style_reference:
item["mel_style_reference"] = style_reference_spec

return data

@@ -175,6 +203,7 @@ def get_global_step(model_path: Path) -> int:
def synthesize_helper(
model,
texts: list[str],
style_reference: Optional[Path],
language: Optional[str],
speaker: Optional[str],
duration_control: Optional[float],
@@ -227,8 +256,8 @@ def synthesize_helper(
filelist=filelist,
model=model,
text_representation=text_representation,
style_reference=style_reference,
)

from pytorch_lightning import Trainer

from ..prediction_writing_callback import get_synthesis_output_callbacks
@@ -270,6 +299,7 @@ def synthesize_helper(
model.lang2id,
model.speaker2id,
teacher_forcing=teacher_forcing,
style_reference=style_reference is not None,
),
return_predictions=True,
),
@@ -314,6 +344,16 @@ def synthesize( # noqa: C901
"-D",
help="Control the speaking rate of the synthesis. Set a value to multiply the durations by; lower numbers produce quicker speaking rates, larger numbers produce slower speaking rates. Default is 1.0",
),
style_reference: Optional[Path] = typer.Option(
None,
"--style-reference",
"-S",
exists=True,
file_okay=True,
dir_okay=False,
help="The path to an audio file containing a style reference. Your text-to-spec must have been trained with the global style token module to use this feature.",
shell_complete=complete_path,
),
speaker: Optional[str] = typer.Option(
None,
"--speaker",
@@ -472,6 +512,7 @@ def synthesize( # noqa: C901
synthesize_helper(
model=model,
texts=texts,
style_reference=style_reference,
language=language,
speaker=speaker,
duration_control=duration_control,
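With these changes, a reference clip can be supplied at synthesis time via the new option, presumably with an invocation along the lines of `everyvoice synthesize from-text <model.ckpt> --style-reference reference.wav` (exact subcommand name assumed; only the `--style-reference`/`-S` option itself is defined in this diff).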
6 changes: 5 additions & 1 deletion fs2/config/__init__.py
@@ -25,7 +25,7 @@
)

# FastSpeech2Config's latest version number
- LATEST_VERSION: str = "1.0"
+ LATEST_VERSION: str = "1.1"


class ConformerConfig(ConfigModel):
@@ -138,6 +138,10 @@ class FastSpeech2ModelConfig(ConfigModel):
True,
description="Whether to jointly learn alignments using monotonic alignment search module (See Badlani et. al. 2021: https://arxiv.org/abs/2108.10447). If set to False, you will have to provide text/audio alignments separately before training a text-to-spec (feature prediction) model.",
)
use_global_style_token_module: bool = Field(
False,
description="Whether to use the Global Style Token (GST) module from Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis (https://arxiv.org/abs/1803.09017)",
)
max_length: int = Field(
1000, description="The maximum length (i.e. number of symbols) for text inputs."
)
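Since the field defaults to `False`, the GST module stays off unless a user opts in. A minimal sketch of enabling it in code (assuming the model config's other fields all have defaults, which this diff does not show):

```python
from fs2.config import FastSpeech2ModelConfig  # import path as in this diff

# Opt in to the Global Style Token module; everything else keeps its default.
cfg = FastSpeech2ModelConfig(use_global_style_token_module=True)
assert cfg.use_global_style_token_module
```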
22 changes: 20 additions & 2 deletions fs2/dataset.py
@@ -35,6 +35,7 @@ def __init__(
speaker2id: LookupTable,
teacher_forcing=False,
inference=False,
style_reference=False,
):
self.dataset = dataset
self.config = config
@@ -43,6 +44,7 @@ def __init__(
self.preprocessed_dir = Path(self.config.preprocessing.save_dir)
self.sampling_rate = self.config.preprocessing.audio.input_sampling_rate
self.teacher_forcing = teacher_forcing
self.style_reference = style_reference
self.inference = inference
self.lang2id = lang2id
self.speaker2id = speaker2id
@@ -57,6 +59,7 @@ def __getitem__(self, index):
"""
Returns dict with keys: {
"mel"
"mel_style_reference"
"duration"
"duration_control"
"pfs"
@@ -104,6 +107,12 @@
) # [mel_bins, frames] -> [frames, mel_bins]
else:
mel = None

if self.style_reference:
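# squeeze the channel dim, then [mel_bins, frames] -> [frames, mel_bins], as for "mel" above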
mel_style_reference = item["mel_style_reference"].squeeze(0).transpose(0, 1)
else:
mel_style_reference = None

if (
self.teacher_forcing or not self.inference
) and self.config.model.learn_alignment:
@@ -176,9 +185,9 @@ def __getitem__(self, index):
else:
energy = None
pitch = None

return {
"mel": mel,
"mel_style_reference": mel_style_reference,
"duration": duration,
"duration_control": duration_control,
"pfs": pfs,
@@ -208,11 +217,13 @@ def __init__(
inference=False,
teacher_forcing=False,
inference_output_dir=Path("synthesis_output"),
style_reference=False,
):
super().__init__(config=config, inference_output_dir=inference_output_dir)
self.inference = inference
self.prepared = False
self.teacher_forcing = teacher_forcing
self.style_reference = style_reference
self.collate_fn = partial(
self.collate_method, learn_alignment=config.model.learn_alignment
)
@@ -278,6 +289,7 @@ def prepare_data(self):
self.speaker2id,
inference=self.inference,
teacher_forcing=self.teacher_forcing,
style_reference=self.style_reference,
)
torch.save(self.predict_dataset, self.predict_path)
elif not self.prepared:
@@ -327,8 +339,14 @@ def __init__(
lang2id: LookupTable,
speaker2id: LookupTable,
teacher_forcing: bool = False,
style_reference=False,
):
- super().__init__(config=config, inference=True, teacher_forcing=teacher_forcing)
+ super().__init__(
+     config=config,
+     inference=True,
+     teacher_forcing=teacher_forcing,
+     style_reference=style_reference,
+ )
self.inference = True
self.data = data
self.collate_fn = partial(
Empty file added fs2/gst/__init__.py