diff --git a/README.md b/README.md
index 66762eb..7729ad5 100644
--- a/README.md
+++ b/README.md
@@ -150,8 +150,9 @@ trainer.train()
 - [x] complete perceiver then cross attention conditioning on ddpm side
 - [x] add classifier free guidance, even if not in paper
 - [x] complete duration / pitch prediction during training - thanks to Manmay
+- [x] make sure pyworld way of computing pitch can also work
 
-- [ ] make sure pyworld way of computing pitch can also work
+- [ ] consult phd student in TTS field about pyworld usage
 - [ ] also offer direct summation conditioning using spear-tts text-to-semantic module, if available
 - [ ] add self-conditioning on ddpm side
 - [ ] take care of automatic slicing of audio for prompt, being aware of minimal audio segment as allowed by the codec model
diff --git a/naturalspeech2_pytorch/naturalspeech2_pytorch.py b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
index e298a76..aaa351b 100644
--- a/naturalspeech2_pytorch/naturalspeech2_pytorch.py
+++ b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -56,6 +56,9 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def identity(t, *args, **kwargs):
     return t
 
@@ -94,7 +97,7 @@ def generate_mask_from_lengths(lengths):
 class LearnedSinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
-        assert (dim % 2) == 0
+        assert divisible_by(dim, 2)
         half_dim = dim // 2
         self.weights = nn.Parameter(torch.randn(half_dim))
 
@@ -115,19 +118,37 @@ def compute_pitch_pytorch(wav, sample_rate):
 
 #as mentioned in paper using pyworld
-def compute_pitch(spec, sample_rate, hop_length, pitch_fmax=640.0):
-    # align F0 length to the spectrogram length
-    if len(spec) % hop_length == 0:
-        spec = np.pad(spec, (0, hop_length // 2), mode="reflect")
+def compute_pitch_pyworld(wav, sample_rate, hop_length, pitch_fmax=640.0):
+    is_tensor_input = torch.is_tensor(wav)
 
-    f0, t = pw.dio(
-        spec.astype(np.double),
-        fs=sample_rate,
-        f0_ceil=pitch_fmax,
-        frame_period=1000 * hop_length / sample_rate,
-    )
-    f0 = pw.stonemask(spec.astype(np.double), f0, t, sample_rate)
-    return f0
+    if is_tensor_input:
+        device = wav.device
+        wav = wav.contiguous().cpu().numpy()
+
+    if divisible_by(len(wav), hop_length):
+        wav = np.pad(wav, (0, hop_length // 2), mode="reflect")
+
+    wav = wav.astype(np.double)
+
+    outs = []
+
+    for sample in wav:
+        f0, t = pw.dio(
+            sample,
+            fs = sample_rate,
+            f0_ceil = pitch_fmax,
+            frame_period = 1000 * hop_length / sample_rate,
+        )
+
+        f0 = pw.stonemask(sample, f0, t, sample_rate)
+        outs.append(f0)
+
+    outs = np.stack(outs)
+
+    if is_tensor_input:
+        outs = torch.from_numpy(outs).to(device)
+
+    return outs
 
 def f0_to_coarse(f0, f0_bin = 256, f0_max = 1100.0, f0_min = 50.0):
     f0_mel_max = 1127 * torch.log(1 + torch.tensor(f0_max) / 700)
@@ -1115,6 +1136,8 @@ def __init__(
         num_phoneme_tokens: int = 150,
         pitch_emb_dim: int = 256,
         pitch_emb_pp_hidden_dim: int= 512,
+        calc_pitch_with_pyworld = True, # pyworld or kaldi from torchaudio
+        mel_hop_length = 160,
         audio_to_mel_kwargs: dict = dict(),
         scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
         duration_loss_weight = 1.,
@@ -1145,11 +1168,16 @@ def __init__(
         if exists(self.target_sample_hz):
             audio_to_mel_kwargs.update(sampling_rate = self.target_sample_hz)
 
+        self.mel_hop_length = mel_hop_length
+
         self.audio_to_mel = AudioToMel(
             n_mels = aligner_dim_in,
+            hop_length = mel_hop_length,
             **audio_to_mel_kwargs
         )
 
+        self.calc_pitch_with_pyworld = calc_pitch_with_pyworld
+
         self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=num_phoneme_tokens)
         self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
         self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
@@ -1456,13 +1484,21 @@ def forward(
         prompt_enc = self.prompt_enc(prompt)
         phoneme_enc = self.phoneme_enc(text)
 
-        # process pitch with kaldi
+        # process pitch
 
         if not exists(pitch):
             assert exists(audio) and audio.ndim == 2
             assert exists(self.target_sample_hz)
 
-            pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
+            if self.calc_pitch_with_pyworld:
+                pitch = compute_pitch_pyworld(
+                    audio,
+                    sample_rate = self.target_sample_hz,
+                    hop_length = self.mel_hop_length
+                )
+            else:
+                pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
+
             pitch = rearrange(pitch, 'b n -> b 1 n')
 
         # process mel
@@ -1470,7 +1506,9 @@
 
         if not exists(mel):
             assert exists(audio) and audio.ndim == 2
             mel = self.audio_to_mel(audio)
-            mel = mel[..., :pitch.shape[-1]]
+
+            if exists(pitch):
+                mel = mel[..., :pitch.shape[-1]]
 
         mel_max_length = mel.shape[-1]
@@ -1803,7 +1841,7 @@ def train(self):
             if accelerator.is_main_process:
                 self.ema.update()
 
-                if self.step % self.save_and_sample_every == 0:
+                if divisible_by(self.step, self.save_and_sample_every):
                     milestone = self.step // self.save_and_sample_every
                     models = [(self.unwrapped_model, str(self.step))]
diff --git a/naturalspeech2_pytorch/version.py b/naturalspeech2_pytorch/version.py
index df9144c..10939f0 100644
--- a/naturalspeech2_pytorch/version.py
+++ b/naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.1'
+__version__ = '0.1.2'
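
For review purposes, a minimal sketch of how the new `compute_pitch_pyworld` could be exercised in isolation. It assumes `pyworld` is installed and imports the function from the module this diff defines it in; the 24kHz sample rate, 160 hop length, and batch shape are illustrative assumptions, not values prescribed by the change:

```python
# minimal sketch: exercising the new pyworld pitch path on a dummy batch
# sample_rate / hop_length / shapes below are illustrative assumptions

import math
import torch

from naturalspeech2_pytorch.naturalspeech2_pytorch import compute_pitch_pyworld

sample_rate = 24000   # should match the target_sample_hz of the codec in use
hop_length = 160      # mirrors the new mel_hop_length default above

# batch of two identical 1-second 220hz sines, shape (batch, samples)
t = torch.arange(sample_rate) / sample_rate
wav = torch.sin(2 * math.pi * 220 * t).repeat(2, 1)

pitch = compute_pitch_pyworld(
    wav,
    sample_rate = sample_rate,
    hop_length = hop_length
)

print(pitch.shape)  # (2, <num frames>), with values near 220 on voiced frames
```

Since tensor input is round-tripped through numpy for `pw.dio` / `pw.stonemask` and the stacked result is moved back onto the input's device, the function can be called directly on GPU-resident training batches.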