diff --git a/.github/workflows/style-check.yml b/.github/workflows/style-check.yml new file mode 100644 index 0000000..2430e8d --- /dev/null +++ b/.github/workflows/style-check.yml @@ -0,0 +1,13 @@ +name: Lint and format check + +on: workflow_dispatch + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: psf/black@stable + with: + options: "--check --verbose" + src: "." diff --git a/commons.py b/commons.py index fc38491..5003f2a 100644 --- a/commons.py +++ b/commons.py @@ -6,166 +6,168 @@ def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) def get_padding(kernel_size, dilation=1): - return int((kernel_size*dilation - dilation)/2) + return int((kernel_size * dilation - dilation) / 2) def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) - return kl + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str def rand_spec_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - 
- -def get_timing_signal_1d( - length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = ( - math.log(float(max_timescale) / float(min_timescale)) / - (num_timescales - 1)) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return 
x def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - device = duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2,3) * mask - return path + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1. 
/ norm_type) - return total_norm + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/convert.py b/convert.py index bcbb624..4c604aa 100644 --- a/convert.py +++ b/convert.py @@ -12,18 +12,30 @@ from wavlm import WavLM, WavLMConfig from speaker_encoder.voice_encoder import SpeakerEncoder import logging -logging.getLogger('numba').setLevel(logging.WARNING) + +logging.getLogger("numba").setLevel(logging.WARNING) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--hpfile", type=str, default="configs/freevc.json", help="path to json config file") - parser.add_argument("--ptfile", type=str, default="checkpoints/freevc.pth", help="path to pth file") - parser.add_argument("--txtpath", type=str, default="convert.txt", help="path to txt file") - parser.add_argument("--outdir", type=str, default="output/freevc", help="path to output dir") + parser.add_argument( + "--hpfile", + type=str, + default="configs/freevc.json", + help="path to json config file", + ) + parser.add_argument( + "--ptfile", type=str, default="checkpoints/freevc.pth", help="path to pth file" + ) + parser.add_argument( + "--txtpath", type=str, default="convert.txt", help="path to txt file" + ) + parser.add_argument( + "--outdir", type=str, default="output/freevc", help="path to output dir" + ) parser.add_argument("--use_timestamp", default=False, action="store_true") args = parser.parse_args() - + os.makedirs(args.outdir, exist_ok=True) hps = utils.get_hparams_from_file(args.hpfile) @@ -31,17 +43,18 @@ net_g = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, - **hps.model).cuda() + **hps.model, + ).cuda() _ = net_g.eval() print("Loading checkpoint...") _ = utils.load_checkpoint(args.ptfile, net_g, None, True) print("Loading WavLM for content...") cmodel = utils.get_cmodel(0) - + if hps.model.use_spk: print("Loading speaker encoder...") - smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt') + smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt") print("Processing text...") titles, srcs, tgts = [], [], [] @@ -65,20 +78,20 @@ else: wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).cuda() mel_tgt = mel_spectrogram_torch( - wav_tgt, + wav_tgt, hps.data.filter_length, hps.data.n_mel_channels, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin, - hps.data.mel_fmax + hps.data.mel_fmax, ) # src wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate) wav_src = torch.from_numpy(wav_src).unsqueeze(0).cuda() c = utils.get_content(cmodel, wav_src) - + if hps.model.use_spk: audio = net_g.infer(c, g=g_tgt) else: @@ -86,7 +99,14 @@ audio = audio[0][0].data.cpu().float().numpy() if args.use_timestamp: timestamp = time.strftime("%m-%d_%H-%M", time.localtime()) - write(os.path.join(args.outdir, "{}.wav".format(timestamp+"_"+title)), hps.data.sampling_rate, audio) + write( + os.path.join(args.outdir, "{}.wav".format(timestamp + "_" + title)), + hps.data.sampling_rate, + audio, + ) else: - write(os.path.join(args.outdir, f"{title}.wav"), hps.data.sampling_rate, audio) - + write( + os.path.join(args.outdir, f"{title}.wav"), + hps.data.sampling_rate, + audio, + ) diff --git a/data_utils.py b/data_utils.py index 3114db6..2c82925 
100644 --- a/data_utils.py +++ b/data_utils.py @@ -5,27 +5,31 @@ import torch import torch.utils.data -import commons +import commons from mel_processing import spectrogram_torch, spec_to_mel_torch from utils import load_wav_to_torch, load_filepaths_and_text, transform -#import h5py + +# import h5py """Multi speaker version""" + + class TextAudioSpeakerLoader(torch.utils.data.Dataset): """ - 1) loads audio, speaker_id, text pairs - 2) normalizes text and converts them to sequences of integers - 3) computes spectrograms from audio files. + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. """ + def __init__(self, audiopaths, hparams): self.audiopaths = load_filepaths_and_text(audiopaths) self.max_wav_value = hparams.data.max_wav_value self.sampling_rate = hparams.data.sampling_rate - self.filter_length = hparams.data.filter_length - self.hop_length = hparams.data.hop_length - self.win_length = hparams.data.win_length - self.sampling_rate = hparams.data.sampling_rate + self.filter_length = hparams.data.filter_length + self.hop_length = hparams.data.hop_length + self.win_length = hparams.data.win_length + self.sampling_rate = hparams.data.sampling_rate self.use_sr = hparams.train.use_sr self.use_spk = hparams.model.use_spk self.spec_len = hparams.train.max_speclen @@ -50,47 +54,55 @@ def _filter(self): def get_audio(self, filename): audio, sampling_rate = load_wav_to_torch(filename) if sampling_rate != self.sampling_rate: - raise ValueError("{} SR doesn't match target {} SR".format( - sampling_rate, self.sampling_rate)) + raise ValueError( + "{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate + ) + ) audio_norm = audio / self.max_wav_value audio_norm = audio_norm.unsqueeze(0) spec_filename = filename.replace(".wav", ".spec.pt") if os.path.exists(spec_filename): spec = torch.load(spec_filename) else: - spec = spectrogram_torch(audio_norm, self.filter_length, - self.sampling_rate, self.hop_length, self.win_length, - center=False) + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) spec = torch.squeeze(spec, 0) torch.save(spec, spec_filename) - + if self.use_spk: spk_filename = filename.replace(".wav", ".npy") spk_filename = spk_filename.replace("DUMMY", "dataset/spk") spk = torch.from_numpy(np.load(spk_filename)) - + if not self.use_sr: c_filename = filename.replace(".wav", ".pt") c_filename = c_filename.replace("DUMMY", "dataset/wavlm") c = torch.load(c_filename).squeeze(0) else: - i = random.randint(68,92) - ''' + i = random.randint(68, 92) + """ basename = os.path.basename(filename)[:-4] spkname = basename[:4] #print(basename, spkname) with h5py.File(f"dataset/rs/wavlm/{spkname}/{i}.hdf5","r") as f: c = torch.from_numpy(f[basename][()]).squeeze(0) #print(c) - ''' + """ c_filename = filename.replace(".wav", f"_{i}.pt") c_filename = c_filename.replace("DUMMY", "dataset/sr/wavlm") c = torch.load(c_filename).squeeze(0) - + # 2023.01.10 update: code below can deteriorate model performance # I added these code during cleaning up, thinking that it can offer better performance than my provided checkpoints, but actually it does the opposite. # What an act of 'adding legs to a snake'! 
- ''' + """ lmin = min(c.size(-1), spec.size(-1)) spec, c = spec[:, :lmin], c[:, :lmin] audio_norm = audio_norm[:, :lmin*self.hop_length] @@ -104,8 +116,8 @@ def get_audio(self, filename): spec = spec[:, start:end] c = c[:, start:end] audio_norm = audio_norm[:, start*self.hop_length:end*self.hop_length] - ''' - + """ + if self.use_spk: return c, spec, audio_norm, spk else: @@ -118,9 +130,9 @@ def __len__(self): return len(self.audiopaths) -class TextAudioSpeakerCollate(): - """ Zero-pads model inputs and targets - """ +class TextAudioSpeakerCollate: + """Zero-pads model inputs and targets""" + def __init__(self, hps): self.hps = hps self.use_sr = hps.train.use_sr @@ -134,8 +146,8 @@ def __call__(self, batch): """ # Right zero-pad all one-hot text sequences to max input length _, ids_sorted_decreasing = torch.sort( - torch.LongTensor([x[0].size(1) for x in batch]), - dim=0, descending=True) + torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True + ) max_spec_len = max([x[1].size(1) for x in batch]) max_wav_len = max([x[2].size(1) for x in batch]) @@ -146,46 +158,54 @@ def __call__(self, batch): spks = torch.FloatTensor(len(batch), batch[0][3].size(0)) else: spks = None - + c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) c_padded.zero_() spec_padded.zero_() wav_padded.zero_() - + for i in range(len(ids_sorted_decreasing)): row = batch[ids_sorted_decreasing[i]] - + c = row[0] - c_padded[i, :, :c.size(1)] = c + c_padded[i, :, : c.size(1)] = c spec = row[1] - spec_padded[i, :, :spec.size(1)] = spec + spec_padded[i, :, : spec.size(1)] = spec spec_lengths[i] = spec.size(1) wav = row[2] - wav_padded[i, :, :wav.size(1)] = wav + wav_padded[i, :, : wav.size(1)] = wav wav_lengths[i] = wav.size(1) - + if self.use_spk: spks[i] = row[3] - - spec_seglen = spec_lengths[-1] if spec_lengths[-1] < self.hps.train.max_speclen + 1 else self.hps.train.max_speclen + 1 - wav_seglen = spec_seglen * self.hps.data.hop_length - - spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen) - wav_padded = commons.slice_segments(wav_padded, ids_slice * self.hps.data.hop_length, wav_seglen) - - c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:,:,:-1] - - spec_padded = spec_padded[:,:,:-1] - wav_padded = wav_padded[:,:,:-self.hps.data.hop_length] + + spec_seglen = ( + spec_lengths[-1] + if spec_lengths[-1] < self.hps.train.max_speclen + 1 + else self.hps.train.max_speclen + 1 + ) + wav_seglen = spec_seglen * self.hps.data.hop_length + + spec_padded, ids_slice = commons.rand_spec_segments( + spec_padded, spec_lengths, spec_seglen + ) + wav_padded = commons.slice_segments( + wav_padded, ids_slice * self.hps.data.hop_length, wav_seglen + ) + + c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:, :, :-1] + + spec_padded = spec_padded[:, :, :-1] + wav_padded = wav_padded[:, :, : -self.hps.data.hop_length] if self.use_spk: - return c_padded, spec_padded, wav_padded, spks + return c_padded, spec_padded, wav_padded, spks else: - return c_padded, spec_padded, wav_padded + return c_padded, spec_padded, wav_padded class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): @@ -193,20 +213,29 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): Maintain similar input lengths in a batch. Length groups are specified by boundaries. 
Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. - + It removes samples which are not included in the boundaries. Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. """ - def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + + def __init__( + self, + dataset, + batch_size, + boundaries, + num_replicas=None, + rank=None, + shuffle=True, + ): super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) self.lengths = dataset.lengths self.batch_size = batch_size self.boundaries = boundaries - + self.buckets, self.num_samples_per_bucket = self._create_buckets() self.total_size = sum(self.num_samples_per_bucket) self.num_samples = self.total_size // self.num_replicas - + def _create_buckets(self): buckets = [[] for _ in range(len(self.boundaries) - 1)] for i in range(len(self.lengths)): @@ -214,74 +243,85 @@ def _create_buckets(self): idx_bucket = self._bisect(length) if idx_bucket != -1: buckets[idx_bucket].append(i) - + for i in range(len(buckets) - 1, 0, -1): if len(buckets[i]) == 0: buckets.pop(i) - self.boundaries.pop(i+1) - + self.boundaries.pop(i + 1) + num_samples_per_bucket = [] for i in range(len(buckets)): len_bucket = len(buckets[i]) total_batch_size = self.num_replicas * self.batch_size - rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size num_samples_per_bucket.append(len_bucket + rem) return buckets, num_samples_per_bucket - + def __iter__(self): - # deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.epoch) - - indices = [] - if self.shuffle: - for bucket in self.buckets: - indices.append(torch.randperm(len(bucket), generator=g).tolist()) - else: - for bucket in self.buckets: - indices.append(list(range(len(bucket)))) - - batches = [] - for i in range(len(self.buckets)): - bucket = self.buckets[i] - len_bucket = len(bucket) - ids_bucket = indices[i] - num_samples_bucket = self.num_samples_per_bucket[i] - - # add extra samples to make it evenly divisible - rem = num_samples_bucket - len_bucket - ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] - - # subsample - ids_bucket = ids_bucket[self.rank::self.num_replicas] - - # batching - for j in range(len(ids_bucket) // self.batch_size): - batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]] - batches.append(batch) - - if self.shuffle: - batch_ids = torch.randperm(len(batches), generator=g).tolist() - batches = [batches[i] for i in batch_ids] - self.batches = batches - - assert len(self.batches) * self.batch_size == self.num_samples - return iter(self.batches) - + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] 
+ ) + + # subsample + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + def _bisect(self, x, lo=0, hi=None): - if hi is None: - hi = len(self.boundaries) - 1 - - if hi > lo: - mid = (hi + lo) // 2 - if self.boundaries[mid] < x and x <= self.boundaries[mid+1]: - return mid - elif x <= self.boundaries[mid]: - return self._bisect(x, lo, mid) - else: - return self._bisect(x, mid + 1, hi) - else: - return -1 + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 def __len__(self): return self.num_samples // self.batch_size diff --git a/downsample.py b/downsample.py index 8882e8f..f632727 100644 --- a/downsample.py +++ b/downsample.py @@ -11,7 +11,7 @@ def process(wav_name): # speaker 's5', 'p280', 'p315' are excluded, speaker = wav_name[:4] wav_path = os.path.join(args.in_dir, speaker, wav_name) - if os.path.exists(wav_path) and '_mic2.flac' in wav_path: + if os.path.exists(wav_path) and "_mic2.flac" in wav_path: os.makedirs(os.path.join(args.out_dir1, speaker), exist_ok=True) os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True) wav, sr = librosa.load(wav_path) @@ -25,14 +25,10 @@ def process(wav_name): save_path1 = os.path.join(args.out_dir1, speaker, save_name) save_path2 = os.path.join(args.out_dir2, speaker, save_name) wavfile.write( - save_path1, - args.sr1, - (wav1 * np.iinfo(np.int16).max).astype(np.int16) + save_path1, args.sr1, (wav1 * np.iinfo(np.int16).max).astype(np.int16) ) wavfile.write( - save_path2, - args.sr2, - (wav2 * np.iinfo(np.int16).max).astype(np.int16) + save_path2, args.sr2, (wav2 * np.iinfo(np.int16).max).astype(np.int16) ) @@ -40,16 +36,24 @@ def process(wav_name): parser = argparse.ArgumentParser() parser.add_argument("--sr1", type=int, default=16000, help="sampling rate") parser.add_argument("--sr2", type=int, default=22050, help="sampling rate") - parser.add_argument("--in_dir", type=str, default="/home/Datasets/lijingyi/data/vctk/wav48_silence_trimmed/", help="path to source dir") - parser.add_argument("--out_dir1", type=str, default="./dataset/vctk-16k", help="path to target dir") - parser.add_argument("--out_dir2", type=str, default="./dataset/vctk-22k", help="path to target dir") + parser.add_argument( + "--in_dir", + type=str, + default="/home/Datasets/lijingyi/data/vctk/wav48_silence_trimmed/", + help="path to source dir", + ) + parser.add_argument( + "--out_dir1", type=str, default="./dataset/vctk-16k", help="path to target dir" + ) + parser.add_argument( + "--out_dir2", type=str, default="./dataset/vctk-22k", help="path to target dir" + ) args = parser.parse_args() - pool = Pool(processes=cpu_count()-2) + pool = Pool(processes=cpu_count() - 2) for speaker in os.listdir(args.in_dir): spk_dir = os.path.join(args.in_dir, speaker) if os.path.isdir(spk_dir): for _ in tqdm(pool.imap_unordered(process, os.listdir(spk_dir))): pass - diff --git 
a/hifigan/__init__.py b/hifigan/__init__.py index f0d12ff..22627e2 100644 --- a/hifigan/__init__.py +++ b/hifigan/__init__.py @@ -4,4 +4,4 @@ class AttrDict(dict): def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self \ No newline at end of file + self.__dict__ = self diff --git a/hifigan/models.py b/hifigan/models.py index 985a8ec..708898f 100644 --- a/hifigan/models.py +++ b/hifigan/models.py @@ -125,7 +125,7 @@ def __init__(self, h): self.ups.append( weight_norm( ConvTranspose1d( - h.upsample_initial_channel // (2 ** i), + h.upsample_initial_channel // (2**i), h.upsample_initial_channel // (2 ** (i + 1)), k, u, @@ -171,4 +171,4 @@ def remove_weight_norm(self): for l in self.resblocks: l.remove_weight_norm() remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) \ No newline at end of file + remove_weight_norm(self.conv_post) diff --git a/losses.py b/losses.py index 41f9be6..2d34837 100644 --- a/losses.py +++ b/losses.py @@ -1,61 +1,61 @@ -import torch +import torch from torch.nn import functional as F import commons def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - rl = rl.float().detach() - gl = gl.float() - loss += torch.mean(torch.abs(rl - gl)) + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 + return loss * 2 def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - dr = dr.float() - dg = dg.float() - r_loss = torch.mean((1-dr)**2) - g_loss = torch.mean(dg**2) - loss += (r_loss + g_loss) - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) - return loss, r_losses, g_losses + return loss, r_losses, g_losses def generator_loss(disc_outputs): - loss = 0 - gen_losses = [] - for dg in disc_outputs: - dg = dg.float() - l = torch.mean((1-dg)**2) - gen_losses.append(l) - loss += l + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l - return loss, gen_losses + return loss, gen_losses def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): - """ - z_p, logs_q: [b, h, t_t] - m_p, logs_p: [b, h, t_t] - """ - z_p = z_p.float() - logs_q = logs_q.float() - m_p = m_p.float() - logs_p = logs_p.float() - z_mask = z_mask.float() - #print(logs_p) - kl = logs_p - logs_q - 0.5 - kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. 
* logs_p) - kl = torch.sum(kl * z_mask) - l = kl / torch.sum(z_mask) - return l + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + # print(logs_p) + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/mel_processing.py b/mel_processing.py index 99c5b35..b829bd0 100644 --- a/mel_processing.py +++ b/mel_processing.py @@ -49,22 +49,38 @@ def spectral_de_normalize_torch(magnitudes): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) global hann_window - dtype_device = str(y.dtype) + '_' + str(y.device) - wnsize_dtype_device = str(win_size) + '_' + dtype_device + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) y = y.squeeze(1) - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec @@ -72,37 +88,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): global mel_basis - dtype_device = str(spec.dtype) + '_' + str(spec.device) - fmax_dtype_device = str(fmax) + '_' + dtype_device + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = spectral_normalize_torch(spec) return spec -def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + 
print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) global mel_basis, hann_window - dtype_device = str(y.dtype) + '_' + str(y.device) - fmax_dtype_device = str(fmax) + '_' + dtype_device - wnsize_dtype_device = str(win_size) + '_' + dtype_device + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=y.dtype, device=y.device + ) if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) y = y.squeeze(1) - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) diff --git a/models.py b/models.py index 46b8aac..1f8e4ca 100644 --- a/models.py +++ b/models.py @@ -13,88 +13,132 @@ class ResidualCouplingBlock(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + 
self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x class Encoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask class Generator(torch.nn.Module): - def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): super(Generator, self).__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) - resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append(weight_norm( - ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), - k, u, padding=(k-u)//2))) + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + 
k, + u, + padding=(k - u) // 2, + ) + ) + ) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): - ch = upsample_initial_channel//(2**(i+1)) - for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): self.resblocks.append(resblock(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) @@ -106,7 +150,7 @@ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_di def forward(self, x, g=None): x = self.conv_pre(x) if g is not None: - x = x + self.cond(g) + x = x + self.cond(g) for i in range(self.num_upsamples): x = F.leaky_relu(x, modules.LRELU_SLOPE) @@ -114,9 +158,9 @@ def forward(self, x, g=None): xs = None for j in range(self.num_kernels): if xs is None: - xs = self.resblocks[i*self.num_kernels+j](x) + xs = self.resblocks[i * self.num_kernels + j](x) else: - xs += self.resblocks[i*self.num_kernels+j](x) + xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) @@ -125,7 +169,7 @@ def forward(self, x, g=None): return x def remove_weight_norm(self): - print('Removing weight norm...') + print("Removing weight norm...") for l in self.ups: remove_weight_norm(l) for l in self.resblocks: @@ -138,13 +182,55 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): self.period = period self.use_spectral_norm = use_spectral_norm norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ]) + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): @@ -152,7 +238,7 @@ def forward(self, x): # 1d to 2d b, c, t = x.shape - if t % self.period != 0: # pad first + if t % self.period != 0: # pad first n_pad = self.period - (t % self.period) x = F.pad(x, (0, n_pad), "reflect") t = t + n_pad @@ -173,14 +259,16 @@ class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, 
padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) def forward(self, x): @@ -200,10 +288,12 @@ def forward(self, x): class MultiPeriodDiscriminator(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(MultiPeriodDiscriminator, self).__init__() - periods = [2,3,5,7,11] + periods = [2, 3, 5, 7, 11] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] self.discriminators = nn.ModuleList(discs) def forward(self, y, y_hat): @@ -220,12 +310,20 @@ def forward(self, y, y_hat): fmap_gs.append(fmap_g) return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - + + class SpeakerEncoder(torch.nn.Module): - def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): + def __init__( + self, + mel_n_channels=80, + model_num_layers=3, + model_hidden_size=256, + model_embedding_size=256, + ): super(SpeakerEncoder, self).__init__() - self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.lstm = nn.LSTM( + mel_n_channels, model_hidden_size, model_num_layers, batch_first=True + ) self.linear = nn.Linear(model_hidden_size, model_embedding_size) self.relu = nn.ReLU() @@ -234,118 +332,145 @@ def forward(self, mels): _, (hidden, _) = self.lstm(mels) embeds_raw = self.relu(self.linear(hidden[-1])) return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) - + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): mel_slices = [] - for i in range(0, total_frames-partial_frames, partial_hop): - mel_range = torch.arange(i, i+partial_frames) + for i in range(0, total_frames - partial_frames, partial_hop): + mel_range = torch.arange(i, i + partial_frames) mel_slices.append(mel_range) - + return mel_slices - + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): mel_len = mel.size(1) - last_mel = mel[:,-partial_frames:] - + last_mel = mel[:, -partial_frames:] + if mel_len > partial_frames: - mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) - mels = list(mel[:,s] for s in mel_slices) + mel_slices = self.compute_partial_slices( + mel_len, partial_frames, partial_hop + ) + mels = list(mel[:, s] for s in mel_slices) mels.append(last_mel) mels = torch.stack(tuple(mels), 0).squeeze(1) - + with torch.no_grad(): partial_embeds = self(mels) embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) - #embed = embed / torch.linalg.norm(embed, 2) + # embed = embed / torch.linalg.norm(embed, 2) else: with torch.no_grad(): embed = self(last_mel) - + return embed class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__(self, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - 
upsample_initial_channel, - upsample_kernel_sizes, - gin_channels, - ssl_dim, - use_spk, - **kwargs): - - super().__init__() - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.gin_channels = gin_channels - self.ssl_dim = ssl_dim - self.use_spk = use_spk - - self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - - if not self.use_spk: - self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels) - - def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None): - if c_lengths == None: - c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) - if spec_lengths == None: - spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) - - if not self.use_spk: - g = self.enc_spk(mel.transpose(1,2)) - g = g.unsqueeze(-1) - - _, m_p, logs_p, _ = self.enc_p(c, c_lengths) - z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) - z_p = self.flow(z, spec_mask, g=g) - - z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) - o = self.dec(z_slice, g=g) - - return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - def infer(self, c, g=None, mel=None, c_lengths=None): - if c_lengths == None: - c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) - if not self.use_spk: - g = self.enc_spk.embed_utterance(mel.transpose(1,2)) - g = g.unsqueeze(-1) - - z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) - z = self.flow(z_p, c_mask, g=g, reverse=True) - o = self.dec(z * c_mask, g=g) - - return o + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + use_spk, + **kwargs + ): + + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + 
self.ssl_dim = ssl_dim + self.use_spk = use_spk + + self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = Encoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels + ) + + if not self.use_spk: + self.enc_spk = SpeakerEncoder( + model_hidden_size=gin_channels, model_embedding_size=gin_channels + ) + + def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None): + if c_lengths == None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if spec_lengths == None: + spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) + + if not self.use_spk: + g = self.enc_spk(mel.transpose(1, 2)) + g = g.unsqueeze(-1) + + _, m_p, logs_p, _ = self.enc_p(c, c_lengths) + z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) + z_p = self.flow(z, spec_mask, g=g) + + z_slice, ids_slice = commons.rand_slice_segments( + z, spec_lengths, self.segment_size + ) + o = self.dec(z_slice, g=g) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, c, g=None, mel=None, c_lengths=None): + if c_lengths == None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if not self.use_spk: + g = self.enc_spk.embed_utterance(mel.transpose(1, 2)) + g = g.unsqueeze(-1) + + z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g) + + return o diff --git a/modules.py b/modules.py index 52ee14e..2641d84 100644 --- a/modules.py +++ b/modules.py @@ -17,193 +17,282 @@ class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - class ConvReluNorm(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): - super().__init__() - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - assert n_layers > 1, "Number of layers should be larger than 0." 
- - self.conv_layers = nn.ModuleList() - self.norm_layers = nn.ModuleList() - self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = nn.Sequential( - nn.ReLU(), - nn.Dropout(p_dropout)) - for _ in range(n_layers-1): - self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask): - x_org = x - for i in range(self.n_layers): - x = self.conv_layers[i](x * x_mask) - x = self.norm_layers[i](x) - x = self.relu_drop(x) - x = x_org + self.proj(x) - return x * x_mask + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, - groups=channels, dilation=dilation, padding=padding - )) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + 
self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask class WN(torch.nn.Module): - def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): - super(WN, self).__init__() - assert(kernel_size % 2 == 1) - self.hidden_channels =hidden_channels - self.kernel_size = kernel_size, - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - self.p_dropout = p_dropout - - self.in_layers = torch.nn.ModuleList() - self.res_skip_layers = torch.nn.ModuleList() - self.drop = nn.Dropout(p_dropout) - - if gin_channels != 0: - cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') - - for i in range(n_layers): - dilation = dilation_rate ** i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, - dilation=dilation, padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') - self.in_layers.append(in_layer) - - # last one is not necessary - if i < n_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels - - res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') - self.res_skip_layers.append(res_skip_layer) - - def forward(self, x, x_mask, g=None, **kwargs): - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) - - if g is not None: - g = self.cond_layer(g) - - for i in range(self.n_layers): - x_in = self.in_layers[i](x) - if g is not None: - cond_offset = i * 2 * self.hidden_channels - g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] - else: - g_l = torch.zeros_like(x_in) - - acts = commons.fused_add_tanh_sigmoid_multiply( - x_in, - g_l, - n_channels_tensor) - acts = self.drop(acts) - - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.n_layers - 1: - res_acts = res_skip_acts[:,:self.hidden_channels,:] - x = (x + res_acts) * x_mask - output = output + res_skip_acts[:,self.hidden_channels:,:] - else: - output = output + res_skip_acts - return output * x_mask - - def remove_weight_norm(self): - if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) - for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) - for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + 
self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d( + gin_channels, 2 * hidden_channels * n_layers, 1 + ) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() - self.convs1 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]))) - ]) + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) self.convs1.apply(init_weights) - self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - 
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))) - ]) + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) self.convs2.apply(init_weights) def forward(self, x, x_mask=None): @@ -231,12 +320,30 @@ def remove_weight_norm(self): class ResBlock2(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super(ResBlock2, self).__init__() - self.convs = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))) - ]) + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) self.convs.apply(init_weights) def forward(self, x, x_mask=None): @@ -256,87 +363,96 @@ def remove_weight_norm(self): class Log(nn.Module): - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + class Flip(nn.Module): - def forward(self, x, *args, reverse=False, **kwargs): - x = torch.flip(x, [1]) - if not reverse: - logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) - return x, logdet - else: - return x + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x class ElementwiseAffine(nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - self.m = nn.Parameter(torch.zeros(channels,1)) - self.logs = nn.Parameter(torch.zeros(channels,1)) - - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1,2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x class ResidualCouplingLayer(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - 
p_dropout=0, - gin_channels=0, - mean_only=False): - assert channels % 2 == 0, "channels should be divisible by 2" - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.half_channels = channels // 2 - self.mean_only = mean_only - - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - self.post.weight.data.zero_() - self.post.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels]*2, 1) - h = self.pre(x0) * x_mask - h = self.enc(h, x_mask, g=g) - stats = self.post(h) * x_mask - if not self.mean_only: - m, logs = torch.split(stats, [self.half_channels]*2, 1) - else: - m = stats - logs = torch.zeros_like(m) - - if not reverse: - x1 = m + x1 * torch.exp(logs) * x_mask - x = torch.cat([x0, x1], 1) - logdet = torch.sum(logs, [1,2]) - return x, logdet - else: - x1 = (x1 - m) * torch.exp(-logs) * x_mask - x = torch.cat([x0, x1], 1) - return x + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/preprocess_flist.py b/preprocess_flist.py index 2c67dbf..c30e0e6 100644 --- a/preprocess_flist.py +++ b/preprocess_flist.py @@ -6,46 +6,62 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list") - parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list") - parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list") - parser.add_argument("--source_dir", type=str, default="./dataset/vctk-16k", help="path to source dir") + parser.add_argument( + "--train_list", + type=str, + default="./filelists/train.txt", + help="path to train list", + ) + parser.add_argument( + "--val_list", type=str, default="./filelists/val.txt", help="path to val list" + ) + parser.add_argument( + 
"--test_list", + type=str, + default="./filelists/test.txt", + help="path to test list", + ) + parser.add_argument( + "--source_dir", + type=str, + default="./dataset/vctk-16k", + help="path to source dir", + ) args = parser.parse_args() - + train = [] val = [] test = [] idx = 0 - + for speaker in tqdm(os.listdir(args.source_dir)): wavs = os.listdir(os.path.join(args.source_dir, speaker)) shuffle(wavs) train += wavs[2:-10] val += wavs[:2] test += wavs[-10:] - + shuffle(train) shuffle(val) shuffle(test) - + print("Writing", args.train_list) with open(args.train_list, "w") as f: for fname in tqdm(train): speaker = fname[:4] wavpath = os.path.join("DUMMY", speaker, fname) f.write(wavpath + "\n") - + print("Writing", args.val_list) with open(args.val_list, "w") as f: for fname in tqdm(val): speaker = fname[:4] wavpath = os.path.join("DUMMY", speaker, fname) f.write(wavpath + "\n") - + print("Writing", args.test_list) with open(args.test_list, "w") as f: for fname in tqdm(test): speaker = fname[:4] wavpath = os.path.join("DUMMY", speaker, fname) f.write(wavpath + "\n") - \ No newline at end of file diff --git a/preprocess_spk.py b/preprocess_spk.py index 67d2c6a..74f912f 100644 --- a/preprocess_spk.py +++ b/preprocess_spk.py @@ -6,22 +6,26 @@ from os.path import join, basename, split from tqdm import tqdm from multiprocessing import cpu_count -from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor from functools import partial -import glob +import glob import argparse def build_from_path(in_dir, out_dir, weights_fpath, num_workers=1): executor = ProcessPoolExecutor(max_workers=num_workers) futures = [] - wavfile_paths = glob.glob(os.path.join(in_dir, '*.wav')) - wavfile_paths= sorted(wavfile_paths) + wavfile_paths = glob.glob(os.path.join(in_dir, "*.wav")) + wavfile_paths = sorted(wavfile_paths) for wav_path in wavfile_paths: - futures.append(executor.submit( - partial(_compute_spkEmbed, out_dir, wav_path, weights_fpath))) + futures.append( + executor.submit( + partial(_compute_spkEmbed, out_dir, wav_path, weights_fpath) + ) + ) return [future.result() for future in tqdm(futures)] + def _compute_spkEmbed(out_dir, wav_path, weights_fpath): utt_id = os.path.basename(wav_path).rstrip(".wav") fpath = Path(wav_path) @@ -33,51 +37,54 @@ def _compute_spkEmbed(out_dir, wav_path, weights_fpath): np.save(fname_save, embed, allow_pickle=False) return os.path.basename(fname_save) + def preprocess(in_dir, out_dir_root, spk, weights_fpath, num_workers): out_dir = os.path.join(out_dir_root, spk) os.makedirs(out_dir, exist_ok=True) metadata = build_from_path(in_dir, out_dir, weights_fpath, num_workers) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--in_dir', type=str, - default='dataset/vctk-16k/') - parser.add_argument('--num_workers', type=int, default=12) - parser.add_argument('--out_dir_root', type=str, - default='dataset') - parser.add_argument('--spk_encoder_ckpt', type=str, \ - default='speaker_encoder/ckpt/pretrained_bak_5805000.pt') + parser.add_argument("--in_dir", type=str, default="dataset/vctk-16k/") + parser.add_argument("--num_workers", type=int, default=12) + parser.add_argument("--out_dir_root", type=str, default="dataset") + parser.add_argument( + "--spk_encoder_ckpt", + type=str, + default="speaker_encoder/ckpt/pretrained_bak_5805000.pt", + ) args = parser.parse_args() - - #split_list = ['train-clean-100', 'train-clean-360'] + + # split_list = ['train-clean-100', 'train-clean-360'] sub_folder_list = 
os.listdir(args.in_dir) sub_folder_list.sort() - + args.num_workers = args.num_workers if args.num_workers is not None else cpu_count() print("Number of workers: ", args.num_workers) - ckpt_step = os.path.basename(args.spk_encoder_ckpt).split('.')[0].split('_')[-1] + ckpt_step = os.path.basename(args.spk_encoder_ckpt).split(".")[0].split("_")[-1] spk_embed_out_dir = os.path.join(args.out_dir_root, "spk") print("[INFO] spk_embed_out_dir: ", spk_embed_out_dir) os.makedirs(spk_embed_out_dir, exist_ok=True) - #for data_split in split_list: - # sub_folder_list = os.listdir(args.in_dir, data_split) + # for data_split in split_list: + # sub_folder_list = os.listdir(args.in_dir, data_split) for spk in sub_folder_list: print("Preprocessing {} ...".format(spk)) in_dir = os.path.join(args.in_dir, spk) - if not os.path.isdir(in_dir): + if not os.path.isdir(in_dir): continue - #out_dir = os.path.join(args.out_dir, spk) - preprocess(in_dir, spk_embed_out_dir, spk, args.spk_encoder_ckpt, args.num_workers) - ''' + # out_dir = os.path.join(args.out_dir, spk) + preprocess( + in_dir, spk_embed_out_dir, spk, args.spk_encoder_ckpt, args.num_workers + ) + """ for data_split in split_list: in_dir = os.path.join(args.in_dir, data_split) preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers) - ''' + """ print("DONE!") sys.exit(0) - - diff --git a/preprocess_sr.py b/preprocess_sr.py index b2903e6..186b025 100644 --- a/preprocess_sr.py +++ b/preprocess_sr.py @@ -10,14 +10,16 @@ import utils from mel_processing import mel_spectrogram_torch from wavlm import WavLM, WavLMConfig -#import h5py + +# import h5py import logging -logging.getLogger('numba').setLevel(logging.WARNING) + +logging.getLogger("numba").setLevel(logging.WARNING) def process(filename): basename = os.path.basename(filename) - speaker = filename.split("/")[-2]#basename[:4] + speaker = filename.split("/")[-2] # basename[:4] wav_dir = os.path.join(args.wav_dir, speaker) ssl_dir = os.path.join(args.ssl_dir, speaker) os.makedirs(wav_dir, exist_ok=True) @@ -25,22 +27,22 @@ def process(filename): wav, _ = librosa.load(filename, sr=hps.sampling_rate) wav = torch.from_numpy(wav).unsqueeze(0).cuda() mel = mel_spectrogram_torch( - wav, - hps.n_fft, - hps.num_mels, - hps.sampling_rate, - hps.hop_size, - hps.win_size, - hps.fmin, - hps.fmax + wav, + hps.n_fft, + hps.num_mels, + hps.sampling_rate, + hps.hop_size, + hps.win_size, + hps.fmin, + hps.fmax, ) - ''' + """ f = {} for i in range(args.min, args.max+1): fpath = os.path.join(ssl_dir, f"{i}.hdf5") f[i] = h5py.File(fpath, "a") - ''' - for i in range(args.min, args.max+1): + """ + for i in range(args.min, args.max + 1): mel_rs = utils.transform(mel, i) wav_rs = vocoder(mel_rs)[0][0].detach().cpu().numpy() _wav_rs = librosa.resample(wav_rs, orig_sr=hps.sampling_rate, target_sr=args.sr) @@ -48,19 +50,14 @@ def process(filename): c = utils.get_content(cmodel, wav_rs) ssl_path = os.path.join(ssl_dir, basename.replace(".wav", f"_{i}.pt")) torch.save(c.cpu(), ssl_path) - #print(wav_rs.size(), c.size()) + # print(wav_rs.size(), c.size()) wav_path = os.path.join(wav_dir, basename.replace(".wav", f"_{i}.wav")) - wavfile.write( - wav_path, - args.sr, - _wav_rs - ) - ''' + wavfile.write(wav_path, args.sr, _wav_rs) + """ f[i][basename[:-4]] = c.cpu() for i in range(args.min, args.max+1): f[i].close() - ''' - + """ if __name__ == "__main__": @@ -68,17 +65,25 @@ def process(filename): parser.add_argument("--sr", type=int, default=16000, help="sampling rate") parser.add_argument("--min", type=int, 
default=68, help="min") parser.add_argument("--max", type=int, default=92, help="max") - parser.add_argument("--config", type=str, default="hifigan/config.json", help="path to config file") - parser.add_argument("--in_dir", type=str, default="dataset/vctk-22k", help="path to input dir") - parser.add_argument("--wav_dir", type=str, default="dataset/sr/wav", help="path to output wav dir") - parser.add_argument("--ssl_dir", type=str, default="dataset/sr/wavlm", help="path to output ssl dir") + parser.add_argument( + "--config", type=str, default="hifigan/config.json", help="path to config file" + ) + parser.add_argument( + "--in_dir", type=str, default="dataset/vctk-22k", help="path to input dir" + ) + parser.add_argument( + "--wav_dir", type=str, default="dataset/sr/wav", help="path to output wav dir" + ) + parser.add_argument( + "--ssl_dir", type=str, default="dataset/sr/wavlm", help="path to output ssl dir" + ) args = parser.parse_args() print("Loading WavLM for content...") - checkpoint = torch.load('wavlm/WavLM-Large.pt') - cfg = WavLMConfig(checkpoint['cfg']) + checkpoint = torch.load("wavlm/WavLM-Large.pt") + cfg = WavLMConfig(checkpoint["cfg"]) cmodel = WavLM(cfg).cuda() - cmodel.load_state_dict(checkpoint['model']) + cmodel.load_state_dict(checkpoint["model"]) cmodel.eval() print("Loaded WavLM.") @@ -86,15 +91,14 @@ def process(filename): vocoder = utils.get_vocoder(0) vocoder.eval() print("Loaded vocoder.") - + config_path = args.config with open(config_path, "r") as f: data = f.read() config = json.loads(data) hps = utils.HParams(**config) - filenames = glob(f'{args.in_dir}/*/*.wav', recursive=True)#[:10] - + filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10] + for filename in tqdm(filenames): process(filename) - diff --git a/preprocess_ssl.py b/preprocess_ssl.py index 6e94040..7450722 100644 --- a/preprocess_ssl.py +++ b/preprocess_ssl.py @@ -24,22 +24,25 @@ def process(filename): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--sr", type=int, default=16000, help="sampling rate") - parser.add_argument("--in_dir", type=str, default="dataset/vctk-16k", help="path to input dir") - parser.add_argument("--out_dir", type=str, default="dataset/wavlm", help="path to output dir") + parser.add_argument( + "--in_dir", type=str, default="dataset/vctk-16k", help="path to input dir" + ) + parser.add_argument( + "--out_dir", type=str, default="dataset/wavlm", help="path to output dir" + ) args = parser.parse_args() - + os.makedirs(args.out_dir, exist_ok=True) print("Loading WavLM for content...") - checkpoint = torch.load('wavlm/WavLM-Large.pt') - cfg = WavLMConfig(checkpoint['cfg']) + checkpoint = torch.load("wavlm/WavLM-Large.pt") + cfg = WavLMConfig(checkpoint["cfg"]) cmodel = WavLM(cfg).cuda() - cmodel.load_state_dict(checkpoint['model']) + cmodel.load_state_dict(checkpoint["model"]) cmodel.eval() print("Loaded WavLM.") - - filenames = glob(f'{args.in_dir}/*/*.wav', recursive=True) - + + filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) + for filename in tqdm(filenames): process(filename) - \ No newline at end of file diff --git a/speaker_encoder/audio.py b/speaker_encoder/audio.py index 2fcb77a..9c637a2 100644 --- a/speaker_encoder/audio.py +++ b/speaker_encoder/audio.py @@ -7,20 +7,21 @@ import librosa import struct -int16_max = (2 ** 15) - 1 +int16_max = (2**15) - 1 -def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], - source_sr: Optional[int] = None): +def preprocess_wav( + fpath_or_wav: Union[str, Path, 
np.ndarray], source_sr: Optional[int] = None +): """ - Applies the preprocessing operations used in training the Speaker Encoder to a waveform + Applies the preprocessing operations used in training the Speaker Encoder to a waveform either on disk or in memory. The waveform will be resampled to match the data hyperparameters. - :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not just .wav), either the waveform as a numpy array of floats. - :param source_sr: if passing an audio waveform, the sampling rate of the waveform before - preprocessing. After preprocessing, the waveform's sampling rate will match the data - hyperparameters. If passing a filepath, the sampling rate will be automatically detected and + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. If passing a filepath, the sampling rate will be automatically detected and this argument will be ignored. """ # Load the wav from disk if needed @@ -28,15 +29,15 @@ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], wav, source_sr = librosa.load(fpath_or_wav, sr=None) else: wav = fpath_or_wav - + # Resample the wav if needed if source_sr is not None and source_sr != sampling_rate: wav = librosa.resample(wav, source_sr, sampling_rate) - - # Apply the preprocessing: normalize volume and shorten long silences + + # Apply the preprocessing: normalize volume and shorten long silences wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) wav = trim_long_silences(wav) - + return wav @@ -50,58 +51,65 @@ def wav_to_mel_spectrogram(wav): sr=sampling_rate, n_fft=int(sampling_rate * mel_window_length / 1000), hop_length=int(sampling_rate * mel_window_step / 1000), - n_mels=mel_n_channels + n_mels=mel_n_channels, ) return frames.astype(np.float32).T def trim_long_silences(wav): """ - Ensures that segments without voice in the waveform remain no longer than a + Ensures that segments without voice in the waveform remain no longer than a threshold determined by the VAD parameters in params.py. 
- :param wav: the raw waveform as a numpy array of floats + :param wav: the raw waveform as a numpy array of floats :return: the same waveform with silences trimmed away (length <= original wav length) """ # Compute the voice detection window size samples_per_window = (vad_window_length * sampling_rate) // 1000 - + # Trim the end of the audio to have a multiple of the window size - wav = wav[:len(wav) - (len(wav) % samples_per_window)] - + wav = wav[: len(wav) - (len(wav) % samples_per_window)] + # Convert the float waveform to 16-bit mono PCM - pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) - + pcm_wave = struct.pack( + "%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16) + ) + # Perform voice activation detection voice_flags = [] vad = webrtcvad.Vad(mode=3) for window_start in range(0, len(wav), samples_per_window): window_end = window_start + samples_per_window - voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], - sample_rate=sampling_rate)) + voice_flags.append( + vad.is_speech( + pcm_wave[window_start * 2 : window_end * 2], sample_rate=sampling_rate + ) + ) voice_flags = np.array(voice_flags) - + # Smooth the voice detection with a moving average def moving_average(array, width): - array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) + array_padded = np.concatenate( + (np.zeros((width - 1) // 2), array, np.zeros(width // 2)) + ) ret = np.cumsum(array_padded, dtype=float) ret[width:] = ret[width:] - ret[:-width] - return ret[width - 1:] / width - + return ret[width - 1 :] / width + audio_mask = moving_average(voice_flags, vad_moving_average_width) audio_mask = np.round(audio_mask).astype(np.bool) - + # Dilate the voiced regions audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) audio_mask = np.repeat(audio_mask, samples_per_window) - + return wav[audio_mask == True] def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): if increase_only and decrease_only: raise ValueError("Both increase only and decrease only are set") - dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2)) if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): return wav return wav * (10 ** (dBFS_change / 20)) diff --git a/speaker_encoder/compute_embed.py b/speaker_encoder/compute_embed.py index 2fee33d..cb7d649 100644 --- a/speaker_encoder/compute_embed.py +++ b/speaker_encoder/compute_embed.py @@ -2,6 +2,7 @@ from multiprocessing.pool import Pool from functools import partial from pathlib import Path + # from utils import logmmse # from tqdm import tqdm # import numpy as np @@ -18,23 +19,25 @@ def embed_utterance(fpaths, encoder_model_fpath): wav = encoder.preprocess_wav(wav) embed = encoder.embed_utterance(wav) np.save(embed_fpath, embed, allow_pickle=False) - - -def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int): + + +def create_embeddings( + outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int +): wav_dir = outdir_root.joinpath("audio") metadata_fpath = synthesizer_root.joinpath("train.txt") assert wav_dir.exists() and metadata_fpath.exists() embed_dir = synthesizer_root.joinpath("embeds") embed_dir.mkdir(exist_ok=True) - + # Gather the input wave filepath and the target output embed filepath with metadata_fpath.open("r") as metadata_file: metadata = [line.split("|") for line in 
metadata_file] fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata] - + # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here. # Embed the utterances in separate threads func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath) job = Pool(n_processes).imap(func, fpaths) - list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) \ No newline at end of file + list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) diff --git a/speaker_encoder/config.py b/speaker_encoder/config.py index 1c21312..bde2ffb 100644 --- a/speaker_encoder/config.py +++ b/speaker_encoder/config.py @@ -1,40 +1,22 @@ librispeech_datasets = { "train": { "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], - "other": ["LibriSpeech/train-other-500"] - }, - "test": { - "clean": ["LibriSpeech/test-clean"], - "other": ["LibriSpeech/test-other"] - }, - "dev": { - "clean": ["LibriSpeech/dev-clean"], - "other": ["LibriSpeech/dev-other"] + "other": ["LibriSpeech/train-other-500"], }, + "test": {"clean": ["LibriSpeech/test-clean"], "other": ["LibriSpeech/test-other"]}, + "dev": {"clean": ["LibriSpeech/dev-clean"], "other": ["LibriSpeech/dev-other"]}, } libritts_datasets = { "train": { "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], - "other": ["LibriTTS/train-other-500"] - }, - "test": { - "clean": ["LibriTTS/test-clean"], - "other": ["LibriTTS/test-other"] - }, - "dev": { - "clean": ["LibriTTS/dev-clean"], - "other": ["LibriTTS/dev-other"] + "other": ["LibriTTS/train-other-500"], }, + "test": {"clean": ["LibriTTS/test-clean"], "other": ["LibriTTS/test-other"]}, + "dev": {"clean": ["LibriTTS/dev-clean"], "other": ["LibriTTS/dev-other"]}, } voxceleb_datasets = { - "voxceleb1" : { - "train": ["VoxCeleb1/wav"], - "test": ["VoxCeleb1/test_wav"] - }, - "voxceleb2" : { - "train": ["VoxCeleb2/dev/aac"], - "test": ["VoxCeleb2/test_wav"] - } + "voxceleb1": {"train": ["VoxCeleb1/wav"], "test": ["VoxCeleb1/test_wav"]}, + "voxceleb2": {"train": ["VoxCeleb2/dev/aac"], "test": ["VoxCeleb2/test_wav"]}, } other_datasets = [ diff --git a/speaker_encoder/data_objects/__init__.py b/speaker_encoder/data_objects/__init__.py index 030317a..2f981b8 100644 --- a/speaker_encoder/data_objects/__init__.py +++ b/speaker_encoder/data_objects/__init__.py @@ -1,2 +1,6 @@ -from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset -from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader +from speaker_encoder.data_objects.speaker_verification_dataset import ( + SpeakerVerificationDataset, +) +from speaker_encoder.data_objects.speaker_verification_dataset import ( + SpeakerVerificationDataLoader, +) diff --git a/speaker_encoder/data_objects/random_cycler.py b/speaker_encoder/data_objects/random_cycler.py index c405db6..b968ebd 100644 --- a/speaker_encoder/data_objects/random_cycler.py +++ b/speaker_encoder/data_objects/random_cycler.py @@ -1,23 +1,24 @@ import random + class RandomCycler: """ - Creates an internal copy of a sequence and allows access to its items in a constrained random - order. For a source sequence of n items and one or several consecutive queries of a total + Creates an internal copy of a sequence and allows access to its items in a constrained random + order. 
For a source sequence of n items and one or several consecutive queries of a total of m items, the following guarantees hold (one implies the other): - Each item will be returned between m // n and ((m - 1) // n) + 1 times. - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. """ - + def __init__(self, source): if len(source) == 0: raise Exception("Can't create RandomCycler from an empty collection") self.all_items = list(source) self.next_items = [] - + def sample(self, count: int): shuffle = lambda l: random.sample(l, len(l)) - + out = [] while count > 0: if count >= len(self.all_items): @@ -31,7 +32,6 @@ def sample(self, count: int): if len(self.next_items) == 0: self.next_items = shuffle(list(self.all_items)) return out - + def __next__(self): return self.sample(1)[0] - diff --git a/speaker_encoder/data_objects/speaker.py b/speaker_encoder/data_objects/speaker.py index 0737984..bc75b0c 100644 --- a/speaker_encoder/data_objects/speaker.py +++ b/speaker_encoder/data_objects/speaker.py @@ -2,6 +2,7 @@ from speaker_encoder.data_objects.utterance import Utterance from pathlib import Path + # Contains the set of utterances of a single speaker class Speaker: def __init__(self, root: Path): @@ -9,25 +10,27 @@ def __init__(self, root: Path): self.name = root.name self.utterances = None self.utterance_cycler = None - + def _load_utterances(self): with self.root.joinpath("_sources.txt").open("r") as sources_file: sources = [l.split(",") for l in sources_file] sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} - self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] + self.utterances = [ + Utterance(self.root.joinpath(f), w) for f, w in sources.items() + ] self.utterance_cycler = RandomCycler(self.utterances) - + def random_partial(self, count, n_frames): """ - Samples a batch of unique partial utterances from the disk in a way that all + Samples a batch of unique partial utterances from the disk in a way that all utterances come up at least once every two cycles and in a random order every time. - - :param count: The number of partial utterances to sample from the set of utterances from - that speaker. Utterances are guaranteed not to be repeated if is not larger than + + :param count: The number of partial utterances to sample from the set of utterances from + that speaker. Utterances are guaranteed not to be repeated if is not larger than the number of utterances available. :param n_frames: The number of frames in the partial utterance. - :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, - frames are the frames of the partial utterances and range is the range of the partial + :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, + frames are the frames of the partial utterances and range is the range of the partial utterance with regard to the complete utterance. 
""" if self.utterances is None: diff --git a/speaker_encoder/data_objects/speaker_batch.py b/speaker_encoder/data_objects/speaker_batch.py index 4485605..3cc5a52 100644 --- a/speaker_encoder/data_objects/speaker_batch.py +++ b/speaker_encoder/data_objects/speaker_batch.py @@ -2,11 +2,18 @@ from typing import List from speaker_encoder.data_objects.speaker import Speaker + class SpeakerBatch: - def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): + def __init__( + self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int + ): self.speakers = speakers - self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} - + self.partials = { + s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers + } + # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) - self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) + self.data = np.array( + [frames for s in speakers for _, frames, _ in self.partials[s]] + ) diff --git a/speaker_encoder/data_objects/speaker_verification_dataset.py b/speaker_encoder/data_objects/speaker_verification_dataset.py index cecd8ed..1a24f2c 100644 --- a/speaker_encoder/data_objects/speaker_verification_dataset.py +++ b/speaker_encoder/data_objects/speaker_verification_dataset.py @@ -7,50 +7,61 @@ # TODO: improve with a pool of speakers for data efficiency + class SpeakerVerificationDataset(Dataset): def __init__(self, datasets_root: Path): self.root = datasets_root speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] if len(speaker_dirs) == 0: - raise Exception("No speakers found. Make sure you are pointing to the directory " - "containing all preprocessed speaker directories.") + raise Exception( + "No speakers found. Make sure you are pointing to the directory " + "containing all preprocessed speaker directories." 
+ ) self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] self.speaker_cycler = RandomCycler(self.speakers) def __len__(self): return int(1e10) - + def __getitem__(self, index): return next(self.speaker_cycler) - + def get_logs(self): log_string = "" for log_fpath in self.root.glob("*.txt"): with log_fpath.open("r") as log_file: log_string += "".join(log_file.readlines()) return log_string - - + + class SpeakerVerificationDataLoader(DataLoader): - def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, - batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, - worker_init_fn=None): + def __init__( + self, + dataset, + speakers_per_batch, + utterances_per_speaker, + sampler=None, + batch_sampler=None, + num_workers=0, + pin_memory=False, + timeout=0, + worker_init_fn=None, + ): self.utterances_per_speaker = utterances_per_speaker super().__init__( - dataset=dataset, - batch_size=speakers_per_batch, - shuffle=False, - sampler=sampler, - batch_sampler=batch_sampler, + dataset=dataset, + batch_size=speakers_per_batch, + shuffle=False, + sampler=sampler, + batch_sampler=batch_sampler, num_workers=num_workers, - collate_fn=self.collate, - pin_memory=pin_memory, - drop_last=False, - timeout=timeout, - worker_init_fn=worker_init_fn + collate_fn=self.collate, + pin_memory=pin_memory, + drop_last=False, + timeout=timeout, + worker_init_fn=worker_init_fn, ) def collate(self, speakers): - return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) - \ No newline at end of file + return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) diff --git a/speaker_encoder/data_objects/utterance.py b/speaker_encoder/data_objects/utterance.py index 0768c34..5b65eaa 100644 --- a/speaker_encoder/data_objects/utterance.py +++ b/speaker_encoder/data_objects/utterance.py @@ -5,16 +5,16 @@ class Utterance: def __init__(self, frames_fpath, wave_fpath): self.frames_fpath = frames_fpath self.wave_fpath = wave_fpath - + def get_frames(self): return np.load(self.frames_fpath) def random_partial(self, n_frames): """ Crops the frames into a partial utterance of n_frames - + :param n_frames: The number of frames of the partial utterance - :return: the partial utterance frames and a tuple indicating the start and end of the + :return: the partial utterance frames and a tuple indicating the start and end of the partial utterance in the complete utterance. """ frames = self.get_frames() @@ -23,4 +23,4 @@ def random_partial(self, n_frames): else: start = np.random.randint(0, frames.shape[0] - n_frames) end = start + n_frames - return frames[start:end], (start, end) \ No newline at end of file + return frames[start:end], (start, end) diff --git a/speaker_encoder/hparams.py b/speaker_encoder/hparams.py index 9a8c164..2c536ae 100644 --- a/speaker_encoder/hparams.py +++ b/speaker_encoder/hparams.py @@ -1,13 +1,13 @@ ## Mel-filterbank mel_window_length = 25 # In milliseconds -mel_window_step = 10 # In milliseconds +mel_window_step = 10 # In milliseconds mel_n_channels = 40 ## Audio sampling_rate = 16000 # Number of spectrogram frames in a partial utterance -partials_n_frames = 160 # 1600 ms +partials_n_frames = 160 # 1600 ms ## Voice Activation Detection @@ -15,7 +15,7 @@ # This sets the granularity of the VAD. Should not need to be changed. vad_window_length = 30 # In milliseconds # Number of frames to average together when performing the moving average smoothing. 
-# The larger this value, the larger the VAD variations must be to not get smoothed out. +# The larger this value, the larger the VAD variations must be to not get smoothed out. vad_moving_average_width = 8 # Maximum number of consecutive silent frames a segment can have. vad_max_silence_length = 6 @@ -28,4 +28,4 @@ ## Model parameters model_hidden_size = 256 model_embedding_size = 256 -model_num_layers = 3 \ No newline at end of file +model_num_layers = 3 diff --git a/speaker_encoder/inference.py b/speaker_encoder/inference.py index 15e6bf1..deac3f8 100644 --- a/speaker_encoder/inference.py +++ b/speaker_encoder/inference.py @@ -1,6 +1,8 @@ from speaker_encoder.params_data import * from speaker_encoder.model import SpeakerEncoder -from speaker_encoder.audio import preprocess_wav # We want to expose this function from here +from speaker_encoder.audio import ( + preprocess_wav, +) # We want to expose this function from here from matplotlib import cm from speaker_encoder import audio from pathlib import Path @@ -8,18 +10,18 @@ import numpy as np import torch -_model = None # type: SpeakerEncoder -_device = None # type: torch.device +_model = None # type: SpeakerEncoder +_device = None # type: torch.device def load_model(weights_fpath: Path, device=None): """ - Loads the model in memory. If this function is not explicitely called, it will be run on the + Loads the model in memory. If this function is not explicitely called, it will be run on the first call to embed_frames() with the default weights file. - + :param weights_fpath: the path to saved model weights. - :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The - model will be loaded and will run on this device. Outputs will however always be on the cpu. + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The + model will be loaded and will run on this device. Outputs will however always be on the cpu. If None, will default to your GPU if it"s available, otherwise your CPU. """ # TODO: I think the slow loading of the encoder might have something to do with the device it @@ -33,9 +35,12 @@ def load_model(weights_fpath: Path, device=None): checkpoint = torch.load(weights_fpath) _model.load_state_dict(checkpoint["model_state"]) _model.eval() - print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) - - + print( + 'Loaded encoder "%s" trained to step %d' + % (weights_fpath.name, checkpoint["step"]) + ) + + def is_loaded(): return _model is not None @@ -43,48 +48,52 @@ def is_loaded(): def embed_frames_batch(frames_batch): """ Computes embeddings for a batch of mel spectrogram. - - :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape + + :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape (batch_size, n_frames, n_channels) :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) """ if _model is None: raise Exception("Model was not loaded. 
Call load_model() before inference.") - + frames = torch.from_numpy(frames_batch).to(_device) embed = _model.forward(frames).detach().cpu().numpy() return embed -def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, - min_pad_coverage=0.75, overlap=0.5): +def compute_partial_slices( + n_samples, + partial_utterance_n_frames=partials_n_frames, + min_pad_coverage=0.75, + overlap=0.5, +): """ - Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain - partial utterances of each. Both the waveform and the mel - spectrogram slices are returned, so as to make each partial utterance waveform correspond to - its spectrogram. This function assumes that the mel spectrogram parameters used are those + Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain + partial utterances of each. Both the waveform and the mel + spectrogram slices are returned, so as to make each partial utterance waveform correspond to + its spectrogram. This function assumes that the mel spectrogram parameters used are those defined in params_data.py. - - The returned ranges may be indexing further than the length of the waveform. It is + + The returned ranges may be indexing further than the length of the waveform. It is recommended that you pad the waveform with zeros up to wave_slices[-1].stop. - + :param n_samples: the number of samples in the waveform - :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial + :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial utterance - :param min_pad_coverage: when reaching the last partial utterance, it may or may not have - enough frames. If at least of are present, - then the last partial utterance will be considered, as if we padded the audio. Otherwise, - it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial + :param min_pad_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered, as if we padded the audio. Otherwise, + it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial utterance, this parameter is ignored so that the function always returns at least 1 slice. - :param overlap: by how much the partial utterance should overlap. If set to 0, the partial - utterances are entirely disjoint. - :return: the waveform slices and mel spectrogram slices as lists of array slices. Index - respectively the waveform and the mel spectrogram with these slices to obtain the partial + :param overlap: by how much the partial utterance should overlap. If set to 0, the partial + utterances are entirely disjoint. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial utterances. 
""" assert 0 <= overlap < 1 assert 0 < min_pad_coverage <= 1 - + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) @@ -97,34 +106,36 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram wav_range = mel_range * samples_per_frame mel_slices.append(slice(*mel_range)) wav_slices.append(slice(*wav_range)) - + # Evaluate whether extra padding is warranted or not last_wav_range = wav_slices[-1] - coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + coverage = (n_samples - last_wav_range.start) / ( + last_wav_range.stop - last_wav_range.start + ) if coverage < min_pad_coverage and len(mel_slices) > 1: mel_slices = mel_slices[:-1] wav_slices = wav_slices[:-1] - + return wav_slices, mel_slices def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): """ Computes an embedding for a single utterance. - + # TODO: handle multiple wavs to benefit from batching on GPU :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 - :param using_partials: if True, then the utterance is split in partial utterances of - frames and the utterance embedding is computed from their - normalized average. If False, the utterance is instead computed from feeding the entire + :param using_partials: if True, then the utterance is split in partial utterances of + frames and the utterance embedding is computed from their + normalized average. If False, the utterance is instead computed from feeding the entire spectogram to the network. - :param return_partials: if True, the partial embeddings will also be returned along with the + :param return_partials: if True, the partial embeddings will also be returned along with the wav slices that correspond to the partial embeddings. :param kwargs: additional arguments to compute_partial_splits() - :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If - is True, the partial utterances as a numpy array of float32 of shape - (n_partials, model_embedding_size) and the wav partials as a list of slices will also be - returned. If is simultaneously set to False, both these values will be None + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. If is simultaneously set to False, both these values will be None instead. 
""" # Process the entire utterance if not using partials @@ -134,22 +145,22 @@ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): if return_partials: return embed, None, None return embed - + # Compute where to split the utterance into partials and pad if necessary wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) max_wave_length = wave_slices[-1].stop if max_wave_length >= len(wav): wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") - + # Split the utterance into partials frames = audio.wav_to_mel_spectrogram(wav) frames_batch = np.array([frames[s] for s in mel_slices]) partial_embeds = embed_frames_batch(frames_batch) - + # Compute the utterance embedding from the partial embeddings raw_embed = np.mean(partial_embeds, axis=0) embed = raw_embed / np.linalg.norm(raw_embed, 2) - + if return_partials: return embed, partial_embeds, wave_slices return embed @@ -159,19 +170,21 @@ def embed_speaker(wavs, **kwargs): raise NotImplemented() -def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): +def plot_embedding_as_heatmap( + embed, ax=None, title="", shape=None, color_range=(0, 0.30) +): if ax is None: ax = plt.gca() - + if shape is None: height = int(np.sqrt(len(embed))) shape = (height, -1) embed = embed.reshape(shape) - + cmap = cm.get_cmap() mappable = ax.imshow(embed, cmap=cmap) cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) cbar.set_clim(*color_range) - + ax.set_xticks([]), ax.set_yticks([]) ax.set_title(title) diff --git a/speaker_encoder/model.py b/speaker_encoder/model.py index c022b66..6fc462f 100644 --- a/speaker_encoder/model.py +++ b/speaker_encoder/model.py @@ -13,84 +13,92 @@ class SpeakerEncoder(nn.Module): def __init__(self, device, loss_device): super().__init__() self.loss_device = loss_device - + # Network defition - self.lstm = nn.LSTM(input_size=mel_n_channels, # 40 - hidden_size=model_hidden_size, # 256 - num_layers=model_num_layers, # 3 - batch_first=True).to(device) - self.linear = nn.Linear(in_features=model_hidden_size, - out_features=model_embedding_size).to(device) + self.lstm = nn.LSTM( + input_size=mel_n_channels, # 40 + hidden_size=model_hidden_size, # 256 + num_layers=model_num_layers, # 3 + batch_first=True, + ).to(device) + self.linear = nn.Linear( + in_features=model_hidden_size, out_features=model_embedding_size + ).to(device) self.relu = torch.nn.ReLU().to(device) - + # Cosine similarity scaling (with fixed initial parameter values) - self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) - self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) + self.similarity_weight = nn.Parameter(torch.tensor([10.0])).to(loss_device) + self.similarity_bias = nn.Parameter(torch.tensor([-5.0])).to(loss_device) # Loss self.loss_fn = nn.CrossEntropyLoss().to(loss_device) - + def do_gradient_ops(self): # Gradient scale self.similarity_weight.grad *= 0.01 self.similarity_bias.grad *= 0.01 - + # Gradient clipping clip_grad_norm_(self.parameters(), 3, norm_type=2) - + def forward(self, utterances, hidden_init=None): """ Computes the embeddings of a batch of utterance spectrograms. 
- - :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape - (batch_size, n_frames, n_channels) - :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, + + :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape + (batch_size, n_frames, n_channels) + :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None. :return: the embeddings as a tensor of shape (batch_size, embedding_size) """ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state # and the final cell state. out, (hidden, cell) = self.lstm(utterances, hidden_init) - + # We take only the hidden state of the last layer embeds_raw = self.relu(self.linear(hidden[-1])) - + # L2-normalize it embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) - + return embeds - + def similarity_matrix(self, embeds): """ Computes the similarity matrix according the section 2.1 of GE2E. - :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, utterances_per_speaker, embedding_size) :return: the similarity matrix as a tensor of shape (speakers_per_batch, utterances_per_speaker, speakers_per_batch) """ speakers_per_batch, utterances_per_speaker = embeds.shape[:2] - + # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation centroids_incl = torch.mean(embeds, dim=1, keepdim=True) - centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True) + centroids_incl = centroids_incl.clone() / torch.norm( + centroids_incl, dim=2, keepdim=True + ) # Exclusive centroids (1 per utterance) - centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) - centroids_excl /= (utterances_per_speaker - 1) - centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True) + centroids_excl = torch.sum(embeds, dim=1, keepdim=True) - embeds + centroids_excl /= utterances_per_speaker - 1 + centroids_excl = centroids_excl.clone() / torch.norm( + centroids_excl, dim=2, keepdim=True + ) # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot # product of these vectors (which is just an element-wise multiplication reduced by a sum). # We vectorize the computation for efficiency. - sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, - speakers_per_batch).to(self.loss_device) + sim_matrix = torch.zeros( + speakers_per_batch, utterances_per_speaker, speakers_per_batch + ).to(self.loss_device) mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) for j in range(speakers_per_batch): mask = np.where(mask_matrix[j])[0] sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) - + ## Even more vectorized version (slower maybe because of transpose) # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker # ).to(self.loss_device) @@ -100,28 +108,29 @@ def similarity_matrix(self, embeds): # mask = np.where(eye) # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) # sim_matrix2 = sim_matrix2.transpose(1, 2) - + sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias return sim_matrix - + def loss(self, embeds): """ Computes the softmax loss according the section 2.1 of GE2E. 
- - :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, utterances_per_speaker, embedding_size) :return: the loss and the EER for this batch of embeddings. """ speakers_per_batch, utterances_per_speaker = embeds.shape[:2] - + # Loss sim_matrix = self.similarity_matrix(embeds) - sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, - speakers_per_batch)) + sim_matrix = sim_matrix.reshape( + (speakers_per_batch * utterances_per_speaker, speakers_per_batch) + ) ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) target = torch.from_numpy(ground_truth).long().to(self.loss_device) loss = self.loss_fn(sim_matrix, target) - + # EER (not backpropagated) with torch.no_grad(): inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] @@ -129,7 +138,7 @@ def loss(self, embeds): preds = sim_matrix.detach().cpu().numpy() # Snippet from https://yangcha.github.io/EER-ROC/ - fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) - eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) - - return loss, eer \ No newline at end of file + fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) + eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) + + return loss, eer diff --git a/speaker_encoder/params_data.py b/speaker_encoder/params_data.py index bdb1716..0619c73 100644 --- a/speaker_encoder/params_data.py +++ b/speaker_encoder/params_data.py @@ -1,16 +1,15 @@ - ## Mel-filterbank mel_window_length = 25 # In milliseconds -mel_window_step = 10 # In milliseconds +mel_window_step = 10 # In milliseconds mel_n_channels = 40 ## Audio sampling_rate = 16000 # Number of spectrogram frames in a partial utterance -partials_n_frames = 160 # 1600 ms +partials_n_frames = 160 # 1600 ms # Number of spectrogram frames at inference -inference_n_frames = 80 # 800 ms +inference_n_frames = 80 # 800 ms ## Voice Activation Detection @@ -18,7 +17,7 @@ # This sets the granularity of the VAD. Should not need to be changed. vad_window_length = 30 # In milliseconds # Number of frames to average together when performing the moving average smoothing. -# The larger this value, the larger the VAD variations must be to not get smoothed out. +# The larger this value, the larger the VAD variations must be to not get smoothed out. vad_moving_average_width = 8 # Maximum number of consecutive silent frames a segment can have. vad_max_silence_length = 6 @@ -26,4 +25,3 @@ ## Audio volume normalization audio_norm_target_dBFS = -30 - diff --git a/speaker_encoder/params_model.py b/speaker_encoder/params_model.py index 3e35647..29026cc 100644 --- a/speaker_encoder/params_model.py +++ b/speaker_encoder/params_model.py @@ -1,4 +1,3 @@ - ## Model parameters model_hidden_size = 256 model_embedding_size = 256 diff --git a/speaker_encoder/preprocess.py b/speaker_encoder/preprocess.py index fe5ab25..a6c2c98 100644 --- a/speaker_encoder/preprocess.py +++ b/speaker_encoder/preprocess.py @@ -12,68 +12,75 @@ class DatasetLog: """ Registers metadata about the dataset in a text file. 
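
With the encoder and loss above and the defaults listed in params_model / params_data (40 mel channels, 256-dim embeddings, 160-frame partials), the whole GE2E path can be exercised on random data. A minimal sketch, assuming the speaker_encoder package from this repository is importable and its dependencies are installed (including a NumPy version old enough to still provide np.int, which the loss code uses); the batch sizes here are arbitrary:

import torch
from speaker_encoder.model import SpeakerEncoder
from speaker_encoder.params_data import mel_n_channels, partials_n_frames

device = loss_device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device)

speakers_per_batch, utterances_per_speaker = 4, 5
# Fake mel batch: (speakers * utterances, n_frames, n_channels), the layout the data loader produces.
inputs = torch.randn(speakers_per_batch * utterances_per_speaker, partials_n_frames, mel_n_channels)
embeds = model(inputs)                                   # (20, 256), rows are L2-normed
embeds = embeds.view(speakers_per_batch, utterances_per_speaker, -1)
loss, eer = model.loss(embeds)                           # scalar loss, EER in [0, 1]
print(loss.item(), eer)

The reshape to (speakers_per_batch, utterances_per_speaker, -1) mirrors what train.py does further down before handing the embeddings to the loss device.
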
""" + def __init__(self, root, name): self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") self.sample_data = dict() - + start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) self.write_line("Creating dataset %s on %s" % (name, start_time)) self.write_line("-----") self._log_params() - + def _log_params(self): from speaker_encoder import params_data + self.write_line("Parameter values:") for param_name in (p for p in dir(params_data) if not p.startswith("__")): value = getattr(params_data, param_name) self.write_line("\t%s: %s" % (param_name, value)) self.write_line("-----") - + def write_line(self, line): self.text_file.write("%s\n" % line) - + def add_sample(self, **kwargs): for param_name, value in kwargs.items(): if not param_name in self.sample_data: self.sample_data[param_name] = [] self.sample_data[param_name].append(value) - + def finalize(self): self.write_line("Statistics:") for param_name, values in self.sample_data.items(): self.write_line("\t%s:" % param_name) self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) - self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values))) + self.write_line( + "\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)) + ) self.write_line("-----") end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) self.write_line("Finished on %s" % end_time) self.text_file.close() - - -def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog): + + +def _init_preprocess_dataset( + dataset_name, datasets_root, out_dir +) -> (Path, DatasetLog): dataset_root = datasets_root.joinpath(dataset_name) if not dataset_root.exists(): - print("Couldn\'t find %s, skipping this dataset." % dataset_root) + print("Couldn't find %s, skipping this dataset." % dataset_root) return None, None return dataset_root, DatasetLog(out_dir, dataset_name) -def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension, - skip_existing, logger): +def _preprocess_speaker_dirs( + speaker_dirs, dataset_name, datasets_root, out_dir, extension, skip_existing, logger +): print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs))) - + # Function to preprocess utterances for one speaker def preprocess_speaker(speaker_dir: Path): # Give a name to the speaker that includes its dataset speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) - - # Create an output directory with that name, as well as a txt file containing a + + # Create an output directory with that name, as well as a txt file containing a # reference to each source file. speaker_out_dir = out_dir.joinpath(speaker_name) speaker_out_dir.mkdir(exist_ok=True) sources_fpath = speaker_out_dir.joinpath("_sources.txt") - - # There's a possibility that the preprocessing was interrupted earlier, check if + + # There's a possibility that the preprocessing was interrupted earlier, check if # there already is a sources file. 
if sources_fpath.exists(): try: @@ -83,7 +90,7 @@ def preprocess_speaker(speaker_dir: Path): existing_fnames = {} else: existing_fnames = {} - + # Gather all audio files for that speaker recursively sources_file = sources_fpath.open("a" if skip_existing else "w") for in_fpath in speaker_dir.glob("**/*.%s" % extension): @@ -92,97 +99,112 @@ def preprocess_speaker(speaker_dir: Path): out_fname = out_fname.replace(".%s" % extension, ".npy") if skip_existing and out_fname in existing_fnames: continue - + # Load and preprocess the waveform wav = audio.preprocess_wav(in_fpath) if len(wav) == 0: continue - + # Create the mel spectrogram, discard those that are too short frames = audio.wav_to_mel_spectrogram(wav) if len(frames) < partials_n_frames: continue - + out_fpath = speaker_out_dir.joinpath(out_fname) np.save(out_fpath, frames) logger.add_sample(duration=len(wav) / sampling_rate) sources_file.write("%s,%s\n" % (out_fname, in_fpath)) - + sources_file.close() - + # Process the utterances for each speaker with ThreadPool(8) as pool: - list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), - unit="speakers")) + list( + tqdm( + pool.imap(preprocess_speaker, speaker_dirs), + dataset_name, + len(speaker_dirs), + unit="speakers", + ) + ) logger.finalize() print("Done preprocessing %s.\n" % dataset_name) # Function to preprocess utterances for one speaker -def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool): - # Give a name to the speaker that includes its dataset - speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) - - # Create an output directory with that name, as well as a txt file containing a - # reference to each source file. - speaker_out_dir = out_dir.joinpath(speaker_name) - speaker_out_dir.mkdir(exist_ok=True) - sources_fpath = speaker_out_dir.joinpath("_sources.txt") - - # There's a possibility that the preprocessing was interrupted earlier, check if - # there already is a sources file. 
- # if sources_fpath.exists(): - # try: - # with sources_fpath.open("r") as sources_file: - # existing_fnames = {line.split(",")[0] for line in sources_file} - # except: - # existing_fnames = {} - # else: - # existing_fnames = {} - existing_fnames = {} - # Gather all audio files for that speaker recursively - sources_file = sources_fpath.open("a" if skip_existing else "w") +def __preprocess_speaker( + speaker_dir: Path, + datasets_root: Path, + out_dir: Path, + extension: str, + skip_existing: bool, +): + # Give a name to the speaker that includes its dataset + speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) - for in_fpath in speaker_dir.glob("**/*.%s" % extension): - # Check if the target output file already exists - out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) - out_fname = out_fname.replace(".%s" % extension, ".npy") - if skip_existing and out_fname in existing_fnames: - continue - - # Load and preprocess the waveform - wav = audio.preprocess_wav(in_fpath) - if len(wav) == 0: - continue - - # Create the mel spectrogram, discard those that are too short - frames = audio.wav_to_mel_spectrogram(wav) - if len(frames) < partials_n_frames: - continue - - out_fpath = speaker_out_dir.joinpath(out_fname) - np.save(out_fpath, frames) - # logger.add_sample(duration=len(wav) / sampling_rate) - sources_file.write("%s,%s\n" % (out_fname, in_fpath)) - - sources_file.close() - return len(wav) + # Create an output directory with that name, as well as a txt file containing a + # reference to each source file. + speaker_out_dir = out_dir.joinpath(speaker_name) + speaker_out_dir.mkdir(exist_ok=True) + sources_fpath = speaker_out_dir.joinpath("_sources.txt") + + # There's a possibility that the preprocessing was interrupted earlier, check if + # there already is a sources file. 
+ # if sources_fpath.exists(): + # try: + # with sources_fpath.open("r") as sources_file: + # existing_fnames = {line.split(",")[0] for line in sources_file} + # except: + # existing_fnames = {} + # else: + # existing_fnames = {} + existing_fnames = {} + # Gather all audio files for that speaker recursively + sources_file = sources_fpath.open("a" if skip_existing else "w") -def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension, - skip_existing, logger): + for in_fpath in speaker_dir.glob("**/*.%s" % extension): + # Check if the target output file already exists + out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) + out_fname = out_fname.replace(".%s" % extension, ".npy") + if skip_existing and out_fname in existing_fnames: + continue + + # Load and preprocess the waveform + wav = audio.preprocess_wav(in_fpath) + if len(wav) == 0: + continue + + # Create the mel spectrogram, discard those that are too short + frames = audio.wav_to_mel_spectrogram(wav) + if len(frames) < partials_n_frames: + continue + + out_fpath = speaker_out_dir.joinpath(out_fname) + np.save(out_fpath, frames) + # logger.add_sample(duration=len(wav) / sampling_rate) + sources_file.write("%s,%s\n" % (out_fname, in_fpath)) + + sources_file.close() + return len(wav) + + +def _preprocess_speaker_dirs_vox2( + speaker_dirs, dataset_name, datasets_root, out_dir, extension, skip_existing, logger +): # from multiprocessing import Pool, cpu_count from pathos.multiprocessing import ProcessingPool as Pool + # Function to preprocess utterances for one speaker def __preprocess_speaker(speaker_dir: Path): # Give a name to the speaker that includes its dataset speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) - - # Create an output directory with that name, as well as a txt file containing a + + # Create an output directory with that name, as well as a txt file containing a # reference to each source file. 
speaker_out_dir = out_dir.joinpath(speaker_name) speaker_out_dir.mkdir(exist_ok=True) sources_fpath = speaker_out_dir.joinpath("_sources.txt") - + existing_fnames = {} # Gather all audio files for that speaker recursively sources_file = sources_fpath.open("a" if skip_existing else "w") @@ -193,17 +215,17 @@ def __preprocess_speaker(speaker_dir: Path): out_fname = out_fname.replace(".%s" % extension, ".npy") if skip_existing and out_fname in existing_fnames: continue - + # Load and preprocess the waveform wav = audio.preprocess_wav(in_fpath) if len(wav) == 0: continue - + # Create the mel spectrogram, discard those that are too short frames = audio.wav_to_mel_spectrogram(wav) if len(frames) < partials_n_frames: continue - + out_fpath = speaker_out_dir.joinpath(out_fname) np.save(out_fpath, frames) # logger.add_sample(duration=len(wav) / sampling_rate) @@ -221,7 +243,7 @@ def __preprocess_speaker(speaker_dir: Path): for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1): for wav_len in wav_lens: logger.add_sample(duration=wav_len / sampling_rate) - print(f'{i}/{len(speaker_dirs)} \r') + print(f"{i}/{len(speaker_dirs)} \r") logger.finalize() print("Done preprocessing %s.\n" % dataset_name) @@ -230,56 +252,78 @@ def __preprocess_speaker(speaker_dir: Path): def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False): for dataset_name in librispeech_datasets["train"]["other"]: # Initialize the preprocessing - dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + dataset_root, logger = _init_preprocess_dataset( + dataset_name, datasets_root, out_dir + ) if not dataset_root: - return - + return + # Preprocess all speakers speaker_dirs = list(dataset_root.glob("*")) - _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac", - skip_existing, logger) + _preprocess_speaker_dirs( + speaker_dirs, + dataset_name, + datasets_root, + out_dir, + "flac", + skip_existing, + logger, + ) def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False): # Initialize the preprocessing dataset_name = "VoxCeleb1" - dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + dataset_root, logger = _init_preprocess_dataset( + dataset_name, datasets_root, out_dir + ) if not dataset_root: return # Get the contents of the meta file with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: metadata = [line.split("\t") for line in metafile][1:] - + # Select the ID and the nationality, filter out non-anglophone speakers nationalities = {line[0]: line[3] for line in metadata} - # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if + # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if # nationality.lower() in anglophone_nationalites] - keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()] - print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." % - (len(keep_speaker_ids), len(nationalities))) - + keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()] + print( + "VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." 
+ % (len(keep_speaker_ids), len(nationalities)) + ) + # Get the speaker directories for anglophone speakers only speaker_dirs = dataset_root.joinpath("wav").glob("*") - speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if - speaker_dir.name in keep_speaker_ids] - print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." % - (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs))) + speaker_dirs = [ + speaker_dir + for speaker_dir in speaker_dirs + if speaker_dir.name in keep_speaker_ids + ] + print( + "VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." + % (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)) + ) # Preprocess all speakers - _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav", - skip_existing, logger) + _preprocess_speaker_dirs( + speaker_dirs, dataset_name, datasets_root, out_dir, "wav", skip_existing, logger + ) def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False): # Initialize the preprocessing dataset_name = "VoxCeleb2" - dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + dataset_root, logger = _init_preprocess_dataset( + dataset_name, datasets_root, out_dir + ) if not dataset_root: return - + # Get the speaker directories # Preprocess all speakers speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*")) - _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", - skip_existing, logger) + _preprocess_speaker_dirs_vox2( + speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", skip_existing, logger + ) diff --git a/speaker_encoder/train.py b/speaker_encoder/train.py index 282e4f5..e1dc745 100644 --- a/speaker_encoder/train.py +++ b/speaker_encoder/train.py @@ -1,42 +1,56 @@ from speaker_encoder.visualizations import Visualizations -from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset +from speaker_encoder.data_objects import ( + SpeakerVerificationDataLoader, + SpeakerVerificationDataset, +) from speaker_encoder.params_model import * from speaker_encoder.model import SpeakerEncoder from utils.profiler import Profiler from pathlib import Path import torch + def sync(device: torch.device): # FIXME - return + return # For correct profiling (cuda operations are async) if device.type == "cuda": torch.cuda.synchronize(device) -def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, - backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, - no_visdom: bool): + +def train( + run_id: str, + clean_data_root: Path, + models_dir: Path, + umap_every: int, + save_every: int, + backup_every: int, + vis_every: int, + force_restart: bool, + visdom_server: str, + no_visdom: bool, +): # Create a dataset and a dataloader dataset = SpeakerVerificationDataset(clean_data_root) loader = SpeakerVerificationDataLoader( dataset, - speakers_per_batch, # 64 - utterances_per_speaker, # 10 + speakers_per_batch, # 64 + utterances_per_speaker, # 10 num_workers=8, ) - - # Setup the device on which to run the forward pass and the loss. These can be different, + + # Setup the device on which to run the forward pass and the loss. These can be different, # because the forward pass is faster on the GPU whereas the loss is often (depending on your # hyperparameters) faster on the CPU. 
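
A minimal sketch of the device split described in the comment above: embeddings come back from the forward pass on the compute device and are regrouped and moved to the loss device, exactly as the training loop below does (the random tensor stands in for model(inputs)):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_device = torch.device("cpu")

speakers_per_batch, utterances_per_speaker, embedding_size = 64, 10, 256
# Stand-in for embeds = model(inputs), produced on the forward-pass device.
embeds = torch.randn(speakers_per_batch * utterances_per_speaker, embedding_size, device=device)

# Regroup per speaker and move to the loss device before calling model.loss(...).
embeds_loss = embeds.view(speakers_per_batch, utterances_per_speaker, -1).to(loss_device)
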
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIXME: currently, the gradient is None if loss_device is cuda loss_device = torch.device("cpu") - + # Create the model and the optimizer model = SpeakerEncoder(device, loss_device) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) init_step = 1 - + # Configure file path for the model state_fpath = models_dir.joinpath(run_id + ".pt") backup_dir = models_dir.joinpath(run_id + "_backups") @@ -44,30 +58,34 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, # Load any existing model if not force_restart: if state_fpath.exists(): - print("Found existing model \"%s\", loading it and resuming training." % run_id) + print( + 'Found existing model "%s", loading it and resuming training.' % run_id + ) checkpoint = torch.load(state_fpath) init_step = checkpoint["step"] model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) optimizer.param_groups[0]["lr"] = learning_rate_init else: - print("No model \"%s\" found, starting training from scratch." % run_id) + print('No model "%s" found, starting training from scratch.' % run_id) else: print("Starting the training from scratch.") model.train() - + # Initialize the visualization environment vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) vis.log_dataset(dataset) vis.log_params() - device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") + device_name = str( + torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU" + ) vis.log_implementation({"Device": device_name}) - + # Training loop profiler = Profiler(summarize_every=10, disabled=False) for step, speaker_batch in enumerate(loader, init_step): profiler.tick("Blocking, waiting for batch (threaded)") - + # Forward pass inputs = torch.from_numpy(speaker_batch.data).to(device) sync(device) @@ -75,7 +93,9 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, embeds = model(inputs) sync(device) profiler.tick("Forward pass") - embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) + embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to( + loss_device + ) loss, eer = model.loss(embeds_loss) sync(loss_device) profiler.tick("Loss") @@ -87,11 +107,11 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, model.do_gradient_ops() optimizer.step() profiler.tick("Parameter update") - + # Update visualizations # learning_rate = optimizer.param_groups[0]["lr"] vis.update(loss.item(), eer, step) - + # Draw projections and save them to the backup folder if umap_every != 0 and step % umap_every == 0: print("Drawing and saving projections (step %d)" % step) @@ -104,22 +124,27 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, # Overwrite the latest version of the model if save_every != 0 and step % save_every == 0: print("Saving the model (step %d)" % step) - torch.save({ - "step": step + 1, - "model_state": model.state_dict(), - "optimizer_state": optimizer.state_dict(), - }, state_fpath) - + torch.save( + { + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, + state_fpath, + ) + # Make a backup if backup_every != 0 and step % backup_every == 0: print("Making a backup (step %d)" % step) backup_dir.mkdir(exist_ok=True) backup_fpath = 
backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) - torch.save({ - "step": step + 1, - "model_state": model.state_dict(), - "optimizer_state": optimizer.state_dict(), - }, backup_fpath) - + torch.save( + { + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, + backup_fpath, + ) + profiler.tick("Extras (visualizations, saving)") - \ No newline at end of file diff --git a/speaker_encoder/visualizations.py b/speaker_encoder/visualizations.py index ec00fc6..fea0efb 100644 --- a/speaker_encoder/visualizations.py +++ b/speaker_encoder/visualizations.py @@ -1,31 +1,42 @@ -from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset +from speaker_encoder.data_objects.speaker_verification_dataset import ( + SpeakerVerificationDataset, +) from datetime import datetime from time import perf_counter as timer import matplotlib.pyplot as plt import numpy as np + # import webbrowser import visdom import umap -colormap = np.array([ - [76, 255, 0], - [0, 127, 70], - [255, 0, 0], - [255, 217, 38], - [0, 135, 255], - [165, 0, 165], - [255, 167, 255], - [0, 255, 255], - [255, 96, 38], - [142, 76, 0], - [33, 0, 127], - [0, 0, 0], - [183, 183, 183], -], dtype=np.float) / 255 +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=np.float, + ) + / 255 +) class Visualizations: - def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False): + def __init__( + self, env_name=None, update_every=10, server="http://localhost", disabled=False + ): # Tracking data self.last_update_timestamp = timer() self.update_every = update_every @@ -33,27 +44,29 @@ def __init__(self, env_name=None, update_every=10, server="http://localhost", di self.losses = [] self.eers = [] print("Updating the visualizations every %d steps." % update_every) - + # If visdom is disabled TODO: use a better paradigm for that - self.disabled = disabled + self.disabled = disabled if self.disabled: - return - + return + # Set the environment name now = str(datetime.now().strftime("%d-%m %Hh%M")) if env_name is None: self.env_name = now else: self.env_name = "%s (%s)" % (env_name, now) - + # Connect to visdom and open the corresponding window in the browser try: self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) except ConnectionError: - raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to " - "start it.") + raise Exception( + 'No visdom server detected. Run the command "visdom" in your CLI to ' + "start it." + ) # webbrowser.open("http://localhost:8097/env/" + self.env_name) - + # Create the windows self.loss_win = None self.eer_win = None @@ -61,12 +74,13 @@ def __init__(self, env_name=None, update_every=10, server="http://localhost", di self.implementation_win = None self.projection_win = None self.implementation_string = "" - + def log_params(self): if self.disabled: - return + return from speaker_encoder import params_data from speaker_encoder import params_model + param_string = "Model parameters:
" for param_name in (p for p in dir(params_model) if not p.startswith("__")): value = getattr(params_model, param_name) @@ -76,27 +90,26 @@ def log_params(self): value = getattr(params_data, param_name) param_string += "\t%s: %s
" % (param_name, value) self.vis.text(param_string, opts={"title": "Parameters"}) - + def log_dataset(self, dataset: SpeakerVerificationDataset): if self.disabled: - return + return dataset_string = "" dataset_string += "Speakers: %s\n" % len(dataset.speakers) dataset_string += "\n" + dataset.get_logs() dataset_string = dataset_string.replace("\n", "
") self.vis.text(dataset_string, opts={"title": "Dataset"}) - + def log_implementation(self, params): if self.disabled: - return + return implementation_string = "" for param, value in params.items(): implementation_string += "%s: %s\n" % (param, value) implementation_string = implementation_string.replace("\n", "
") self.implementation_string = implementation_string self.implementation_win = self.vis.text( - implementation_string, - opts={"title": "Training implementation"} + implementation_string, opts={"title": "Training implementation"} ) def update(self, loss, eer, step): @@ -107,14 +120,18 @@ def update(self, loss, eer, step): self.losses.append(loss) self.eers.append(eer) print(".", end="") - + # Update the plots every steps if step % self.update_every != 0: return - time_string = "Step time: mean: %5dms std: %5dms" % \ - (int(np.mean(self.step_times)), int(np.std(self.step_times))) - print("\nStep %6d Loss: %.4f EER: %.4f %s" % - (step, np.mean(self.losses), np.mean(self.eers), time_string)) + time_string = "Step time: mean: %5dms std: %5dms" % ( + int(np.mean(self.step_times)), + int(np.std(self.step_times)), + ) + print( + "\nStep %6d Loss: %.4f EER: %.4f %s" + % (step, np.mean(self.losses), np.mean(self.eers), time_string) + ) if not self.disabled: self.loss_win = self.vis.line( [np.mean(self.losses)], @@ -126,7 +143,7 @@ def update(self, loss, eer, step): xlabel="Step", ylabel="Loss", title="Loss", - ) + ), ) self.eer_win = self.vis.line( [np.mean(self.eers)], @@ -137,12 +154,12 @@ def update(self, loss, eer, step): legend=["Avg. EER"], xlabel="Step", ylabel="EER", - title="Equal error rate" - ) + title="Equal error rate", + ), ) if self.implementation_win is not None: self.vis.text( - self.implementation_string + ("%s" % time_string), + self.implementation_string + ("%s" % time_string), win=self.implementation_win, opts={"title": "Training implementation"}, ) @@ -151,16 +168,17 @@ def update(self, loss, eer, step): self.losses.clear() self.eers.clear() self.step_times.clear() - - def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, - max_speakers=10): + + def draw_projections( + self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10 + ): max_speakers = min(max_speakers, len(colormap)) - embeds = embeds[:max_speakers * utterances_per_speaker] - + embeds = embeds[: max_speakers * utterances_per_speaker] + n_speakers = len(embeds) // utterances_per_speaker ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) colors = [colormap[i] for i in ground_truth] - + reducer = umap.UMAP() projected = reducer.fit_transform(embeds) plt.scatter(projected[:, 0], projected[:, 1], c=colors) @@ -171,8 +189,7 @@ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, if out_fpath is not None: plt.savefig(out_fpath) plt.clf() - + def save(self): if not self.disabled: self.vis.save([self.env_name]) - \ No newline at end of file diff --git a/speaker_encoder/voice_encoder.py b/speaker_encoder/voice_encoder.py index 88cdee2..93f70f7 100644 --- a/speaker_encoder/voice_encoder.py +++ b/speaker_encoder/voice_encoder.py @@ -9,30 +9,34 @@ class SpeakerEncoder(nn.Module): - def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True): + def __init__( + self, weights_fpath, device: Union[str, torch.device] = None, verbose=True + ): """ - :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). - If None, defaults to cuda if it is available on your machine, otherwise the model will + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). + If None, defaults to cuda if it is available on your machine, otherwise the model will run on cpu. Outputs are always returned on the cpu, as numpy arrays. 
""" super().__init__() - + # Define the network - self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.lstm = nn.LSTM( + mel_n_channels, model_hidden_size, model_num_layers, batch_first=True + ) self.linear = nn.Linear(model_hidden_size, model_embedding_size) self.relu = nn.ReLU() - + # Get the target device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") elif isinstance(device, str): device = torch.device(device) self.device = device - + # Load the pretrained model'speaker weights # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt") # if not weights_fpath.exists(): - # raise Exception("Couldn't find the voice encoder pretrained model at %s." % + # raise Exception("Couldn't find the voice encoder pretrained model at %s." % # weights_fpath) start = timer() @@ -40,60 +44,65 @@ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose self.load_state_dict(checkpoint["model_state"], strict=False) self.to(device) - + if verbose: - print("Loaded the voice encoder model on %s in %.2f seconds." % - (device.type, timer() - start)) + print( + "Loaded the voice encoder model on %s in %.2f seconds." + % (device.type, timer() - start) + ) def forward(self, mels: torch.FloatTensor): """ Computes the embeddings of a batch of utterance spectrograms. - :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape - (batch_size, n_frames, n_channels) - :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). + :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). Embeddings are positive and L2-normed, thus they lay in the range [0, 1]. """ - # Pass the input through the LSTM layers and retrieve the final hidden state of the last + # Pass the input through the LSTM layers and retrieve the final hidden state of the last # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings. _, (hidden, _) = self.lstm(mels) embeds_raw = self.relu(self.linear(hidden[-1])) return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) - + @staticmethod def compute_partial_slices(n_samples: int, rate, min_coverage): """ - Computes where to split an utterance waveform and its corresponding mel spectrogram to - obtain partial utterances of each. Both the waveform and the - mel spectrogram slices are returned, so as to make each partial utterance waveform + Computes where to split an utterance waveform and its corresponding mel spectrogram to + obtain partial utterances of each. Both the waveform and the + mel spectrogram slices are returned, so as to make each partial utterance waveform correspond to its spectrogram. - - The returned ranges may be indexing further than the length of the waveform. It is + + The returned ranges may be indexing further than the length of the waveform. It is recommended that you pad the waveform with zeros up to wav_slices[-1].stop. - + :param n_samples: the number of samples in the waveform - :param rate: how many partial utterances should occur per second. Partial utterances must - cover the span of the entire utterance, thus the rate should not be lower than the inverse - of the duration of a partial utterance. By default, partial utterances are 1.6s long and + :param rate: how many partial utterances should occur per second. 
Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and the minimum rate is thus 0.625. - :param min_coverage: when reaching the last partial utterance, it may or may not have - enough frames. If at least of are present, - then the last partial utterance will be considered by zero-padding the audio. Otherwise, - it will be discarded. If there aren't enough frames for one partial utterance, + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, this parameter is ignored so that the function always returns at least one slice. - :return: the waveform slices and mel spectrogram slices as lists of array slices. Index - respectively the waveform and the mel spectrogram with these slices to obtain the partial + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial utterances. """ assert 0 < min_coverage <= 1 - + # Compute how many frames separate two partial utterances samples_per_frame = int((sampling_rate * mel_window_step / 1000)) n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) assert 0 < frame_step, "The rate is too high" - assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \ - (sampling_rate / (samples_per_frame * partials_n_frames)) - + assert ( + frame_step <= partials_n_frames + ), "The rate is too low, it should be %f at least" % ( + sampling_rate / (samples_per_frame * partials_n_frames) + ) + # Compute the slices wav_slices, mel_slices = [], [] steps = max(1, n_frames - partials_n_frames + frame_step + 1) @@ -102,72 +111,83 @@ def compute_partial_slices(n_samples: int, rate, min_coverage): wav_range = mel_range * samples_per_frame mel_slices.append(slice(*mel_range)) wav_slices.append(slice(*wav_range)) - + # Evaluate whether extra padding is warranted or not last_wav_range = wav_slices[-1] - coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + coverage = (n_samples - last_wav_range.start) / ( + last_wav_range.stop - last_wav_range.start + ) if coverage < min_coverage and len(mel_slices) > 1: mel_slices = mel_slices[:-1] wav_slices = wav_slices[:-1] - + return wav_slices, mel_slices - - def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75): + + def embed_utterance( + self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75 + ): """ - Computes an embedding for a single utterance. The utterance is divided in partial - utterances and an embedding is computed for each. The complete utterance embedding is the + Computes an embedding for a single utterance. The utterance is divided in partial + utterances and an embedding is computed for each. The complete utterance embedding is the L2-normed average embedding of the partial utterances. 
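
Plugging the params_data defaults into the arithmetic above (16 kHz audio, 10 ms mel hop, 160-frame partials) makes the slicing concrete. A small worked sketch, assuming the speaker_encoder package from this repository is importable, using an arbitrary 3-second waveform and the default rate of 1.3:

from speaker_encoder.voice_encoder import SpeakerEncoder

sampling_rate = 16000
samples_per_frame = sampling_rate * 10 // 1000                  # 160 samples per 10 ms mel frame
partial_samples = 160 * samples_per_frame                       # 25600 samples, i.e. 1.6 s per partial
frame_step = round((sampling_rate / 1.3) / samples_per_frame)   # ~77 frames between partial starts
min_rate = sampling_rate / partial_samples                      # 0.625, the lower bound quoted above

wav_slices, mel_slices = SpeakerEncoder.compute_partial_slices(
    n_samples=3 * sampling_rate, rate=1.3, min_coverage=0.75
)
print(len(wav_slices))   # 3 partials
print(mel_slices)        # start frames 0, 77 and 154, each 160 frames long
# The last partial covers roughly 91% real audio, above min_coverage, so it is kept and the
# caller zero-pads the waveform up to wav_slices[-1].stop.
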
- + TODO: independent batched version of this function - + :param wav: a preprocessed utterance waveform as a numpy array of float32 - :param return_partials: if True, the partial embeddings will also be returned along with + :param return_partials: if True, the partial embeddings will also be returned along with the wav slices corresponding to each partial utterance. - :param rate: how many partial utterances should occur per second. Partial utterances must - cover the span of the entire utterance, thus the rate should not be lower than the inverse - of the duration of a partial utterance. By default, partial utterances are 1.6s long and + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and the minimum rate is thus 0.625. - :param min_coverage: when reaching the last partial utterance, it may or may not have - enough frames. If at least of are present, - then the last partial utterance will be considered by zero-padding the audio. Otherwise, - it will be discarded. If there aren't enough frames for one partial utterance, + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, this parameter is ignored so that the function always returns at least one slice. - :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If - is True, the partial utterances as a numpy array of float32 of shape - (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be returned. """ - # Compute where to split the utterance into partials and pad the waveform with zeros if - # the partial utterances cover a larger range. - wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) + # Compute where to split the utterance into partials and pad the waveform with zeros if + # the partial utterances cover a larger range. + wav_slices, mel_slices = self.compute_partial_slices( + len(wav), rate, min_coverage + ) max_wave_length = wav_slices[-1].stop if max_wave_length >= len(wav): wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") - + # Split the utterance into partials and forward them through the model mel = audio.wav_to_mel_spectrogram(wav) mels = np.array([mel[s] for s in mel_slices]) with torch.no_grad(): mels = torch.from_numpy(mels).to(self.device) partial_embeds = self(mels).cpu().numpy() - + # Compute the utterance embedding from the partial embeddings raw_embed = np.mean(partial_embeds, axis=0) embed = raw_embed / np.linalg.norm(raw_embed, 2) - + if return_partials: return embed, partial_embeds, wav_slices return embed - + def embed_speaker(self, wavs: List[np.ndarray], **kwargs): """ - Compute the embedding of a collection of wavs (presumably from the same speaker) by + Compute the embedding of a collection of wavs (presumably from the same speaker) by averaging their embedding and L2-normalizing it. 
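
A minimal usage sketch for the embedding path just shown, assuming the pretrained checkpoint path used by convert_24.py further down and a hypothetical input recording; per the preprocessing code above, audio.preprocess_wav accepts a file path and returns the 16 kHz float32 waveform the encoder expects:

import numpy as np
from speaker_encoder import audio
from speaker_encoder.voice_encoder import SpeakerEncoder

smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt")

wav = audio.preprocess_wav("target_speaker.wav")     # hypothetical input file
embed, partial_embeds, wav_slices = smodel.embed_utterance(wav, return_partials=True)

print(embed.shape)                                   # (256,), i.e. model_embedding_size
print(np.linalg.norm(embed))                         # ~1.0: the utterance embedding is L2-normed
print(partial_embeds.shape, len(wav_slices))         # one 256-dim embedding per partial
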
- + :param wavs: list of wavs a numpy arrays of float32. :param kwargs: extra arguments to embed_utterance() :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). """ - raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \ - for wav in wavs], axis=0) - return raw_embed / np.linalg.norm(raw_embed, 2) \ No newline at end of file + raw_embed = np.mean( + [ + self.embed_utterance(wav, return_partials=False, **kwargs) + for wav in wavs + ], + axis=0, + ) + return raw_embed / np.linalg.norm(raw_embed, 2) diff --git a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/convert_24.py b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/convert_24.py index 3077c38..e39141d 100644 --- a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/convert_24.py +++ b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/convert_24.py @@ -11,18 +11,30 @@ from mel_processing import mel_spectrogram_torch from speaker_encoder.voice_encoder import SpeakerEncoder import logging -logging.getLogger('numba').setLevel(logging.WARNING) + +logging.getLogger("numba").setLevel(logging.WARNING) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--hpfile", type=str, default="configs/freevc.json", help="path to json config file") - parser.add_argument("--ptfile", type=str, default="checkpoints/freevc.pth", help="path to pth file") - parser.add_argument("--txtpath", type=str, default="convert.txt", help="path to txt file") - parser.add_argument("--outdir", type=str, default="output/freevc", help="path to output dir") + parser.add_argument( + "--hpfile", + type=str, + default="configs/freevc.json", + help="path to json config file", + ) + parser.add_argument( + "--ptfile", type=str, default="checkpoints/freevc.pth", help="path to pth file" + ) + parser.add_argument( + "--txtpath", type=str, default="convert.txt", help="path to txt file" + ) + parser.add_argument( + "--outdir", type=str, default="output/freevc", help="path to output dir" + ) parser.add_argument("--use_timestamp", default=False, action="store_true") args = parser.parse_args() - + os.makedirs(args.outdir, exist_ok=True) hps = utils.get_hparams_from_file(args.hpfile) @@ -30,17 +42,18 @@ net_g = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, - **hps.model).cuda() + **hps.model, + ).cuda() _ = net_g.eval() print("Loading checkpoint...") _ = utils.load_checkpoint(args.ptfile, net_g, None) print("Loading WavLM for content...") cmodel = utils.get_cmodel(0) - + if hps.model.use_spk: print("Loading speaker encoder...") - smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt') + smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt") print("Processing text...") titles, srcs, tgts = [], [], [] @@ -64,20 +77,20 @@ else: wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).cuda() mel_tgt = mel_spectrogram_torch( - wav_tgt, + wav_tgt, hps.data.filter_length, hps.data.n_mel_channels, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin, - hps.data.mel_fmax + hps.data.mel_fmax, ) # src wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate) wav_src = torch.from_numpy(wav_src).unsqueeze(0).cuda() c = utils.get_content(cmodel, wav_src) - + if hps.model.use_spk: audio = net_g.infer(c, g=g_tgt) else: @@ -85,7 +98,10 @@ audio = audio[0][0].data.cpu().float().numpy() if args.use_timestamp: timestamp = time.strftime("%m-%d_%H-%M", time.localtime()) - write(os.path.join(args.outdir, 
"{}.wav".format(timestamp+"_"+title)), 24000, audio) + write( + os.path.join(args.outdir, "{}.wav".format(timestamp + "_" + title)), + 24000, + audio, + ) else: write(os.path.join(args.outdir, f"{title}.wav"), 24000, audio) - diff --git a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/data_utils_24.py b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/data_utils_24.py index 62c9523..ae45820 100644 --- a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/data_utils_24.py +++ b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/data_utils_24.py @@ -5,26 +5,30 @@ import torch import torch.utils.data -import commons +import commons from mel_processing import spectrogram_torch, spec_to_mel_torch from utils import load_wav_to_torch, load_filepaths_and_text, transform -#import h5py + +# import h5py """Multi speaker version""" + + class TextAudioSpeakerLoader(torch.utils.data.Dataset): """ - 1) loads audio, speaker_id, text pairs - 2) normalizes text and converts them to sequences of integers - 3) computes spectrograms from audio files. + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. """ + def __init__(self, audiopaths, hparams): self.audiopaths = load_filepaths_and_text(audiopaths) self.max_wav_value = hparams.data.max_wav_value self.sampling_rate = hparams.data.sampling_rate - self.filter_length = hparams.data.filter_length - self.hop_length = hparams.data.hop_length - self.win_length = hparams.data.win_length + self.filter_length = hparams.data.filter_length + self.hop_length = hparams.data.hop_length + self.win_length = hparams.data.win_length self.use_sr = hparams.train.use_sr self.use_spk = hparams.model.use_spk self.spec_len = hparams.train.max_speclen @@ -47,46 +51,56 @@ def _filter(self): self.lengths = lengths def get_audio(self, filename): - audio, sampling_rate = load_wav_to_torch(filename.replace("DUMMY", "dataset/vctk-24k")) + audio, sampling_rate = load_wav_to_torch( + filename.replace("DUMMY", "dataset/vctk-24k") + ) if sampling_rate != 24000: - raise ValueError("{} SR doesn't match target {} SR".format( - sampling_rate, self.sampling_rate)) + raise ValueError( + "{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate + ) + ) audio_norm = audio / self.max_wav_value audio_norm = audio_norm.unsqueeze(0) spec_filename = filename.replace(".wav", ".spec.pt") if os.path.exists(spec_filename): spec = torch.load(spec_filename) else: - spec = spectrogram_torch(audio_norm, self.filter_length, - self.sampling_rate, self.hop_length, self.win_length, - center=False) + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) spec = torch.squeeze(spec, 0) torch.save(spec, spec_filename) - + if self.use_spk: spk_filename = filename.replace(".wav", ".npy") spk_filename = spk_filename.replace("DUMMY", "dataset/spk") spk = torch.from_numpy(np.load(spk_filename)) - + if not self.use_sr: c_filename = filename.replace(".wav", ".pt") c_filename = c_filename.replace("DUMMY", "dataset/wavlm") c = torch.load(c_filename).squeeze(0) else: - i = random.randint(68,92) - ''' + i = random.randint(68, 92) + """ basename = os.path.basename(filename)[:-4] spkname = basename[:4] #print(basename, spkname) with h5py.File(f"dataset/rs/wavlm/{spkname}/{i}.hdf5","r") as f: c = torch.from_numpy(f[basename][()]).squeeze(0) #print(c) - ''' + """ c_filename = filename.replace(".wav", f"_{i}.pt") c_filename = 
c_filename.replace("DUMMY", "dataset/sr/wavlm") c = torch.load(c_filename).squeeze(0) - - ''' + + """ lmin = min(c.size(-1), spec.size(-1)) spec, c = spec[:, :lmin], c[:, :lmin] audio_norm = audio_norm[:, :lmin*480] @@ -100,8 +114,8 @@ def get_audio(self, filename): spec = spec[:, start:end] c = c[:, start:end] audio_norm = audio_norm[:, start*480:end*480] - ''' - + """ + if self.use_spk: return c, spec, audio_norm, spk else: @@ -114,9 +128,9 @@ def __len__(self): return len(self.audiopaths) -class TextAudioSpeakerCollate(): - """ Zero-pads model inputs and targets - """ +class TextAudioSpeakerCollate: + """Zero-pads model inputs and targets""" + def __init__(self, hps): self.hps = hps self.use_sr = hps.train.use_sr @@ -130,8 +144,8 @@ def __call__(self, batch): """ # Right zero-pad all one-hot text sequences to max input length _, ids_sorted_decreasing = torch.sort( - torch.LongTensor([x[0].size(1) for x in batch]), - dim=0, descending=True) + torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True + ) max_spec_len = max([x[1].size(1) for x in batch]) max_wav_len = max([x[2].size(1) for x in batch]) @@ -142,67 +156,82 @@ def __call__(self, batch): spks = torch.FloatTensor(len(batch), batch[0][3].size(0)) else: spks = None - + c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) c_padded.zero_() spec_padded.zero_() wav_padded.zero_() - + for i in range(len(ids_sorted_decreasing)): row = batch[ids_sorted_decreasing[i]] - + c = row[0] - c_padded[i, :, :c.size(1)] = c + c_padded[i, :, : c.size(1)] = c spec = row[1] - spec_padded[i, :, :spec.size(1)] = spec + spec_padded[i, :, : spec.size(1)] = spec spec_lengths[i] = spec.size(1) wav = row[2] - wav_padded[i, :, :wav.size(1)] = wav + wav_padded[i, :, : wav.size(1)] = wav wav_lengths[i] = wav.size(1) - + if self.use_spk: spks[i] = row[3] - - spec_seglen = spec_lengths[-1] if spec_lengths[-1] < self.hps.train.max_speclen + 1 else self.hps.train.max_speclen + 1 + + spec_seglen = ( + spec_lengths[-1] + if spec_lengths[-1] < self.hps.train.max_speclen + 1 + else self.hps.train.max_speclen + 1 + ) wav_seglen = spec_seglen * 480 - spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen) + spec_padded, ids_slice = commons.rand_spec_segments( + spec_padded, spec_lengths, spec_seglen + ) wav_padded = commons.slice_segments(wav_padded, ids_slice * 480, wav_seglen) - - c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:,:,:-1] - - spec_padded = spec_padded[:,:,:-1] - wav_padded = wav_padded[:,:,:-480] + + c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:, :, :-1] + + spec_padded = spec_padded[:, :, :-1] + wav_padded = wav_padded[:, :, :-480] if self.use_spk: - return c_padded, spec_padded, wav_padded, spks + return c_padded, spec_padded, wav_padded, spks else: - return c_padded, spec_padded, wav_padded - + return c_padded, spec_padded, wav_padded + class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): """ Maintain similar input lengths in a batch. Length groups are specified by boundaries. Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. - + It removes samples which are not included in the boundaries. Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
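
The bucketing scheme described above is wired up in train_24.py further down with length boundaries [32, 300, ..., 1000]. A minimal single-process sketch using the dataset and collate classes earlier in this file; num_replicas=1, rank=0 and the config path are assumptions here (train_24.py builds hps from the CLI and passes n_gpus and the GPU rank instead), and it presumes the 24 kHz data and filelists have already been prepared:

import utils
from torch.utils.data import DataLoader
from data_utils_24 import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)

hps = utils.get_hparams_from_file("configs/freevc-24.json")   # placeholder config path

train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
train_sampler = DistributedBucketSampler(
    train_dataset,
    hps.train.batch_size,
    [32, 300, 400, 500, 600, 700, 800, 900, 1000],   # spectrogram-length buckets, as in train_24.py
    num_replicas=1,   # single-process sketch
    rank=0,
    shuffle=True,
)
train_loader = DataLoader(
    train_dataset,
    num_workers=8,
    shuffle=False,                 # per-epoch shuffling happens inside the sampler
    pin_memory=True,
    collate_fn=TextAudioSpeakerCollate(hps),
    batch_sampler=train_sampler,
)

Each batch drawn from train_loader then contains utterances of similar spectrogram length, which limits the padding the collate function has to add before it takes its random segments.
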
""" - def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + + def __init__( + self, + dataset, + batch_size, + boundaries, + num_replicas=None, + rank=None, + shuffle=True, + ): super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) self.lengths = dataset.lengths self.batch_size = batch_size self.boundaries = boundaries - + self.buckets, self.num_samples_per_bucket = self._create_buckets() self.total_size = sum(self.num_samples_per_bucket) self.num_samples = self.total_size // self.num_replicas - + def _create_buckets(self): buckets = [[] for _ in range(len(self.boundaries) - 1)] for i in range(len(self.lengths)): @@ -210,74 +239,85 @@ def _create_buckets(self): idx_bucket = self._bisect(length) if idx_bucket != -1: buckets[idx_bucket].append(i) - + for i in range(len(buckets) - 1, 0, -1): if len(buckets[i]) == 0: buckets.pop(i) - self.boundaries.pop(i+1) - + self.boundaries.pop(i + 1) + num_samples_per_bucket = [] for i in range(len(buckets)): len_bucket = len(buckets[i]) total_batch_size = self.num_replicas * self.batch_size - rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size num_samples_per_bucket.append(len_bucket + rem) return buckets, num_samples_per_bucket - + def __iter__(self): - # deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.epoch) - - indices = [] - if self.shuffle: - for bucket in self.buckets: - indices.append(torch.randperm(len(bucket), generator=g).tolist()) - else: - for bucket in self.buckets: - indices.append(list(range(len(bucket)))) - - batches = [] - for i in range(len(self.buckets)): - bucket = self.buckets[i] - len_bucket = len(bucket) - ids_bucket = indices[i] - num_samples_bucket = self.num_samples_per_bucket[i] - - # add extra samples to make it evenly divisible - rem = num_samples_bucket - len_bucket - ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] - - # subsample - ids_bucket = ids_bucket[self.rank::self.num_replicas] - - # batching - for j in range(len(ids_bucket) // self.batch_size): - batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]] - batches.append(batch) - - if self.shuffle: - batch_ids = torch.randperm(len(batches), generator=g).tolist() - batches = [batches[i] for i in batch_ids] - self.batches = batches - - assert len(self.batches) * self.batch_size == self.num_samples - return iter(self.batches) - + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] + ) + + # subsample + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] + 
batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + def _bisect(self, x, lo=0, hi=None): - if hi is None: - hi = len(self.boundaries) - 1 - - if hi > lo: - mid = (hi + lo) // 2 - if self.boundaries[mid] < x and x <= self.boundaries[mid+1]: - return mid - elif x <= self.boundaries[mid]: - return self._bisect(x, lo, mid) - else: - return self._bisect(x, mid + 1, hi) - else: - return -1 + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 def __len__(self): return self.num_samples // self.batch_size diff --git a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/downsample_24k.py b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/downsample_24k.py index 19a7bbb..9ced899 100644 --- a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/downsample_24k.py +++ b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/downsample_24k.py @@ -11,37 +11,41 @@ def process(wav_name): # speaker 's5', 'p280', 'p315' are excluded, speaker = wav_name[:4] wav_path = os.path.join(args.in_dir, speaker, wav_name) - if os.path.exists(wav_path) and '_mic2.flac' in wav_path: + if os.path.exists(wav_path) and "_mic2.flac" in wav_path: os.makedirs(os.path.join(args.out_dir1, speaker), exist_ok=True) wav, sr = librosa.load(wav_path) wav, index = librosa.effects.trim(wav, top_db=20) peak = np.abs(wav).max() if peak > 1.0: wav = 0.98 * wav / peak - #wav1 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr1) + # wav1 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr1) wav1, sr = librosa.load(wav_path, sr=args.sr1) - wav1 = wav1[int(index[0]*args.sr1/22050): int(index[1]*args.sr1/22050)] + wav1 = wav1[int(index[0] * args.sr1 / 22050) : int(index[1] * args.sr1 / 22050)] save_name = wav_name.replace("_mic2.flac", ".wav") save_path1 = os.path.join(args.out_dir1, speaker, save_name) wavfile.write( - save_path1, - args.sr1, - (wav1 * np.iinfo(np.int16).max).astype(np.int16) + save_path1, args.sr1, (wav1 * np.iinfo(np.int16).max).astype(np.int16) ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--sr1", type=int, default=24000, help="sampling rate") - parser.add_argument("--in_dir", type=str, default="/home/Datasets/lijingyi/data/vctk/wav48_silence_trimmed/", help="path to source dir") - parser.add_argument("--out_dir1", type=str, default="./dataset/vctk-24k", help="path to target dir") + parser.add_argument( + "--in_dir", + type=str, + default="/home/Datasets/lijingyi/data/vctk/wav48_silence_trimmed/", + help="path to source dir", + ) + parser.add_argument( + "--out_dir1", type=str, default="./dataset/vctk-24k", help="path to target dir" + ) args = parser.parse_args() - pool = Pool(processes=cpu_count()-2) + pool = Pool(processes=cpu_count() - 2) for speaker in os.listdir(args.in_dir): spk_dir = os.path.join(args.in_dir, speaker) if os.path.isdir(spk_dir): for _ in tqdm(pool.imap_unordered(process, os.listdir(spk_dir))): pass - diff --git a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/train_24.py b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/train_24.py index d49aed9..cb826a7 100644 --- 
a/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/train_24.py +++ b/tips-for-synthesizing-24KHz-wavs-from-16kHz-wavs/train_24.py @@ -16,277 +16,366 @@ import commons import utils from data_utils_24 import ( - TextAudioSpeakerLoader, - TextAudioSpeakerCollate, - DistributedBucketSampler + TextAudioSpeakerLoader, + TextAudioSpeakerCollate, + DistributedBucketSampler, ) from models import ( - SynthesizerTrn, - MultiPeriodDiscriminator, -) -from losses import ( - generator_loss, - discriminator_loss, - feature_loss, - kl_loss + SynthesizerTrn, + MultiPeriodDiscriminator, ) +from losses import generator_loss, discriminator_loss, feature_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch torch.backends.cudnn.benchmark = True global_step = 0 -#os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' +# os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' def main(): - """Assume Single Node Multi GPUs Training Only""" - assert torch.cuda.is_available(), "CPU training is not allowed." - hps = utils.get_hparams() + """Assume Single Node Multi GPUs Training Only""" + assert torch.cuda.is_available(), "CPU training is not allowed." + hps = utils.get_hparams() - n_gpus = torch.cuda.device_count() - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = hps.train.port + n_gpus = torch.cuda.device_count() + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = hps.train.port - mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) def run(rank, n_gpus, hps): - global global_step - if rank == 0: - logger = utils.get_logger(hps.model_dir) - logger.info(hps) - utils.check_git_hash(hps.model_dir) - writer = SummaryWriter(log_dir=hps.model_dir) - writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) - - dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) - torch.manual_seed(hps.train.seed) - torch.cuda.set_device(rank) - - train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) - train_sampler = DistributedBucketSampler( - train_dataset, - hps.train.batch_size, - [32,300,400,500,600,700,800,900,1000], - num_replicas=n_gpus, - rank=rank, - shuffle=True) - collate_fn = TextAudioSpeakerCollate(hps) - train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, - collate_fn=collate_fn, batch_sampler=train_sampler) - if rank == 0: - eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps) - eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=True, - batch_size=hps.train.batch_size, pin_memory=False, - drop_last=False, collate_fn=collate_fn) - - net_g = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // 480, - **hps.model).cuda(rank) - net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) - optim_g = torch.optim.AdamW( - net_g.parameters(), - hps.train.learning_rate, - betas=hps.train.betas, - eps=hps.train.eps) - optim_d = torch.optim.AdamW( - net_d.parameters(), - hps.train.learning_rate, - betas=hps.train.betas, - eps=hps.train.eps) - net_g = DDP(net_g, device_ids=[rank])#, find_unused_parameters=True) - net_d = DDP(net_d, device_ids=[rank]) - - try: - _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g) - _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) - global_step = (epoch_str - 1) * len(train_loader) - 
except: - epoch_str = 1 - global_step = 0 - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) - - scaler = GradScaler(enabled=hps.train.fp16_run) - - for epoch in range(epoch_str, hps.train.epochs + 1): - if rank==0: - train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) - else: - train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) - scheduler_g.step() - scheduler_d.step() - - -def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): - - net_g, net_d = nets - optim_g, optim_d = optims - scheduler_g, scheduler_d = schedulers - train_loader, eval_loader = loaders - if writers is not None: - writer, writer_eval = writers - - train_loader.batch_sampler.set_epoch(epoch) - global global_step - - net_g.train() - net_d.train() - for batch_idx, items in enumerate(train_loader): - if hps.model.use_spk: - c, spec, y, spk = items - g = spk.cuda(rank, non_blocking=True) - else: - c, spec, y = items - g = None - spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) - c = c.cuda(rank, non_blocking=True) - - with autocast(enabled=hps.train.fp16_run): - y_hat, ids_slice, z_mask,\ - (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g) - - #print(ids_slice) - - mel = mel_spectrogram_torch( - y.squeeze(1), - 960, - hps.data.n_mel_channels, - 24000, - 240, - 960, - hps.data.mel_fmin, - hps.data.mel_fmax - ) - y_mel = commons.slice_segments(mel, ids_slice * 2, hps.train.segment_size // 240) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1), - 960, - hps.data.n_mel_channels, - 24000, - 240, - 960, - hps.data.mel_fmin, - hps.data.mel_fmax - ) - y = commons.slice_segments(y, ids_slice * 480, hps.train.segment_size) # slice - - # Discriminator - y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - with autocast(enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) - loss_disc_all = loss_disc - optim_d.zero_grad() - scaler.scale(loss_disc_all).backward() - scaler.unscale_(optim_d) - grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) - scaler.step(optim_d) - - with autocast(enabled=hps.train.fp16_run): - # Generator - y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) - with autocast(enabled=False): - loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl - loss_fm = feature_loss(fmap_r, fmap_g) - loss_gen, losses_gen = generator_loss(y_d_hat_g) - loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl - optim_g.zero_grad() - scaler.scale(loss_gen_all).backward() - scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) - scaler.step(optim_g) - scaler.update() - - if rank==0: - if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]['lr'] - losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] - logger.info('Train Epoch: {} [{:.0f}%]'.format( - epoch, - 100. 
* batch_idx / len(train_loader))) - logger.info([x.item() for x in losses] + [global_step, lr]) - - scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} - scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}) - - scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) - scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) - scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), - "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), - "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), - } - utils.summarize( - writer=writer, - global_step=global_step, - images=image_dict, - scalars=scalar_dict) - - if global_step % hps.train.eval_interval == 0: - evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) - utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) - global_step += 1 - - if rank == 0: - logger.info('====> Epoch: {}'.format(epoch)) - - + global global_step + if rank == 0: + logger = utils.get_logger(hps.model_dir) + logger.info(hps) + utils.check_git_hash(hps.model_dir) + writer = SummaryWriter(log_dir=hps.model_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) + + dist.init_process_group( + backend="nccl", init_method="env://", world_size=n_gpus, rank=rank + ) + torch.manual_seed(hps.train.seed) + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate(hps) + train_loader = DataLoader( + train_dataset, + num_workers=8, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + ) + if rank == 0: + eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps) + eval_loader = DataLoader( + eval_dataset, + num_workers=8, + shuffle=True, + batch_size=hps.train.batch_size, + pin_memory=False, + drop_last=False, + collate_fn=collate_fn, + ) + + net_g = SynthesizerTrn( + hps.data.filter_length // 2 + 1, hps.train.segment_size // 480, **hps.model + ).cuda(rank) + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) + optim_g = torch.optim.AdamW( + net_g.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + optim_d = torch.optim.AdamW( + net_d.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + net_g = DDP(net_g, device_ids=[rank]) # , find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank]) + + try: + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g + ) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d + ) + global_step = (epoch_str - 1) * len(train_loader) + except: + epoch_str = 1 + global_step = 0 + + scheduler_g = 
torch.optim.lr_scheduler.ExponentialLR( + optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + + scaler = GradScaler(enabled=hps.train.fp16_run) + + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, eval_loader], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + scheduler_d.step() + + +def train_and_evaluate( + rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers +): + + net_g, net_d = nets + optim_g, optim_d = optims + scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, items in enumerate(train_loader): + if hps.model.use_spk: + c, spec, y, spk = items + g = spk.cuda(rank, non_blocking=True) + else: + c, spec, y = items + g = None + spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) + c = c.cuda(rank, non_blocking=True) + + with autocast(enabled=hps.train.fp16_run): + y_hat, ids_slice, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( + c, spec, g=g + ) + + # print(ids_slice) + + mel = mel_spectrogram_torch( + y.squeeze(1), + 960, + hps.data.n_mel_channels, + 24000, + 240, + 960, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_mel = commons.slice_segments( + mel, ids_slice * 2, hps.train.segment_size // 240 + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + 960, + hps.data.n_mel_channels, + 24000, + 240, + 960, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y = commons.slice_segments( + y, ids_slice * 480, hps.train.segment_size + ) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g + ) + loss_disc_all = loss_disc + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, 100.0 * batch_idx / len(train_loader) + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + 
"loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + scalar_dict.update( + {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl} + ) + + scalar_dict.update( + {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} + ) + scalar_dict.update( + {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} + ) + scalar_dict.update( + {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} + ) + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy() + ), + "slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy() + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy() + ), + } + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + + if global_step % hps.train.eval_interval == 0: + evaluate(hps, net_g, eval_loader, writer_eval) + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), + ) + global_step += 1 + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + def evaluate(hps, generator, eval_loader, writer_eval): generator.eval() with torch.no_grad(): - for batch_idx, items in enumerate(eval_loader): - if hps.model.use_spk: - c, spec, y, spk = items - g = spk[:1].cuda(0) - else: - c, spec, y = items - g = None - spec, y = spec[:1].cuda(0), y[:1].cuda(0) - c = c[:1].cuda(0) - break - mel = mel_spectrogram_torch( - y.squeeze(1), - 960, - hps.data.n_mel_channels, - 24000, - 240, - 960, - hps.data.mel_fmin, - hps.data.mel_fmax - ) - y_hat = generator.module.infer(c, g=g) - - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1).float(), - 960, - hps.data.n_mel_channels, - 24000, - 240, - 960, - hps.data.mel_fmin, - hps.data.mel_fmax - ) + for batch_idx, items in enumerate(eval_loader): + if hps.model.use_spk: + c, spec, y, spk = items + g = spk[:1].cuda(0) + else: + c, spec, y = items + g = None + spec, y = spec[:1].cuda(0), y[:1].cuda(0) + c = c[:1].cuda(0) + break + mel = mel_spectrogram_torch( + y.squeeze(1), + 960, + hps.data.n_mel_channels, + 24000, + 240, + 960, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_hat = generator.module.infer(c, g=g) + + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + 960, + hps.data.n_mel_channels, + 24000, + 240, + 960, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) image_dict = { - "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), - "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()) - } - audio_dict = { - "gen/audio": y_hat[0], - "gt/audio": y[0] + "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), + "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()), } + audio_dict = {"gen/audio": y_hat[0], "gt/audio": y[0]} utils.summarize( - writer=writer_eval, - global_step=global_step, - images=image_dict, - audios=audio_dict, - audio_sampling_rate=24000 + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=24000, ) generator.train() - + if __name__ == "__main__": - main() + main() diff --git a/train.py b/train.py index fd4516d..8639b6d 100644 --- a/train.py +++ b/train.py @@ -16,269 +16,362 @@ import commons import utils from data_utils 
import ( - TextAudioSpeakerLoader, - TextAudioSpeakerCollate, - DistributedBucketSampler + TextAudioSpeakerLoader, + TextAudioSpeakerCollate, + DistributedBucketSampler, ) from models import ( - SynthesizerTrn, - MultiPeriodDiscriminator, -) -from losses import ( - generator_loss, - discriminator_loss, - feature_loss, - kl_loss + SynthesizerTrn, + MultiPeriodDiscriminator, ) +from losses import generator_loss, discriminator_loss, feature_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch torch.backends.cudnn.benchmark = True global_step = 0 -#os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' +# os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' def main(): - """Assume Single Node Multi GPUs Training Only""" - assert torch.cuda.is_available(), "CPU training is not allowed." - hps = utils.get_hparams() + """Assume Single Node Multi GPUs Training Only""" + assert torch.cuda.is_available(), "CPU training is not allowed." + hps = utils.get_hparams() - n_gpus = torch.cuda.device_count() - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = hps.train.port + n_gpus = torch.cuda.device_count() + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = hps.train.port - mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) def run(rank, n_gpus, hps): - global global_step - if rank == 0: - logger = utils.get_logger(hps.model_dir) - logger.info(hps) - utils.check_git_hash(hps.model_dir) - writer = SummaryWriter(log_dir=hps.model_dir) - writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) - - dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) - torch.manual_seed(hps.train.seed) - torch.cuda.set_device(rank) - - train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) - train_sampler = DistributedBucketSampler( - train_dataset, - hps.train.batch_size, - [32,300,400,500,600,700,800,900,1000], - num_replicas=n_gpus, - rank=rank, - shuffle=True) - collate_fn = TextAudioSpeakerCollate(hps) - train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, - collate_fn=collate_fn, batch_sampler=train_sampler) - if rank == 0: - eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps) - eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=True, - batch_size=hps.train.batch_size, pin_memory=False, - drop_last=False, collate_fn=collate_fn) - - net_g = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - **hps.model).cuda(rank) - net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) - optim_g = torch.optim.AdamW( - net_g.parameters(), - hps.train.learning_rate, - betas=hps.train.betas, - eps=hps.train.eps) - optim_d = torch.optim.AdamW( - net_d.parameters(), - hps.train.learning_rate, - betas=hps.train.betas, - eps=hps.train.eps) - net_g = DDP(net_g, device_ids=[rank])#, find_unused_parameters=True) - net_d = DDP(net_d, device_ids=[rank]) - - try: - _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g) - _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) - global_step = (epoch_str - 1) * len(train_loader) - except: - epoch_str = 1 - global_step = 0 - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) - scheduler_d = 
torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) - - scaler = GradScaler(enabled=hps.train.fp16_run) - - for epoch in range(epoch_str, hps.train.epochs + 1): - if rank==0: - train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) - else: - train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) - scheduler_g.step() - scheduler_d.step() - - -def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): - - net_g, net_d = nets - optim_g, optim_d = optims - scheduler_g, scheduler_d = schedulers - train_loader, eval_loader = loaders - if writers is not None: - writer, writer_eval = writers - - train_loader.batch_sampler.set_epoch(epoch) - global global_step - - net_g.train() - net_d.train() - for batch_idx, items in enumerate(train_loader): - if hps.model.use_spk: - c, spec, y, spk = items - g = spk.cuda(rank, non_blocking=True) - else: - c, spec, y = items - g = None - spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) - c = c.cuda(rank, non_blocking=True) - mel = spec_to_mel_torch( - spec, - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.mel_fmin, - hps.data.mel_fmax) - - with autocast(enabled=hps.train.fp16_run): - y_hat, ids_slice, z_mask,\ - (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g, mel=mel) - - y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax - ) - y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice - - # Discriminator - y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - with autocast(enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) - loss_disc_all = loss_disc - optim_d.zero_grad() - scaler.scale(loss_disc_all).backward() - scaler.unscale_(optim_d) - grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) - scaler.step(optim_d) - - with autocast(enabled=hps.train.fp16_run): - # Generator - y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) - with autocast(enabled=False): - loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl - loss_fm = feature_loss(fmap_r, fmap_g) - loss_gen, losses_gen = generator_loss(y_d_hat_g) - loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl - optim_g.zero_grad() - scaler.scale(loss_gen_all).backward() - scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) - scaler.step(optim_g) - scaler.update() - - if rank==0: - if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]['lr'] - losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] - logger.info('Train Epoch: {} [{:.0f}%]'.format( - epoch, - 100. 
* batch_idx / len(train_loader))) - logger.info([x.item() for x in losses] + [global_step, lr]) - - scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} - scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}) - - scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) - scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) - scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), - "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), - "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), - } - utils.summarize( - writer=writer, - global_step=global_step, - images=image_dict, - scalars=scalar_dict) - - if global_step % hps.train.eval_interval == 0: - evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) - utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) - global_step += 1 - - if rank == 0: - logger.info('====> Epoch: {}'.format(epoch)) - - + global global_step + if rank == 0: + logger = utils.get_logger(hps.model_dir) + logger.info(hps) + utils.check_git_hash(hps.model_dir) + writer = SummaryWriter(log_dir=hps.model_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) + + dist.init_process_group( + backend="nccl", init_method="env://", world_size=n_gpus, rank=rank + ) + torch.manual_seed(hps.train.seed) + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate(hps) + train_loader = DataLoader( + train_dataset, + num_workers=8, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + ) + if rank == 0: + eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps) + eval_loader = DataLoader( + eval_dataset, + num_workers=8, + shuffle=True, + batch_size=hps.train.batch_size, + pin_memory=False, + drop_last=False, + collate_fn=collate_fn, + ) + + net_g = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model + ).cuda(rank) + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) + optim_g = torch.optim.AdamW( + net_g.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + optim_d = torch.optim.AdamW( + net_d.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + net_g = DDP(net_g, device_ids=[rank]) # , find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank]) + + try: + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g + ) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d + ) + global_step = (epoch_str - 1) * len(train_loader) + except: + epoch_str = 1 + global_step = 0 + + scheduler_g = 
torch.optim.lr_scheduler.ExponentialLR( + optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + + scaler = GradScaler(enabled=hps.train.fp16_run) + + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, eval_loader], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + scheduler_d.step() + + +def train_and_evaluate( + rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers +): + + net_g, net_d = nets + optim_g, optim_d = optims + scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, items in enumerate(train_loader): + if hps.model.use_spk: + c, spec, y, spk = items + g = spk.cuda(rank, non_blocking=True) + else: + c, spec, y = items + g = None + spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) + c = c.cuda(rank, non_blocking=True) + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + + with autocast(enabled=hps.train.fp16_run): + y_hat, ids_slice, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( + c, spec, g=g, mel=mel + ) + + y_mel = commons.slice_segments( + mel, ids_slice, hps.train.segment_size // hps.data.hop_length + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y = commons.slice_segments( + y, ids_slice * hps.data.hop_length, hps.train.segment_size + ) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g + ) + loss_disc_all = loss_disc + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, 100.0 * batch_idx / len(train_loader) + ) + ) + logger.info([x.item() for x in 
losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + scalar_dict.update( + {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl} + ) + + scalar_dict.update( + {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} + ) + scalar_dict.update( + {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} + ) + scalar_dict.update( + {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} + ) + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy() + ), + "slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy() + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy() + ), + } + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + + if global_step % hps.train.eval_interval == 0: + evaluate(hps, net_g, eval_loader, writer_eval) + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), + ) + global_step += 1 + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + def evaluate(hps, generator, eval_loader, writer_eval): generator.eval() with torch.no_grad(): - for batch_idx, items in enumerate(eval_loader): - if hps.model.use_spk: - c, spec, y, spk = items - g = spk[:1].cuda(0) - else: - c, spec, y = items - g = None - spec, y = spec[:1].cuda(0), y[:1].cuda(0) - c = c[:1].cuda(0) - break - mel = spec_to_mel_torch( - spec, - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.mel_fmin, - hps.data.mel_fmax) - y_hat = generator.module.infer(c, g=g, mel=mel) - - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1).float(), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax - ) + for batch_idx, items in enumerate(eval_loader): + if hps.model.use_spk: + c, spec, y, spk = items + g = spk[:1].cuda(0) + else: + c, spec, y = items + g = None + spec, y = spec[:1].cuda(0), y[:1].cuda(0) + c = c[:1].cuda(0) + break + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_hat = generator.module.infer(c, g=g, mel=mel) + + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) image_dict = { - "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), - "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()) - } - audio_dict = { - "gen/audio": y_hat[0], - "gt/audio": y[0] + "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), + "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()), } + audio_dict = {"gen/audio": y_hat[0], "gt/audio": y[0]} utils.summarize( - writer=writer_eval, - global_step=global_step, - images=image_dict, - audios=audio_dict, - audio_sampling_rate=hps.data.sampling_rate + writer=writer_eval, + global_step=global_step, + images=image_dict, + 
audios=audio_dict, + audio_sampling_rate=hps.data.sampling_rate, ) generator.train() - + if __name__ == "__main__": - main() + main() diff --git a/utils.py b/utils.py index 3c66418..f8a80bd 100644 --- a/utils.py +++ b/utils.py @@ -21,14 +21,14 @@ def get_cmodel(rank): - checkpoint = torch.load('wavlm/WavLM-Large.pt') - cfg = WavLMConfig(checkpoint['cfg']) + checkpoint = torch.load("wavlm/WavLM-Large.pt") + cfg = WavLMConfig(checkpoint["cfg"]) cmodel = WavLM(cfg).cuda(rank) - cmodel.load_state_dict(checkpoint['model']) + cmodel.load_state_dict(checkpoint["model"]) cmodel.eval() return cmodel - - + + def get_content(cmodel, y): with torch.no_grad(): c = cmodel.extract_features(y.squeeze(1))[0] @@ -47,265 +47,295 @@ def get_vocoder(rank): vocoder.remove_weight_norm() vocoder.cuda(rank) return vocoder - - -def transform(mel, height): # 68-92 - #r = np.random.random() - #rate = r * 0.3 + 0.85 # 0.85-1.15 - #height = int(mel.size(-2) * rate) + + +def transform(mel, height): # 68-92 + # r = np.random.random() + # rate = r * 0.3 + 0.85 # 0.85-1.15 + # height = int(mel.size(-2) * rate) tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1))) if height >= mel.size(-2): - return tgt[:, :mel.size(-2), :] + return tgt[:, : mel.size(-2), :] else: - silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1) + silence = tgt[:, -1:, :].repeat(1, mel.size(-2) - height, 1) silence += torch.randn_like(silence) / 10 return torch.cat((tgt, silence), 1) - - -def stretch(mel, width): # 0.5-2 + + +def stretch(mel, width): # 0.5-2 return torchvision.transforms.functional.resize(mel, (mel.size(-2), width)) def load_checkpoint(checkpoint_path, model, optimizer=None, strict=False): - assert os.path.isfile(checkpoint_path) - checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') - iteration = checkpoint_dict['iteration'] - learning_rate = checkpoint_dict['learning_rate'] - if optimizer is not None: - optimizer.load_state_dict(checkpoint_dict['optimizer']) - saved_state_dict = checkpoint_dict['model'] - if hasattr(model, 'module'): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - if strict: - assert state_dict.keys() == saved_state_dict.keys(), "Mismatched model config and checkpoint." - new_state_dict= {} - for k, v in state_dict.items(): - try: - new_state_dict[k] = saved_state_dict[k] - except: - logger.info("%s is not in the checkpoint" % k) - new_state_dict[k] = v - if hasattr(model, 'module'): - model.module.load_state_dict(new_state_dict) - else: - model.load_state_dict(new_state_dict) - logger.info("Loaded checkpoint '{}' (iteration {})" .format( - checkpoint_path, iteration)) - return model, optimizer, learning_rate, iteration + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if optimizer is not None: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + if strict: + assert ( + state_dict.keys() == saved_state_dict.keys() + ), "Mismatched model config and checkpoint." 
+ new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + logger.info( + "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) + ) + return model, optimizer, learning_rate, iteration def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): - logger.info("Saving model and optimizer state at iteration {} to {}".format( - iteration, checkpoint_path)) - if hasattr(model, 'module'): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - torch.save({'model': state_dict, - 'iteration': iteration, - 'optimizer': optimizer.state_dict(), - 'learning_rate': learning_rate}, checkpoint_path) - - -def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): - for k, v in scalars.items(): - writer.add_scalar(k, v, global_step) - for k, v in histograms.items(): - writer.add_histogram(k, v, global_step) - for k, v in images.items(): - writer.add_image(k, v, global_step, dataformats='HWC') - for k, v in audios.items(): - writer.add_audio(k, v, global_step, audio_sampling_rate) + logger.info( + "Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) def latest_checkpoint_path(dir_path, regex="G_*.pth"): - f_list = glob.glob(os.path.join(dir_path, regex)) - f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) - x = f_list[-1] - print(x) - return x + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x def plot_spectrogram_to_numpy(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10,2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import 
numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data def plot_alignment_to_numpy(alignment, info=None): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(6, 4)) - im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', - interpolation='none') - fig.colorbar(im, ax=ax) - xlabel = 'Decoder timestep' - if info is not None: - xlabel += '\n\n' + info - plt.xlabel(xlabel) - plt.ylabel('Encoder timestep') - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), aspect="auto", origin="lower", interpolation="none" + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data def load_wav_to_torch(full_path): - sampling_rate, data = read(full_path) - return torch.FloatTensor(data.astype(np.float32)), sampling_rate + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate def load_filepaths_and_text(filename, split="|"): - with open(filename, encoding='utf-8') as f: - filepaths_and_text = [line.strip().split(split) for line in f] - return filepaths_and_text + with open(filename, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text def get_hparams(init=True): - parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, default="./configs/base.json", - help='JSON file for configuration') - parser.add_argument('-m', '--model', type=str, required=True, - help='Model name') - - args = parser.parse_args() - model_dir = os.path.join("./logs", args.model) - - if not os.path.exists(model_dir): - os.makedirs(model_dir) - - config_path = args.config - config_save_path = os.path.join(model_dir, "config.json") - if init: - with open(config_path, "r") as f: - data = f.read() - with open(config_save_path, "w") as f: - f.write(data) - else: - with open(config_save_path, "r") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - hparams.model_dir = model_dir - return hparams + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config", + type=str, + default="./configs/base.json", + help="JSON file for 
configuration", + ) + parser.add_argument("-m", "--model", type=str, required=True, help="Model name") + + args = parser.parse_args() + model_dir = os.path.join("./logs", args.model) + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + config_path = args.config + config_save_path = os.path.join(model_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams def get_hparams_from_dir(model_dir): - config_save_path = os.path.join(model_dir, "config.json") - with open(config_save_path, "r") as f: - data = f.read() - config = json.loads(data) + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) - hparams =HParams(**config) - hparams.model_dir = model_dir - return hparams + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams def get_hparams_from_file(config_path): - with open(config_path, "r") as f: - data = f.read() - config = json.loads(data) + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) - hparams =HParams(**config) - return hparams + hparams = HParams(**config) + return hparams def check_git_hash(model_dir): - source_dir = os.path.dirname(os.path.realpath(__file__)) - if not os.path.exists(os.path.join(source_dir, ".git")): - logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( - source_dir - )) - return + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn( + "git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8] + ) + ) + else: + open(path, "w").write(cur_hash) - cur_hash = subprocess.getoutput("git rev-parse HEAD") - path = os.path.join(model_dir, "githash") - if os.path.exists(path): - saved_hash = open(path).read() - if saved_hash != cur_hash: - logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( - saved_hash[:8], cur_hash[:8])) - else: - open(path, "w").write(cur_hash) +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger -def get_logger(model_dir, filename="train.log"): - global logger - logger = logging.getLogger(os.path.basename(model_dir)) - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - h = logging.FileHandler(os.path.join(model_dir, filename)) - h.setLevel(logging.DEBUG) - h.setFormatter(formatter) - logger.addHandler(h) - return logger - - -class HParams(): - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if type(v) == dict: - v = HParams(**v) - self[k] = v - - def keys(self): - return self.__dict__.keys() - - def items(self): - return self.__dict__.items() - - def values(self): - return self.__dict__.values() - - def __len__(self): - return len(self.__dict__) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - return setattr(self, key, value) - - def __contains__(self, key): - return key in self.__dict__ - - def __repr__(self): - return self.__dict__.__repr__() + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() diff --git a/wavlm/WavLM.py b/wavlm/WavLM.py index 09af97b..0905ab3 100644 --- a/wavlm/WavLM.py +++ b/wavlm/WavLM.py @@ -155,60 +155,96 @@ def arrange(s, e, length, keep_length): if len(mask_idc) > min_len: mask_idc = np.random.choice(mask_idc, min_len, replace=False) mask[i, mask_idc] = True - + return mask class WavLMConfig: def __init__(self, cfg=None): - self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) - self.encoder_layers: int = 12 # num encoder layers in the transformer + self.extractor_mode: str = ( + "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + ) + self.encoder_layers: int = 12 # num encoder layers in the transformer - self.encoder_embed_dim: int = 768 # encoder embedding dimension - self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN - self.encoder_attention_heads: int = 12 # num encoder attention heads - self.activation_fn: str = "gelu" # activation function to use + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use - self.layer_norm_first: bool = False # apply layernorm first in the transformer - self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] - self.conv_bias: bool = False # include bias in conv encoder - self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = ( + "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + ) + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = ( + 1.0 # multiply feature extractor var grads by this + ) - self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + self.normalize: bool = ( + False # normalize input to have 0 mean and unit variance during training + ) # dropouts - self.dropout: float = 0.1 # dropout probability for the transformer - self.attention_dropout: float = 0.1 # dropout probability for attention weights - self.activation_dropout: float = 0.0 # dropout probability after activation in FFN - self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer - self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) - self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = ( + 0.0 # dropout probability after activation in FFN + ) + self.encoder_layerdrop: float = ( + 0.0 # probability of dropping a tarnsformer layer + ) + self.dropout_input: float = ( + 0.0 # dropout to apply to the input (after feat extr) + ) + self.dropout_features: float = ( + 0.0 # dropout to apply to the features (after feat extr) + ) # masking - self.mask_length: int = 10 # mask length - self.mask_prob: float = 0.65 # probability of replacing a token with mask - self.mask_selection: str = "static" # how to choose mask length - self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh - self.no_mask_overlap: bool = False # whether to allow masks to overlap - self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + 
self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + ) + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = ( + 1 # min space between spans (if no overlap is enabled) + ) # channel masking - self.mask_channel_length: int = 10 # length of the mask for features (channels) - self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 - self.mask_channel_selection: str = "static" # how to choose mask length for channel masking - self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices - self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap - self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = ( + "static" # how to choose mask length for channel masking + ) + self.mask_channel_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + ) + self.no_mask_channel_overlap: bool = ( + False # whether to allow channel masks to overlap + ) + self.mask_channel_min_space: int = ( + 1 # min space between spans (if no overlap is enabled) + ) # positional embeddings - self.conv_pos: int = 128 # number of filters for convolutional positional embeddings - self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + self.conv_pos: int = ( + 128 # number of filters for convolutional positional embeddings + ) + self.conv_pos_groups: int = ( + 16 # number of groups for convolutional positional embedding + ) # relative position embedding - self.relative_position_embedding: bool = False # apply relative position embedding - self.num_buckets: int = 320 # number of buckets for relative position embedding - self.max_distance: int = 1280 # maximum distance for relative position embedding - self.gru_rel_pos: bool = False # apply gated relative position embedding + self.relative_position_embedding: bool = ( + False # apply relative position embedding + ) + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = ( + 1280 # maximum distance for relative position embedding + ) + self.gru_rel_pos: bool = False # apply gated relative position embedding if cfg is not None: self.update(cfg) @@ -305,19 +341,19 @@ def apply_mask(self, x, padding_mask): .expand(-1, T, -1) ) x[mask_channel_indices] = 0 - + return x, mask_indices def forward_padding_mask( - self, features: torch.Tensor, padding_mask: torch.Tensor, + self, + features: torch.Tensor, + padding_mask: torch.Tensor, ) -> torch.Tensor: extra = padding_mask.size(1) % features.size(1) if extra > 0: padding_mask = padding_mask[:, :-extra] - padding_mask = padding_mask.view( - padding_mask.size(0), features.size(1), -1 - ) - #padding_mask = padding_mask.all(-1) + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + # padding_mask = padding_mask.all(-1) padding_mask = padding_mask.any(-1) return padding_mask @@ -343,19 +379,17 @@ def extract_features( if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask) - + if 
self.post_extract_proj is not None: features = self.post_extract_proj(features) features = self.dropout_input(features) if mask: - x, mask_indices = self.apply_mask( - features, padding_mask - ) + x, mask_indices = self.apply_mask(features, padding_mask) else: x = features - + # feature: (B, T, D), float # target: (B, T), long # x: (B, T, D), float @@ -364,10 +398,15 @@ def extract_features( x, layer_results = self.encoder( x, padding_mask=padding_mask, - layer=None if output_layer is None else output_layer - 1 + layer=None if output_layer is None else output_layer - 1, ) - - res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + res = { + "x": x, + "padding_mask": padding_mask, + "features": features, + "layer_results": layer_results, + } feature = res["features"] if ret_conv else res["x"] if ret_layer_results: @@ -377,25 +416,25 @@ def extract_features( class ConvFeatureExtractionModel(nn.Module): def __init__( - self, - conv_layers: List[Tuple[int, int, int]], - dropout: float = 0.0, - mode: str = "default", - conv_bias: bool = False, - conv_type: str = "default" + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default", ): super().__init__() assert mode in {"default", "layer_norm"} def block( - n_in, - n_out, - k, - stride, - is_layer_norm=False, - is_group_norm=False, - conv_bias=False, + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, ): def make_conv(): conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) @@ -403,8 +442,8 @@ def make_conv(): return conv assert ( - is_layer_norm and is_group_norm - ) == False, "layer norm and group norm are exclusive" + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" if is_layer_norm: return nn.Sequential( @@ -454,9 +493,7 @@ def make_conv(): assert len(cl) == 3 (dim, k, stride) = cl - self.conv_layers.append( - torch.nn.Conv2d(in_d, dim, k, stride) - ) + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride)) self.conv_layers.append(torch.nn.ReLU()) in_d = dim elif self.conv_type == "custom": @@ -469,9 +506,7 @@ def make_conv(): self.conv_layers.append( torch.nn.Conv2d(in_d, dim, k, stride, padding=1) ) - self.conv_layers.append( - torch.nn.LayerNorm([dim, idim]) - ) + self.conv_layers.append(torch.nn.LayerNorm([dim, idim])) self.conv_layers.append(torch.nn.ReLU()) in_d = dim if (i + 1) % 2 == 0: @@ -546,7 +581,9 @@ def __init__(self, args): activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, - has_relative_attention_bias=(self.relative_position_embedding and i == 0), + has_relative_attention_bias=( + self.relative_position_embedding and i == 0 + ), num_buckets=self.num_buckets, max_distance=self.max_distance, gru_rel_pos=args.gru_rel_pos, @@ -563,17 +600,19 @@ def __init__(self, args): def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) - + if self.layer_norm_first and layer is None: x = self.layer_norm(x) return x, layer_results - def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + def extract_features( + self, x, padding_mask=None, streaming_mask=None, tgt_layer=None + ): if padding_mask is not None: x[padding_mask] = 0 - + x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) x 
+= x_conv @@ -595,8 +634,13 @@ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer= for i, layer in enumerate(self.layers): dropout_probability = np.random.random() if not self.training or (dropout_probability > self.layerdrop): - x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, - self_attn_mask=streaming_mask, pos_bias=pos_bias) + x, z, pos_bias = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + self_attn_mask=streaming_mask, + pos_bias=pos_bias, + ) if tgt_layer is not None: layer_results.append((x, z)) if i == tgt_layer: @@ -619,20 +663,20 @@ class TransformerSentenceEncoderLayer(nn.Module): """ def __init__( - self, - embedding_dim: float = 768, - ffn_embedding_dim: float = 3072, - num_attention_heads: float = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - activation_fn: str = "relu", - layer_norm_first: bool = False, - has_relative_attention_bias: bool = False, - num_buckets: int = 0, - max_distance: int = 0, - rescale_init: bool = False, - gru_rel_pos: bool = False, + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, ) -> None: super().__init__() @@ -675,12 +719,12 @@ def __init__( self.final_layer_norm = LayerNorm(self.embedding_dim) def forward( - self, - x: torch.Tensor, - self_attn_mask: torch.Tensor = None, - self_attn_padding_mask: torch.Tensor = None, - need_weights: bool = False, - pos_bias=None + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None, ): """ LayerNorm is applied either before or after the self-attention/ffn @@ -697,7 +741,7 @@ def forward( key_padding_mask=self_attn_padding_mask, need_weights=False, attn_mask=self_attn_mask, - position_bias=pos_bias + position_bias=pos_bias, ) x = self.dropout1(x) x = residual + x @@ -720,7 +764,7 @@ def forward( key_padding_mask=self_attn_padding_mask, need_weights=need_weights, attn_mask=self_attn_mask, - position_bias=pos_bias + position_bias=pos_bias, ) x = self.dropout1(x) diff --git a/wavlm/__init__.py b/wavlm/__init__.py index 03f8908..2d9ca9a 100644 --- a/wavlm/__init__.py +++ b/wavlm/__init__.py @@ -1 +1 @@ -from wavlm.WavLM import WavLM, WavLMConfig \ No newline at end of file +from wavlm.WavLM import WavLM, WavLMConfig diff --git a/wavlm/modules.py b/wavlm/modules.py index cd360aa..70ad1b6 100644 --- a/wavlm/modules.py +++ b/wavlm/modules.py @@ -84,8 +84,7 @@ def forward(self, x): class Swish(nn.Module): - """Swish function - """ + """Swish function""" def __init__(self): """Construct an MultiHeadedAttention object.""" @@ -122,9 +121,14 @@ def forward(self, x): x = self.linear(x) if self.glu_type == "bilinear": - x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + x = ( + x[:, :, 0 : self.output_dim] + * x[:, :, self.output_dim : self.output_dim * 2] + ) else: - x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + x = x[:, :, 0 : self.output_dim] * self.glu_act( + x[:, :, self.output_dim : self.output_dim * 2] + ) return x @@ -149,9 +153,7 @@ def 
get_activation_fn(activation: str): elif activation == "gelu": return gelu elif activation == "gelu_fast": - warnings.warn( - "--activation-fn=gelu_fast has been renamed to gelu_accurate" - ) + warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate") return gelu_accurate elif activation == "gelu_accurate": return gelu_accurate @@ -182,9 +184,7 @@ def init_bert_params(module): def normal_(data): # with FSDP, module params will be on CUDA, so we cast them back to CPU # so that the RNG is consistent with and without FSDP - data.copy_( - data.cpu().normal_(mean=0.0, std=0.02).to(data.device) - ) + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) if isinstance(module, nn.Linear): normal_(module.weight.data) @@ -307,24 +307,24 @@ class MultiheadAttention(nn.Module): """ def __init__( - self, - embed_dim, - num_heads, - kdim=None, - vdim=None, - dropout=0.0, - bias=True, - add_bias_kv=False, - add_zero_attn=False, - self_attention=False, - encoder_decoder_attention=False, - q_noise=0.0, - qn_block_size=8, - has_relative_attention_bias=False, - num_buckets=32, - max_distance=128, - gru_rel_pos=False, - rescale_init=False, + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, ): super().__init__() self.embed_dim = embed_dim @@ -345,9 +345,9 @@ def __init__( self.q_head_dim = self.head_dim self.k_head_dim = self.head_dim assert ( - self.head_dim * num_heads == self.embed_dim + self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention @@ -424,21 +424,26 @@ def _relative_positions_bucket(self, relative_positions, bidirectional=True): relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets relative_positions = torch.abs(relative_positions) else: - relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + relative_positions = -torch.min( + relative_positions, torch.zeros_like(relative_positions) + ) max_exact = num_buckets // 2 is_small = relative_positions < max_exact relative_postion_if_large = max_exact + ( - torch.log(relative_positions.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) ).to(torch.long) relative_postion_if_large = torch.min( - relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + relative_postion_if_large, + torch.full_like(relative_postion_if_large, num_buckets - 1), ) - relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + relative_buckets += torch.where( + is_small, relative_positions, relative_postion_if_large + ) return relative_buckets def compute_bias(self, query_length, key_length): @@ -446,27 +451,28 @@ def compute_bias(self, query_length, key_length): memory_position = torch.arange(key_length, dtype=torch.long)[None, :] relative_position = memory_position - context_position relative_position_bucket = self._relative_positions_bucket( - relative_position, - bidirectional=True 
+ relative_position, bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to( + self.relative_attention_bias.weight.device ) - relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) values = self.relative_attention_bias(relative_position_bucket) values = values.permute([2, 0, 1]) return values def forward( - self, - query, - key: Optional[Tensor], - value: Optional[Tensor], - key_padding_mask: Optional[Tensor] = None, - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - need_weights: bool = True, - static_kv: bool = False, - attn_mask: Optional[Tensor] = None, - before_softmax: bool = False, - need_head_weights: bool = False, - position_bias: Optional[Tensor] = None + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: """Input shape: Time x Batch x Channel @@ -503,16 +509,20 @@ def forward( if self.has_relative_attention_bias and position_bias is None: position_bias = self.compute_bias(tgt_len, src_len) - position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + position_bias = ( + position_bias.unsqueeze(0) + .repeat(bsz, 1, 1, 1) + .view(bsz * self.num_heads, tgt_len, src_len) + ) if ( - not is_tpu # don't use PyTorch version on TPUs - and incremental_state is None - and not static_kv - # A workaround for quantization to work. Otherwise JIT compilation - # treats bias in linear module as method. - and not torch.jit.is_scripting() - and self.q_head_dim == self.head_dim + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. 
+ and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim ): assert key is not None and value is not None assert attn_mask is None @@ -527,10 +537,15 @@ def forward( query_layer = query_layer.permute(0, 2, 1, 3) _B, _H, _L, __ = query_layer.size() - gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( - _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer) + .view(_B, _H, _L, 2, 4) + .sum(-1, keepdim=False) + ).chunk(2, dim=-1) gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 - attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + attn_mask_rel_pos = ( + gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + ) attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) k_proj_bias = self.k_proj.bias @@ -614,20 +629,20 @@ def forward( q = ( q.contiguous() - .view(tgt_len, bsz * self.num_heads, self.q_head_dim) - .transpose(0, 1) + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) ) if k is not None: k = ( k.contiguous() - .view(-1, bsz * self.num_heads, self.k_head_dim) - .transpose(0, 1) + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) ) if v is not None: v = ( v.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) ) if saved_state is not None: @@ -731,18 +746,21 @@ def forward( if self.gru_rel_pos == 1: query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) _B, _H, _L, __ = query_layer.size() - gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( - _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer) + .view(_B, _H, _L, 2, 4) + .sum(-1, keepdim=False) + ).chunk(2, dim=-1) gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 - position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + position_bias = ( + gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + ) position_bias = position_bias.view(attn_weights.size()) attn_weights = attn_weights + position_bias - attn_weights_float = F.softmax( - attn_weights, dim=-1 - ) + attn_weights_float = F.softmax(attn_weights, dim=-1) attn_weights = attn_weights_float.type_as(attn_weights) attn_probs = self.dropout_module(attn_weights) @@ -764,11 +782,11 @@ def forward( @staticmethod def _append_prev_key_padding_mask( - key_padding_mask: Optional[Tensor], - prev_key_padding_mask: Optional[Tensor], - batch_size: int, - src_len: int, - static_kv: bool, + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, ) -> Optional[Tensor]: # saved key padding masks have shape (bsz, seq_len) if prev_key_padding_mask is not None and static_kv: @@ -807,7 +825,7 @@ def _append_prev_key_padding_mask( return new_key_padding_mask def _get_input_buffer( - self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] ) -> Dict[str, Optional[Tensor]]: result = self.get_incremental_state(incremental_state, "attn_state") if result is not None: @@ -817,11 +835,11 @@ def _get_input_buffer( return empty_result def _set_input_buffer( - self, - incremental_state: Dict[str, Dict[str, Optional[Tensor]]], - buffer: Dict[str, Optional[Tensor]], + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + 
buffer: Dict[str, Optional[Tensor]], ): return self.set_incremental_state(incremental_state, "attn_state", buffer) def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): - return attn_weights \ No newline at end of file + return attn_weights
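
A few notes on the code touched by this patch follow; the snippets are illustrative sketches, not part of the diff. First, the conv_feature_layers default reformatted in WavLMConfig is a string holding Python list syntax that the model turns into (dim, kernel_size, stride) tuples for the conv feature extractor. A small sketch of how such a string decodes and what frame hop it implies; the eval-based parsing mirrors the usual fairseq handling and should be treated as an assumption here:

```python
# conv_feature_layers is Python list syntax stored as a string; each tuple is
# (output_dim, kernel_size, stride) for one Conv1d block.
conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
layers = eval(conv_feature_layers)  # assumption: parsed with eval, fairseq-style

hop = 1
for dim, kernel, stride in layers:
    hop *= stride

print(len(layers), hop)  # 7 conv blocks, hop of 320 samples (20 ms at 16 kHz)
```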
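
The forward_padding_mask hunk above pools the sample-level padding mask down to the conv frame rate: trim the mask to a multiple of the frame count, reshape, then reduce with .any(-1), so a frame counts as padded when any sample it covers is padded (the commented-out .all(-1) is the stricter alternative). A minimal standalone sketch; the function name and toy sizes are chosen here for illustration:

```python
import torch


def downsample_padding_mask(padding_mask: torch.Tensor, num_frames: int) -> torch.Tensor:
    # padding_mask: (B, T_samples) bool, True where the waveform is padding.
    extra = padding_mask.size(1) % num_frames
    if extra > 0:
        padding_mask = padding_mask[:, :-extra]  # drop the ragged tail
    padding_mask = padding_mask.view(padding_mask.size(0), num_frames, -1)
    # A frame is padding if any sample it covers is padding.
    return padding_mask.any(-1)


mask = torch.zeros(2, 16, dtype=torch.bool)
mask[0, 12:] = True  # last quarter of the first item is padding
print(downsample_padding_mask(mask, 4))
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
```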
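
The loop reformatted in TransformerEncoder.extract_features applies LayerDrop: while training, each transformer layer is skipped for the whole batch with probability self.layerdrop; at eval time every layer runs. A minimal sketch of that rule, with a hypothetical LayerDropStack wrapper and plain feed-forward blocks standing in for the real layers:

```python
import torch
import torch.nn as nn


class LayerDropStack(nn.Module):
    """Run a stack of layers, skipping each one with probability p during training."""

    def __init__(self, layers, p: float = 0.05):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.p = p

    def forward(self, x):
        for layer in self.layers:
            # Eval: always run. Training: drop this layer for the current pass.
            if self.training and torch.rand(()) < self.p:
                continue
            x = layer(x)
        return x


blocks = [nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)]
stack = LayerDropStack(blocks, p=0.5)
stack.train()
print(stack(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```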
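
compute_bias and _relative_positions_bucket in MultiheadAttention build a T5-style relative position bias: signed query-key offsets are mapped to a fixed set of buckets, exact for small offsets and logarithmically spaced up to max_distance, and an embedding table then turns each bucket into a per-head bias. A rough standalone sketch of the bucketing; it assumes the standard halving of num_buckets for the bidirectional case (that line falls outside the hunks shown), and the clamp is added only to avoid log(0):

```python
import math

import torch


def relative_position_bucket(relative_positions: torch.Tensor,
                             num_buckets: int = 320,
                             max_distance: int = 1280) -> torch.Tensor:
    # Bidirectional: reserve half the buckets for positive (key-after-query) offsets.
    num_buckets //= 2
    buckets = (relative_positions > 0).long() * num_buckets
    relative_positions = torch.abs(relative_positions)

    # Small offsets get their own bucket; large ones share log-spaced buckets.
    max_exact = num_buckets // 2
    is_small = relative_positions < max_exact
    if_large = max_exact + (
        torch.log(relative_positions.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    if_large = torch.min(if_large, torch.full_like(if_large, num_buckets - 1))
    return buckets + torch.where(is_small, relative_positions, if_large)


# Bucket ids over a (query, key) grid, as compute_bias does before the embedding lookup.
q_len = k_len = 6
context = torch.arange(q_len)[:, None]
memory = torch.arange(k_len)[None, :]
print(relative_position_bucket(memory - context))
```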
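
Finally, the gru_rel_pos branches reformatted in WavLM.py and modules.py scale that shared bias by a query-dependent gate: grep_linear projects each query head to eight values, which are summed in groups of four and squashed into two sigmoid gates, combined with the learned grep_a, and broadcast over the bias. A sketch under the shapes used in the diff; the free function and the toy sizes are assumptions made here:

```python
import torch
import torch.nn as nn


def gated_rel_pos_bias(q, position_bias, grep_linear, grep_a):
    # q: (B, H, L, head_dim); position_bias: (B * H, L, L)
    B, H, L, _ = q.size()
    # Eight projections per query, summed in groups of four, give two gates.
    gates = torch.sigmoid(grep_linear(q).view(B, H, L, 2, 4).sum(-1))
    gate_a, gate_b = gates.chunk(2, dim=-1)          # each (B, H, L, 1)
    gate = gate_a * (gate_b * grep_a - 1.0) + 2.0    # roughly 1..2 when grep_a is 1
    return gate.view(B * H, L, 1) * position_bias


B, H, L, D = 2, 4, 10, 64
q = torch.randn(B, H, L, D)
bias = torch.randn(B * H, L, L)
grep_linear = nn.Linear(D, 8)
grep_a = torch.ones(1, H, 1, 1)  # a learned nn.Parameter in the real module
print(gated_rel_pos_bias(q, bias, grep_linear, grep_a).shape)  # torch.Size([8, 10, 10])
```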