Skip to content

Commit

Permalink
fix: allow teacherforcing with gst module
Browse files Browse the repository at this point in the history
  • Loading branch information
roedoejet committed Jan 14, 2025
1 parent 4663ed0 commit 455ecff
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions fs2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ def forward(self, batch, control=InferenceControl(), inference=False):

# Add Global Style Token Embedding
if self.config.model.use_global_style_token_module:
if torch.is_tensor(batch["mel_style_reference"]):
# Used in training and also for synthesis with a reference audio
if inference and torch.is_tensor(batch["mel_style_reference"]):
style_embs = self.gst(batch["mel_style_reference"])
else:
elif inference and not teacher_forcing:
style_embs = self.gst.condition_on_gst_tokens(batch["text"].size(0))
else:
style_embs = self.gst(batch["mel"])
x = x + style_embs.unsqueeze(1)

# Speaker Embedding
Expand Down

0 comments on commit 455ecff

Please sign in to comment.