Skip to content

Commit

Permalink
add silero vad model for sense voice asr
Browse files Browse the repository at this point in the history
  • Loading branch information
lovemefan committed Dec 8, 2024
1 parent 40d0d46 commit dab6c6e
Show file tree
Hide file tree
Showing 11 changed files with 621 additions and 60 deletions.
27 changes: 26 additions & 1 deletion scripts/convert-pt-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
)

self.model_checkpoint = "model.pt"
self.vad_model_checkpoint = 'silero_vad.pt'
self.hparams = Model.load_hparams(self.dir_model)
self.gguf_writer = gguf.GGUFWriter(
fname_out,
Expand Down Expand Up @@ -105,9 +106,24 @@ def set_vocab(self):
raise NotImplementedError

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:

print(f"gguf: loading model part '{self.model_checkpoint}'")
print(f"gguf: loading vad model part '{self.vad_model_checkpoint}'")
ctx: ContextManager[Any]

ctx = contextlib.nullcontext(
torch.load(
str(self.dir_model / self.vad_model_checkpoint),
map_location="cpu",
mmap=True,
weights_only=True,
)
)

with ctx as model_part:
for name, data in model_part.items():
yield name, data

ctx = contextlib.nullcontext(
torch.load(
str(self.dir_model / self.model_checkpoint),
Expand Down Expand Up @@ -254,7 +270,7 @@ def write_one_tensor(self, data_torch, name):
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

_data = data_torch.squeeze().numpy()
_data = data_torch.numpy()
# use max to avoid n_dim of single tensor become 0
if len(_data.shape) != 0:
data = _data
Expand All @@ -281,6 +297,15 @@ def write_one_tensor(self, data_torch, name):
):
data = data.astype(np.float16)

if self.ftype == 0 and name in [
'_model.stft.forward_basis_buffer.weight',
'_model.encoder.0.reparam_conv.weight',
'_model.encoder.1.reparam_conv.weight',
'_model.encoder.2.reparam_conv.weight',
'_model.encoder.3.reparam_conv.weight',
'_model.decoder.decoder.2.weight'
]:
data = data.astype(np.float16)

print(
f"|{name}| n_dims = {n_dims}| {old_dtype} | {data.dtype} | {data.size}|"
Expand Down
2 changes: 2 additions & 0 deletions sense-voice/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ set(SOURCE_FILES
sense-voice-decoder.cc
sense-voice.h
sense-voice.cc
silero-vad.h
silero-vad.cc
)

add_library(sense-voice-core STATIC ${SOURCE_FILES})
Expand Down
22 changes: 22 additions & 0 deletions sense-voice/csrc/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,25 @@ struct sense_voice_full_params sense_voice_full_default_params(enum sense_voice_

return result;
}

bool ggml_graph_compute_helper(
ggml_backend_sched_t sched,
struct ggml_cgraph * graph,
int n_threads) {

for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;

auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (fn_set_n_threads) {
fn_set_n_threads(backend, n_threads);
}
}


bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
ggml_backend_sched_reset(sched);
return t;
}
31 changes: 30 additions & 1 deletion sense-voice/csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "sense-voice-frontend.h"


#ifdef __GNUC__
#define SENSEVOICE_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
Expand Down Expand Up @@ -202,6 +203,18 @@ struct sense_voice {

};

struct silero_vad;

struct silero_vad_model {
struct silero_vad *model;
// context
struct ggml_context *ctx;

// the model backend data is read-only and can be shared between processors
ggml_backend_buffer_t buffer = nullptr;
};


static const std::map<std::string, std::pair<int, std::string>> g_lang = {
{ "auto", { 0, "auto", } },
{ "zh", { 3, "chinese", } },
Expand All @@ -226,7 +239,7 @@ struct sense_voice_hparams {
int n_decoder_attention_heads = 4;
int n_decoder_layers = 14;
int fsmn_kernel_size = 11;

int n_vad_encoder_layers = 4;
int n_predictor_dim = 512;
float predictor_tail_threshold = 0.45;

Expand Down Expand Up @@ -362,9 +375,21 @@ struct sense_voice_state {

std::vector<ggml_backend_t> backends;

sense_voice_sched sched_vad;
sense_voice_sched sched_vad_sate;
sense_voice_sched sched_encode;
sense_voice_sched sched_decode;

// hidden state in vad lstm
ggml_context * vad_ctx = nullptr;
struct ggml_tensor * vad_lstm_hidden_state;
struct ggml_tensor * vad_lstm_context;
ggml_backend_buffer_t vad_lstm_hidden_state_buffer = nullptr;
ggml_backend_buffer_t vad_lstm_context_buffer = nullptr;

ggml_cgraph *sense_voice_encoder_graph;
ggml_cgraph *sense_voice_decoder_graph;

// result of the encoder
struct ggml_tensor *encoder_out = nullptr;

Expand Down Expand Up @@ -443,6 +468,7 @@ struct sense_voice_full_params {
struct sense_voice_model {
std::string model_type;
sense_voice_hparams hparams;
silero_vad *vad_model;
sense_voice *model;
// context
struct ggml_context *ctx;
Expand Down Expand Up @@ -472,6 +498,7 @@ struct sense_voice_context {
ggml_type itype =
ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)

silero_vad_model vad_model;
sense_voice_model model;
sense_voice_vocab vocab;

Expand All @@ -483,5 +510,7 @@ struct sense_voice_context {
};

struct sense_voice_full_params sense_voice_full_default_params(enum sense_voice_decoding_strategy strategy);
bool ggml_graph_compute_helper(ggml_backend_sched_t sched, struct ggml_cgraph * graph, int n_threads);


#endif//SENSEVOICE_CPP_COMMON_H
Loading

0 comments on commit dab6c6e

Please sign in to comment.