Skip to content

Commit

Permalink
added soft quantization to BWE models
Browse files Browse the repository at this point in the history
  • Loading branch information
janpbuethe committed Feb 6, 2025
1 parent 99f871e commit 313930b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dnn/torch/osce/adv_train_bwe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def criterion(x, y, x_up):

print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"generator: {model.flop_count(48000) / 1e6:5.3f} MFLOPS")
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")


Expand Down
26 changes: 17 additions & 9 deletions dnn/torch/osce/models/bbwe_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from utils.layers.silk_upsampler import SilkUpsampler
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper

from dnntools.quantization.softquant import soft_quant

DUMP=False

Expand All @@ -30,7 +30,8 @@ def __init__(self,
feature_dim=84,
num_channels=256,
upsamp_factor=2,
lookahead=False):
lookahead=False,
softquant=False):

super().__init__()

Expand All @@ -46,6 +47,11 @@ def __init__(self,

self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)

if softquant:
self.conv2 = soft_quant(self.conv2)
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
self.tconv = soft_quant(self.tconv)

def flop_count(self, rate=100):
count = 0
for conv in self.conv1, self.conv2, self.tconv:
Expand Down Expand Up @@ -125,6 +131,8 @@ def __init__(self,
func_extension=True,
shaper='TDShaper',
bias=False,
softquant=False,
lookahead=False,
):

super().__init__()
Expand Down Expand Up @@ -152,14 +160,14 @@ def __init__(self,
self.upsampler = SilkUpsampler()

# feature net
self.feature_net = FloatFeatureNet(feature_dim=feature_dim, num_channels=cond_dim)
self.feature_net = FloatFeatureNet(feature_dim=feature_dim, num_channels=cond_dim, softquant=softquant, lookahead=lookahead)

# non-linear transforms

if self.shape_extension:
if self.shaper == 'TDShaper':
self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=interpolate_k32, bias=bias)
self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=interpolate_k48, bias=bias)
self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=interpolate_k32, bias=bias, softquant=softquant)
self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=interpolate_k48, bias=bias, softquant=softquant)
elif self.shaper == 'Folder':
self.tdshape1 = Folder(8, frame_size=self.frame_size32)
self.tdshape2 = Folder(12, frame_size=self.frame_size48)
Expand All @@ -178,9 +186,9 @@ def __init__(self,
if self.func_extension: latent_channels += 1

# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size16, cond_dim, frame_size=self.frame_size16, overlap_size=self.frame_size16//2, use_bias=False, padding=[self.kernel_size16 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
self.af2 = LimitedAdaptiveConv1d(latent_channels, latent_channels, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
self.af3 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
self.af1 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size16, cond_dim, frame_size=self.frame_size16, overlap_size=self.frame_size16//2, use_bias=False, padding=[self.kernel_size16 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
self.af2 = LimitedAdaptiveConv1d(latent_channels, latent_channels, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
self.af3 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)


def flop_count(self, rate=16000, verbose=False):
Expand Down Expand Up @@ -246,4 +254,4 @@ def forward(self, x, features, debug=False):
# 2nd mixing
y48_out = self.af3(y48_out, cf)

return y48_out
return y48_out
2 changes: 1 addition & 1 deletion dnn/torch/osce/train_bwe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def criterion(x, y, x_up):

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"{model.flop_count(48000) / 1e6:5.3f} MFLOPS")
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")


best_loss = 1e9
Expand Down

0 comments on commit 313930b

Please sign in to comment.