Skip to content

Commit

Permalink
Merge from release.
Browse files Browse the repository at this point in the history
  • Loading branch information
KPatr1ck committed Mar 14, 2022
2 parents 2063340 + 2c609e0 commit a25574b
Show file tree
Hide file tree
Showing 357 changed files with 43,210 additions and 427 deletions.
98 changes: 98 additions & 0 deletions modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# kwmlp_speech_commands

|模型名称|kwmlp_speech_commands|
| :--- | :---: |
|类别|语音-语言识别|
|网络|Keyword-MLP|
|数据集|Google Speech Commands V2|
|是否支持Fine-tuning||
|模型大小|1.6MB|
|最新更新日期|2022-01-04|
|数据指标|ACC 97.56%|

## 一、模型基本信息

### 模型介绍

kwmlp_speech_commands采用了 [Keyword-MLP](https://arxiv.org/pdf/2110.07749v1.pdf) 的轻量级模型结构,并在 [Google Speech Commands V2](https://arxiv.org/abs/1804.03209) 数据集上进行了预训练,在其测试集的测试结果为 ACC 97.56%。

<p align="center">
<img src="https://d3i71xaburhd42.cloudfront.net/fa690a97f76ba119ca08fb02fa524a546c47f031/2-Figure1-1.png" hspace='10' height="550"/> <br />
</p>


更多详情请参考
- [Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition](https://arxiv.org/abs/1804.03209)
- [ATTENTION-FREE KEYWORD SPOTTING](https://arxiv.org/pdf/2110.07749v1.pdf)
- [Keyword-MLP](https://github.com/AI-Research-BD/Keyword-MLP)


## 二、安装

- ### 1、环境依赖

- paddlepaddle >= 2.2.0

- paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)

- ### 2、安装

- ```shell
$ hub install kwmlp_speech_commands
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)


## 三、模型API预测

- ### 1、预测代码示例

```python
import paddlehub as hub
model = hub.Module(
name='kwmlp_speech_commands',
version='1.0.0')
# 通过下列链接可下载示例音频
# https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav
# Keyword spotting
score, label = model.keyword_recognize('no.wav')
print(score, label)
# [0.89498246] no
score, label = model.keyword_recognize('go.wav')
print(score, label)
# [0.8997176] go
score, label = model.keyword_recognize('one.wav')
print(score, label)
# [0.88598305] one
```

- ### 2、API
- ```python
def keyword_recognize(
wav: os.PathLike,
)
```
- 检测音频中包含的关键词。

- **参数**

- `wav`:输入的包含关键词的音频文件,格式为`*.wav`

- **返回**

- 输出结果的得分和对应的关键词标签。


## 四、更新历史

* 1.0.0

初始发布

```shell
$ hub install kwmlp_speech_commands
```
13 changes: 13 additions & 0 deletions modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
59 changes: 59 additions & 0 deletions modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import paddle
import paddleaudio


def create_dct(n_mfcc: int, n_mels: int, norm: str = 'ortho'):
n = paddle.arange(float(n_mels))
k = paddle.arange(float(n_mfcc)).unsqueeze(1)
dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels)
if norm is None:
dct *= 2.0
else:
assert norm == "ortho"
dct[0] *= 1.0 / math.sqrt(2.0)
dct *= math.sqrt(2.0 / float(n_mels))
return dct.t()


def compute_mfcc(
x: paddle.Tensor,
sr: int = 16000,
n_mels: int = 40,
n_fft: int = 480,
win_length: int = 480,
hop_length: int = 160,
f_min: float = 0.0,
f_max: float = None,
center: bool = False,
top_db: float = 80.0,
norm: str = 'ortho',
):
fbank = paddleaudio.features.spectrum.MelSpectrogram(
sr=sr,
n_mels=n_mels,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
f_min=0.0,
f_max=f_max,
center=center)(x) # waveforms batch ~ (B, T)
log_fbank = paddleaudio.features.spectrum.power_to_db(fbank, top_db=top_db)
dct_matrix = create_dct(n_mfcc=n_mels, n_mels=n_mels, norm=norm)
mfcc = paddle.matmul(log_fbank.transpose((0, 2, 1)), dct_matrix).transpose((0, 2, 1)) # (B, n_mels, L)
return mfcc
143 changes: 143 additions & 0 deletions modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class Residual(nn.Layer):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x):
return self.fn(x) + x


class PreNorm(nn.Layer):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)

def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)


class PostNorm(nn.Layer):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn

def forward(self, x, **kwargs):
return self.norm(self.fn(x, **kwargs))


class SpatialGatingUnit(nn.Layer):
def __init__(self, dim, dim_seq, act=nn.Identity(), init_eps=1e-3):
super().__init__()
dim_out = dim // 2

self.norm = nn.LayerNorm(dim_out)
self.proj = nn.Conv1D(dim_seq, dim_seq, 1)

self.act = act

init_eps /= dim_seq

def forward(self, x):
res, gate = x.split(2, axis=-1)
gate = self.norm(gate)

weight, bias = self.proj.weight, self.proj.bias
gate = F.conv1d(gate, weight, bias)

return self.act(gate) * res


class gMLPBlock(nn.Layer):
def __init__(self, *, dim, dim_ff, seq_len, act=nn.Identity()):
super().__init__()
self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())

self.sgu = SpatialGatingUnit(dim_ff, seq_len, act)
self.proj_out = nn.Linear(dim_ff // 2, dim)

def forward(self, x):
x = self.proj_in(x)
x = self.sgu(x)
x = self.proj_out(x)
return x


class Rearrange(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
x = x.transpose([0, 1, 3, 2]).squeeze(1)
return x


class Reduce(nn.Layer):
def __init__(self, axis=1):
super().__init__()
self.axis = axis

def forward(self, x):
x = x.mean(axis=self.axis, keepdim=False)
return x


class KW_MLP(nn.Layer):
"""Keyword-MLP."""

def __init__(self,
input_res=[40, 98],
patch_res=[40, 1],
num_classes=35,
dim=64,
depth=12,
ff_mult=4,
channels=1,
prob_survival=0.9,
pre_norm=False,
**kwargs):
super().__init__()
image_height, image_width = input_res
patch_height, patch_width = patch_res
assert (image_height % patch_height) == 0 and (
image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width)

P_Norm = PreNorm if pre_norm else PostNorm

dim_ff = dim * ff_mult

self.to_patch_embed = nn.Sequential(Rearrange(), nn.Linear(channels * patch_height * patch_width, dim))

self.prob_survival = prob_survival

self.layers = nn.LayerList(
[Residual(P_Norm(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for i in range(depth)])

self.to_logits = nn.Sequential(nn.LayerNorm(dim), Reduce(axis=1), nn.Linear(dim, num_classes))

def forward(self, x):
x = self.to_patch_embed(x)
layers = self.layers
x = nn.Sequential(*layers)(x)
return self.to_logits(x)
Loading

0 comments on commit a25574b

Please sign in to comment.