diff --git a/fs2/model.py b/fs2/model.py index 0011f15..aa661b5 100644 --- a/fs2/model.py +++ b/fs2/model.py @@ -193,11 +193,12 @@ def forward(self, batch, control=InferenceControl(), inference=False): # Add Global Style Token Embedding if self.config.model.use_global_style_token_module: - if torch.is_tensor(batch["mel_style_reference"]): - # Used in training and also for synthesis with a reference audio + if inference and torch.is_tensor(batch["mel_style_reference"]): style_embs = self.gst(batch["mel_style_reference"]) - else: + elif inference and not teacher_forcing: style_embs = self.gst.condition_on_gst_tokens(batch["text"].size(0)) + else: + style_embs = self.gst(batch["mel"]) x = x + style_embs.unsqueeze(1) # Speaker Embedding