diff --git a/configs/config_naivev2.yaml b/configs/config_naivev2.yaml index dc113be..ade595d 100644 --- a/configs/config_naivev2.yaml +++ b/configs/config_naivev2.yaml @@ -36,7 +36,16 @@ model: conv_dropout: 0.0 atten_dropout: 0.1 use_weight_norm: false -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_naivev2diff.yaml b/configs/config_naivev2diff.yaml index 704816c..dea0e3b 100644 --- a/configs/config_naivev2diff.yaml +++ b/configs/config_naivev2diff.yaml @@ -41,8 +41,20 @@ model: conv_model_type: 'mode1' conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_naivev2diff_comb.yaml b/configs/config_naivev2diff_comb.yaml index a798e10..d9d88e3 100644 --- a/configs/config_naivev2diff_comb.yaml +++ b/configs/config_naivev2diff_comb.yaml @@ -42,6 +42,9 @@ model: conv_model_type: 'mode1' # dont change conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use naive_fn: type: 'LYNXNet' # LYNXNet is thr other name of ConformerNaiveEncoder(NaiveNet) @@ -60,7 +63,16 @@ model: use_weight_norm: false naive_fn_grad_not_by_diffusion: false # dont change if dont understand; more info:diffusion/unit2mel.py naive_out_mel_cond_diff: false # mel condition diffusion is a test function, maybe can make the model learn faster but less quality and pitch range. 
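Note on the ddp block added to each config above and below: use_ddp switches training to one process per entry in ddp_device, port is the rendezvous port, and ddp_cache_gpu decides whether each rank caches its dataset on its own GPU (see the get_data_loaders changes later in this diff). A minimal sketch of how a launcher might consume these fields with the standard torch.distributed APIs; the _worker/launch names and the exact wiring in train.py are assumptions, not part of this patch:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, args):
    # one process per configured device; args is the parsed YAML (DotDict-style access assumed)
    device = args.ddp.ddp_device[rank]                      # e.g. 'cuda:1'
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(args.ddp.port)          # '13348' in these configs
    dist.init_process_group('nccl', rank=rank, world_size=len(args.ddp.ddp_device))
    torch.cuda.set_device(torch.device(device))
    # build the model, wrap it in torch.nn.parallel.DistributedDataParallel, then call train(rank, ...)

def launch(args):
    if args.ddp.use_ddp:
        mp.spawn(_worker, args=(args,), nprocs=len(args.ddp.ddp_device))
    # else: single-process training keeps using the plain `device` field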
-device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_naivev2diff_shallow.yaml b/configs/config_naivev2diff_shallow.yaml index d81fdf0..dcfbd62 100644 --- a/configs/config_naivev2diff_shallow.yaml +++ b/configs/config_naivev2diff_shallow.yaml @@ -42,8 +42,20 @@ model: conv_model_type: 'mode1' conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_naivev2diff_vae.yaml b/configs/config_naivev2diff_vae.yaml index 4a698dd..f9ca573 100644 --- a/configs/config_naivev2diff_vae.yaml +++ b/configs/config_naivev2diff_vae.yaml @@ -41,8 +41,20 @@ model: conv_model_type: 'mode1' conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'hifivaegan' ckpt: 'pretrain/hifivaegan/G_224800.pth' diff --git a/configs/config_naivev2reflow.yaml b/configs/config_naivev2reflow.yaml index a4032a4..0856478 100644 --- a/configs/config_naivev2reflow.yaml +++ b/configs/config_naivev2reflow.yaml @@ -43,9 +43,24 @@ model: atten_dropout: 0.1 conv_model_activation: 'SiLU' GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use loss_type: 'l2' # 'l1', 'l2' or 'l2_lognorm' -device: cuda + consistency: false + consistency_only: true + consistency_delta_t: 0.1 + consistency_lambda_f: 1.0 + consistency_lambda_v: 0.01 +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_naivev2diff_comb_compile.yaml b/configs/config_naivev2reflow_1step_test.yaml similarity index 62% rename from configs/config_naivev2diff_comb_compile.yaml rename to configs/config_naivev2reflow_1step_test.yaml index 52af02d..fa7d858 100644 --- a/configs/config_naivev2diff_comb_compile.yaml +++ b/configs/config_naivev2reflow_1step_test.yaml @@ -18,27 +18,17 @@ data: extensions: # List of extension included in the data collection - wav model: - torch_compile_args: - use_copile: true - fullgraph: false - dynamic: 'none' # 'none',false or true - backend: 'inductor' # 'cudagraphs', 'inductor', 'onnxrt', 'openxla', 'openxla_eval', 'tvm' - mode: 'reduce-overhead' # 'default','reduce-overhead','max-autotune' or 'max-autotune-no-cudagraphs' - use_options: false # if use options, should be true - options: - k: 'v' - k_step_max: 100 - type: 'DiffusionV2' + t_start: 0.0 # do not change + type: 'ReFlow1Step' n_hidden: 256 use_pitch_aug: true n_spk: 2 # max number of different speakers z_rate: 0 # dont change mean_only: true - max_beta: 0.02 
spec_min: -12 spec_max: 2 - denoise_fn: - type: 'NaiveV2Diff' + velocity_fn: + type: 'LYNXNetDiff' cn_layers: 6 cn_chans: 512 use_mlp: false # is use MLP in cond_emb and output_proj @@ -48,36 +38,32 @@ model: conv_only: true # use Transformer block with conv block, if false wavenet_like: false # dont change if dont understand; more info:diffusion/naive_v2/naive_v2_diff.py use_norm: false # pre-norm for every layers - conv_model_type: 'mode1' # dont change + conv_model_type: 'mode1' conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use - naive_fn: - type: 'LYNXNet' # LYNXNet is thr other name of ConformerNaiveEncoder(NaiveNet) - n_layers: 3 - n_chans: 256 - simple_stack: false # use simple stack for unit emb - out_put_norm: true # norm and weight_norm in last layer - expansion_factor: 2 - kernel_size: 31 - conv_model_type: 'mode1' # dont change - num_heads: 8 - use_norm: false # pre-norm for every layers - conv_only: true # use Transformer block with conv block, if false - conv_dropout: 0.0 - atten_dropout: 0.1 - use_weight_norm: false - naive_fn_grad_not_by_diffusion: false # dont change if dont understand; more info:diffusion/unit2mel.py - naive_out_mel_cond_diff: false # mel condition diffusion is a test function, maybe can make the model learn faster but less quality and pitch range. -device: cuda + loss_type: 'l2' # 'l1', 'l2' or 'l2_lognorm' +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' infer: - speedup: 1 - method: 'pndm' # 'ddim', 'pndm', 'dpm-solver' or 'unipc' + infer_step: 10 + method: 'euler' # 'euler', 'rk4' env: - expdir: exp/naivev2diffcombocompile + expdir: exp/naivev2reflow gpu_id: 0 train: ema_decay: 0.999 # <1 @@ -96,4 +82,4 @@ train: decay_step: 100000 gamma: 0.5 weight_decay: 0 - save_opt: false \ No newline at end of file + save_opt: false diff --git a/configs/config_naivev2reflow_combo.yaml b/configs/config_naivev2reflow_combo.yaml index 641c09f..8fc1b8e 100644 --- a/configs/config_naivev2reflow_combo.yaml +++ b/configs/config_naivev2reflow_combo.yaml @@ -41,6 +41,9 @@ model: conv_model_type: 'mode1' conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use loss_type: 'l2' # 'l1', 'l2' or 'l2_lognorm' naive_fn: @@ -60,7 +63,21 @@ model: use_weight_norm: false naive_fn_grad_not_by_reflow: false # dont change if dont understand; more info:diffusion/unit2mel.py naive_out_mel_cond_reflow: false # mel condition diffusion is a test function, maybe can make the model learn faster but less quality and pitch range. 
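The infer block above replaces the diffusion-style speedup/pndm settings with a step count plus an ODE solver ('euler' or 'rk4'). infer_step is simply the number of fixed-size integration steps from noise to mel. A condensed view of the Euler path, mirroring sample_euler in diffusion/reflow/reflow.py further down in this diff (the 1000 * t scaling matches how the velocity network is conditioned there; the standalone function name and the (1 - t_start)/infer_step step size are illustrative, the exact dt computation lives in the unchanged part of forward()):

import torch

def euler_sample(velocity_fn, cond, shape, infer_step=10, t_start=0.0, device='cpu'):
    x = torch.randn(shape, device=device)                 # start from Gaussian noise
    t = torch.full((shape[0],), t_start, device=device)
    dt = (1.0 - t_start) / infer_step                     # fixed step size
    for _ in range(infer_step):
        x = x + velocity_fn(x, 1000 * t, cond) * dt       # same update as sample_euler
        t = t + dt
    return x                                              # denormalized to mel by the caller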
-device: cuda + consistency: false + consistency_only: true + consistency_delta_t: 0.1 + consistency_lambda_f: 1.0 + consistency_lambda_v: 0.01 +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_naivev2reflow_shallow.yaml b/configs/config_naivev2reflow_shallow.yaml index d7a3001..f75e86d 100644 --- a/configs/config_naivev2reflow_shallow.yaml +++ b/configs/config_naivev2reflow_shallow.yaml @@ -41,9 +41,26 @@ model: conv_model_type: 'mode1' conv_dropout: 0.0 atten_dropout: 0.1 + conv_model_activation: 'SiLU' + GLU_type: 'GLU' + channel_norm: false mask_cond_ratio: 'NOTUSE' # input 'NOTUSE' if not use loss_type: 'l2' # 'l1', 'l2' or 'l2_lognorm' -device: cuda + consistency: false + consistency_only: true + consistency_delta_t: 0.1 + consistency_lambda_f: 1.0 + consistency_lambda_v: 0.01 +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_v2.yaml b/configs/config_v2.yaml index cc645b9..4f371aa 100644 --- a/configs/config_v2.yaml +++ b/configs/config_v2.yaml @@ -37,7 +37,16 @@ model: wn_tf_rf: false # only wn_tf_use is true and here is true will use RoFormer wn_tf_n_layers: 2 wn_tf_n_head: 4 -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_v2_comb.yaml b/configs/config_v2_comb.yaml index 4553e75..59c4e06 100644 --- a/configs/config_v2_comb.yaml +++ b/configs/config_v2_comb.yaml @@ -56,7 +56,16 @@ model: use_weight_norm: false naive_fn_grad_not_by_diffusion: false # dont change if dont understand; more info:diffusion/unit2mel.py naive_out_mel_cond_diff: false # mel condition diffusion is a test function, maybe can make the model learn faster but less quality and pitch range. 
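The new consistency* keys added to the reflow configs map onto the reflow_consistency_loss introduced in diffusion/reflow/reflow.py below: two nearby times t_a = t - dt/2 and t_b = t + dt/2 are drawn on the same noise-to-data path (with dt scaled by consistency_delta_t), and the model is pushed to predict the same endpoint and the same velocity at both; consistency_lambda_f and consistency_lambda_v weight the two terms, and consistency_only decides whether the plain reflow loss is still computed alongside it. Condensed restatement of that loss, same math as the patch with the class plumbing stripped:

import torch
import torch.nn.functional as F

def consistency_loss(velocity_fn, x_1, t_a, t_b, cond, delta_t, lambda_f, lambda_v):
    x_0 = torch.randn_like(x_1)
    x_ta = x_0 + t_a[:, None, None, None] * (x_1 - x_0)
    x_tb = x_0 + t_b[:, None, None, None] * (x_1 - x_0)
    v_a = velocity_fn(x_ta, 1000 * t_a, cond)
    v_b = velocity_fn(x_tb, 1000 * t_b, cond).detach()      # stop-gradient target
    f_a = x_ta + (1 - t_a[:, None, None, None]) * v_a       # each point's predicted x_1
    f_b = x_tb + (1 - t_b[:, None, None, None]) * v_b
    return lambda_f * F.mse_loss(f_a, f_b) / delta_t ** 2 + lambda_v * F.mse_loss(v_a, v_b)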
-device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_v2_reflow.yaml b/configs/config_v2_reflow.yaml index 6a3af72..a88fa98 100644 --- a/configs/config_v2_reflow.yaml +++ b/configs/config_v2_reflow.yaml @@ -39,7 +39,21 @@ model: wn_tf_n_layers: 2 wn_tf_n_head: 4 loss_type: 'l2_lognorm' # 'l1', 'l2' or 'l2_lognorm' -device: cuda + consistency: false + consistency_only: true + consistency_delta_t: 0.1 + consistency_lambda_f: 1.0 + consistency_lambda_v: 0.01 +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_v2_shallow.yaml b/configs/config_v2_shallow.yaml index d80eee2..02f8ab9 100644 --- a/configs/config_v2_shallow.yaml +++ b/configs/config_v2_shallow.yaml @@ -38,7 +38,16 @@ model: wn_tf_rf: false # only wn_tf_use is true and here is true will use RoFormer wn_tf_n_layers: 2 wn_tf_n_head: 4 -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'nsf-hifigan' ckpt: 'pretrain/nsf_hifigan/model' diff --git a/configs/config_v2_vae.yaml b/configs/config_v2_vae.yaml index 3f4dfa6..2e1b1a1 100644 --- a/configs/config_v2_vae.yaml +++ b/configs/config_v2_vae.yaml @@ -37,7 +37,16 @@ model: wn_tf_rf: false # only wn_tf_use is true and here is true will use RoFormer wn_tf_n_layers: 2 wn_tf_n_head: 4 -device: cuda +device: 'cuda' +ddp: + use_ddp: false # if true, ddp_device will cover device and gpu id + port: '13348' + ddp_cache_gpu: false + ddp_device: + - 'cuda:1' + - 'cuda:2' + - 'cuda:3' + - 'cuda:4' vocoder: type: 'hifivaegan' ckpt: 'pretrain/hifivaegan/G_224800.pth' diff --git a/diffusion/data_loaders.py b/diffusion/data_loaders.py index 42ad719..74d2010 100644 --- a/diffusion/data_loaders.py +++ b/diffusion/data_loaders.py @@ -49,11 +49,15 @@ def traverse_dir( return file_list -def get_data_loaders(args, whole_audio=False): +def get_data_loaders(args, whole_audio=False, ddp=False, rank=0, ddp_cache_gpu=False, ddp_device_list=None): if args.data.volume_noise == 0: volume_noise = None else: volume_noise = args.data.volume_noise + _ddp_device = 'cpu' + if ddp: + if ddp_cache_gpu: + _ddp_device = ddp_device_list[rank] data_train = AudioDataset( args.data.train_path, waveform_sec=args.data.duration, @@ -63,42 +67,62 @@ def get_data_loaders(args, whole_audio=False): whole_audio=whole_audio, extensions=args.data.extensions, n_spk=args.model.n_spk, - device=args.train.cache_device, + device=args.train.cache_device if not ddp else _ddp_device, fp16=args.train.cache_fp16, use_aug=True, use_spk_encoder=args.model.use_speaker_encoder, spk_encoder_mode=args.data.speaker_encoder_mode, - volume_noise=volume_noise + volume_noise=volume_noise, + tqdm_rank=rank, + load_nothing_data=args.train.load_nothing_data ) - loader_train = torch.utils.data.DataLoader( - data_train, - batch_size=args.train.batch_size if not whole_audio else 1, - shuffle=True, - num_workers=args.train.num_workers if args.train.cache_device == 'cpu' else 0, - persistent_workers=(args.train.num_workers > 
0) if args.train.cache_device == 'cpu' else False, - pin_memory=True if args.train.cache_device == 'cpu' else False - ) - data_valid = AudioDataset( - args.data.valid_path, - waveform_sec=args.data.duration, - hop_size=args.data.block_size, - sample_rate=args.data.sampling_rate, - load_all_data=args.train.cache_all_data, - whole_audio=True, - extensions=args.data.extensions, - n_spk=args.model.n_spk, - use_spk_encoder=args.model.use_speaker_encoder, - spk_encoder_mode=args.data.speaker_encoder_mode, - volume_noise=volume_noise - ) - loader_valid = torch.utils.data.DataLoader( - data_valid, - batch_size=1, - shuffle=False, - num_workers=0, - pin_memory=True - ) - return loader_train, loader_valid + if not ddp: + samper_train = None + loader_train = torch.utils.data.DataLoader( + data_train, + batch_size=args.train.batch_size if not whole_audio else 1, + shuffle=True, + num_workers=args.train.num_workers if args.train.cache_device == 'cpu' else 0, + persistent_workers=(args.train.num_workers > 0) if args.train.cache_device == 'cpu' else False, + pin_memory=True if args.train.cache_device == 'cpu' else False + ) + else: + samper_train = torch.utils.data.distributed.DistributedSampler(data_train) + loader_train = torch.utils.data.DataLoader( + data_train, + batch_size=args.train.batch_size if not whole_audio else 1, + shuffle=False, + sampler=samper_train, + num_workers=args.train.num_workers if (_ddp_device == 'cpu') else 0, + persistent_workers=(args.train.num_workers > 0) if (_ddp_device == 'cpu') else False, + pin_memory=True if (_ddp_device == 'cpu') else False, + drop_last=True + ) + if ddp and (rank != 0): + loader_valid = None + else: + data_valid = AudioDataset( + args.data.valid_path, + waveform_sec=args.data.duration, + hop_size=args.data.block_size, + sample_rate=args.data.sampling_rate, + load_all_data=args.train.cache_all_data, + whole_audio=True, + extensions=args.data.extensions, + n_spk=args.model.n_spk, + use_spk_encoder=args.model.use_speaker_encoder, + spk_encoder_mode=args.data.speaker_encoder_mode, + volume_noise=volume_noise, + tqdm_rank=rank + ) + loader_valid = torch.utils.data.DataLoader( + data_valid, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=True + ) + return loader_train, loader_valid, samper_train class AudioDataset(Dataset): @@ -117,10 +141,17 @@ def __init__( use_aug=False, use_spk_encoder=False, spk_encoder_mode='each_spk', - volume_noise=None + volume_noise=None, + tqdm_rank=0, + load_nothing_data=False ): super().__init__() + if load_all_data: + load_nothing_data = False + self.load_nothing_data = load_nothing_data + self.device = device + self.volume_noise = volume_noise self.waveform_sec = waveform_sec self.sample_rate = sample_rate self.hop_size = hop_size @@ -142,28 +173,29 @@ def __init__( print('Load all the data from :', path_root) else: print('Load the f0, volume data from :', path_root) - for name_ext in tqdm(self.paths, total=len(self.paths)): + for name_ext in tqdm(self.paths, total=len(self.paths), position=tqdm_rank): name = os.path.splitext(name_ext)[0] path_audio = os.path.join(self.path_root, 'audio', name_ext) duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate) - path_f0 = os.path.join(self.path_root, 'f0', name_ext) + '.npy' - f0 = np.load(path_f0) - f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device) - - path_volume = os.path.join(self.path_root, 'volume', name_ext) + '.npy' - volume = np.load(path_volume) - volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) - if volume_noise 
is not None: - _noise = volume_noise * torch.rand(volume.shape,).to(device) - volume = volume + _noise * torch.sign(volume) - - path_augvol = os.path.join(self.path_root, 'aug_vol', name_ext) + '.npy' - aug_vol = np.load(path_augvol) - aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) - if volume_noise is not None: - _noise = volume_noise * torch.rand(aug_vol.shape,).to(device) - aug_vol = aug_vol + _noise * torch.sign(aug_vol) + if not load_nothing_data: + path_f0 = os.path.join(self.path_root, 'f0', name_ext) + '.npy' + f0 = np.load(path_f0) + f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device) + + path_volume = os.path.join(self.path_root, 'volume', name_ext) + '.npy' + volume = np.load(path_volume) + volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) + if volume_noise is not None: + _noise = volume_noise * torch.rand(volume.shape,).to(device) + volume = volume + _noise * torch.sign(volume) + + path_augvol = os.path.join(self.path_root, 'aug_vol', name_ext) + '.npy' + aug_vol = np.load(path_augvol) + aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) + if volume_noise is not None: + _noise = volume_noise * torch.rand(aug_vol.shape,).to(device) + aug_vol = aug_vol + _noise * torch.sign(aug_vol) if n_spk is not None and n_spk > 1: dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0] @@ -228,14 +260,21 @@ def __init__( 'spk_emb': spk_emb } else: - self.data_buffer[name_ext] = { - 'duration': duration, - 'f0': f0, - 'volume': volume, - 'aug_vol': aug_vol, - 'spk_id': spk_id, - 't_spk_id': t_spk_id - } + if not load_nothing_data: + self.data_buffer[name_ext] = { + 'duration': duration, + 'f0': f0, + 'volume': volume, + 'aug_vol': aug_vol, + 'spk_id': spk_id, + 't_spk_id': t_spk_id + } + else: + self.data_buffer[name_ext] = { + 'duration': duration, + 'spk_id': spk_id, + 't_spk_id': t_spk_id + } def __getitem__(self, file_idx): name_ext = self.paths[file_idx] @@ -322,15 +361,29 @@ def get_data(self, name_ext, data_buffer): spk_emb = torch.rand(1, 1) # load f0 - f0 = data_buffer.get('f0') + if self.load_nothing_data: + path_f0 = os.path.join(self.path_root, 'f0', name_ext) + '.npy' + f0 = np.load(path_f0) + f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(self.device) + else: + f0 = data_buffer.get('f0') aug_shift = 0 if aug_flag: aug_shift = self.pitch_aug_dict[name_ext] f0_frames = 2 ** (aug_shift / 12) * f0[start_frame: start_frame + units_frame_len] # load volume - vol_key = 'aug_vol' if aug_flag else 'volume' - volume = data_buffer.get(vol_key) + if self.load_nothing_data: + vol_key = 'aug_vol' if aug_flag else 'volume' + path_volume = os.path.join(self.path_root, vol_key, name_ext) + '.npy' + volume = np.load(path_volume) + volume = torch.from_numpy(volume).float().unsqueeze(-1).to(self.device) + if self.volume_noise is not None: + _noise = self.volume_noise * torch.rand(volume.shape, ).to(self.device) + volume = volume + _noise * torch.sign(volume) + else: + vol_key = 'aug_vol' if aug_flag else 'volume' + volume = data_buffer.get(vol_key) volume_frames = volume[start_frame: start_frame + units_frame_len] # load spk_id diff --git a/diffusion/naive/pcmer.py b/diffusion/naive/pcmer.py index f0eb32a..d88bc3e 100644 --- a/diffusion/naive/pcmer.py +++ b/diffusion/naive/pcmer.py @@ -5,7 +5,6 @@ from functools import partial from einops import rearrange, repeat -from local_attention import LocalAttention import torch.nn.functional as F #import fast_transformers.causal_product.causal_product_cuda @@ -315,6 +314,7 @@ def 
forward(self, q, k, v): class SelfAttention(nn.Module): def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False): super().__init__() + from local_attention import LocalAttention assert dim % heads == 0, 'dimension must be divisible by number of heads' dim_head = default(dim_head, dim // heads) inner_dim = dim_head * heads diff --git a/diffusion/naive/pcmer_onnx.py b/diffusion/naive/pcmer_onnx.py index 7176a95..3fdd6ca 100644 --- a/diffusion/naive/pcmer_onnx.py +++ b/diffusion/naive/pcmer_onnx.py @@ -5,7 +5,6 @@ from functools import partial from einops import rearrange, repeat -from local_attention import LocalAttention import torch.nn.functional as F #import fast_transformers.causal_product.causal_product_cuda @@ -315,6 +314,7 @@ def forward(self, q, k, v): class SelfAttention(nn.Module): def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False): super().__init__() + from local_attention import LocalAttention assert dim % heads == 0, 'dimension must be divisible by number of heads' dim_head = default(dim_head, dim // heads) inner_dim = dim_head * heads diff --git a/diffusion/naive/pcmer_siren_full.py b/diffusion/naive/pcmer_siren_full.py index 15b4aa8..d564f4f 100644 --- a/diffusion/naive/pcmer_siren_full.py +++ b/diffusion/naive/pcmer_siren_full.py @@ -5,7 +5,7 @@ from functools import partial from einops import rearrange, repeat from siren import Sine -from local_attention import LocalAttention + import torch.nn.functional as F #import fast_transformers.causal_product.causal_product_cuda @@ -305,6 +305,7 @@ def forward(self, q, k, v): class SelfAttention(nn.Module): # nn.ReLu() def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = Sine(w0=1), qr_uniform_q = False, dropout = 0., no_projection = False): super().__init__() + from local_attention import LocalAttention assert dim % heads == 0, 'dimension must be divisible by number of heads' dim_head = default(dim_head, dim // heads) inner_dim = dim_head * heads diff --git a/diffusion/naive/pcmer_siren_full_onnx.py b/diffusion/naive/pcmer_siren_full_onnx.py index 6d4fa85..3a0e899 100644 --- a/diffusion/naive/pcmer_siren_full_onnx.py +++ b/diffusion/naive/pcmer_siren_full_onnx.py @@ -5,7 +5,7 @@ from functools import partial from einops import rearrange, repeat from siren import Sine -from local_attention import LocalAttention + import torch.nn.functional as F #import fast_transformers.causal_product.causal_product_cuda @@ -305,6 +305,7 @@ def forward(self, q, k, v): class SelfAttention(nn.Module): # nn.ReLu() def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = Sine(w0=1), qr_uniform_q = False, dropout = 0., no_projection = False): super().__init__() + from local_attention import LocalAttention assert dim % heads == 0, 'dimension must be divisible by number of heads' dim_head = default(dim_head, dim // heads) inner_dim = 
dim_head * heads diff --git a/diffusion/naive_v2/model_conformer_naive.py b/diffusion/naive_v2/model_conformer_naive.py index a0a19b5..89b9baf 100644 --- a/diffusion/naive_v2/model_conformer_naive.py +++ b/diffusion/naive_v2/model_conformer_naive.py @@ -74,7 +74,8 @@ def __init__(self, atten_dropout: float = 0.1, conv_model_type='mode1', conv_model_activation='SiLU', - GLU_type='GLU' + GLU_type='GLU', + fix_free_norm=False ): super().__init__() self.num_layers = num_layers @@ -83,6 +84,7 @@ def __init__(self, self.use_norm = use_norm self.residual_dropout = 0.1 # 废弃代码,仅做兼容性保留 self.attention_dropout = 0.1 # 废弃代码,仅做兼容性保留 + self.fix_free_norm = fix_free_norm self.encoder_layers = nn.ModuleList( [ @@ -97,7 +99,8 @@ def __init__(self, atten_dropout=atten_dropout, conv_model_type=conv_model_type, conv_model_activation=conv_model_activation, - GLU_type=GLU_type + GLU_type=GLU_type, + fix_free_norm=fix_free_norm ) for _ in range(num_layers) ] @@ -143,7 +146,8 @@ def __init__(self, atten_dropout: float = 0.1, conv_model_type='mode1', conv_model_activation='SiLU', - GLU_type='GLU' + GLU_type='GLU', + fix_free_norm=False ): super().__init__() @@ -157,8 +161,13 @@ def __init__(self, activation=conv_model_activation, GLU_type=GLU_type ) - - self.norm = nn.LayerNorm(dim_model) + if conv_only: + if not fix_free_norm: + self.norm = nn.LayerNorm(dim_model) + else: + self.norm = None + else: + self.norm = nn.LayerNorm(dim_model) self.dropout = nn.Dropout(0.1) # 废弃代码,仅做兼容性保留 @@ -200,7 +209,7 @@ def __init__( use_norm=False, conv_model_type='mode1', activation='SiLU', - GLU_type='GLU', + GLU_type='GLU' ): super().__init__() diff --git a/diffusion/naive_v2/naive_v2.py b/diffusion/naive_v2/naive_v2.py index bf54644..2a52f9a 100644 --- a/diffusion/naive_v2/naive_v2.py +++ b/diffusion/naive_v2/naive_v2.py @@ -40,6 +40,7 @@ def __init__( self.atten_dropout = net_fn.atten_dropout if (net_fn.atten_dropout is not None) else 0.1 self.conv_model_activation = net_fn.conv_model_activation if (net_fn.conv_model_activation is not None)\ else 'SiLU' + self.fix_free_norm = net_fn.fix_free_norm if (net_fn.fix_free_norm is not None) else False self.decoder = ConformerNaiveEncoder( num_layers=self.n_layers, @@ -52,7 +53,8 @@ def __init__( conv_dropout=self.conv_dropout, atten_dropout=self.atten_dropout, conv_model_type=self.conv_model_type, - conv_model_activation=self.conv_model_activation + conv_model_activation=self.conv_model_activation, + fix_free_norm=self.fix_free_norm, ) else: raise ValueError(f'net_fn.type={net_fn.type} is not supported') diff --git a/diffusion/naive_v2/naive_v2_diff.py b/diffusion/naive_v2/naive_v2_diff.py index 78b52bf..3dca64e 100644 --- a/diffusion/naive_v2/naive_v2_diff.py +++ b/diffusion/naive_v2/naive_v2_diff.py @@ -46,7 +46,8 @@ def __init__(self, conv_model_type='mode1', no_t_emb=False, conv_model_activation='SiLU', - GLU_type='GLU' + GLU_type='GLU', + fix_free_norm=False ): super().__init__() @@ -58,9 +59,8 @@ def __init__(self, use_norm=use_norm, conv_model_type=conv_model_type, activation=conv_model_activation, - GLU_type=GLU_type + GLU_type=GLU_type, ) - self.norm = nn.LayerNorm(dim_model) # 请务必注意这是个可学习的层,但模型并没有使用到这一层,在很多地方backward都会出现严重问题,但是在该项目中不会,但直接删除该层会导致加载过往的权重失败,待解决 self.dropout = nn.Dropout(0.1) # 废弃代码,仅做兼容性保留 if wavenet_like: @@ -84,8 +84,13 @@ def __init__(self, dropout=atten_dropout, activation='gelu' ) + self.norm = nn.LayerNorm(dim_model) else: self.attn = None + if fix_free_norm: + self.norm = None + else: + self.norm = nn.LayerNorm(dim_model) def forward(self, x, 
condition=None, diffusion_step=None) -> torch.Tensor: res_x = x.transpose(1, 2) @@ -129,13 +134,20 @@ def __init__( atten_dropout=0.1, no_t_emb=False, conv_model_activation='SiLU', - GLU_type='GLU' + GLU_type='GLU', + fix_free_norm=False, + channel_norm=False ): super(NaiveV2Diff, self).__init__() self.no_t_emb = no_t_emb if (no_t_emb is not None) else False self.wavenet_like = wavenet_like self.mask_cond_ratio = None + if channel_norm: + self.channel_norm = nn.LayerNorm(dim) + else: + self.channel_norm = None + self.input_projection = nn.Conv1d(mel_channels, dim, 1) if self.no_t_emb: self.diffusion_embedding = None @@ -173,7 +185,8 @@ def __init__( conv_model_type=conv_model_type, no_t_emb=self.no_t_emb, conv_model_activation=conv_model_activation, - GLU_type=GLU_type + GLU_type=GLU_type, + fix_free_norm=fix_free_norm ) for i in range(num_layers) ] @@ -250,6 +263,11 @@ def forward(self, spec, diffusion_step, cond): # forward x = layer(x, condition, diffusion_step) + if self.channel_norm is not None: + x = x.transpose(-1, -2) + x = self.channel_norm(x) + x = x.transpose(-1, -2) + # MLP and GLU x = self.output_projection(x) # [B, 128, T] diff --git a/diffusion/reflow/reflow.py b/diffusion/reflow/reflow.py index dafbe03..9b75290 100644 --- a/diffusion/reflow/reflow.py +++ b/diffusion/reflow/reflow.py @@ -14,13 +14,24 @@ def __init__(self, out_dims=128, spec_min=-12, spec_max=2, - loss_type='l2'): + loss_type='l2', + consistency=False, + consistency_only=True, + consistency_delta_t=0.1, + consistency_lambda_f=1.0, + consistency_lambda_v=1.0, + ): super().__init__() self.velocity_fn = velocity_fn self.out_dims = out_dims self.spec_min = spec_min self.spec_max = spec_max self.loss_type = loss_type + self.consistency = consistency + self.consistency_only = consistency_only + self.consistency_delta_t = consistency_delta_t + self.consistency_lambda_f = consistency_lambda_f + self.consistency_lambda_v = consistency_lambda_v def reflow_loss(self, x_1, t, cond, loss_type=None): x_0 = torch.randn_like(x_1) @@ -44,6 +55,30 @@ def reflow_loss(self, x_1, t, cond, loss_type=None): return loss + # 一致性损失 + def reflow_consistency_loss(self, x_1, t_a, t_b, cond, loss_type=None): + x_0 = torch.randn_like(x_1) + x_t_a = x_0 + t_a[:, None, None, None] * (x_1 - x_0) + x_t_b = x_0 + t_b[:, None, None, None] * (x_1 - x_0) + v_pred_a = self.velocity_fn(x_t_a, 1000 * t_a, cond) + v_pred_b = self.velocity_fn(x_t_b, 1000 * t_b, cond).detach() + f_pred_a = x_t_a + (1 - t_a[:, None, None, None]) * v_pred_a + f_pred_b = x_t_b + (1 - t_b[:, None, None, None]) * v_pred_b + if loss_type is None: + loss_type = self.loss_type + else: + loss_type = loss_type + + if loss_type == 'l2': + loss_f = F.mse_loss(f_pred_a, f_pred_b) / self.consistency_delta_t ** 2 + loss_v = F.mse_loss(v_pred_a, v_pred_b) + loss = self.consistency_lambda_f * loss_f + self.consistency_lambda_v * loss_v + else: + raise NotImplementedError() + + return loss + + def sample_euler(self, x, t, dt, cond): x += self.velocity_fn(x, 1000 * t, cond) * dt t += dt @@ -117,7 +152,7 @@ def sample_rk5_fp64(self, x, t, dt, cond): t += dt return x, t - def sample_heun(self, x, t, dt, cond=None): + def sample_heun(self, x, t, dt, cond): # Predict k_1 = self.velocity_fn(x, 1000 * t, cond=cond) x_pred = x + k_1 * dt @@ -128,7 +163,7 @@ def sample_heun(self, x, t, dt, cond=None): t += dt return x, t - def sample_PECECE(self, x, t, dt, cond=None): + def sample_PECECE(self, x, t, dt, cond): # Predict1 k_1 = self.velocity_fn(x, 1000 * t, cond=cond) x_pred1 = x + k_1 * dt @@ 
-144,6 +179,15 @@ def sample_PECECE(self, x, t, dt, cond=None): x += (k_3 + k_4) / 2 * dt t += dt return x, t + + def sample_rf_solver(self, x, t, dt, cond): + v_t = self.velocity_fn(x, 1000 * t, cond=cond) + x_half = x + v_t * dt / 2 + v_half = self.velocity_fn(x_half, 1000 * (t + dt / 2), cond=cond) + v_prime = (v_half - v_t) / (dt / 2) + x += v_t * dt + v_prime / 2 * (dt ** 2) + t += dt + return x, t def forward(self, condition, @@ -162,11 +206,25 @@ def forward(self, if t_start > 1.0: t_start = 1.0 if not infer: - x_1 = self.norm_spec(gt_spec) - x_1 = x_1.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] - t = t_start + (1.0 - t_start) * torch.rand(b, device=device) - t = torch.clip(t, 1e-7, 1 - 1e-7) - return self.reflow_loss(x_1, t, cond=cond) + if self.consistency: + x_1 = self.norm_spec(gt_spec) + x_1 = x_1.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + t = t_start + (1.0 - t_start) * torch.rand(b, device=device) + dt = self.consistency_delta_t * torch.randn(b, device=device).abs() + t_a = torch.clip(t - 0.5 * dt, t_start, 1) + t_b = torch.clip(t + 0.5 * dt, t_start, 1) + consistency_loss = self.reflow_consistency_loss(x_1, t_a, t_b, cond=cond) + if self.consistency_only: + return consistency_loss + else: + reflow_loss = self.reflow_loss(x_1, t_a, cond=cond) + return reflow_loss, consistency_loss + else: + x_1 = self.norm_spec(gt_spec) + x_1 = x_1.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + t = t_start + (1.0 - t_start) * torch.rand(b, device=device) + t = torch.clip(t, 1e-7, 1 - 1e-7) + return self.reflow_loss(x_1, t, cond=cond) else: shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) # [B, 1, M, T] @@ -261,6 +319,14 @@ def forward(self, else: for i in range(infer_step): x, t = self.sample_PECECE(x, t, dt, cond) + + elif method == 'RF-Solver': + if use_tqdm: + for i in tqdm(range(infer_step), desc='sample time step', total=infer_step): + x, t = self.sample_rf_solver(x, t, dt, cond) + else: + for i in range(infer_step): + x, t = self.sample_rf_solver(x, t, dt, cond) else: raise NotImplementedError(method) diff --git a/diffusion/reflow/reflow_1step.py b/diffusion/reflow/reflow_1step.py index 138fde7..b1b23b6 100644 --- a/diffusion/reflow/reflow_1step.py +++ b/diffusion/reflow/reflow_1step.py @@ -14,37 +14,43 @@ def __init__(self, out_dims=128, spec_min=-12, spec_max=2, - loss_type='l2'): + loss_type='l2', + consistency=True + ): super().__init__() self.velocity_fn = velocity_fn self.out_dims = out_dims self.spec_min = spec_min self.spec_max = spec_max self.loss_type = loss_type - if loss_type != "l2": - raise ValueError("loss_type must be l2 when use ReFlow1Step") + self.consistency = consistency - def reflow_loss(self, x_1, t, cond, loss_type=None): + def reflow_loss(self, x_1, t, t0, cond, loss_type=None): x_0 = torch.randn_like(x_1) x_t = x_0 + t[:, None, None, None] * (x_1 - x_0) v_pred = self.velocity_fn(x_t, 1000 * t, cond) + v_pred_0 = self.velocity_fn(x_0, 1000 * t0, cond) if loss_type is None: loss_type = self.loss_type else: loss_type = loss_type + consistency_loss = 0 if loss_type == 'l1': - loss = (x_1 - x_0 - v_pred).abs().mean() + xt_loss = (x_1 - x_0 - v_pred).abs().mean() + x0_loss = (x_0 - v_pred_0).abs().mean() + if self.consistency: + consistency_loss = (v_pred.detach() - v_pred_0).abs().mean() elif loss_type == 'l2': - loss = F.mse_loss(x_1 - x_0, v_pred) - elif loss_type == 'l2_lognorm': - weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / ( 1 - t)) ** 2) - loss = torch.mean(weights[:, None, None, None] * F.mse_loss(x_1 - x_0, v_pred, 
reduction='none')) + xt_loss = F.mse_loss(x_1 - x_0, v_pred) + x0_loss = F.mse_loss(x_0, v_pred_0) + if self.consistency: + consistency_loss = F.mse_loss(v_pred.detach(), v_pred_0) else: raise NotImplementedError() - return loss + return xt_loss, x0_loss, consistency_loss def sample_euler(self, x, t, dt, cond): x += self.velocity_fn(x, 1000 * t, cond) * dt @@ -146,22 +152,17 @@ def sample_PECECE(self, x, t, dt, cond=None): x += (k_3 + k_4) / 2 * dt t += dt return x, t - + def forward(self, condition, gt_spec=None, infer=True, - infer_step=1, + infer_step=10, method='euler', t_start=0.0, use_tqdm=True): cond = condition.transpose(1, 2) # [B, H, T] b, device = condition.shape[0], condition.device - if infer_step != 1: - raise ValueError("infer step must be 1 when use ReFlow1Step") - if method != "euler": - raise ValueError("euler must be euler when use ReFlow1Step") - if t_start is None: t_start = 0.0 if t_start < 0.0: @@ -171,9 +172,11 @@ def forward(self, if not infer: x_1 = self.norm_spec(gt_spec) x_1 = x_1.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] - t = torch.full((b,), t_start, device=device) - # t = torch.clip(t, 1e-7, 1 - 1e-7) - return self.reflow_loss(x_1, t, cond=cond) + t0 = t_start * torch.ones(b, device=device) + t = t_start + (1.0 - t_start) * torch.rand(b, device=device) + xt_loss, x0_loss, consistency_loss = self.reflow_loss(x_1, t, t0, cond) + xt_loss = 10 * xt_loss + return xt_loss, x0_loss, consistency_loss else: shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) # [B, 1, M, T] diff --git a/diffusion/solver.py b/diffusion/solver.py index c715d80..180792a 100644 --- a/diffusion/solver.py +++ b/diffusion/solver.py @@ -6,8 +6,9 @@ from logger.saver import Saver from logger import utils from torch import autocast -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from nsf_hifigan.nvSTFT import STFT +import random WAV_TO_MEL = None @@ -51,9 +52,15 @@ def calculate_mel_psnr(gt_mel, pred_mel): return psnr -def test(args, model, vocoder, loader_test, saver): +def test(args, model, vocoder, loader_test, saver, device): print(' [*] testing...') model.eval() + if args.vocoder.type == 'hifivaegan': + use_vae = True + elif args.vocoder.type == 'hifivaegan2': + use_vae = True + else: + use_vae = False # losses test_loss = 0. 
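On the solver.py import change just above: torch.cuda.amp.GradScaler is deprecated in recent PyTorch releases in favour of torch.amp.GradScaler, which takes the device type as its first argument and defaults to 'cuda', so an unchanged GradScaler() call keeps its old behaviour, e.g.:

from torch.amp import GradScaler

scaler = GradScaler('cuda')   # same behaviour as the deprecated torch.cuda.amp.GradScaler()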
@@ -82,7 +89,7 @@ def test(args, model, vocoder, loader_test, saver): # unpack data for k in data.keys(): if not k.startswith('name'): - data[k] = data[k].to(args.device) + data[k] = data[k].to(device) print('>>', data['name'][0]) # forward @@ -98,7 +105,9 @@ def test(args, model, vocoder, loader_test, saver): infer_step=args.infer.infer_step, method=args.infer.method, t_start=args.model.t_start, - spk_emb=data['spk_emb']) + spk_emb=data['spk_emb'], + use_vae=use_vae + ) else: mel = model( data['units'], @@ -110,7 +119,9 @@ def test(args, model, vocoder, loader_test, saver): infer_speedup=args.infer.speedup, k_step=args.model.k_step_max, method=args.infer.method, - spk_emb=data['spk_emb']) + spk_emb=data['spk_emb'], + use_vae=use_vae + ) signal = vocoder.infer(mel, data['f0']) ed_time = time.time() @@ -133,7 +144,7 @@ def test(args, model, vocoder, loader_test, saver): infer=False, t_start=args.model.t_start, spk_emb=data['spk_emb'], - use_vae=(args.vocoder.type == 'hifivaegan') + use_vae=use_vae ) else: loss_dict = model( @@ -145,7 +156,7 @@ def test(args, model, vocoder, loader_test, saver): infer=False, k_step=args.model.k_step_max, spk_emb=data['spk_emb'], - use_vae=(args.vocoder.type == 'hifivaegan') + use_vae=use_vae ) _loss = 0 if not isinstance(loss_dict, dict): @@ -170,6 +181,8 @@ def test(args, model, vocoder, loader_test, saver): # log mel if args.vocoder.type == 'hifivaegan': log_from_signal = True + elif args.vocoder.type == 'hifivaegan2': + log_from_signal = True else: log_from_signal = False @@ -196,9 +209,10 @@ def test(args, model, vocoder, loader_test, saver): gt_mel = gt_mel.transpose(-1, -2) # 如果形状不同,裁剪使得形状相同 if pre_mel.shape[1] != gt_mel.shape[1]: - gt_mel = gt_mel[:, :pre_mel.shape[1], :] + gt_mel = gt_mel[:, :min(pre_mel.shape[1],gt_mel.shape[1]), :] saver.log_spec(data['name'][0], gt_mel, pre_mel) # 计算指标 + spec_range = 14 # for mel mel_val_mse_all += torch.nn.functional.mse_loss(pre_mel, gt_mel).detach().cpu().numpy() gt_mel_norm = torch.clip(gt_mel, spec_min, spec_max) gt_mel_norm = gt_mel_norm / spec_range + spec_min @@ -254,7 +268,8 @@ def test(args, model, vocoder, loader_test, saver): return test_loss_dict, test_loss -def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): +def train(rank, args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test, device, + ddp=False, samper_train=None): # saver saver = Saver(args, initial_global_step=initial_global_step) @@ -262,15 +277,19 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade params_count = utils.get_network_paras_amount({'model': model}) saver.log_info('--- model size ---') saver.log_info(params_count) + last_decay_step = args.train.last_decay_step if args.vocoder.type == 'hifivaegan': use_vae = True + elif args.vocoder.type == 'hifivaegan2': + use_vae = True else: use_vae = False # set up EMA if args.train.use_ema: - ema_model = torch.optim.swa_utils.AveragedModel(model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(args.train.ema_decay)) - saver.log_info('ModelEmaV2 is enable') + raise NotImplementedError(' [x] EMA is not supported now.') + #ema_model = torch.optim.swa_utils.AveragedModel(model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(args.train.ema_decay)) + #saver.log_info('ModelEmaV2 is enable') # run num_batches = len(loader_train) @@ -287,14 +306,34 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade else: raise ValueError(' [x] Unknown 
amp_dtype: ' + args.train.amp_dtype) for epoch in range(start_epoch, args.train.epochs): + if ddp: + samper_train.set_epoch(epoch) for batch_idx, data in enumerate(loader_train): saver.global_step_increment() optimizer.zero_grad() # unpack data + duration_random_range = args.data.duration_random_range + if duration_random_range is None: + duration_random_range = 0.0 + if float(duration_random_range) > 0.0: + random_dur = random.random() * duration_random_range + sli_frame = random_dur // args.data.sampling_rate + n_frame = int(data['units'].shape[1]) + sli_frame = n_frame - sli_frame + sli_frame = random.randint(0, n_frame - sli_frame) + else: + sli_frame = None for k in data.keys(): if not k.startswith('name'): - data[k] = data[k].to(args.device) + data[k] = data[k].to(device) + if sli_frame is not None: + if k in ['units','f0','volume','mel']: + _shape_len = len(data[k].shape) + if _shape_len == 3: + data[k] = data[k][:, sli_frame:, :] + elif _shape_len == 4: + data[k] = data[k][:, sli_frame:, :, :] # forward if (args.model.type == 'ReFlow') or (args.model.type == 'ReFlow1Step'): @@ -305,7 +344,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade t_start=args.model.t_start, spk_emb=data['spk_emb'], use_vae=use_vae) else: - with autocast(device_type=args.device, dtype=dtype): + with autocast(device_type=device, dtype=dtype): loss_dict = model(data['units'], data['f0'], data['volume'], data['spk_id'], aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False, t_start=args.model.t_start, @@ -318,7 +357,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade k_step=args.model.k_step_max, spk_emb=data['spk_emb'], use_vae=use_vae) else: - with autocast(device_type=args.device, dtype=dtype): + with autocast(device_type=device, dtype=dtype): loss_dict = model(data['units'], data['f0'], data['volume'], data['spk_id'], aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=args.model.k_step_max, @@ -357,74 +396,87 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade scaler.update() if args.train.use_ema: - ema_model.update_parameters(model) - - scheduler.step() + raise NotImplementedError(' [x] EMA is not supported now.') + #ema_model.update_parameters(model) + + if last_decay_step is not None: + # 如果在last_decay_step步数之后,则不再更新学习率 + if saver.global_step <= last_decay_step: + scheduler.step() + else: + scheduler.step() # log loss - if saver.global_step % args.train.interval_log == 0: - current_lr = optimizer.param_groups[0]['lr'] - saver.log_info( - 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( - epoch, - batch_idx, - num_batches, - args.env.expdir, - args.train.interval_log / saver.get_interval_time(), - current_lr, - loss.item(), - saver.get_total_time(), - saver.global_step + if rank == 0: + if saver.global_step % args.train.interval_log == 0: + current_lr = optimizer.param_groups[0]['lr'] + saver.log_info( + 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( + epoch, + batch_idx, + num_batches, + args.env.expdir, + args.train.interval_log / saver.get_interval_time(), + current_lr, + loss.item(), + saver.get_total_time(), + saver.global_step + ) ) - ) - - saver.log_value({ - 'train/loss': loss.item() - }) - for k in loss_float_dict.keys(): saver.log_value({ - 'train/' + k: loss_float_dict[k] + 'train/loss': loss.item() }) - saver.log_value({ - 'train/lr': current_lr - }) + 
for k in loss_float_dict.keys(): + saver.log_value({ + 'train/' + k: loss_float_dict[k] + }) - # validation - if saver.global_step % args.train.interval_val == 0: - optimizer_save = optimizer if args.train.save_opt else None - - # save latest - if args.train.use_ema: - saver.save_model(ema_model.module, optimizer_save, postfix=f'{saver.global_step}') - else: - saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') - - last_val_step = saver.global_step - args.train.interval_val - if last_val_step % args.train.interval_force_save != 0: - saver.delete_model(postfix=f'{last_val_step}') + saver.log_value({ + 'train/lr': current_lr + }) - # run testing set - if args.train.use_ema: - test_loss_dict, test_loss = test(args, ema_model, vocoder, loader_test, saver) - else: - test_loss_dict, test_loss = test(args, model, vocoder, loader_test, saver) + # validation + if rank == 0: + if saver.global_step % args.train.interval_val == 0: + optimizer_save = optimizer if args.train.save_opt else None + + # save latest + if args.train.use_ema: + raise NotImplementedError(' [x] EMA is not supported now.') + #saver.save_model(ema_model.module, optimizer_save, postfix=f'{saver.global_step}') + else: + if ddp: + saver.save_model(model.module, optimizer_save, postfix=f'{saver.global_step}') + else: + saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') + + last_val_step = saver.global_step - args.train.interval_val + if last_val_step % args.train.interval_force_save != 0: + saver.delete_model(postfix=f'{last_val_step}') + + # run testing set + if args.train.use_ema: + raise NotImplementedError(' [x] EMA is not supported now.') + #test_loss_dict, test_loss = test(args, ema_model, vocoder, loader_test, saver) + else: + test_loss_dict, test_loss = test(args, model, vocoder, loader_test, saver, device) - # log loss - saver.log_info( - ' --- --- \nloss: {:.3f}. '.format( - test_loss, + # log loss + saver.log_info( + ' --- --- \nloss: {:.3f}. 
'.format( + test_loss, + ) ) - ) - - saver.log_value({ - 'validation/loss': test_loss - }) - for k in test_loss_dict.keys(): saver.log_value({ - 'validation/' + k: test_loss_dict[k] + 'validation/loss': test_loss }) - model.train() + for k in test_loss_dict.keys(): + saver.log_value({ + 'validation/' + k: test_loss_dict[k] + }) + + model.train() diff --git a/diffusion/unit2mel.py b/diffusion/unit2mel.py index 71fa136..3e21c30 100644 --- a/diffusion/unit2mel.py +++ b/diffusion/unit2mel.py @@ -100,6 +100,8 @@ def get_network_from_dot(netdot, out_dims, cond_dims): no_t_emb = netdot.no_t_emb if (netdot.no_t_emb is not None) else False conv_model_activation = netdot.conv_model_activation if (netdot.conv_model_activation is not None) else 'SiLU' GLU_type = netdot.GLU_type if (netdot.GLU_type is not None) else 'GLU' + fix_free_norm = netdot.fix_free_norm if (netdot.fix_free_norm is not None) else False + channel_norm = netdot.channel_norm if (netdot.channel_norm is not None) else False # init convnext denoiser denoiser = NaiveV2Diff( mel_channels=out_dims, @@ -118,7 +120,9 @@ def get_network_from_dot(netdot, out_dims, cond_dims): atten_dropout=atten_dropout, no_t_emb=no_t_emb, conv_model_activation=conv_model_activation, - GLU_type=GLU_type + GLU_type=GLU_type, + fix_free_norm=fix_free_norm, + channel_norm=channel_norm ) else: @@ -127,7 +131,7 @@ def get_network_from_dot(netdot, out_dims, cond_dims): return denoiser -def get_z(stack_tensor, mean_only=False): +def get_z(stack_tensor, mean_only=False, clip_min=None, clip_max=None): # stack_tensor: [B x N x D x 2] # sample z, or mean only m = stack_tensor.transpose(-1, 0)[:1].transpose(-1, 0).squeeze(-1) @@ -136,6 +140,10 @@ def get_z(stack_tensor, mean_only=False): z = m # mean only else: z = m + torch.randn_like(m) * torch.exp(logs) # sample z + if (clip_min is not None) or (clip_max is not None): + assert clip_min is not None + assert clip_max is not None + z = z.clamp(min=clip_min, max=clip_max) return z # [B x N x D] @@ -158,7 +166,7 @@ def load_model_vocoder( model = load_svc_model(args=args, vocoder_dimension=vocoder.dimension) print(' [Loading] ' + model_path) - ckpt = torch.load(model_path, map_location=torch.device(device)) + ckpt = torch.load(model_path, map_location=torch.device(device), weights_only=True) model.to(device) model.load_state_dict(ckpt['model']) model.eval() @@ -166,7 +174,7 @@ def load_model_vocoder( def load_model_vocoder_from_combo(combo_model_path, device='cpu', loaded_vocoder=None): - read_dict = torch.load(combo_model_path, map_location=torch.device(device)) + read_dict = torch.load(combo_model_path, map_location=torch.device(device), weights_only=True) # 检查是否有键名“_version_” if '_version_' in read_dict.keys(): raise ValueError(" [X] 这是新版本的模型, 请在新仓库中使用") @@ -268,7 +276,13 @@ def load_svc_model(args, vocoder_dimension): naive_fn=args.model.naive_fn, naive_fn_grad_not_by_reflow=args.model.naive_fn_grad_not_by_reflow, naive_out_mel_cond_reflow=args.model.naive_out_mel_cond_reflow, - loss_type=args.model.loss_type,) + loss_type=args.model.loss_type, + consistency=args.model.consistency, + consistency_only=args.model.consistency_only, + consistency_delta_t=args.model.consistency_delta_t, + consistency_lambda_f=args.model.consistency_lambda_f, + consistency_lambda_v=args.model.consistency_lambda_v, + ) elif args.model.type == 'ReFlow1Step': model = Unit2MelV2ReFlow1Step( @@ -289,7 +303,9 @@ def load_svc_model(args, vocoder_dimension): naive_fn=args.model.naive_fn, 
naive_fn_grad_not_by_reflow=args.model.naive_fn_grad_not_by_reflow, naive_out_mel_cond_reflow=args.model.naive_out_mel_cond_reflow, - loss_type=args.model.loss_type, ) + loss_type=args.model.loss_type, + consistency=args.model.consistency, + ) elif args.model.type == 'Naive': model = Unit2MelNaive( @@ -429,6 +445,8 @@ def __init__( mask_cond_ratio = float(mask_cond_ratio) if (str(mask_cond_ratio) != 'NOTUSE') else -99 if mask_cond_ratio > 0: self.mask_cond_ratio = mask_cond_ratio + # 未实现错误 + raise NotImplementedError(" [X] mask_cond_ratio is not implemented.") else: self.mask_cond_ratio = None else: @@ -526,7 +544,7 @@ def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=N # sample z or mean only if use_vae and (gt_spec is not None): - gt_spec = get_z(gt_spec, mean_only=self.mean_only) + gt_spec = get_z(gt_spec, mean_only=self.mean_only, clip_min=self.spec_min, clip_max=self.spec_max) if (self.z_rate is not None) and (self.z_rate != 0): gt_spec = gt_spec * self.z_rate # scale z @@ -615,7 +633,13 @@ def spawn_decoder(self, velocity_fn, out_dims): out_dims=out_dims, spec_min=self.spec_min, spec_max=self.spec_max, - loss_type=self.loss_type) + loss_type=self.loss_type, + consistency=self.consistency, + consistency_only=self.consistency_only, + consistency_delta_t=self.consistency_delta_t, + consistency_lambda_f=self.consistency_lambda_f, + consistency_lambda_v=self.consistency_lambda_v, + ) return decoder def __init__( @@ -637,9 +661,19 @@ def __init__( naive_fn=None, naive_fn_grad_not_by_reflow=False, naive_out_mel_cond_reflow=True, - loss_type='l2' + loss_type='l2', + consistency=False, + consistency_only=True, + consistency_delta_t=0.1, + consistency_lambda_f=1.0, + consistency_lambda_v=1.0, ): self.loss_type = loss_type if (loss_type is not None) else 'l2' + self.consistency = consistency if (consistency is not None) else False + self.consistency_only = consistency_only if (consistency_only is not None) else True + self.consistency_delta_t = consistency_delta_t if (consistency_delta_t is not None) else 0.1 + self.consistency_lambda_f = consistency_lambda_f if (consistency_lambda_f is not None) else 1.0 + self.consistency_lambda_v = consistency_lambda_v if (consistency_lambda_v is not None) else 1.0 super().__init__( input_channel, n_spk, @@ -678,7 +712,7 @@ def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=N # sample z or mean only if use_vae and (gt_spec is not None): - gt_spec = get_z(gt_spec, mean_only=self.mean_only) + gt_spec = get_z(gt_spec, mean_only=self.mean_only, clip_min=self.spec_min, clip_max=self.spec_max) if (self.z_rate is not None) and (self.z_rate != 0): gt_spec = gt_spec * self.z_rate # scale z @@ -696,8 +730,9 @@ def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=N if infer: x = self.decoder(x, gt_spec=gt_spec, infer=True, infer_step=infer_step, method=method, t_start=t_start, use_tqdm=use_tqdm) + _step_loss_dict = None else: - x = self.decoder(x, gt_spec=gt_spec, t_start=t_start, infer=False) + _step_loss_dict = self.step_train(x, gt_spec, t_start) # mask cond end self.mask_cond_train_end() @@ -707,12 +742,43 @@ def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=N x = x / self.z_rate # scale z if not infer: - if self.combo_trained_model: + return self.make_loss_dict(_step_loss_dict, naive_loss) + else: + return x + + def make_loss_dict(self, _step_loss_dict, naive_loss): + x = _step_loss_dict['x'] + co_loss = _step_loss_dict['co_loss'] + if 
self.combo_trained_model: + if self.consistency: + if self.consistency_only: + return {'reflow_consistency_loss': x, 'naive_loss': naive_loss} + else: + return {'consistency_loss': co_loss, + 'reflow_loss': x, + 'naive_loss': naive_loss} + else: return {'reflow_loss': x, 'naive_loss': naive_loss} + else: + if self.consistency: + if self.consistency_only: + return {'reflow_consistency_loss': (x + naive_loss)} + else: + return {'consistency_loss': co_loss, + 'reflow_loss': (x + naive_loss)} else: return {'reflow_loss': (x + naive_loss)} - return x + def step_train(self, x, gt_spec, t_start): + co_loss = None + if self.consistency: + if self.consistency_only: + x = self.decoder(x, gt_spec=gt_spec, t_start=t_start, infer=False) + else: + x, co_loss = self.decoder(x, gt_spec=gt_spec, t_start=t_start, infer=False) + else: + x = self.decoder(x, gt_spec=gt_spec, t_start=t_start, infer=False) + return {'x':x, 'co_loss':co_loss} class Unit2MelV2ReFlow1Step(Unit2MelV2ReFlow): @@ -722,8 +788,72 @@ def spawn_decoder(self, velocity_fn, out_dims): out_dims=out_dims, spec_min=self.spec_min, spec_max=self.spec_max, - loss_type=self.loss_type) + loss_type=self.loss_type, + consistency=self.x0_xt_consistency + ) return decoder + def __init__( + self, + input_channel, + n_spk, + use_pitch_aug=False, + out_dims=128, + n_hidden=256, + use_speaker_encoder=False, + speaker_encoder_out_channels=256, + z_rate=None, + mean_only=False, + max_beta=0.02, # 暂时废弃,但是极有可能未来会有用吧,所以先不删除, 可以为None + spec_min=-12, + spec_max=2, + velocity_fn=None, + mask_cond_ratio=None, + naive_fn=None, + naive_fn_grad_not_by_reflow=False, + naive_out_mel_cond_reflow=True, + loss_type='l2', + consistency=True, + ): + self.x0_xt_consistency = consistency if (consistency is not None) else True + super().__init__( + input_channel, + n_spk, + use_pitch_aug=use_pitch_aug, + out_dims=out_dims, + n_hidden=n_hidden, + use_speaker_encoder=use_speaker_encoder, + speaker_encoder_out_channels=speaker_encoder_out_channels, + z_rate=z_rate, + mean_only=mean_only, + max_beta=max_beta, + spec_min=spec_min, + spec_max=spec_max, + velocity_fn=velocity_fn, + mask_cond_ratio=mask_cond_ratio, + naive_fn=naive_fn, + naive_fn_grad_not_by_reflow=naive_fn_grad_not_by_reflow, + naive_out_mel_cond_reflow=naive_out_mel_cond_reflow, + loss_type=loss_type, + consistency=False, + ) + + def step_train(self, x, gt_spec, t_start): + xt_loss, x0_loss, consistency_loss = self.decoder(x, gt_spec=gt_spec, t_start=t_start, infer=False) + return {'x':xt_loss, 'x0_loss':x0_loss, 'consistency_loss':consistency_loss} + + def make_loss_dict(self, _step_loss_dict, naive_loss): + x = _step_loss_dict['x'] + x0_loss = _step_loss_dict['x0_loss'] + consistency_loss = _step_loss_dict['consistency_loss'] + if self.combo_trained_model: + return {'reflow_loss': x, + 'reflow_loss_x0':x0_loss, + 'consistency_loss': consistency_loss, + 'naive_loss': naive_loss} + else: + return {'reflow_loss': (x + naive_loss), + 'reflow_loss_x0': x0_loss, + 'consistency_loss': consistency_loss} class Unit2Mel(nn.Module): diff --git a/diffusion/vocoder.py b/diffusion/vocoder.py index 0a26a36..01b9dab 100644 --- a/diffusion/vocoder.py +++ b/diffusion/vocoder.py @@ -10,6 +10,8 @@ from encoder.evagan import EVAGANBase as EVABase from encoder.evagan import EVAGANBig as EVABig from encoder.dct.dct import DCT, IDCT +from encoder.wavs.wavs import WAVS, IWAVS +from encoder.hifi_vaegan2.modules.models import Encoder, Generator def load_vocoder_for_save(vocoder_type, model_path, device='cpu'): @@ -19,6 +21,8 @@ def 
load_vocoder_for_save(vocoder_type, model_path, device='cpu'): vocoder = NsfHifiGANLog10(model_path, device=device) elif vocoder_type == 'hifivaegan': vocoder = HiFiVAEGAN(model_path, device=device) + elif vocoder_type == 'hifivaegan2': + vocoder = HiFiVAEGAN2(model_path, device=device) elif vocoder_type == 'fireflygan-base': vocoder = FireFlyGANBase(model_path, device=device) elif vocoder_type == 'evagan-base': @@ -29,6 +33,8 @@ def load_vocoder_for_save(vocoder_type, model_path, device='cpu'): vocoder = DCT512(model_path, device=device) elif vocoder_type == 'dct512log': vocoder = DCT512(model_path, device=device, l_norm=True) + elif vocoder_type == 'wavs512': + vocoder = WAVS512(model_path, device=device) else: raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") out_dict = vocoder.load_model_for_combo(model_path=model_path) @@ -53,6 +59,8 @@ def __init__(self, vocoder_type, vocoder_ckpt, device=None): self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device) elif vocoder_type == 'hifivaegan': self.vocoder = HiFiVAEGAN(vocoder_ckpt, device=device) + elif vocoder_type == 'hifivaegan2': + self.vocoder = HiFiVAEGAN2(vocoder_ckpt, device=device) elif vocoder_type == 'fireflygan-base': self.vocoder = FireFlyGANBase(vocoder_ckpt, device=device) elif vocoder_type == 'evagan-base': @@ -63,6 +71,8 @@ def __init__(self, vocoder_type, vocoder_ckpt, device=None): self.vocoder = DCT512(vocoder_ckpt, device=device) elif vocoder_type == 'dct512log': self.vocoder = DCT512(vocoder_ckpt, device=device, l_norm=True) + elif vocoder_type == 'wavs512': + self.vocoder = WAVS512(vocoder_ckpt, device=device) else: raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") @@ -93,6 +103,50 @@ def infer(self, mel, f0): return audio +class WAVS512(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.h_sampling_rate = 44100 + self.h_num_mels = 512 + self.h_hop_size = 512 + self.wavs = WAVS(self.h_hop_size) + self.iwavs = IWAVS(self.h_hop_size) + + def sample_rate(self): + return self.h_sampling_rate + + def hop_size(self): + return self.h_hop_size + + def dimension(self): + return self.h_num_mels + + def extract(self, audio, keyshift=0): + assert keyshift == 0 + with torch.no_grad(): + audio = audio.to(self.device) + mel = self.wavs(audio) # B, n_frames, bins + return mel + + def forward(self, mel, f0): # mel: B, n_frames, bins; f0: B, n_frames + assert mel.shape[-1] == 512 + with torch.no_grad(): + audio = self.iwavs(mel) + return audio.unsqueeze(1) # B, 1, T + + def load_model_for_combo(self, model_path=None, device='cpu'): + config = {"sampling_rate": self.sampling_rate, "num_mels": self.num_mels, "hop_size": self.hop_size} + model = NothingFlag() + out_dict = { + "config": config, + "model": model + } + return out_dict + + class DCT512(torch.nn.Module): def __init__(self, model_path, device=None, l_norm=False): super().__init__() @@ -238,6 +292,7 @@ def dimension(self): return self.model.inter_channels def extract(self, audio, keyshift=0, only_z=False): + assert keyshift == 0 if audio.shape[-1] % self.model.hop_size == 0: audio = torch.cat((audio, torch.zeros_like(audio[:, :1])), dim=-1) if keyshift != 0: @@ -260,10 +315,85 @@ def load_model_for_combo(self, model_path=None, device='cpu'): model_path = self.model_path assert self.config_path is not None config_path = os.path.join(os.path.split(model_path)[0], 'config.json') - with open(config_path, "r") as f: + with 
open(config_path, "r", encoding='utf-8') as f: + data = f.read() + config = json.loads(data) + model_state_dict = torch.load(model_path, map_location=torch.device(device), weights_only=True) + out_dict = { + "config": config, + "model": model_state_dict + } + return out_dict + + +class HiFiVAEGAN2(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + + # 如果model_path是字典,说明传入的是config + model + if type(model_path) == dict: + self.config_path = None + self.model_path = None + self.config = model_path['config'] + self.enc = Encoder(self.config['hps']).to(device) + self.dec = Generator(self.config['hps']).to(device) + self.enc.load_state_dict(model_path['model']['encoder']) + self.dec.load_state_dict(model_path['model']['decoder']) + self.enc.eval() + self.dec.eval() + else: + self.model_path = model_path + self.config_path = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(self.config_path, "r", encoding='utf-8') as f: + self.config = json.load(f) + self.enc = Encoder(self.config['hps']).to(device) + self.dec = Generator(self.config['hps']).to(device) + model_state_dict = torch.load(model_path, map_location=torch.device(device), weights_only=True) + self.enc.load_state_dict(model_state_dict['encoder']) + self.dec.load_state_dict(model_state_dict['decoder']) + self.enc.eval() + self.dec.eval() + + def sample_rate(self): + return self.config["hps"]["sampling_rate"] + + def hop_size(self): + return self.config["hop_size"] + + def dimension(self): + return self.config["hps"]["inter_channels"] + + def extract(self, audio, keyshift=0, only_z=False): + assert keyshift == 0 + if audio.shape[-1] % self.config["hop_size"] == 0: + audio = torch.cat((audio, torch.zeros_like(audio[:, :1])), dim=-1) + if keyshift != 0: + raise ValueError("HiFiVAEGAN could not use keyshift!") + with torch.no_grad(): + z, m, logs = self.enc(audio) + if only_z: + return z.transpose(1, 2) + mel = torch.stack((m.transpose(-1, -2), logs.transpose(-1, -2)), dim=-1) + return mel + + def forward(self, mel, f0): + with torch.no_grad(): + z = mel.transpose(1, 2) + audio = self.dec(z) + return audio + + def load_model_for_combo(self, model_path=None, device='cpu'): + if model_path is None: + model_path = self.model_path + assert self.config_path is not None + config_path = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_path, "r", encoding='utf-8') as f: data = f.read() config = json.loads(data) - model_state_dict = torch.load(model_path, map_location=torch.device(device)) + model_state_dict = torch.load(model_path, map_location=torch.device(device), weights_only=True) out_dict = { "config": config, "model": model_state_dict @@ -340,7 +470,7 @@ def load_model_for_combo(self, model_path=None, device='cpu'): config_path = os.path.join(os.path.split(model_path)[0], 'config.yaml') with open(config_path, "r") as config: config = yaml.safe_load(config) - model = torch.load(model_path, map_location=torch.device(device)) + model = torch.load(model_path, map_location=torch.device(device), weights_only=True) out_dict = { "config": config, "model": model @@ -417,7 +547,7 @@ def load_model_for_combo(self, model_path=None, device='cpu'): config_path = os.path.join(os.path.split(model_path)[0], 'config.yaml') with open(config_path, "r") as config: config = yaml.safe_load(config) - model = torch.load(model_path, map_location=torch.device(device)) + model = 
torch.load(model_path, map_location=torch.device(device), weights_only=True) out_dict = { "config": config, "model": model @@ -494,7 +624,7 @@ def load_model_for_combo(self, model_path=None, device='cpu'): config_path = os.path.join(os.path.split(model_path)[0], 'config.yaml') with open(config_path, "r") as config: config = yaml.safe_load(config) - model = torch.load(model_path, map_location=torch.device(device)) + model = torch.load(model_path, map_location=torch.device(device), weights_only=True) out_dict = { "config": config, "model": model diff --git a/encoder/hifi_vaegan/model.py b/encoder/hifi_vaegan/model.py index fda59e7..9520940 100644 --- a/encoder/hifi_vaegan/model.py +++ b/encoder/hifi_vaegan/model.py @@ -4,15 +4,15 @@ from torch import nn from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm _OLD_WEIGHT_NORM = False try: from torch.nn.utils.parametrizations import weight_norm + from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm except ImportError: from torch.nn.utils import weight_norm + from torch.nn.utils import remove_weight_norm _OLD_WEIGHT_NORM = True import os -from vector_quantize_pytorch import VectorQuantize LRELU_SLOPE = 0.1 @@ -69,6 +69,14 @@ def remove_weight_norm(self): a_lay.remove_weight_norm() remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) + else: + for a_lay in self.ups: + remove_weight_norm(a_lay, 'weight') + for a_lay in self.resblocks: + a_lay.remove_weight_norm() + remove_weight_norm(self.conv_pre, 'weight') + remove_weight_norm(self.conv_post, 'weight') + def forward(self, x): x = x[:, None, :] @@ -148,6 +156,11 @@ def remove_weight_norm(self): remove_weight_norm(l) for l in self.convs2: remove_weight_norm(l) + else: + for l in self.convs1: + remove_weight_norm(l, 'weight') + for l in self.convs2: + remove_weight_norm(l, 'weight') class ResBlock2(torch.nn.Module): @@ -174,11 +187,15 @@ def remove_weight_norm(self): if _OLD_WEIGHT_NORM: for l in self.convs: remove_weight_norm(l) + else: + for l in self.convs: + remove_weight_norm(l, 'weight') class Generator(torch.nn.Module): def __init__(self, h): super(Generator, self).__init__() + from vector_quantize_pytorch import VectorQuantize self.h = h self.num_kernels = len(h["resblock_kernel_sizes"]) @@ -243,6 +260,13 @@ def remove_weight_norm(self): a_lay.remove_weight_norm() remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) + else: + for a_lay in self.ups: + remove_weight_norm(a_lay, 'weight') + for a_lay in self.resblocks: + a_lay.remove_weight_norm() + remove_weight_norm(self.conv_pre, 'weight') + remove_weight_norm(self.conv_post, 'weight') def feature_loss(fmap_r, fmap_g): @@ -435,7 +459,7 @@ def decode(self, z): @torch.no_grad() def load(self, model_type): assert os.path.isfile(self.model_path) - model_dict = torch.load(self.model_path, map_location='cpu')["model"] + model_dict = torch.load(self.model_path, map_location='cpu', weights_only=True)["model"] load_dict = {} for k, v in model_dict.items(): if k[:len(model_type)] == model_type: diff --git a/encoder/hifi_vaegan2/__init__.py b/encoder/hifi_vaegan2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/encoder/hifi_vaegan2/modules/ConformerConv.py b/encoder/hifi_vaegan2/modules/ConformerConv.py new file mode 100644 index 0000000..7da8f68 --- /dev/null +++ b/encoder/hifi_vaegan2/modules/ConformerConv.py @@ -0,0 +1,92 @@ +from torch import nn +import torch + +class 
ConformerConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, + hidden_size, + conv_depthwise_kernel_size + ): + super().__init__() + if (conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.pointwise_conv1 = nn.Conv1d( + hidden_size, + 2 * hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + hidden_size, + hidden_size, + conv_depthwise_kernel_size, + stride=1, + padding=0, + groups=hidden_size, + bias=False, + ) + + self.depthwise_layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.activation = nn.SiLU() + self.pointwise_conv2 = nn.Conv1d( + hidden_size, + hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = nn.Dropout(0.1) + + def forward(self, hidden_states): + hidden_states_input = hidden_states + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.layer_norm(hidden_states) + + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # Pad the sequence entirely on the left because of causal convolution. + hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0)) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + + hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + # hidden_states = hidden_states.transpose(1, 2) + return hidden_states + hidden_states_input * 0.5 + +class ConformerConvLayer(nn.Module): + def __init__(self, + hidden_size, + conv_depthwise_kernel_size, + layer_num + ): + super().__init__() + self.layers = nn.ModuleList([ + ConformerConvolutionModule( + hidden_size, + conv_depthwise_kernel_size + ) for _ in range(layer_num) + ]) + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + \ No newline at end of file diff --git a/encoder/hifi_vaegan2/modules/MRD.py b/encoder/hifi_vaegan2/modules/MRD.py new file mode 100644 index 0000000..9f8eb1f --- /dev/null +++ b/encoder/hifi_vaegan2/modules/MRD.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn.utils import weight_norm +import torchaudio +BANDS = [(0.0, 0.1), (0.1, 0.2), (0.2, 0.3), (0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (0.6, 0.7), (0.7, 0.8), (0.8, 0.9), (0.9, 1.0)] +# WEIGHT = [1., 1., 2., 2.5, 3.] + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.SELU(0.1)) + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + channels: int = 1 + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``.
+ sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + from audiotools import STFTParams + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + self.channels = channels + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, window_fn=torch.hann_window, + normalized=True, center=False, pad_mode=None, power=None, return_complex=True) + + def spectrogram(self, x): + from audiotools import AudioSignal + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + # x = torch.view_as_real(self.spec_transform(x)) + # print(x.shape) + x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + +if __name__ == "__main__": + x = torch.randn(2, 1, 44100 * 10) + mrd = MRD(1024) + fmap, logist = mrd(x) + print([f.shape for f in fmap]) + print(logist.shape) \ No newline at end of file diff --git a/encoder/hifi_vaegan2/modules/__init__.py b/encoder/hifi_vaegan2/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/encoder/hifi_vaegan2/modules/commons.py b/encoder/hifi_vaegan2/modules/commons.py new file mode 100644 index 0000000..120968a --- /dev/null +++ b/encoder/hifi_vaegan2/modules/commons.py @@ -0,0 +1,56 @@ +import math +import torch +from torch.nn import functional as F + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if "Depthwise_Separable" in classname: + m.depth_conv.weight.data.normal_(mean, std) + m.point_conv.weight.data.normal_(mean, std) + elif classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = (math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float) * 
-log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/encoder/hifi_vaegan2/modules/losses.py b/encoder/hifi_vaegan2/modules/losses.py new file mode 100644 index 0000000..91e2946 --- /dev/null +++ b/encoder/hifi_vaegan2/modules/losses.py @@ -0,0 +1,92 @@ +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + +def kl_loss(logs, m): + kl = 0.5 * (m**2 + torch.exp(logs) - logs - 1).sum(dim=1) + kl = torch.mean(kl) + return kl + +class SSSLoss(nn.Module): + """ + Single-scale Spectral Loss. + """ + + def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7): + super().__init__() + self.n_fft = n_fft + self.alpha = alpha + self.eps = eps + self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=1, normalized=True, center=False) + + def forward(self, x_true, x_pred): + S_true = self.spec(x_true) + self.eps + S_pred = self.spec(x_pred) + self.eps + + converge_term = torch.mean(torch.linalg.norm(S_true - S_pred, dim = (1, 2)) / torch.linalg.norm(S_true + S_pred, dim = (1, 2))) + + log_term = F.l1_loss(S_true.log(), S_pred.log()) + + loss = converge_term + self.alpha * log_term + return loss + + +class RSSLoss(nn.Module): + ''' + Random-scale Spectral Loss. 
+ ''' + + def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'): + super().__init__() + self.fft_min = fft_min + self.fft_max = fft_max + self.n_scale = n_scale + self.lossdict = {} + for n_fft in range(fft_min, fft_max): + self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device) + + def forward(self, x_pred, x_true): + value = 0. + n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,)) + for n_fft in n_ffts: + loss_func = self.lossdict[int(n_fft)] + value += loss_func(x_true, x_pred) + return value / self.n_scale \ No newline at end of file diff --git a/encoder/hifi_vaegan2/modules/mel_processing.py b/encoder/hifi_vaegan2/modules/mel_processing.py new file mode 100644 index 0000000..d81a4c6 --- /dev/null +++ b/encoder/hifi_vaegan2/modules/mel_processing.py @@ -0,0 +1,83 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + y_dtype = y.dtype + if y.dtype == torch.bfloat16: + y = y.to(torch.float32) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec).to(y_dtype) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + '_' + str(spec.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) + spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) + + return spec diff --git a/encoder/hifi_vaegan2/modules/models.py b/encoder/hifi_vaegan2/modules/models.py new file mode 100644 index 
0000000..ea39f9b --- /dev/null +++ b/encoder/hifi_vaegan2/modules/models.py @@ -0,0 +1,419 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm +from torch.nn.utils.parametrizations import spectral_norm, weight_norm +#from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .modules import LRELU_SLOPE as M_LRELU_SLOPE +from .commons import get_padding, init_weights +from .msstftd import MultiScaleSTFTDiscriminator +from .ConformerConv import ConformerConvLayer +from .MRD import MRD +LRELU_SLOPE = 0.1 + +class Encoder(nn.Module): + def __init__(self, h, + ): + super().__init__() + + self.h = h + # h["inter_channels"] + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.out_channels = h["inter_channels"] + self.num_downsamples = len(h["upsample_rates"]) + self.conv_pre = weight_norm(Conv1d(1, h["upsample_initial_channel"]// (2 ** len(h["upsample_rates"])), 7, 1, padding=3)) + resblock = ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(reversed(h["upsample_rates"]), reversed(h["upsample_kernel_sizes"]))): + self.ups.append(weight_norm( + Conv1d(h["upsample_initial_channel"] // (2 ** (len(h["upsample_rates"]) - i)), h["upsample_initial_channel"] // (2 ** (len(h["upsample_rates"]) - i - 1)), + k, u, padding= (k - u + 1) // 2))) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups), 0, -1): + ch = h["upsample_initial_channel"] // (2 ** (i - 1)) + for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 2 * h["inter_channels"], 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.upp = np.prod(h["upsample_rates"]) + + def forward(self, x): + x = x[:,None,:] + x = self.conv_pre(x) + for i in range(self.num_downsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + m, logs = torch.split(x, self.out_channels, dim=1) + z = m + torch.randn_like(m) * torch.exp(logs) + return z, m, logs + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l, 'weight') + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre, 'weight') + remove_weight_norm(self.conv_post, 'weight') + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = 
[] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, M_LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11, 13, 17] + fft_sizes = [2048, 1024, 512] + + # discs = [MultiScaleSTFTDiscriminator(filters=32) ,] + discs = [MRD(f) for f in fft_sizes] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + if isinstance(d, MultiScaleSTFTDiscriminator): + y_d_rs.extend(y_d_r) + y_d_gs.extend(y_d_g) + fmap_rs.extend(fmap_r) + fmap_gs.extend(fmap_g) + else: + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l, 'weight') + for l in self.convs2: + remove_weight_norm(l, 'weight') + + +class 
ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l, 'weight') + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.num_upsamples = len(h["upsample_rates"]) + self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)) + resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])): + self.ups.append(weight_norm( + ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), + k, u, padding=(k - u + 1) // 2))) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h["upsample_initial_channel"] // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.upp = np.prod(h["upsample_rates"]) + self.conv_layers = ConformerConvLayer(h["inter_channels"], 31, 4) + def forward(self, x): + x = self.conv_layers(x) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l, 'weight') + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre, 'weight') + remove_weight_norm(self.conv_post, 'weight') + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + 
loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + +class TrainModel(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + hop_size, + windows_size, + inter_channels, + hidden_channels, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + sampling_rate=44100, + use_vq = False, + codebook_size = 4096, + **kwargs): + + super().__init__() + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.hop_size = hop_size + self.windows_size = windows_size + + hps = { + "sampling_rate": sampling_rate, + "inter_channels": inter_channels, + "resblock": resblock, + "resblock_kernel_sizes": resblock_kernel_sizes, + "resblock_dilation_sizes": resblock_dilation_sizes, + "upsample_rates": upsample_rates, + "upsample_initial_channel": upsample_initial_channel, + "upsample_kernel_sizes": upsample_kernel_sizes + } + + self.dec = Generator(h=hps) + + self.enc_q = Encoder(h=hps) + + if use_vq: + from vector_quantize_pytorch import VectorQuantize + self.quantizer = VectorQuantize( + dim = inter_channels, + codebook_size = codebook_size, + decay = 0.8, + commitment_weight = 1.) 
+ else: + self.quantizer = None + + def forward(self, wav): + z, m, logs = self.enc_q(wav) + if self.quantizer is not None and self.training: + z_, indices, commit_loss = self.quantizer(z.transpose(1,2)) + else: + commit_loss = 0 + wav = self.dec(z) + + return z, wav, (m, logs), commit_loss + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.enc_q.remove_weight_norm() + +if __name__ == "__main__": + disc = MultiPeriodDiscriminator() \ No newline at end of file diff --git a/encoder/hifi_vaegan2/modules/modules.py b/encoder/hifi_vaegan2/modules/modules.py new file mode 100644 index 0000000..225b197 --- /dev/null +++ b/encoder/hifi_vaegan2/modules/modules.py @@ -0,0 +1,140 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +LRELU_SLOPE = 0.1 + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, n_channels): + n_channels_int = n_channels[0] + t_act = torch.tanh(input_a[:, :n_channels_int, :]) + s_act = torch.sigmoid(input_a[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers-1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + self.hidden_channels =hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + + acts = fused_add_tanh_sigmoid_multiply( + x_in, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:,:self.hidden_channels,:] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:,self.hidden_channels:,:] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + for l in self.in_layers: + remove_weight_norm(l) + for l in self.res_skip_layers: + remove_weight_norm(l) + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x diff --git a/encoder/hifi_vaegan2/modules/msstftd.py b/encoder/hifi_vaegan2/modules/msstftd.py new file mode 100644 index 0000000..6649dd8 --- /dev/null +++ b/encoder/hifi_vaegan2/modules/msstftd.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""MS-STFT discriminator, provided here for reference.""" + +import typing as tp +import einops +import torchaudio +import torch +from torch import nn +from einops import rearrange +from torch.nn.utils import spectral_norm, weight_norm + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, 'b ... t -> b t ...') + x = super().forward(x) + x = einops.rearrange(x, 'b t ... -> b ... t') + return + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): + return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. 
Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. Default: 1 + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, + filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], + stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', + activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, + normalized=self.normalized, center=False, pad_mode=None, power=None) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, + dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm)) + in_chs = out_chs + out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm)) + self.conv_post = NormConv2d(out_chs, self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm) + + def forward(self, x: torch.Tensor): + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, 'b c w t -> b c t w') + for i, layer in enumerate(self.convs): + z = layer(z) + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. 
Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], + win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList([ + DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, + n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) + for i in range(len(n_ffts)) + ]) + self.num_discriminators = len(self.discriminators) + + def forward(self, x: torch.Tensor) -> DiscriminatorOutput: + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps + + +def test(): + disc = MultiScaleSTFTDiscriminator(filters=32) + y = torch.randn(1, 1, 24000) + y_hat = torch.randn(1, 1, 24000) + + y_disc_r, fmap_r = disc(y) + y_disc_gen, fmap_gen = disc(y_hat) + assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators + + assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) + assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) + assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) + + +if __name__ == '__main__': + test() \ No newline at end of file diff --git a/encoder/hifi_vaegan2/readme.txt b/encoder/hifi_vaegan2/readme.txt new file mode 100644 index 0000000..8111347 --- /dev/null +++ b/encoder/hifi_vaegan2/readme.txt @@ -0,0 +1 @@ +from https://github.com/OpenNSP/Hifi-vaegan \ No newline at end of file diff --git a/encoder/rmvpe/inference.py b/encoder/rmvpe/inference.py index e48f29c..a795040 100644 --- a/encoder/rmvpe/inference.py +++ b/encoder/rmvpe/inference.py @@ -11,7 +11,7 @@ class RMVPE: def __init__(self, model_path, hop_length=160): self.resample_kernel = {} model = E2E0(4, 1, (2, 2)) - ckpt = torch.load(model_path) + ckpt = torch.load(model_path, weights_only=True) model.load_state_dict(ckpt['model'], strict=False) model.eval() self.model = model diff --git a/encoder/wavs/__init__.py b/encoder/wavs/__init__.py new file mode 100644 index 0000000..5eef152 --- /dev/null +++ b/encoder/wavs/__init__.py @@ -0,0 +1 @@ +from .wavs import WAVS, IWAVS \ No newline at end of file diff --git a/encoder/wavs/wavs.py b/encoder/wavs/wavs.py new file mode 100644 index 0000000..21e1a9b --- /dev/null +++ b/encoder/wavs/wavs.py @@ -0,0 +1,31 @@ +import torch + + +class WAVS(torch.nn.Module): + def __init__(self, hop_length): + super(WAVS, self).__init__() + self.hop_length = hop_length + + def forward(self, x): + # x: (batch_size, x_len) + x_len = x.shape[-1] + # pad x to make it a multiple of hop_length + pad_len = (self.hop_length - x_len % self.hop_length) + # pad zeros on the right for all batches + x = torch.nn.functional.pad(x, (0, pad_len)) + # unfold x + x = x.unfold(-1, self.hop_length, self.hop_length) # (batch_size, time, hop_length) + return x + + +class IWAVS(torch.nn.Module): + def __init__(self, hop_length): + super(IWAVS, self).__init__() + self.hop_length = hop_length + + def forward(self, x): + assert x.shape[-1] == self.hop_length + # x: (batch_size, time, hop_length) + # fold x + x = 
x.flatten(-2, -1) + return x # (batch_size, x_len) diff --git a/logger/utils.py b/logger/utils.py index d27d39b..9d117b7 100644 --- a/logger/utils.py +++ b/logger/utils.py @@ -73,7 +73,7 @@ def load_config(path_config): def to_json(path_params, path_json): - params = torch.load(path_params, map_location=torch.device('cpu')) + params = torch.load(path_params, map_location=torch.device('cpu'), weights_only=True) raw_state_dict = {} for k, v in params.items(): val = v.flatten().numpy().tolist() @@ -99,7 +99,8 @@ def load_model( optimizer, name='model', postfix='', - device='cpu'): + device='cpu', + model_only=False): if postfix == '': postfix = '_' + postfix path = os.path.join(expdir, name + postfix) @@ -113,9 +114,34 @@ def load_model( else: path_pt = path + 'best.pt' print(' [*] restoring model from', path_pt) - ckpt = torch.load(path_pt, map_location=torch.device(device)) + ckpt = torch.load(path_pt, map_location=torch.device(device), weights_only=True) global_step = ckpt['global_step'] model.load_state_dict(ckpt['model'], strict=False) + if not model_only: + if ckpt.get('optimizer') != None: + optimizer.load_state_dict(ckpt['optimizer']) + return global_step, model, optimizer + + +def load_optimizer( + expdir, + optimizer, + name='model', + postfix='', + device='cpu'): + if postfix == '': + postfix = '_' + postfix + path = os.path.join(expdir, name + postfix) + path_pt = traverse_dir(expdir, ['pt'], is_ext=False) + if len(path_pt) > 0: + steps = [s[len(path):] for s in path_pt] + maxstep = max([int(s) if s.isdigit() else 0 for s in steps]) + if maxstep >= 0: + path_pt = path + str(maxstep) + '.pt' + else: + path_pt = path + 'best.pt' + print(' [*] restoring model from', path_pt) + ckpt = torch.load(path_pt, map_location=torch.device(device), weights_only=True) if ckpt.get('optimizer') != None: optimizer.load_state_dict(ckpt['optimizer']) - return global_step, model, optimizer + return optimizer diff --git a/nsf_hifigan/models.py b/nsf_hifigan/models.py index 8342e01..8355481 100644 --- a/nsf_hifigan/models.py +++ b/nsf_hifigan/models.py @@ -6,13 +6,14 @@ import torch.nn.functional as F import torch.nn as nn from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import remove_weight_norm, spectral_norm from .utils import init_weights, get_padding _OLD_WEIGHT_NORM = False try: - from torch.nn.utils.parametrizations import weight_norm + from torch.nn.utils.parametrizations import weight_norm, spectral_norm + from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm except ImportError: from torch.nn.utils import weight_norm + from torch.nn.utils import remove_weight_norm, spectral_norm _OLD_WEIGHT_NORM = True LRELU_SLOPE = 0.1 @@ -25,7 +26,7 @@ def load_model(model_path, device='cuda', load_for_combo=False): with open(config_file) as f: data = f.read() json_config = json.loads(data) - model_state_dict = torch.load(model_path, map_location=device) + model_state_dict = torch.load(model_path, map_location=device, weights_only=True) return json_config, model_state_dict # 如果model_path是字典,说明传入的是config + model @@ -36,7 +37,7 @@ def load_model(model_path, device='cuda', load_for_combo=False): else: h = load_config(model_path) generator = Generator(h).to(device) - cp_dict = torch.load(model_path, map_location=device) + cp_dict = torch.load(model_path, map_location=device, weights_only=True) generator.load_state_dict(cp_dict['generator']) generator.eval() @@ -95,6 +96,11 @@ def remove_weight_norm(self): remove_weight_norm(l) for l in 
self.convs2: remove_weight_norm(l) + else: + for l in self.convs1: + remove_weight_norm(l, 'weight') + for l in self.convs2: + remove_weight_norm(l, 'weight') class ResBlock2(torch.nn.Module): @@ -121,6 +127,9 @@ def remove_weight_norm(self): if _OLD_WEIGHT_NORM: for l in self.convs: remove_weight_norm(l) + else: + for l in self.convs: + remove_weight_norm(l, 'weight') class SineGen(torch.nn.Module): @@ -298,6 +307,13 @@ def remove_weight_norm(self): l.remove_weight_norm() remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) + else: + for l in self.ups: + remove_weight_norm(l, 'weight') + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre, 'weight') + remove_weight_norm(self.conv_post, 'weight') class DiscriminatorP(torch.nn.Module): diff --git a/preprocess.py b/preprocess.py index bf971f7..1ae0549 100644 --- a/preprocess.py +++ b/preprocess.py @@ -197,8 +197,12 @@ def process(file): use_pitch_aug = True if args.vocoder.type == 'hifivaegan': use_pitch_aug = False + if args.vocoder.type == 'hifivaegan2': + use_pitch_aug = False if str(args.vocoder.type)[:3] == 'dct': use_pitch_aug = False + if str(args.vocoder.type)[:4] == 'wavs': + use_pitch_aug = False # initialize units encoder if args.data.encoder == 'cnhubertsoftfish': diff --git a/requirements.txt b/requirements.txt index 006c7b5..4dd444b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ sounddevice gradio faiss-cpu vector-quantize-pytorch +torch-dct diff --git a/tools/combine_tools.py b/tools/combine_tools.py index 9732732..1e7128b 100644 --- a/tools/combine_tools.py +++ b/tools/combine_tools.py @@ -62,9 +62,9 @@ def __init__(self, diff_model_path, naive_model_path=None, vocoder_type=None, vo self.device = device # load ckpt - self.diff_model = torch.load(diff_model_path, map_location=torch.device(device)) + self.diff_model = torch.load(diff_model_path, map_location=torch.device(device), weights_only=False) if not self.is_combo_trained_model: - self.naive_model = torch.load(naive_model_path, map_location=torch.device(device)) + self.naive_model = torch.load(naive_model_path, map_location=torch.device(device), weights_only=False) print(" [INFO] Loaded model and config check out.") # vocoder diff --git a/tools/get_z_range.py b/tools/get_z_range.py index 49667be..114a10f 100644 --- a/tools/get_z_range.py +++ b/tools/get_z_range.py @@ -93,10 +93,10 @@ def parse_args(args=None, namespace=None): args = load_config(cmd.config) print(f' [INFO] args: {args}') print(f' [INFO] config: {cmd.config}') - print(f' [INFO] args.data.{train_or_val}_data_path: {args.data.train_path}') # get all file path data_path = args.data.train_path if train_or_val == 'train' else args.data.valid_path + print(f' [INFO] args.data.{train_or_val}_data_path: {data_path}') path_srcdir = os.path.join(data_path, 'mel') filelist = traverse_dir( path_srcdir, diff --git a/tools/get_z_range_png.py b/tools/get_z_range_png.py new file mode 100644 index 0000000..66c4eca --- /dev/null +++ b/tools/get_z_range_png.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import tqdm +import yaml +import torch +import argparse +import tqdm + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_config(path_config): + with open(path_config, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + # print(args) + return args + + +def traverse_dir( + 
root_dir, + extensions, + amount=None, + str_include=None, + str_exclude=None, + is_pure=False, + is_sort=False, + is_ext=True): + file_list = [] + cnt = 0 + for root, _, files in os.walk(root_dir): + for file in files: + if any([file.endswith(f".{ext}") for ext in extensions]): + # path + mix_path = os.path.join(root, file) + pure_path = mix_path[len(root_dir) + 1:] if is_pure else mix_path + + # amount + if (amount is not None) and (cnt == amount): + if is_sort: + file_list.sort() + return file_list + + # check string + if (str_include is not None) and (str_include not in pure_path): + continue + if (str_exclude is not None) and (str_exclude in pure_path): + continue + + if not is_ext: + ext = pure_path.split('.')[-1] + pure_path = pure_path[:-(len(ext) + 1)] + file_list.append(pure_path) + cnt += 1 + if is_sort: + file_list.sort() + return file_list + + +def parse_args(args=None, namespace=None): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config", + type=str, + required=True, + help="path to the config file") + parser.add_argument( + "-m", + "--model", + type=str, + default=None, + required=False, + help="val or train; default is train, it means get z range from train data path") + parser.add_argument( + "-min", + "--min", + type=float, + default=-10.0, + required=False, + help="z_min") + parser.add_argument( + "-max", + "--max", + type=float, + default=10.0, + required=False, + help="z_max") + return parser.parse_args(args=args, namespace=namespace) + + +if __name__ == '__main__': + # parse commands + cmd = parse_args() + train_or_val = cmd.model + if train_or_val is None: + train_or_val = 'train' + print(f' [INFO] train_or_val: {train_or_val}') + + # load config + args = load_config(cmd.config) + print(f' [INFO] args: {args}') + print(f' [INFO] config: {cmd.config}') + + # get all file path + data_path = args.data.train_path if train_or_val == 'train' else args.data.valid_path + print(f' [INFO] args.data.{train_or_val}_data_path: {data_path}') + path_srcdir = os.path.join(data_path, 'mel') + filelist = traverse_dir( + path_srcdir, + extensions=['npy'], + is_pure=True, + is_sort=True, + is_ext=True) + + # 定义0.001一级,范围z的直方图的表,tensor + z_min = cmd.min + z_max = cmd.max + bins = int((z_max - z_min) * 1000 + 1) + hist = torch.zeros(bins) + # 遍历所有文件 + for file in tqdm.tqdm(filelist): + path_specfile = os.path.join(path_srcdir, file) + # load spec + spec = np.load(path_specfile, allow_pickle=True) + spec = torch.from_numpy(spec).float() + m = spec.transpose(-1, 0)[:1].transpose(-1, 0).squeeze(-1) + logs = spec.transpose(-1, 0)[1:].transpose(-1, 0).squeeze(-1) + z = m + torch.randn_like(m) * torch.exp(logs) + # 计算直方图 + # clip将z限制在-10到10之间, 超出部分视为-10或10 + z_c = z.clamp(z_min, z_max) + hist += torch.histc(z_c, bins=bins, min=z_min, max=z_max) + + # 计算直方图的累积分布函数 + # 从左到右累积 + cdf = torch.cumsum(hist, dim=0) + # total count + cdf_total = cdf[-1] + # 找到0.001和0.999的位置 + z_find_min = z_min + z_find_max = z_max + for i in range(bins): + if cdf[i] > (0.001 * cdf_total): + z_find_min = i / 1000 - 10 + break + for i in range(bins): + if cdf[i] > (0.999 * cdf_total): + z_find_max = i / 1000 - 10 + break + print(f' [INFO] z_min(0.001): {z_find_min}, z_max(0.009): {z_find_max}') + # 刨去两端极值的数据占比 + _sum = (cdf[-2] - cdf[0]) + _sum = _sum / cdf_total + print(f' [INFO] sum(min > max): {_sum}') + import matplotlib.pyplot as plt + # 画图 + plt.figure() + plt.plot(torch.arange(z_min, z_max + 0.001, 0.001), hist) + plt.xlabel('z') + 
plt.ylabel('count') + plt.title('z range') + plt.grid() + plt.savefig(os.path.join(data_path, 'z_range.png')) + plt.close() diff --git a/tools/infer_tools.py b/tools/infer_tools.py index fe305e5..dfe83ba 100644 --- a/tools/infer_tools.py +++ b/tools/infer_tools.py @@ -309,6 +309,8 @@ def __call__(self, units, f0, volume, spk_id=1, spk_mix_dict=None, aug_shift=0, if self.args.vocoder.type == 'hifivaegan': use_vae = True + elif self.args.vocoder.type == 'hifivaegan2': + use_vae = True else: use_vae = False diff --git a/tools/tools.py b/tools/tools.py index e9ac091..9b97b65 100644 --- a/tools/tools.py +++ b/tools/tools.py @@ -98,7 +98,7 @@ def __init__(self, config_path, ckpt_path, device='cuda'): self.config['model']["num_lstm_layers"], ) with fsspec.open(ckpt_path, "rb") as f: - state = torch.load(f, map_location=device) + state = torch.load(f, map_location=device, weights_only=True) self.model.load_state_dict(state["model"]) self.model = self.model.to(device) self.model.eval() @@ -531,7 +531,7 @@ def __init__(self, path, h_sample_rate=16000, h_hop_size=320): print(' [Encoder Model] HuBERT Soft') self.hubert = HubertSoft() print(' [Loading] ' + path) - checkpoint = torch.load(path) + checkpoint = torch.load(path, map_location='cpu', weights_only=True) consume_prefix_in_state_dict_if_present(checkpoint, "module.") self.hubert.load_state_dict(checkpoint) self.hubert.eval() @@ -688,7 +688,7 @@ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu', gate self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256)) # self.label_embedding = nn.Embedding(128, 256) - state_dict = torch.load(path, map_location=device) + state_dict = torch.load(path, map_location=device, weights_only=True) self.load_state_dict(state_dict) @torch.no_grad() diff --git a/train.py b/train.py index 0b572ff..2f2f86d 100644 --- a/train.py +++ b/train.py @@ -7,6 +7,7 @@ from diffusion.solver import train from diffusion.unit2mel import Unit2Mel, Unit2MelNaive, load_svc_model from diffusion.vocoder import Vocoder +import time def parse_args(args=None, namespace=None): @@ -25,49 +26,145 @@ def parse_args(args=None, namespace=None): required=False, default=None, help="print model") + parser.add_argument( + "-pre", + "--pretrain", + type=str, + required=False, + default=None, + help="path to the pretraining model") return parser.parse_args(args=args, namespace=namespace) -if __name__ == '__main__': - # parse commands - cmd = parse_args() - - # load config - args = utils.load_config(cmd.config) - print(' > config:', cmd.config) - print(' > exp:', args.env.expdir) - +def train_run(rank, config_path, print_model, pretrain, ddp=False, ddp_device_list=None): + args = utils.load_config(config_path) # load vocoder - vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device) - + if not ddp: + vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=f"{args.device}:{args.env.gpu_id}") + else: + vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=ddp_device_list[rank]) + # load model model = load_svc_model(args=args, vocoder_dimension=vocoder.dimension) - - # load parameters - optimizer = torch.optim.AdamW(model.parameters()) - initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device) - for param_group in optimizer.param_groups: - param_group['initial_lr'] = args.train.lr - param_group['lr'] = args.train.lr * args.train.gamma ** max((initial_global_step - 2) // args.train.decay_step, 0) - 
param_group['weight_decay'] = args.train.weight_decay - scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma, last_epoch=initial_global_step-2) - + + # load parameters not ddp + if not ddp: + optimizer = torch.optim.AdamW(model.parameters()) + initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, + device=f"{args.device}:{args.env.gpu_id}") + if pretrain is not None: # 加载预训练模型 + if initial_global_step == 0: + _ckpt = torch.load(pretrain, map_location=torch.device(f"{args.device}:{args.env.gpu_id}"), + weights_only=True) + model.load_state_dict(_ckpt['model'], strict=False) + if _ckpt.get('optimizer') != None: + optimizer = torch.optim.AdamW(model.parameters()) + optimizer.load_state_dict(_ckpt['optimizer']) + else: + optimizer = None + initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, + device=ddp_device_list[rank], model_only=True) + if pretrain is not None: # 加载预训练模型 + if initial_global_step == 0: + _ckpt = torch.load(pretrain, map_location=torch.device(ddp_device_list[rank]),weights_only=True) + model.load_state_dict(_ckpt['model'], strict=False) + # device - if args.device == 'cuda': - torch.cuda.set_device(args.env.gpu_id) - model.to(args.device) + if ddp: + # init + if rank != 0: + time.sleep(5) + torch.distributed.init_process_group( + backend='gloo' if os.name == 'nt' else 'nccl', + init_method='env://', world_size=len(ddp_device_list), + rank=rank + ) + # device + device = ddp_device_list[rank] + torch.cuda.set_device(torch.device(device)) + model = model.to(device) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) + model = model.to(device) + optimizer = torch.optim.AdamW(model.parameters()) + if pretrain is not None: # 加载预训练模型 + if initial_global_step == 0: + _ckpt = torch.load(pretrain, map_location=device, weights_only=True) + if _ckpt.get('optimizer') != None: + optimizer.load_state_dict(_ckpt['optimizer']) + else: + optimizer = utils.load_optimizer(args.env.expdir, optimizer, device=device) + else: + optimizer = utils.load_optimizer(args.env.expdir, optimizer, device=device) + else: + optimizer = utils.load_optimizer(args.env.expdir, optimizer, device=device) + else: + device = args.device + if args.device == 'cuda': + torch.cuda.set_device(args.env.gpu_id) + model = model.to(device) # 打印模型结构 - if (str(cmd.print) == 'True') or (str(cmd.print) == 'true'): + if (str(print_model) == 'True') or (str(print_model) == 'true'): print(model) - + for state in optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): - state[k] = v.to(args.device) - - # datas - loader_train, loader_valid = get_data_loaders(args, whole_audio=False) - - # run - train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid) - + state[k] = v.to(device) + + # init scheduler + if args.train.last_decay_step is not None: + _lds = args.train.last_decay_step + else: + _lds = 99999999999 + for param_group in optimizer.param_groups: + param_group['initial_lr'] = args.train.lr + param_group['lr'] = args.train.lr * args.train.gamma ** max( + min(_lds,(initial_global_step - 2)) // args.train.decay_step, 0) + param_group['weight_decay'] = args.train.weight_decay + _last_epoch = initial_global_step + if args.train.last_decay_step is not None: + if initial_global_step > args.train.last_decay_step: + _last_epoch = args.train.last_decay_step + scheduler = lr_scheduler.StepLR(optimizer, 
step_size=args.train.decay_step, gamma=args.train.gamma, + last_epoch=_last_epoch - 2) + + # datas run + if ddp: + loader_train, loader_valid, samper_train = get_data_loaders(args, whole_audio=False, ddp=True, rank=rank, + ddp_cache_gpu=args.ddp.ddp_cache_gpu, + ddp_device_list=ddp_device_list) + if rank != 0: + loader_valid = None + train(rank, args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid, device, + ddp=ddp, samper_train=samper_train) + else: + loader_train, loader_valid, samper_train = get_data_loaders(args, whole_audio=False, ddp=False, rank=0) + train(0, args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid, device, + ddp=ddp, samper_train=samper_train) + + +if __name__ == '__main__': + # parse commands + cmd = parse_args() + + # load config + args = utils.load_config(cmd.config) + print(' > config:', cmd.config) + print(' > exp:', args.env.expdir) + + if args.ddp.use_ddp: + # device + device_list = args.ddp.ddp_device + device_ids = [] + for device in device_list: + _device_ids = device.split(':')[-1] + device_ids.append(int(_device_ids)) + # init gloo or nccl + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = args.ddp.port + # run + torch.multiprocessing.set_start_method('spawn') + torch.multiprocessing.spawn(train_run, args=(cmd.config, cmd.print, cmd.pretrain, True, device_list), nprocs=len(device_ids)) + + else: + train_run(0, cmd.config, cmd.print, cmd.pretrain, ddp=False)
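
The torch.load calls above now pass weights_only=True, which restricts unpickling to plain tensors and simple containers and suits the state-dict checkpoints loaded in logger/utils.py, nsf_hifigan/models.py, tools/tools.py and train.py; tools/combine_tools.py keeps weights_only=False because its combined checkpoints are full pickled model objects. A minimal sketch of the two modes (paths are placeholders, not files from this repo):

    import torch

    # State-dict checkpoint (plain tensors): safe to load with weights_only=True.
    ckpt = torch.load('exp/model_100000.pt', map_location='cpu', weights_only=True)
    # model.load_state_dict(ckpt['model'], strict=False)

    # Whole pickled model object (as in tools/combine_tools.py): still needs
    # weights_only=False, which can execute arbitrary pickle code, so only
    # use it on trusted checkpoint files.
    combo_model = torch.load('exp/combo.ptc', map_location='cpu', weights_only=False)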
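
nsf_hifigan/models.py now prefers the parametrization-based weight-norm API and falls back to the deprecated torch.nn.utils functions on older PyTorch; remove_parametrizations requires the tensor name, while the legacy remove_weight_norm defaults to 'weight', hence the branching in the remove_weight_norm methods. A minimal sketch of the same compatibility pattern on a single layer:

    import torch.nn as nn

    _OLD_WEIGHT_NORM = False
    try:
        # Newer PyTorch: parametrization-based implementation
        from torch.nn.utils.parametrizations import weight_norm
        from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
    except ImportError:
        # Older PyTorch: deprecated functional API
        from torch.nn.utils import weight_norm, remove_weight_norm
        _OLD_WEIGHT_NORM = True

    conv = weight_norm(nn.Conv1d(256, 256, kernel_size=3, padding=1))

    # Legacy call defaults to the 'weight' tensor; the new one must name it.
    if _OLD_WEIGHT_NORM:
        remove_weight_norm(conv)
    else:
        remove_weight_norm(conv, 'weight')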
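
train.py's new DDP path sets MASTER_ADDR/MASTER_PORT from the ddp config block, spawns one process per entry in ddp_device, and has each rank call init_process_group (gloo on Windows, nccl otherwise) before wrapping the model in DistributedDataParallel. A minimal, self-contained sketch of that launch flow, with a stand-in model and a hard-coded device list taken from the example configs:

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP


    def worker(rank, device_list):
        # One process per device; env:// rendezvous reads MASTER_ADDR/MASTER_PORT set below.
        dist.init_process_group(
            backend='gloo' if os.name == 'nt' else 'nccl',
            init_method='env://',
            world_size=len(device_list),
            rank=rank)
        device = torch.device(device_list[rank])
        torch.cuda.set_device(device)
        model = torch.nn.Linear(8, 8).to(device)   # stand-in for load_svc_model()
        model = DDP(model, device_ids=[device])
        # ... build optimizer, scheduler and data loaders, then call train() ...
        dist.destroy_process_group()


    if __name__ == '__main__':
        device_list = ['cuda:1', 'cuda:2']          # mirrors ddp.ddp_device in the configs
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '13348'         # mirrors ddp.port
        torch.multiprocessing.set_start_method('spawn')
        torch.multiprocessing.spawn(worker, args=(device_list,), nprocs=len(device_list))

With ddp.use_ddp left false, the unchanged single-process path runs via train_run(0, cmd.config, cmd.print, cmd.pretrain, ddp=False) exactly as before.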