diff --git a/AcademiCodec/env.py b/AcademiCodec/env.py
deleted file mode 100644
index 2bdbc95..0000000
--- a/AcademiCodec/env.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import os
-import shutil
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def build_env(config, config_name, path):
-    t_path = os.path.join(path, config_name)
-    if config != t_path:
-        os.makedirs(path, exist_ok=True)
-        shutil.copyfile(config, os.path.join(path, config_name))
diff --git a/AcademiCodec/meldataset.py b/AcademiCodec/meldataset.py
deleted file mode 100644
index 2cf2a26..0000000
--- a/AcademiCodec/meldataset.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# code based on https://github.com/b04901014/MQTTS
-import math
-import os
-import random
-import torch
-import torch.utils.data
-import numpy as np
-from librosa.util import normalize
-from scipy.io.wavfile import read
-from librosa.filters import mel as librosa_mel_fn
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_wav(full_path):
-    sampling_rate, data = read(full_path)
-    return data, sampling_rate
-
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-
-def dynamic_range_decompression(x, C=1):
-    return np.exp(x) / C
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-
-
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
-    if torch.min(y) < -1.:
-        print('min value is ', torch.min(y))
-    if torch.max(y) > 1.:
-        print('max value is ', torch.max(y))
-
-    global mel_basis, hann_window
-    if str(fmax) + '_' + str(y.device) not in mel_basis:  # key must match the one stored below, otherwise the basis is rebuilt on every call
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
-        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
-
-    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
-    y = y.squeeze(1)
-
-    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
-                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
-
-    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
-
-    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
-    spec = spectral_normalize_torch(spec)
-
-    return spec
-
-
-def get_dataset_filelist(a):
-    with open(a.input_training_file, 'r') as f:
-        training_files = [l.strip() for l in f]
-    with open(a.input_validation_file, 'r') as f:
-        validation_files = [l.strip() for l in f]
-    return training_files, validation_files
-
-
-class MelDataset(torch.utils.data.Dataset):
-    def __init__(self, training_files, segment_size, n_fft, num_mels,
-                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
-                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
-        self.audio_files = training_files
-        random.seed(1234)
-        if shuffle:
-            random.shuffle(self.audio_files)
-        self.segment_size = segment_size
-        self.sampling_rate = sampling_rate
-        self.split = split
-        self.n_fft = n_fft
-        self.num_mels = num_mels
-        self.hop_size = hop_size
-        self.win_size = win_size
-        self.fmin = fmin
-        self.fmax = fmax
-        self.fmax_loss = fmax_loss
-        self.cached_wav = None
-        self.n_cache_reuse = n_cache_reuse
-        self._cache_ref_count = 0
-        self.device = device
-        self.fine_tuning = fine_tuning
-        self.base_mels_path = base_mels_path
-
-    def __getitem__(self, index):
-        filename = self.audio_files[index]
-        if self._cache_ref_count == 0:
-            try:
-                audio, sampling_rate = load_wav(filename)
-                audio = audio / MAX_WAV_VALUE
-                if not self.fine_tuning:
-                    audio = normalize(audio) * 0.95
-            except Exception:
-                print(f"Error on audio: {filename}")
-                audio = np.random.normal(size=(160000,)) * 0.05
-                sampling_rate = self.sampling_rate
-            self.cached_wav = audio
-            if sampling_rate != self.sampling_rate:
-                raise ValueError("{} SR doesn't match target {} SR".format(
-                    sampling_rate, self.sampling_rate))
-            self._cache_ref_count = self.n_cache_reuse
-        else:
-            audio = self.cached_wav
-            self._cache_ref_count -= 1
-
-        audio = torch.FloatTensor(audio)
-        audio = audio.unsqueeze(0)
-
-        if not self.fine_tuning:
-            if self.split:
-                if audio.size(1) >= self.segment_size:
-                    max_audio_start = audio.size(1) - self.segment_size
-                    audio_start = random.randint(0, max_audio_start)
-                    audio = audio[:, audio_start:audio_start+self.segment_size]
-                else:
-                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
-
-            mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
-                                  self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
-                                  center=False)
-        else:
-            mel = np.load(
-                os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
-            mel = torch.from_numpy(mel)
-
-            if len(mel.shape) < 3:
-                mel = mel.unsqueeze(0)
-
-            if self.split:
-                frames_per_seg = math.ceil(self.segment_size / self.hop_size)
-
-                if audio.size(1) >= self.segment_size:
-                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
-                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
-                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
-                else:
-                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
-                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
-
-        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
-                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
-                                   center=False)
-
-        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
-
-    def __len__(self):
-        return len(self.audio_files)
diff --git a/AcademiCodec/models.py b/AcademiCodec/models.py
deleted file mode 100644
index 3dd54f6..0000000
--- a/AcademiCodec/models.py
+++ /dev/null
@@ -1,458 +0,0 @@
-import torch
-import torch.nn.functional as F
-import torch.nn as nn
-from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
-from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-try:
-    from utils import init_weights, get_padding
-except ImportError:
-    from .utils import init_weights, get_padding
-
-LRELU_SLOPE = 0.1
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.h = h
-        self.convs1 = nn.ModuleList([
-            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
-                               padding=get_padding(kernel_size, dilation[0]))),
-            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
-                               padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]))) - ]) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))) - ]) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class ResBlock2(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): - super(ResBlock2, self).__init__() - self.h = h - self.convs = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))) - ]) - self.convs.apply(init_weights) - - def forward(self, x): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) - - -class Generator(torch.nn.Module): - def __init__(self, h): - super(Generator, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm(Conv1d(512, h.upsample_initial_channel, 7, 1, padding=3)) - resblock = ResBlock1 if h.resblock == '1' else ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append(weight_norm( - ConvTranspose1d( - h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), - k, u, - # padding=(u//2 + u%2), - padding=(k - u )//2, - # output_padding=u%2 - ) - )) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel//(2**(i+1)) - for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i*self.num_kernels+j](x) - else: - xs += self.resblocks[i*self.num_kernels+j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - norm_f = weight_norm if use_spectral_norm == 
False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), - ]) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self): - super(MultiPeriodDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - DiscriminatorP(2), - DiscriminatorP(3), - DiscriminatorP(5), - DiscriminatorP(7), - DiscriminatorP(11), - ]) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 128, 15, 1, padding=7)), - norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), - norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), - norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiScaleDiscriminator(torch.nn.Module): - def __init__(self): - super(MultiScaleDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - DiscriminatorS(use_spectral_norm=True), - DiscriminatorS(), - DiscriminatorS(), - ]) - self.meanpools = nn.ModuleList([ - AvgPool1d(4, 2, padding=2), - AvgPool1d(4, 2, padding=2) - ]) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - if i != 0: - y = self.meanpools[i-1](y) - y_hat = self.meanpools[i-1](y_hat) - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - loss += torch.mean(torch.abs(rl - gl)) - - return loss*2 - - -def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - r_losses = [] - g_losses = [] - for dr, 
dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean((1-dr)**2) - g_loss = torch.mean(dg**2) - loss += (r_loss + g_loss) - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) - - return loss, r_losses, g_losses - - -def generator_loss(disc_outputs): - loss = 0 - gen_losses = [] - for dg in disc_outputs: - l = torch.mean((1-dg)**2) - gen_losses.append(l) - loss += l - - return loss, gen_losses - - -class Encoder(torch.nn.Module): - def __init__(self, h): - super(Encoder, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3)) - self.normalize = nn.ModuleList() - resblock = ResBlock1 if h.resblock == '1' else ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(list(reversed(list(zip(h.upsample_rates, h.upsample_kernel_sizes))))): - self.ups.append(weight_norm( - Conv1d(32*(2**i), 32*(2**(i+1)), - k, u, - padding=((k-u)//2) - # padding=(u//2 + u%2) - ))) - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = 32*(2**(i+1)) - for j, (k, d) in enumerate( - zip( - list(reversed(h.resblock_kernel_sizes)), - list(reversed(h.resblock_dilation_sizes)) - ) - ): - self.resblocks.append(resblock(h, ch, k, d)) - self.normalize.append(torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True)) - self.conv_post = Conv1d(512, 512, 3, 1, padding=1) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i*self.num_kernels+j](x) - xs = self.normalize[i*self.num_kernels+j](xs) - else: - xs += self.resblocks[i*self.num_kernels+j](x) - xs = self.normalize[i*self.num_kernels+j](xs) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - return x - - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - - -class Quantizer_module(torch.nn.Module): - def __init__(self, n_e, e_dim): - super(Quantizer_module, self).__init__() - self.embedding = nn.Embedding(n_e, e_dim) - self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) - - def forward(self, x): - # compute Euclidean distance - d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \ - - 2 * torch.matmul(x, self.embedding.weight.T) - min_indicies = torch.argmin(d, 1) - z_q = self.embedding(min_indicies) - return z_q, min_indicies - - -class Quantizer(torch.nn.Module): - def __init__(self, h): - super(Quantizer, self).__init__() - assert 512 % h.n_code_groups == 0 - self.quantizer_modules = nn.ModuleList([ - Quantizer_module(h.n_codes, 512 // h.n_code_groups) for _ in range(h.n_code_groups) - ]) - self.quantizer_modules2 = nn.ModuleList([ - Quantizer_module(h.n_codes, 512 // h.n_code_groups) for _ in range(h.n_code_groups) - ]) - self.h = h - self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1 - self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25 - self.residul_layer = 2 - self.n_code_groups = h.n_code_groups - - def for_one_step(self, xin, idx): - xin = xin.transpose(1, 2) - x = xin.reshape(-1, 512) - x = torch.split(x, 512 // self.h.n_code_groups, dim=-1) - min_indicies = [] - z_q = [] - if idx == 0: 
-            for _x, m in zip(x, self.quantizer_modules):
-                _z_q, _min_indicies = m(_x)
-                z_q.append(_z_q)
-                min_indicies.append(_min_indicies)  # B * T,
-            z_q = torch.cat(z_q, -1).reshape(xin.shape)
-            # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
-            loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
-                + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
-            z_q = xin + (z_q - xin).detach()
-            z_q = z_q.transpose(1, 2)
-            return z_q, loss, min_indicies
-        else:
-            for _x, m in zip(x, self.quantizer_modules2):
-                _z_q, _min_indicies = m(_x)
-                z_q.append(_z_q)
-                min_indicies.append(_min_indicies)  # B * T,
-            z_q = torch.cat(z_q, -1).reshape(xin.shape)
-            # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
-            loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
-                + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
-            z_q = xin + (z_q - xin).detach()
-            z_q = z_q.transpose(1, 2)
-            return z_q, loss, min_indicies
-
-    def forward(self, xin):
-        # B, C, T
-        quantized_out = 0.0
-        residual = xin
-        all_losses = []
-        all_indices = []
-        for i in range(self.residul_layer):
-            quantized, loss, indices = self.for_one_step(residual, i)
-            residual = residual - quantized
-            quantized_out = quantized_out + quantized
-            all_indices.extend(indices)
-            all_losses.append(loss)
-        all_losses = torch.stack(all_losses)
-        loss = torch.mean(all_losses)
-        return quantized_out, loss, all_indices
-
-    def embed(self, x):
-        # idx: N, T, 4
-        # print('x ', x.shape)
-        quantized_out = torch.tensor(0.0, device=x.device)
-        x = torch.split(x, 1, 2)  # split the last dimension; each chunk belongs to one index group
-        # print('x.shape ', len(x), x[0].shape)
-        for i in range(self.residul_layer):
-            ret = []
-            if i == 0:
-                for j in range(self.n_code_groups):
-                    q = x[j]
-                    embed = self.quantizer_modules[j]
-                    q = embed.embedding(q.squeeze(-1))
-                    ret.append(q)
-                ret = torch.cat(ret, -1)
-                # print(ret.shape)
-                quantized_out = quantized_out + ret
-            else:
-                for j in range(self.n_code_groups):
-                    q = x[j+self.n_code_groups]
-                    embed = self.quantizer_modules2[j]
-                    q = embed.embedding(q.squeeze(-1))
-                    ret.append(q)
-                ret = torch.cat(ret, -1)
-                quantized_out = quantized_out + ret
-        return quantized_out.transpose(1, 2)  # N, C, T
-
diff --git a/AcademiCodec/msstftd.py b/AcademiCodec/msstftd.py
deleted file mode 100644
index 3ea98f6..0000000
--- a/AcademiCodec/msstftd.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""MS-STFT discriminator, provided here for reference."""
-
-import typing as tp
-
-import torchaudio
-import torch
-from torch import nn
-from einops import rearrange
-
-from modules import NormConv2d
-
-
-FeatureMapType = tp.List[torch.Tensor]
-LogitsType = torch.Tensor
-DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
-
-
-def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
-    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
-
-
-class DiscriminatorSTFT(nn.Module):
-    """STFT sub-discriminator.
-    Args:
-        filters (int): Number of filters in convolutions
-        in_channels (int): Number of input channels. Default: 1
-        out_channels (int): Number of output channels. Default: 1
-        n_fft (int): Size of FFT for each scale. 
Default: 1024 - hop_length (int): Length of hop between STFT windows for each scale. Default: 256 - kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` - stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` - dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` - win_length (int): Window size for each scale. Default: 1024 - normalized (bool): Whether to normalize by magnitude after stft. Default: True - norm (str): Normalization method. Default: `'weight_norm'` - activation (str): Activation function. Default: `'LeakyReLU'` - activation_params (dict): Parameters to provide to the activation function. - growth (int): Growth factor for the filters. Default: 1 - """ - def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, - n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, - filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], - stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', - activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): - super().__init__() - assert len(kernel_size) == 2 - assert len(stride) == 2 - self.filters = filters - self.in_channels = in_channels - self.out_channels = out_channels - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.normalized = normalized - self.activation = getattr(torch.nn, activation)(**activation_params) - self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, - normalized=self.normalized, center=False, pad_mode=None, power=None) - spec_channels = 2 * self.in_channels - self.convs = nn.ModuleList() - self.convs.append( - NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) - ) - in_chs = min(filters_scale * self.filters, max_filters) - for i, dilation in enumerate(dilations): - out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) - self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, - dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), - norm=norm)) - in_chs = out_chs - out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) - self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), - padding=get_2d_padding((kernel_size[0], kernel_size[0])), - norm=norm)) - self.conv_post = NormConv2d(out_chs, self.out_channels, - kernel_size=(kernel_size[0], kernel_size[0]), - padding=get_2d_padding((kernel_size[0], kernel_size[0])), - norm=norm) - - def forward(self, x: torch.Tensor): - fmap = [] - # print('x ', x.shape) - z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] - # print('z ', z.shape) - z = torch.cat([z.real, z.imag], dim=1) - # print('cat_z ', z.shape) - z = rearrange(z, 'b c w t -> b c t w') - for i, layer in enumerate(self.convs): - z = layer(z) - z = self.activation(z) - # print('z i', i, z.shape) - fmap.append(z) - z = self.conv_post(z) - # print('logit ', z.shape) - return z, fmap - - -class MultiScaleSTFTDiscriminator(nn.Module): - """Multi-Scale STFT (MS-STFT) discriminator. - Args: - filters (int): Number of filters in convolutions - in_channels (int): Number of input channels. Default: 1 - out_channels (int): Number of output channels. 
Default: 1 - n_ffts (Sequence[int]): Size of FFT for each scale - hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale - win_lengths (Sequence[int]): Window size for each scale - **kwargs: additional args for STFTDiscriminator - """ - def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, - n_ffts: tp.List[int] = [1024, 2048, 512, 256, 128], hop_lengths: tp.List[int] = [256, 512, 128, 64, 32], - win_lengths: tp.List[int] = [1024, 2048, 512, 256, 128], **kwargs): - super().__init__() - assert len(n_ffts) == len(hop_lengths) == len(win_lengths) - self.discriminators = nn.ModuleList([ - DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, - n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) - for i in range(len(n_ffts)) - ]) - self.num_discriminators = len(self.discriminators) - - def forward(self, x: torch.Tensor) -> DiscriminatorOutput: - logits = [] - fmaps = [] - for disc in self.discriminators: - logit, fmap = disc(x) - logits.append(logit) - fmaps.append(fmap) - return logits, fmaps - - -def test(): - disc = MultiScaleSTFTDiscriminator(filters=32) - y = torch.randn(1, 1, 24000) - y_hat = torch.randn(1, 1, 24000) - - y_disc_r, fmap_r = disc(y) - #print('y_disc_r ', len(y_disc_r)) - # print('fmap_r ', len(fmap_r)) - y_disc_gen, fmap_gen = disc(y_hat) - # print('y_disc_gen ', y_disc_gen.shape) - # print('fmap_gen ', len(fmap_gen)) - assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators - - assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) - assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) - assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) - - -if __name__ == '__main__': - test() diff --git a/AcademiCodec/param_config.json b/AcademiCodec/param_config.json deleted file mode 100644 index 0c7cf0f..0000000 --- a/AcademiCodec/param_config.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "resblock": "1", - "num_gpus": 8, - "batch_size": 32, - "learning_rate": 0.0002, - "adam_b1": 0.5, - "adam_b2": 0.9, - "lr_decay": 0.98, - "seed": 1234, - - "upsample_rates": [8,5,3,2], - "upsample_kernel_sizes": [16,11,7,4], - "upsample_initial_channel": 512, - "resblock_kernel_sizes": [3,7,11], - "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], - - "segment_size": 12000, - "num_mels": 80, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 240, - "win_size": 1024, - - "sampling_rate": 24000, - - "n_code_groups": 2, - "n_codes": 1024, - "codebook_loss_lambda": 1.0, - "commitment_loss_lambda": 0.25, - - "fmin": 0, - "fmax": 8000, - "fmax_for_loss": null, - - "num_workers": 12, - - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1 - } -} diff --git a/AcademiCodec/quantization/__init__.py b/AcademiCodec/quantization/__init__.py deleted file mode 100644 index bfabe52..0000000 --- a/AcademiCodec/quantization/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# flake8: noqa -from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/AcademiCodec/quantization/ac.py b/AcademiCodec/quantization/ac.py deleted file mode 100644 index f0f3e5d..0000000 --- a/AcademiCodec/quantization/ac.py +++ /dev/null @@ -1,292 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Arithmetic coder."""
-
-import io
-import math
-import random
-import typing as tp
-import torch
-
-from ..binary import BitPacker, BitUnpacker
-
-
-def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
-                               roundoff: float = 1e-8, min_range: int = 2,
-                               check: bool = True) -> torch.Tensor:
-    """Turn the given PDF into a quantized CDF that splits
-    [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
-    to the PDF.
-
-    Args:
-        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
-        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
-            during the coding process is `[0, 2 ** total_range_bits - 1]`.
-        roundoff (float): will round the pdf up to that level to remove difference coming
-            from e.g. evaluating the Language Model on different architectures.
-        min_range (int): minimum range width. Should always be at least 2 for numerical
-            stability. Use this to avoid pathological behavior if a value
-            that is expected to be rare actually happens in real life.
-        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
-    """
-    pdf = pdf.detach()
-    if roundoff:
-        pdf = (pdf / roundoff).floor() * roundoff
-    # interpolate with uniform distribution to achieve desired minimum probability.
-    total_range = 2 ** total_range_bits
-    cardinality = len(pdf)
-    alpha = min_range * cardinality / total_range
-    assert alpha <= 1, "you must reduce min_range"
-    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
-    ranges += min_range
-    quantized_cdf = torch.cumsum(ranges, dim=-1)
-    if min_range < 2:
-        raise ValueError("min_range must be at least 2.")
-    if check:
-        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
-        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
-            raise ValueError("You must increase your total_range_bits.")
-    return quantized_cdf
-
-
-class ArithmeticCoder:
-    """ArithmeticCoder,
-    Let us take a distribution `p` over `N` symbols, and assume we have a stream
-    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
-    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
-    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
-    sequence `(s_t)` by doing the following:
-
-    1) Initialize the current range to `[0, 2 ** B - 1]`.
-    2) For each time step t, split the current range into contiguous chunks,
-        one for each possible outcome, with size roughly proportional to `p`.
-        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
-        would be `{[0, 2], [3, 3]}`.
-    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
-    4) When done encoding all the values, just select any value remaining in the range.
-
-    You will notice that this procedure can fail: for instance if at any point in time
-    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
-    possible outcome. Intuitively, the more likely a value is, the less the range width
-    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
-    coding scheme, likely outcomes would take less bits, and more of them can be coded
-    with a fixed budget.
-
-    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
-    when the current range decreases below a given limit (given by `total_range_bits`), without
-    having to redo all the computations. If we encode mostly likely values, we will seldom
-    need to inject new bits, but a single rare value can deplete our stock of entropy!
-
-    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
-    code works for any sequence `(p_t)` possibly different for each timestep.
-    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
-    the KL between the true distribution and `p_t`, the more efficient the coding will be.
-
-    Args:
-        fo (IO[bytes]): file-like object to which the bytes will be written.
-        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
-            Any time the current range width falls under this limit, new bits will
-            be injected to rescale the initial range.
-    """
-
-    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
-        assert total_range_bits <= 30
-        self.total_range_bits = total_range_bits
-        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
-        self.low: int = 0
-        self.high: int = 0
-        self.max_bit: int = -1
-        self._dbg: tp.List[tp.Any] = []
-        self._dbg2: tp.List[tp.Any] = []
-
-    @property
-    def delta(self) -> int:
-        """Return the current range width."""
-        return self.high - self.low + 1
-
-    def _flush_common_prefix(self):
-        # If self.low and self.high start with the same bits,
-        # those won't change anymore as we always just increase the range
-        # by powers of 2, and we can flush them out to the bit stream.
-        assert self.high >= self.low, (self.low, self.high)
-        assert self.high < 2 ** (self.max_bit + 1)
-        while self.max_bit >= 0:
-            b1 = self.low >> self.max_bit
-            b2 = self.high >> self.max_bit
-            if b1 == b2:
-                self.low -= (b1 << self.max_bit)
-                self.high -= (b1 << self.max_bit)
-                assert self.high >= self.low, (self.high, self.low, self.max_bit)
-                assert self.low >= 0
-                self.max_bit -= 1
-                self.packer.push(b1)
-            else:
-                break
-
-    def push(self, symbol: int, quantized_cdf: torch.Tensor):
-        """Push the given symbol on the stream, flushing out bits
-        if possible.
-
-        Args:
-            symbol (int): symbol to encode with the AC.
-            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
-                to build this from your pdf estimate.
-        """
-        while self.delta < 2 ** self.total_range_bits:
-            self.low *= 2
-            self.high = self.high * 2 + 1
-            self.max_bit += 1
-
-        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
-        range_high = quantized_cdf[symbol].item() - 1
-        effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
-        effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
-        assert self.low <= self.high
-        self.high = self.low + effective_high
-        self.low = self.low + effective_low
-        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
-        self._dbg.append((self.low, self.high))
-        self._dbg2.append((self.low, self.high))
-        outs = self._flush_common_prefix()
-        assert self.low <= self.high
-        assert self.max_bit >= -1
-        assert self.max_bit <= 61, self.max_bit
-        return outs
-
-    def flush(self):
-        """Flush the remaining information to the stream.
- """ - while self.max_bit >= 0: - b1 = (self.low >> self.max_bit) & 1 - self.packer.push(b1) - self.max_bit -= 1 - self.packer.flush() - - -class ArithmeticDecoder: - """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. - - Note that this must be called with **exactly** the same parameters and sequence - of quantized cdf as the arithmetic encoder or the wrong values will be decoded. - - If the AC encoder current range is [L, H], with `L` and `H` having the some common - prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. - For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside - `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained - for a specific sequence of symbols and a binary-search allows us to decode those symbols. - At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, - and we will need to read new bits from the stream and repeat the process. - - """ - def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): - self.total_range_bits = total_range_bits - self.low: int = 0 - self.high: int = 0 - self.current: int = 0 - self.max_bit: int = -1 - self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. - # Following is for debugging - self._dbg: tp.List[tp.Any] = [] - self._dbg2: tp.List[tp.Any] = [] - self._last: tp.Any = None - - @property - def delta(self) -> int: - return self.high - self.low + 1 - - def _flush_common_prefix(self): - # Given the current range [L, H], if both have a common prefix, - # we know we can remove it from our representation to avoid handling large numbers. - while self.max_bit >= 0: - b1 = self.low >> self.max_bit - b2 = self.high >> self.max_bit - if b1 == b2: - self.low -= (b1 << self.max_bit) - self.high -= (b1 << self.max_bit) - self.current -= (b1 << self.max_bit) - assert self.high >= self.low - assert self.low >= 0 - self.max_bit -= 1 - else: - break - - def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: - """Pull a symbol, reading as many bits from the stream as required. - This returns `None` when the stream has been exhausted. - - Args: - quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` - to build this from your pdf estimate. This must be **exatly** - the same cdf as the one used at encoding time. 
- """ - while self.delta < 2 ** self.total_range_bits: - bit = self.unpacker.pull() - if bit is None: - return None - self.low *= 2 - self.high = self.high * 2 + 1 - self.current = self.current * 2 + bit - self.max_bit += 1 - - def bin_search(low_idx: int, high_idx: int): - # Binary search is not just for coding interviews :) - if high_idx < low_idx: - raise RuntimeError("Binary search failed") - mid = (low_idx + high_idx) // 2 - range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 - range_high = quantized_cdf[mid].item() - 1 - effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) - effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) - low = effective_low + self.low - high = effective_high + self.low - if self.current >= low: - if self.current <= high: - return (mid, low, high, self.current) - else: - return bin_search(mid + 1, high_idx) - else: - return bin_search(low_idx, mid - 1) - - self._last = (self.low, self.high, self.current, self.max_bit) - sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) - self._dbg.append((self.low, self.high, self.current)) - self._flush_common_prefix() - self._dbg2.append((self.low, self.high, self.current)) - - return sym - - -def test(): - torch.manual_seed(1234) - random.seed(1234) - for _ in range(4): - pdfs = [] - cardinality = random.randrange(4000) - steps = random.randrange(100, 500) - fo = io.BytesIO() - encoder = ArithmeticCoder(fo) - symbols = [] - for step in range(steps): - pdf = torch.softmax(torch.randn(cardinality), dim=0) - pdfs.append(pdf) - q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) - symbol = torch.multinomial(pdf, 1).item() - symbols.append(symbol) - encoder.push(symbol, q_cdf) - encoder.flush() - - fo.seek(0) - decoder = ArithmeticDecoder(fo) - for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): - q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) - decoded_symbol = decoder.pull(q_cdf) - assert decoded_symbol == symbol, idx - assert decoder.pull(torch.zeros(1)) is None - - -if __name__ == "__main__": - test() diff --git a/AcademiCodec/quantization/core_vq.py b/AcademiCodec/quantization/core_vq.py deleted file mode 100644 index 48f1635..0000000 --- a/AcademiCodec/quantization/core_vq.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# This implementation is inspired from -# https://github.com/lucidrains/vector-quantize-pytorch -# which is released under MIT License. Hereafter, the original license: -# MIT License -# -# Copyright (c) 2020 Phil Wang -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Core vector quantization implementation.""" -import typing as tp - -from einops import rearrange, repeat -import torch -from torch import nn -import torch.nn.functional as F - -from .distrib import broadcast_tensors, rank - - -def default(val: tp.Any, d: tp.Any) -> tp.Any: - return val if val is not None else d - - -def ema_inplace(moving_avg, new, decay: float): - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): - return (x + epsilon) / (x.sum() + n_categories * epsilon) - - -def uniform_init(*shape: int): - t = torch.empty(shape) - nn.init.kaiming_uniform_(t) - return t - - -def sample_vectors(samples, num: int): - num_samples, device = samples.shape[0], samples.device - - if num_samples >= num: - indices = torch.randperm(num_samples, device=device)[:num] - else: - indices = torch.randint(0, num_samples, (num,), device=device) - - return samples[indices] - - -def kmeans(samples, num_clusters: int, num_iters: int = 10): - dim, dtype = samples.shape[-1], samples.dtype - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) - - buckets = dists.max(dim=-1).indices - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] - - means = torch.where(zero_mask[..., None], means, new_means) - - return means, bins - - -class EuclideanCodebook(nn.Module): - """Codebook with Euclidean distance. - Args: - dim (int): Dimension. - codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. - kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. 
- """ - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: int = False, - kmeans_iters: int = 10, - decay: float = 0.99, - epsilon: float = 1e-5, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros - embed = init_fn(codebook_size, dim) - - self.codebook_size = codebook_size - - self.kmeans_iters = kmeans_iters - self.epsilon = epsilon - self.threshold_ema_dead_code = threshold_ema_dead_code - - self.register_buffer("inited", torch.Tensor([not kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(codebook_size)) - self.register_buffer("embed", embed) - self.register_buffer("embed_avg", embed.clone()) - - @torch.jit.ignore - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed.clone()) - self.cluster_size.data.copy_(cluster_size) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - #broadcast_tensors(self.buffers()) - - def replace_(self, samples, mask): - modified_codebook = torch.where( - mask[..., None], sample_vectors(samples, self.codebook_size), self.embed - ) - self.embed.data.copy_(modified_codebook) - - def expire_codes_(self, batch_samples): - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - #broadcast_tensors(self.buffers()) - - def preprocess(self, x): - x = rearrange(x, "... d -> (...) d") - return x - - def quantize(self, x): - embed = self.embed.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def postprocess_emb(self, embed_ind, shape): - return embed_ind.view(*shape[:-1]) - - def dequantize(self, embed_ind): - quantize = F.embedding(embed_ind, self.embed) - return quantize - - def encode(self, x): - shape = x.shape - # pre-process - x = self.preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = self.postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind): - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x): - shape, dtype = x.shape, x.dtype - x = self.preprocess(x) - - self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) - embed_ind = self.postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - - if self.training: - # We do the expiry of code at that point as buffers are in sync - # and all the workers will take the same decision. - self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) - * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) - - return quantize, embed_ind - - -class VectorQuantization(nn.Module): - """Vector quantization implementation. 
- Currently supports only euclidean distance. - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - commitment_weight (float): Weight for commitment loss. - """ - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - decay: float = 0.99, - epsilon: float = 1e-5, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - commitment_weight: float = 1., - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) - self.codebook_size = codebook_size - - @property - def codebook(self): - return self._codebook.embed - - def encode(self, x): - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - embed_in = self._codebook.encode(x) - return embed_in - - def decode(self, embed_ind): - quantize = self._codebook.decode(embed_ind) - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def forward(self, x): - device = x.device - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - - quantize, embed_ind = self._codebook(x) - - if self.training: - quantize = x + (quantize - x).detach() - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize, embed_ind, loss - - -class ResidualVectorQuantization(nn.Module): - """Residual vector quantization implementation. - Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf - """ - def __init__(self, *, num_quantizers, **kwargs): - super().__init__() - self.layers = nn.ModuleList( - [VectorQuantization(**kwargs) for _ in range(num_quantizers)] - ) - - def forward(self, x, n_q: tp.Optional[int] = None): - quantized_out = 0.0 - residual = x - - all_losses = [] - all_indices = [] - - n_q = n_q or len(self.layers) - - for layer in self.layers[:n_q]: - quantized, indices, loss = layer(residual) - residual = residual - quantized - quantized_out = quantized_out + quantized - - all_indices.append(indices) - all_losses.append(loss) - - out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) - return quantized_out, out_indices, out_losses - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: - indices = layer.encode(residual) - quantized = layer.decode(indices) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized - return quantized_out diff --git a/AcademiCodec/quantization/distrib.py b/AcademiCodec/quantization/distrib.py deleted file mode 100644 index 4edc88a..0000000 --- a/AcademiCodec/quantization/distrib.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Torch distributed utilities.""" - -import typing as tp - -import torch - - -def rank(): - if torch.distributed.is_initialized(): - return torch.distributed.get_rank() - else: - return 0 - - -def world_size(): - if torch.distributed.is_initialized(): - return torch.distributed.get_world_size() - else: - return 1 - - -def is_distributed(): - return world_size() > 1 - - -def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): - if is_distributed(): - return torch.distributed.all_reduce(tensor, op) - - -def _is_complex_or_float(tensor): - return torch.is_floating_point(tensor) or torch.is_complex(tensor) - - -def _check_number_of_params(params: tp.List[torch.Tensor]): - # utility function to check that the number of params in all workers is the same, - # and thus avoid a deadlock with distributed all reduce. - if not is_distributed() or not params: - return - #print('params[0].device ', params[0].device) - tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) - all_reduce(tensor) - if tensor.item() != len(params) * world_size(): - # If not all the workers have the same number, for at least one of them, - # this inequality will be verified. - raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " - "at least one worker has a different one.") - - -def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): - """Broadcast the tensors from the given parameters to all workers. - This can be used to ensure that all workers have the same model to start with. 
- """ - if not is_distributed(): - return - tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] - _check_number_of_params(tensors) - handles = [] - for tensor in tensors: - # src = int(rank()) # added code - handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) - handles.append(handle) - for handle in handles: - handle.wait() - - -def sync_buffer(buffers, average=True): - """ - Sync grad for buffers. If average is False, broadcast instead of averaging. - """ - if not is_distributed(): - return - handles = [] - for buffer in buffers: - if torch.is_floating_point(buffer.data): - if average: - handle = torch.distributed.all_reduce( - buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) - else: - handle = torch.distributed.broadcast( - buffer.data, src=0, async_op=True) - handles.append((buffer, handle)) - for buffer, handle in handles: - handle.wait() - if average: - buffer.data /= world_size - - -def sync_grad(params): - """ - Simpler alternative to DistributedDataParallel, that doesn't rely - on any black magic. For simple models it can also be as fast. - Just call this on your model parameters after the call to backward! - """ - if not is_distributed(): - return - handles = [] - for p in params: - if p.grad is not None: - handle = torch.distributed.all_reduce( - p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) - handles.append((p, handle)) - for p, handle in handles: - handle.wait() - p.grad.data /= world_size() - - -def average_metrics(metrics: tp.Dict[str, float], count=1.): - """Average a dictionary of metrics across all workers, using the optional - `count` as unormalized weight. - """ - if not is_distributed(): - return metrics - keys, values = zip(*metrics.items()) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) - tensor *= count - all_reduce(tensor) - averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() - return dict(zip(keys, averaged)) diff --git a/AcademiCodec/quantization/vq.py b/AcademiCodec/quantization/vq.py deleted file mode 100644 index d603e39..0000000 --- a/AcademiCodec/quantization/vq.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Residual vector quantizer implementation.""" - -from dataclasses import dataclass, field -import math -import typing as tp - -import torch -from torch import nn - -from .core_vq import ResidualVectorQuantization - - -@dataclass -class QuantizedResult: - quantized: torch.Tensor - codes: torch.Tensor - bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. - penalty: tp.Optional[torch.Tensor] = None - metrics: dict = field(default_factory=dict) - - -class ResidualVectorQuantizer(nn.Module): - """Residual Vector Quantizer. - Args: - dimension (int): Dimension of the codebooks. - n_q (int): Number of residual vector quantizers used. - bins (int): Codebook size. - decay (float): Decay for exponential moving average over the codebooks. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. 
Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - def __init__( - self, - dimension: int = 256, - n_q: int = 8, - bins: int = 1024, - decay: float = 0.99, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.n_q = n_q - self.dimension = dimension - self.bins = bins - self.decay = decay - self.kmeans_init = kmeans_init - self.kmeans_iters = kmeans_iters - self.threshold_ema_dead_code = threshold_ema_dead_code - self.vq = ResidualVectorQuantization( - dim=self.dimension, - codebook_size=self.bins, - num_quantizers=self.n_q, - decay=self.decay, - kmeans_init=self.kmeans_init, - kmeans_iters=self.kmeans_iters, - threshold_ema_dead_code=self.threshold_ema_dead_code, - ) - - def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult: - """Residual vector quantization on the given input tensor. - Args: - x (torch.Tensor): Input tensor. - sample_rate (int): Sample rate of the input tensor. - bandwidth (float): Target bandwidth. - Returns: - QuantizedResult: - The quantized (or approximately quantized) representation with - the associated bandwidth and any penalty term for the loss. - """ - bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) - n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) - quantized, codes, commit_loss = self.vq(x, n_q=n_q) - bw = torch.tensor(n_q * bw_per_q).to(x) - return quantized, codes, bw, torch.mean(commit_loss) - #return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) - - def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: - """Return n_q based on specified target bandwidth. - """ - bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) - n_q = self.n_q - if bandwidth and bandwidth > 0.: - n_q = int(max(1, math.floor(bandwidth / bw_per_q))) - return n_q - - def get_bandwidth_per_quantizer(self, sample_rate: int): - """Return bandwidth per quantizer for a given input sample rate. - """ - return math.log2(self.bins) * sample_rate / 1000 - - def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: - """Encode a given input tensor with the specified sample rate at the given bandwidth. - The RVQ encode method sets the appropriate number of quantizer to use - and returns indices for each quantizer. - """ - n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) - codes = self.vq.encode(x, n_q=n_q) - return codes - - def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation. - """ - quantized = self.vq.decode(codes) - return quantized diff --git a/AcademiCodec/start.sh b/AcademiCodec/start.sh deleted file mode 100644 index 810519d..0000000 --- a/AcademiCodec/start.sh +++ /dev/null @@ -1,44 +0,0 @@ - -#sleep 666666666666666666666666666666666666666666666666666666666666666 - -set -e - -proj_dir="path_of_academiCodec" - -export PYTHONPATH=${proj_dir}:$PYTHONPATH -log_root="logs" - -input_training_file="train.lst" # .lst save the wav path. 
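Both `input_training_file` above and `input_validation_file` just below point at plain-text lists of wav paths, one path per line, which train.py consumes through `get_dataset_filelist`. A minimal sketch of how such lists might be produced; the corpus directory, the 95/5 split, and the 256-utterance validation set are assumptions for illustration, not part of the repo:

# sketch: build the .lst filelists consumed by start.sh (paths are hypothetical)
import glob
import random

wavs = sorted(glob.glob('/data/corpus_24k/**/*.wav', recursive=True))
random.seed(1234)
random.shuffle(wavs)
split = int(0.95 * len(wavs))

with open('train.lst', 'w') as f:
    f.write('\n'.join(wavs[:split]))
with open('valid_256.lst', 'w') as f:
    f.write('\n'.join(wavs[split:split + 256]))  # e.g. 256 held-out utterances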
-input_validation_file="valid_256.lst"
-
-#mode=$1 # debug or train
-#mode=debug
-mode=train
-
-if [ "${mode}" == "debug" ]; then
- ## debug
- echo "Debug"
- log_root=${log_root}_debug
- export CUDA_VISIBLE_DEVICES=0
- python ${proj_dir}/train.py \
- --config ${proj_dir}/configs/param_config.json \
- --checkpoint_path ${log_root} \
- --input_training_file ${input_training_file} \
- --input_validation_file ${input_validation_file} \
- --checkpoint_interval 100 \
- --summary_interval 10 \
- --validation_interval 100
-
-elif [ "$mode" == "train" ]; then
- ## train
- echo "Train model..."
- export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
- python ${proj_dir}/train.py \
- --config ${proj_dir}/configs/param_config.json \
- --checkpoint_path ${log_root} \
- --input_training_file ${input_training_file} \
- --input_validation_file ${input_validation_file} \
- --checkpoint_interval 5000 \
- --summary_interval 100 \
- --validation_interval 5000
-fi
diff --git a/AcademiCodec/test.sh b/AcademiCodec/test.sh
deleted file mode 100644
index ce3b1b1..0000000
--- a/AcademiCodec/test.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-
-proj_dir="path_of_proj"
-log_root="logs"
-ckpt="$(ls -dt "${log_root}"/g_* | head -1 || true)"
-echo "checkpoint path: ${ckpt}"
-
-# the path of the test waves
-wav_dir="test_wavs"
-
-outputdir=${log_root}/copysyn_$(date '+%Y-%m-%d-%H-%M-%S')
-mkdir -p ${outputdir}
-
-CUDA_VISIBLE_DEVICES=0 python ./vqvae_copy_syn.py \
- --model_path ${ckpt} \
- --config_path ${log_root}/config.json \
- --input_wavdir ${wav_dir} \
- --outputdir ${outputdir} \
- --num_gens 10000
diff --git a/AcademiCodec/train.py b/AcademiCodec/train.py
deleted file mode 100644
index 3aa4d63..0000000
--- a/AcademiCodec/train.py
+++ /dev/null
@@ -1,340 +0,0 @@
-import warnings
-warnings.simplefilter(action='ignore', category=FutureWarning)
-import itertools
-import os
-import time
-import argparse
-import json
-import torch
-import torch.nn.functional as F
-from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data import DistributedSampler, DataLoader
-import torch.multiprocessing as mp
-from torch.distributed import init_process_group
-from torch.nn.parallel import DistributedDataParallel
-from torchaudio.transforms import MelSpectrogram
-from env import AttrDict, build_env
-from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
-from msstftd import MultiScaleSTFTDiscriminator
-from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
- discriminator_loss, Encoder, Quantizer
-try:
- from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
-except ImportError:
- from .utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
-
-torch.backends.cudnn.benchmark = True
-
-def reconstruction_loss(x, G_x, device, eps=1e-7):
- L = 100*F.mse_loss(x, G_x) # scaled waveform MSE (L2) loss
- for i in range(6,11):
- s = 2**i
- melspec = MelSpectrogram(sample_rate=24000, n_fft=s, hop_length=s//4, n_mels=64, wkwargs={"device": device}).to(device)
- # per scale (n_fft, hop_length): (64,16), (128,32), (256,64), (512,128), (1024,256)
- S_x = melspec(x)
- S_G_x = melspec(G_x)
- loss = ((S_x-S_G_x).abs().mean() + (((torch.log(S_x.abs()+eps)-torch.log(S_G_x.abs()+eps))**2).mean(dim=-2)**0.5).mean())/(i)
- L += loss
- return L
-
-def train(rank, a, h):
- if h.num_gpus > 1:
- init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
- world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
-
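The distrib.py helpers deleted above are a hand-rolled alternative to DistributedDataParallel: broadcast the parameters once at startup, then all-reduce gradients after every backward pass. A minimal sketch of that loop, assuming the process group has already been initialized as in train.py and that the removed AcademiCodec.quantization.distrib module is on the path; the toy model and step count are illustrative only:

# sketch: manual gradient sync with the deleted distrib helpers
import torch
from torch import nn

from AcademiCodec.quantization import distrib  # assumed importable

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(4, 4).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# all ranks start from rank 0's weights (no-op when not distributed)
distrib.broadcast_tensors(model.parameters())

for _ in range(10):
    x = torch.randn(8, 4, device=device)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    # sum gradients across ranks, then divide by world_size (no-op when not distributed)
    distrib.sync_grad(model.parameters())
    opt.step()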
- torch.cuda.manual_seed(h.seed)
- device = torch.device('cuda:{:d}'.format(rank))
-
- encoder = Encoder(h).to(device)
- generator = Generator(h).to(device)
- quantizer = Quantizer(h).to(device)
- mpd = MultiPeriodDiscriminator().to(device)
- msd = MultiScaleDiscriminator().to(device)
- mstftd = MultiScaleSTFTDiscriminator(32).to(device)
- if rank == 0:
- print(encoder)
- print(quantizer)
- print(generator)
- os.makedirs(a.checkpoint_path, exist_ok=True)
- print("checkpoints directory : ", a.checkpoint_path)
-
- cp_g = cp_do = None
- if os.path.isdir(a.checkpoint_path):
- cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
- cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
-
- steps = 0
- if cp_g is None or cp_do is None:
- state_dict_do = None
- last_epoch = -1
- else:
- state_dict_g = load_checkpoint(cp_g, device)
- state_dict_do = load_checkpoint(cp_do, device)
- generator.load_state_dict(state_dict_g['generator'])
- encoder.load_state_dict(state_dict_g['encoder'])
- quantizer.load_state_dict(state_dict_g['quantizer'])
- mpd.load_state_dict(state_dict_do['mpd'])
- msd.load_state_dict(state_dict_do['msd'])
- mstftd.load_state_dict(state_dict_do['mstftd'])
- steps = state_dict_do['steps'] + 1
- last_epoch = state_dict_do['epoch']
-
- if h.num_gpus > 1:
- generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
- encoder = DistributedDataParallel(encoder, device_ids=[rank]).to(device)
- quantizer = DistributedDataParallel(quantizer, device_ids=[rank]).to(device)
- mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
- msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
- mstftd = DistributedDataParallel(mstftd, device_ids=[rank]).to(device)
-
- optim_g = torch.optim.Adam(itertools.chain(generator.parameters(), encoder.parameters(), quantizer.parameters()),
- h.learning_rate, betas=[h.adam_b1, h.adam_b2])
- optim_d = torch.optim.Adam(itertools.chain(msd.parameters(), mpd.parameters(), mstftd.parameters()),
- h.learning_rate, betas=[h.adam_b1, h.adam_b2])
- if state_dict_do is not None:
- optim_g.load_state_dict(state_dict_do['optim_g'])
- optim_d.load_state_dict(state_dict_do['optim_d'])
-
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
-
- training_filelist, validation_filelist = get_dataset_filelist(a)
-
- trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
- h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
- shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
- fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
-
- train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
-
- train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
- sampler=train_sampler,
- batch_size=h.batch_size,
- pin_memory=True,
- drop_last=True)
-
- if rank == 0:
- validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
- h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
- fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
- base_mels_path=a.input_mels_dir)
- validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
- sampler=None,
- batch_size=1,
- pin_memory=True,
- drop_last=True)
- sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
- plot_gt_once = False
- generator.train()
- encoder.train()
-
quantizer.train() - mpd.train() - msd.train() - for epoch in range(max(0, last_epoch), a.training_epochs): - if rank == 0: - start = time.time() - print("Epoch: {}".format(epoch+1)) - if h.num_gpus > 1: - train_sampler.set_epoch(epoch) - for i, batch in enumerate(train_loader): - if rank == 0: - start_b = time.time() - x, y, _, y_mel = batch - x = torch.autograd.Variable(x.to(device, non_blocking=True)) - y = torch.autograd.Variable(y.to(device, non_blocking=True)) - y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) - y = y.unsqueeze(1) - - c = encoder(y) - # print("c.shape: ", c.shape) - q, loss_q, c = quantizer(c) - # print("q.shape: ", q.shape) - y_g_hat = generator(q) - y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, - h.fmin, h.fmax_for_loss) # 1024, 80, 24000, 240,1024 - y_r_mel_1 = mel_spectrogram(y.squeeze(1), 512, h.num_mels, h.sampling_rate, 120, 512, - h.fmin, h.fmax_for_loss) - y_g_mel_1 = mel_spectrogram(y_g_hat.squeeze(1), 512, h.num_mels, h.sampling_rate, 120, 512, - h.fmin, h.fmax_for_loss) - y_r_mel_2 = mel_spectrogram(y.squeeze(1), 256, h.num_mels, h.sampling_rate, 60, 256, - h.fmin, h.fmax_for_loss) - y_g_mel_2 = mel_spectrogram(y_g_hat.squeeze(1), 256, h.num_mels, h.sampling_rate, 60, 256, - h.fmin, h.fmax_for_loss) - y_r_mel_3 = mel_spectrogram(y.squeeze(1), 128, h.num_mels, h.sampling_rate, 30, 128, - h.fmin, h.fmax_for_loss) - y_g_mel_3 = mel_spectrogram(y_g_hat.squeeze(1), 128, h.num_mels, h.sampling_rate, 30, 128, - h.fmin, h.fmax_for_loss) - # print("x.shape: ", x.shape) - # print("y.shape: ", y.shape) - # print("y_g_hat.shape: ", y_g_hat.shape) - optim_d.zero_grad() - - # MPD - y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) - loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) - - # MSD - y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) - loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) - - y_disc_r, fmap_r = mstftd(y) - y_disc_gen, fmap_gen = mstftd(y_g_hat.detach()) - loss_disc_stft, losses_disc_stft_r, losses_disc_stft_g = discriminator_loss(y_disc_r, y_disc_gen) - loss_disc_all = loss_disc_s + loss_disc_f + loss_disc_stft - - loss_disc_all.backward() - optim_d.step() - - # Generator - optim_g.zero_grad() - - # L1 Mel-Spectrogram Loss - loss_mel1 = F.l1_loss(y_r_mel_1, y_g_mel_1) - loss_mel2 = F.l1_loss(y_r_mel_2, y_g_mel_2) - loss_mel3 = F.l1_loss(y_r_mel_3, y_g_mel_3) - #print('loss_mel1, loss_mel2 ', loss_mel1, loss_mel2) - loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 + loss_mel1 + loss_mel2 - # print('loss_mel ', loss_mel) - # assert 1==2 - y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) - y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) - y_stftd_hat_r, fmap_stftd_r = mstftd(y) - y_stftd_hat_g, fmap_stftd_g = mstftd(y_g_hat) - loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) - loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) - loss_fm_stft = feature_loss(fmap_stftd_r, fmap_stftd_g) - loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) - loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) - loss_gen_stft, losses_gen_stft = generator_loss(y_stftd_hat_g) - loss_gen_all = loss_gen_s + loss_gen_f + loss_gen_stft + loss_fm_s + loss_fm_f + loss_fm_stft + loss_mel + loss_q * 10 - loss_gen_all.backward() - optim_g.step() - if rank == 0: - # STDOUT logging - if steps % a.stdout_interval == 0: - with torch.no_grad(): - mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() 
- print('Steps : {:d}, Gen Loss Total : {:4.3f}, Loss Q : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
- format(steps, loss_gen_all, loss_q, mel_error, time.time() - start_b))
- # checkpointing
- if steps % a.checkpoint_interval == 0 and steps != 0:
- checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
- save_checkpoint(checkpoint_path,
- {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict(),
- 'encoder': (encoder.module if h.num_gpus > 1 else encoder).state_dict(),
- 'quantizer': (quantizer.module if h.num_gpus > 1 else quantizer).state_dict()
- }, num_ckpt_keep=a.num_ckpt_keep)
- checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
- save_checkpoint(checkpoint_path,
- {'mpd': (mpd.module if h.num_gpus > 1
- else mpd).state_dict(),
- 'msd': (msd.module if h.num_gpus > 1
- else msd).state_dict(),
- 'mstftd': (mstftd.module if h.num_gpus > 1
- else mstftd).state_dict(),
- 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
- 'epoch': epoch}, num_ckpt_keep=a.num_ckpt_keep)
- # Tensorboard summary logging
- if steps % a.summary_interval == 0:
- sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
- sw.add_scalar("training/mel_spec_error", mel_error, steps)
-
- # Validation
- if steps % a.validation_interval == 0 and steps != 0:
- generator.eval()
- encoder.eval()
- quantizer.eval()
- torch.cuda.empty_cache()
- val_err_tot = 0
- with torch.no_grad():
- for j, batch in enumerate(validation_loader):
- x, y, _, y_mel = batch
- c = encoder(y.to(device).unsqueeze(1))
- q, loss_q, c = quantizer(c)
- y_g_hat = generator(q)
- y_mel = torch.autograd.Variable(y_mel.to(device))
- y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
- h.hop_size, h.win_size,
- h.fmin, h.fmax_for_loss)
- i_size = min(y_mel.size(2), y_g_hat_mel.size(2))
- val_err_tot += F.l1_loss(y_mel[:, :, :i_size], y_g_hat_mel[:, :, :i_size]).item()
-
- if j <= 8:
- if not plot_gt_once:
- sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
- sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
-
- sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
- y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
- h.sampling_rate, h.hop_size, h.win_size,
- h.fmin, h.fmax)
- sw.add_figure('generated/y_hat_spec_{}'.format(j),
- plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
-
- val_err = val_err_tot / (j+1)
- sw.add_scalar("validation/mel_spec_error", val_err, steps)
- if not plot_gt_once:
- plot_gt_once = True
-
- generator.train()
- encoder.train()
- quantizer.train()
-
- steps += 1
-
- scheduler_g.step()
- scheduler_d.step()
-
- if rank == 0:
- print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
-
-
-def main():
- print('Initializing Training Process..')
-
- parser = argparse.ArgumentParser()
-
- # parser.add_argument('--group_name', default=None)
- # parser.add_argument('--input_wavs_dir', default='../datasets/audios')
- parser.add_argument('--input_mels_dir', default=None)
- parser.add_argument('--input_training_file', required=True)
- parser.add_argument('--input_validation_file', required=True)
- parser.add_argument('--checkpoint_path', default='checkpoints')
- parser.add_argument('--config', default='')
- parser.add_argument('--training_epochs', default=2000, type=int)
- parser.add_argument('--stdout_interval', default=5, type=int)
- parser.add_argument('--checkpoint_interval', default=5000, type=int)
- parser.add_argument('--summary_interval', default=100, type=int)
- parser.add_argument('--validation_interval', default=5000, type=int)
- parser.add_argument('--num_ckpt_keep', default=5, type=int)
- parser.add_argument('--fine_tuning', action='store_true')
-
- a = parser.parse_args()
-
- with open(a.config) as f:
- data = f.read()
-
- json_config = json.loads(data)
- h = AttrDict(json_config)
- build_env(a.config, 'config.json', a.checkpoint_path)
-
- torch.manual_seed(h.seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(h.seed)
- h.num_gpus = torch.cuda.device_count()
- h.batch_size = int(h.batch_size / h.num_gpus)
- print('Batch size per GPU :', h.batch_size)
- else:
- h.num_gpus = 0 # no CUDA device; train() below still assumes one
-
- if h.num_gpus > 1:
- mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
- else:
- train(0, a, h)
-
-
-if __name__ == '__main__':
- main()
diff --git a/AcademiCodec/utils.py b/AcademiCodec/utils.py
deleted file mode 100644
index b56f00b..0000000
--- a/AcademiCodec/utils.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import glob
-import os
-import matplotlib
-import torch
-from torch.nn.utils import weight_norm
-matplotlib.use("Agg")
-import matplotlib.pylab as plt
-import re
-import pathlib
-
-
-def plot_spectrogram(spectrogram):
- fig, ax = plt.subplots(figsize=(10, 2))
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
- interpolation='none')
- plt.colorbar(im, ax=ax)
-
- fig.canvas.draw()
- plt.close()
-
- return fig
-
-
-def init_weights(m, mean=0.0, std=0.01):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- m.weight.data.normal_(mean, std)
-
-
-def apply_weight_norm(m):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- weight_norm(m)
-
-
-def get_padding(kernel_size, dilation=1):
- return int((kernel_size*dilation - dilation)/2)
-
-
-def load_checkpoint(filepath, device):
- assert os.path.isfile(filepath)
- print("Loading '{}'".format(filepath))
- checkpoint_dict = torch.load(filepath, map_location=device)
- print("Complete.")
- return checkpoint_dict
-
-
-def save_checkpoint(filepath, obj, num_ckpt_keep=5):
- name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1)
- ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*'))
- if len(ckpts) > num_ckpt_keep:
- for c in ckpts[:-num_ckpt_keep]:
- os.remove(c)
- print("Saving checkpoint to {}".format(filepath))
- torch.save(obj, filepath)
- print("Complete.")
-
-
-def scan_checkpoint(cp_dir, prefix):
- pattern = os.path.join(cp_dir, prefix + '????????')
- cp_list = glob.glob(pattern)
- if len(cp_list) == 0:
- return None
- return sorted(cp_list)[-1]
-
diff --git a/AcademiCodec/vqvae.py b/AcademiCodec/vqvae.py
deleted file mode 100644
index 28f5bd4..0000000
--- a/AcademiCodec/vqvae.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-import torch.nn as nn
-import json
-
-from models import Generator, Quantizer, Encoder
-from env import AttrDict
-
-
-class VQVAE(nn.Module):
- def __init__(self, config_path, ckpt_path, with_encoder=False):
- super(VQVAE, self).__init__()
- ckpt = torch.load(ckpt_path, map_location='cpu')
- with open(config_path) as f:
- data = f.read()
- json_config = json.loads(data)
- self.h = AttrDict(json_config)
- self.quantizer = Quantizer(self.h)
- self.generator = Generator(self.h)
- self.generator.load_state_dict(ckpt['generator'])
- self.quantizer.load_state_dict(ckpt['quantizer'])
- if with_encoder:
- self.encoder = Encoder(self.h)
- self.encoder.load_state_dict(ckpt['encoder'])
-
- def forward(self, x):
- # x holds the discrete codebook indices produced by encode()
- return self.generator(self.quantizer.embed(x))
-
- def encode(self, x):
- batch_size = x.size(0)
- c = self.encoder(x.unsqueeze(1))
- q, loss_q, c = self.quantizer(c)
- c = [code.reshape(batch_size, -1) for code in c]
- return torch.stack(c, -1) # (N, T, n_q)
diff --git a/AcademiCodec/vqvae_copy_syn.py b/AcademiCodec/vqvae_copy_syn.py
deleted file mode 100644
index af6dc69..0000000
--- a/AcademiCodec/vqvae_copy_syn.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import argparse
-import soundfile as sf
-import os
-from pathlib import Path
-import json
-import glob
-from tqdm import tqdm
-from vqvae_tester import VqvaeTester
-
-parser = argparse.ArgumentParser()
-
-#Path
-parser.add_argument('--outputdir', type=str, required=True)
-parser.add_argument('--model_path', type=str, required=True)
-parser.add_argument('--input_wavdir', type=str, required=True)
-parser.add_argument('--config_path', type=str, required=True)
-parser.add_argument('--num_gens', type=int, default=1024)
-
-#Data
-parser.add_argument('--sample_rate', type=int, default=24000)
-
-args = parser.parse_args()
-
-with open(args.config_path, 'r') as f:
- argdict = json.load(f)
- assert argdict['sampling_rate'] == args.sample_rate, \
- f"Sampling rate not consistent: stated {args.sample_rate}, but the model was trained at {argdict['sampling_rate']}"
- argdict.update(args.__dict__)
- args.__dict__ = argdict
-
-
-if __name__ == '__main__':
- Path(args.outputdir).mkdir(parents=True, exist_ok=True)
- print("Init model and load weights")
- model = VqvaeTester(args)
- model.cuda()
- model.vqvae.generator.remove_weight_norm()
- model.vqvae.encoder.remove_weight_norm()
- model.eval()
- print("Model ready")
-
- wav_paths = glob.glob(f"{args.input_wavdir}/*.wav")[:args.num_gens]
- print(f"Globbed {len(wav_paths)} wav files.")
-
- for wav_path in tqdm(wav_paths):
- fid, wav = model(wav_path)
- wav = wav.squeeze().cpu().numpy()
- sf.write(os.path.join(args.outputdir, f'{fid}.wav'), wav, args.sample_rate)
diff --git a/AcademiCodec/vqvae_tester.py b/AcademiCodec/vqvae_tester.py
deleted file mode 100644
index 889028a..0000000
--- a/AcademiCodec/vqvae_tester.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-import soundfile as sf
-import librosa
-from librosa.util import normalize
-from tqdm import tqdm
-
-from vqvae import VQVAE
-
-
-class VqvaeTester(nn.Module):
-
- def __init__(self, hp):
- super().__init__()
- self.hp = hp
- self.vqvae = VQVAE(hp.config_path, hp.model_path, with_encoder=True)
- self.sample_rate = self.hp.sample_rate
-
- @torch.no_grad()
- def forward(self, wav_path):
- wav, sr = sf.read(wav_path)
- if sr != self.sample_rate:
- wav = librosa.resample(wav, orig_sr=sr, target_sr=self.hp.sample_rate)
- fid = os.path.basename(wav_path)[:-4]
- wav = normalize(wav) * 0.95
- wav = torch.FloatTensor(wav).unsqueeze(0)
- wav = wav.to(torch.device('cuda'))
- vq_codes = self.vqvae.encode(wav)
- syn = self.vqvae(vq_codes)
- return fid, syn
-
- @torch.no_grad()
- def vq(self, wav_path):
- wav, sr = sf.read(wav_path)
- if sr != self.sample_rate:
- wav = librosa.resample(wav, orig_sr=sr, target_sr=self.hp.sample_rate)
- fid = os.path.basename(wav_path)[:-4]
- wav = normalize(wav) * 0.95
- wav = torch.FloatTensor(wav).unsqueeze(0)
- wav = wav.to(torch.device('cuda'))
- vq_codes = self.vqvae.encode(wav)
-
- return fid, vq_codes
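Taken together, vqvae.py and vqvae_tester.py support a simple copy-synthesis round trip: encode a waveform to discrete codes with the encoder and quantizer, then decode with the generator. A minimal sketch of that flow; the config, checkpoint, and wav paths are placeholders, the input is assumed to already match the model's sample rate, and CUDA is assumed as in vqvae_tester.py:

# sketch: VQVAE copy-synthesis round trip (all paths are hypothetical)
import torch
import soundfile as sf
from librosa.util import normalize
from vqvae import VQVAE

model = VQVAE('logs/config.json', 'logs/g_00100000', with_encoder=True).cuda().eval()

wav, sr = sf.read('example_24k.wav')  # assumed to already be at the training rate
wav = torch.FloatTensor(normalize(wav) * 0.95).unsqueeze(0).cuda()

with torch.no_grad():
    codes = model.encode(wav)   # (N, T, n_q) integer codebook indices
    recon = model(codes)        # embed the codes, then run the generator
sf.write('recon.wav', recon.squeeze().cpu().numpy(), sr)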