diff --git a/naturalspeech2_pytorch/naturalspeech2_pytorch.py b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
index 385dfae..a66564c 100644
--- a/naturalspeech2_pytorch/naturalspeech2_pytorch.py
+++ b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -67,6 +67,15 @@ def has_int_squareroot(num):
 
 # tensor helpers
 
+def pad_or_curtail_to_length(t, length):
+    if t.shape[-1] == length:
+        return t
+
+    if t.shape[-1] > length:
+        return t[..., :length]
+
+    return F.pad(t, (0, length - t.shape[-1]))
+
 def prob_mask_like(shape, prob, device):
     if prob == 1:
         return torch.ones(shape, device = device, dtype = torch.bool)
@@ -834,6 +843,7 @@ def __init__(
         )
 
         # prompt condition
 
+        self.cond_drop_prob = cond_drop_prob # for classifier free guidance
         self.condition_on_prompt = condition_on_prompt
         self.to_prompt_cond = None
@@ -861,6 +871,15 @@ def __init__(
             use_flash_attn = use_flash_attn
         )
 
+        # aligned conditioning from aligner + duration module
+
+        self.null_cond = None
+        self.cond_to_model_dim = None
+
+        if self.condition_on_prompt:
+            self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
+            self.null_cond = nn.Parameter(torch.zeros(dim, 1))
+
         # conditioning includes time and optionally prompt
 
         dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)
@@ -913,23 +932,27 @@ def forward(
         times,
         prompt = None,
         prompt_mask = None,
-        cond= None,
+        cond = None,
         cond_drop_prob = None
     ):
         b = x.shape[0]
 
         cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
-        drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
+        # prepare prompt condition
+        # prob should remove going forward
 
         t = self.to_time_cond(times)
         c = None
 
         if exists(self.to_prompt_cond):
             assert exists(prompt)
+
+            prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
             prompt_cond = self.to_prompt_cond(prompt)
 
             prompt_cond = torch.where(
-                rearrange(drop_mask, 'b -> b 1'),
+                rearrange(prompt_cond_drop_mask, 'b -> b 1'),
                 self.null_prompt_cond,
                 prompt_cond,
             )
@@ -939,12 +962,37 @@ def forward(
             resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)
 
             c = torch.where(
-                rearrange(drop_mask, 'b -> b 1 1'),
+                rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
                 self.null_prompt_tokens,
                 resampled_prompt_tokens
             )
 
+        # rearrange to channel first
+
         x = rearrange(x, 'b n d -> b d n')
+
+        # sum aligned condition to input sequence
+
+        if exists(self.cond_to_model_dim):
+            assert exists(cond)
+
+            cond = self.cond_to_model_dim(cond)
+
+            cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
+            cond = torch.where(
+                rearrange(cond_drop_mask, 'b -> b 1 1'),
+                self.null_cond,
+                cond
+            )
+
+            # for now, conform the condition to the length of the latent features
+
+            cond = pad_or_curtail_to_length(cond, x.shape[-1])
+
+            x = x + cond
+
+        # main wavenet body
 
         x = self.wavenet(x, t)
         x = rearrange(x, 'b d n -> b n d')
@@ -1527,6 +1575,7 @@ def forward(
         duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)
 
         pitch = average_over_durations(pitch, aln_hard)
+        cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)
 
         # pitch and duration loss
 
@@ -1536,6 +1585,7 @@ def forward(
         pitch = rearrange(pitch, 'b 1 d -> b d')
         pitch_loss = F.l1_loss(pitch, pitch_pred)
         align_loss = self.aligner_loss(aln_log , text_lens, mel_lens)
+
         # weigh the losses
 
         aux_loss = (duration_loss * self.duration_loss_weight) \
diff --git a/naturalspeech2_pytorch/version.py b/naturalspeech2_pytorch/version.py
index 2fb2513..124e462 100644
--- a/naturalspeech2_pytorch/version.py
+++ b/naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.6'
+__version__ = '0.1.7'
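
For reference, the new pieces compose as follows. This is a minimal runnable sketch of the conditioning-dropout path the diff adds to the model's forward pass, not code from the repository: the batch/channel/length sizes and the 0.25 drop probability are illustrative assumptions, and null_cond below is a plain tensor standing in for the learned nn.Parameter.

import torch
import torch.nn.functional as F

def pad_or_curtail_to_length(t, length):
    # trim or right-pad the last dimension to a target length (helper added in this diff)
    if t.shape[-1] == length:
        return t
    if t.shape[-1] > length:
        return t[..., :length]
    return F.pad(t, (0, length - t.shape[-1]))

def prob_mask_like(shape, prob, device):
    # boolean mask, True with probability `prob` per element (existing repo helper)
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch, dim, latent_len, cond_len = 4, 8, 100, 96    # hypothetical sizes
device = torch.device('cpu')

x         = torch.randn(batch, dim, latent_len)     # channel-first latents
cond      = torch.randn(batch, dim, cond_len)       # aligned condition, possibly a different length
null_cond = torch.zeros(dim, 1)                     # stand-in for the learned null condition

# drop the aligned condition per sample with its own mask, independent of the
# prompt mask; [:, None, None] mirrors rearrange(mask, 'b -> b 1 1')
cond_drop_mask = prob_mask_like((batch,), 0.25, device)
cond = torch.where(cond_drop_mask[:, None, None], null_cond, cond)

# conform the condition to the latent length before summing, as the diff does
cond = pad_or_curtail_to_length(cond, x.shape[-1])
x = x + cond                                        # torch.Size([4, 8, 100])

The design point is that the prompt condition and the aligned (duration/pitch) condition are now dropped independently for classifier-free guidance, each via its own mask, rather than sharing the single drop_mask the old code used.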