From 313930b654a475e019db9791ffada265302d1a2f Mon Sep 17 00:00:00 2001
From: Jan Buethe
Date: Wed, 5 Feb 2025 16:48:06 -0800
Subject: [PATCH] added soft quantization to BWE models

---
 dnn/torch/osce/adv_train_bwe_model.py |  2 +-
 dnn/torch/osce/models/bbwe_net.py     | 26 +++++++++++++++++---------
 dnn/torch/osce/train_bwe_model.py     |  2 +-
 3 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/dnn/torch/osce/adv_train_bwe_model.py b/dnn/torch/osce/adv_train_bwe_model.py
index 0d1681e2d..e1a55778a 100644
--- a/dnn/torch/osce/adv_train_bwe_model.py
+++ b/dnn/torch/osce/adv_train_bwe_model.py
@@ -315,7 +315,7 @@ def criterion(x, y, x_up):
 
 print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
 if hasattr(model, 'flop_count'):
-    print(f"generator: {model.flop_count(48000) / 1e6:5.3f} MFLOPS")
+    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
 
 print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
 
diff --git a/dnn/torch/osce/models/bbwe_net.py b/dnn/torch/osce/models/bbwe_net.py
index fb9b1f349..c89e84e77 100644
--- a/dnn/torch/osce/models/bbwe_net.py
+++ b/dnn/torch/osce/models/bbwe_net.py
@@ -6,7 +6,7 @@
 from utils.layers.silk_upsampler import SilkUpsampler
 from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
 from utils.layers.td_shaper import TDShaper
-
+from dnntools.quantization.softquant import soft_quant
 
 DUMP=False
 
@@ -30,7 +30,8 @@ def __init__(self,
                  feature_dim=84,
                  num_channels=256,
                  upsamp_factor=2,
-                 lookahead=False):
+                 lookahead=False,
+                 softquant=False):
 
         super().__init__()
 
@@ -46,6 +47,11 @@ def __init__(self,
 
         self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)
 
+        if softquant:
+            self.conv2 = soft_quant(self.conv2)
+            self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.tconv = soft_quant(self.tconv)
+
     def flop_count(self, rate=100):
         count = 0
         for conv in self.conv1, self.conv2, self.tconv:
@@ -125,6 +131,8 @@ def __init__(self,
                  func_extension=True,
                  shaper='TDShaper',
                  bias=False,
+                 softquant=False,
+                 lookahead=False,
                  ):
 
         super().__init__()
@@ -152,14 +160,14 @@ def __init__(self,
 
         self.upsampler = SilkUpsampler()
 
         # feature net
-        self.feature_net = FloatFeatureNet(feature_dim=feature_dim, num_channels=cond_dim)
+        self.feature_net = FloatFeatureNet(feature_dim=feature_dim, num_channels=cond_dim, softquant=softquant, lookahead=lookahead)
 
         # non-linear transforms
         if self.shape_extension:
             if self.shaper == 'TDShaper':
-                self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=interpolate_k32, bias=bias)
-                self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=interpolate_k48, bias=bias)
+                self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=interpolate_k32, bias=bias, softquant=softquant)
+                self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=interpolate_k48, bias=bias, softquant=softquant)
             elif self.shaper == 'Folder':
                 self.tdshape1 = Folder(8, frame_size=self.frame_size32)
                 self.tdshape2 = Folder(12, frame_size=self.frame_size48)
@@ -178,9 +186,9 @@ def __init__(self,
         if self.func_extension: latent_channels += 1
 
         # spectral shaping
-        self.af1 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size16, cond_dim, frame_size=self.frame_size16, overlap_size=self.frame_size16//2, use_bias=False, padding=[self.kernel_size16 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
-        self.af2 = LimitedAdaptiveConv1d(latent_channels, latent_channels, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
-        self.af3 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
+        self.af1 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size16, cond_dim, frame_size=self.frame_size16, overlap_size=self.frame_size16//2, use_bias=False, padding=[self.kernel_size16 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
+        self.af2 = LimitedAdaptiveConv1d(latent_channels, latent_channels, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
+        self.af3 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
 
 
     def flop_count(self, rate=16000, verbose=False):
@@ -246,4 +254,4 @@ def forward(self, x, features, debug=False):
         # 2nd mixing
         y48_out = self.af3(y48_out, cf)
 
-        return y48_out
\ No newline at end of file
+        return y48_out
diff --git a/dnn/torch/osce/train_bwe_model.py b/dnn/torch/osce/train_bwe_model.py
index 16bc0226d..268fa6fd8 100644
--- a/dnn/torch/osce/train_bwe_model.py
+++ b/dnn/torch/osce/train_bwe_model.py
@@ -263,7 +263,7 @@ def criterion(x, y, x_up):
 
 print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
 if hasattr(model, 'flop_count'):
-    print(f"{model.flop_count(48000) / 1e6:5.3f} MFLOPS")
+    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
 
 best_loss = 1e9
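With this patch applied, soft quantization is enabled end to end through the single
softquant constructor flag: the top-level BWE model forwards it to FloatFeatureNet,
to the TDShaper stages, and to the three LimitedAdaptiveConv1d spectral-shaping
filters, so training already sees (approximately) quantized weights and the exported
fixed-point weights lose less accuracy. A minimal usage sketch follows; the top-level
class name BBWENet and the availability of defaults for its other constructor
arguments are assumptions (only the modified lines are visible in the hunks above),
while the soft_quant calls shown in the comments mirror the patch exactly:

    # Usage sketch only. BBWENet is the assumed name of the model class in
    # dnn/torch/osce/models/bbwe_net.py; all other arguments are assumed to
    # have defaults, as in the hunks above.
    from models.bbwe_net import BBWENet

    # Enabling the flag wraps the sub-modules with soft_quant, e.g. inside
    # FloatFeatureNet (taken verbatim from the patch):
    #     self.conv2 = soft_quant(self.conv2)
    #     self.gru   = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
    #     self.tconv = soft_quant(self.tconv)
    # Note that for the GRU only the named weight tensors are wrapped.
    model = BBWENet(softquant=True)

    # flop_count now defaults to rate=16000, matching the 16 kHz input of the
    # BWE model; this is why the training scripts report flop_count(16000)
    # instead of flop_count(48000).
    if hasattr(model, 'flop_count'):
        print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")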