From da7e18eb1e382f4a1b5ecee0bbdc5a3ac73faf35 Mon Sep 17 00:00:00 2001 From: dialogflowchatbot Date: Sun, 8 Sep 2024 12:36:54 +0800 Subject: [PATCH] Added TTS and ASR module --- Cargo.toml | 10 +- src/ai/asr.rs | 9 + src/ai/audio.rs | 31 +++ src/ai/bs1770.rs | 506 ++++++++++++++++++++++++++++++++++++++++++ src/ai/embedding.rs | 10 +- src/ai/huggingface.rs | 61 ++++- src/ai/llama.rs | 14 +- src/ai/mod.rs | 4 + src/ai/tts.rs | 9 + src/man/settings.rs | 60 ++++- 10 files changed, 697 insertions(+), 17 deletions(-) create mode 100644 src/ai/asr.rs create mode 100644 src/ai/audio.rs create mode 100644 src/ai/bs1770.rs create mode 100644 src/ai/tts.rs diff --git a/Cargo.toml b/Cargo.toml index a3eb683..3f5c695 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dialogflow" -version = "1.16.1" +version = "1.17.0" edition = "2021" homepage = "https://dialogflowchatbot.github.io/" authors = ["dialogflowchatbot "] @@ -18,10 +18,12 @@ anyhow = "1.0" axum = {version = "0.7", features = ["query", "tokio", "macros"]} bigdecimal = "0.4" # candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.1" } -candle = { version = "0.6", package = "candle-core", default-features = false } -candle-nn = "0.6" +# candle = { version = "0.6", package = "candle-core", default-features = false } +candle = { git = "https://github.com/huggingface/candle.git", package = "candle-core", default-features = false } +candle-nn = { git = "https://github.com/huggingface/candle.git" } # candle-onnx = "0.6" -candle-transformers = { version = "0.6" } +candle-transformers = { git = "https://github.com/huggingface/candle.git" } +# candle-transformers = { version = "0.6" } # candle-transformers = { version = "0.6", features = ["flash-attn"] } # crossbeam-channel = "0.5" frand = "0.10" diff --git a/src/ai/asr.rs b/src/ai/asr.rs new file mode 100644 index 0000000..40690b2 --- /dev/null +++ b/src/ai/asr.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Serialize}; + +use super::huggingface::{load_bert_model_files, HuggingFaceModel, HuggingFaceModelInfo}; + +#[derive(Clone, Deserialize, Serialize)] +#[serde(tag = "id", content = "model")] +pub(crate) enum AsrProvider { + HuggingFace(HuggingFaceModel), +} diff --git a/src/ai/audio.rs b/src/ai/audio.rs new file mode 100644 index 0000000..56475f9 --- /dev/null +++ b/src/ai/audio.rs @@ -0,0 +1,31 @@ +use candle::Tensor; + +use crate::result::Result; + +pub fn normalize_loudness( + wav: &Tensor, + sample_rate: u32, + loudness_compressor: bool, +) -> Result { + let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::()?; + if energy < 2e-3 { + return Ok(wav.clone()); + } + let wav_array = wav.to_vec1::()?; + let mut meter = super::bs1770::ChannelLoudnessMeter::new(sample_rate); + meter.push(wav_array.into_iter()); + let power = meter.as_100ms_windows(); + let loudness = match super::bs1770::gated_mean(power) { + None => return Ok(wav.clone()), + Some(gp) => gp.loudness_lkfs() as f64, + }; + let delta_loudness = -14. - loudness; + let gain = 10f64.powf(delta_loudness / 20.); + let wav = (wav * gain)?; + if loudness_compressor { + let r = wav.tanh()?; + Ok(r) + } else { + Ok(wav) + } +} diff --git a/src/ai/bs1770.rs b/src/ai/bs1770.rs new file mode 100644 index 0000000..fbda6df --- /dev/null +++ b/src/ai/bs1770.rs @@ -0,0 +1,506 @@ +// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs +// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770 +// Copyright 2020 Ruud van Asseldonk + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// A copy of the License has been included in the root of the repository. + +//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704]. +//! +//! This library offers the building blocks to perform BS.1770 loudness +//! measurements, but you need to put the pieces together yourself. +//! +//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en +//! +//! # Stereo integrated loudness example +//! +//! ```ignore +//! # fn load_stereo_audio() -> [Vec; 2] { +//! # [vec![0; 48_000], vec![0; 48_000]] +//! # } +//! # +//! let sample_rate_hz = 44_100; +//! let bits_per_sample = 16; +//! let channel_samples: [Vec; 2] = load_stereo_audio(); +//! +//! // When converting integer samples to float, note that the maximum amplitude +//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit. +//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32; +//! +//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| { +//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz); +//! meter.push(samples.iter().map(|&s| s as f32 * normalizer)); +//! meter.into_100ms_windows() +//! }).collect(); +//! +//! let stereo_power = bs1770::reduce_stereo( +//! channel_power[0].as_ref(), +//! channel_power[1].as_ref(), +//! ); +//! +//! let gated_power = bs1770::gated_mean( +//! stereo_power.as_ref() +//! ).unwrap_or(bs1770::Power(0.0)); +//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs()); +//! ``` + +use std::f32; + +/// Coefficients for a 2nd-degree infinite impulse response filter. +/// +/// Coefficient a0 is implicitly 1.0. +#[derive(Clone)] +struct Filter { + a1: f32, + a2: f32, + b0: f32, + b1: f32, + b2: f32, + + // The past two input and output samples. + x1: f32, + x2: f32, + y1: f32, + y2: f32, +} + +impl Filter { + /// Stage 1 of th BS.1770-4 pre-filter. + pub fn high_shelf(sample_rate_hz: f32) -> Filter { + // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136. + let gain_db = 3.999_843_8; + let q = 0.707_175_25; + let center_hz = 1_681.974_5; + + // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143. + let k = (f32::consts::PI * center_hz / sample_rate_hz).tan(); + let vh = 10.0_f32.powf(gain_db / 20.0); + let vb = vh.powf(0.499_666_78); + let a0 = 1.0 + k / q + k * k; + Filter { + b0: (vh + vb * k / q + k * k) / a0, + b1: 2.0 * (k * k - vh) / a0, + b2: (vh - vb * k / q + k * k) / a0, + a1: 2.0 * (k * k - 1.0) / a0, + a2: (1.0 - k / q + k * k) / a0, + + x1: 0.0, + x2: 0.0, + y1: 0.0, + y2: 0.0, + } + } + + /// Stage 2 of th BS.1770-4 pre-filter. + pub fn high_pass(sample_rate_hz: f32) -> Filter { + // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136. + let q = 0.500_327_05; + let center_hz = 38.135_47; + + // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151 + let k = (f32::consts::PI * center_hz / sample_rate_hz).tan(); + Filter { + a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k), + a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k), + b0: 1.0, + b1: -2.0, + b2: 1.0, + + x1: 0.0, + x2: 0.0, + y1: 0.0, + y2: 0.0, + } + } + + /// Feed the next input sample, get the next output sample. + #[inline(always)] + pub fn apply(&mut self, x0: f32) -> f32 { + let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2 + - self.a1 * self.y1 + - self.a2 * self.y2; + + self.x2 = self.x1; + self.x1 = x0; + self.y2 = self.y1; + self.y1 = y0; + + y0 + } +} + +/// Compensated sum, for summing many values of different orders of magnitude +/// accurately. +#[derive(Copy, Clone, PartialEq)] +struct Sum { + sum: f32, + residue: f32, +} + +impl Sum { + #[inline(always)] + fn zero() -> Sum { + Sum { + sum: 0.0, + residue: 0.0, + } + } + + #[inline(always)] + fn add(&mut self, x: f32) { + let sum = self.sum + (self.residue + x); + self.residue = (self.residue + x) - (sum - self.sum); + self.sum = sum; + } +} + +/// The mean of the squares of the K-weighted samples in a window of time. +/// +/// K-weighted power is equivalent to K-weighted loudness, the only difference +/// is one of scale: power is quadratic in sample amplitudes, whereas loudness +/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power, +/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS). +/// +/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale) +/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise +/// interchangeable with the more widespread term “LUFS” (Loudness Units, +/// relative to Full Scale). Loudness units are related to decibels in the +/// following sense: boosting a signal that has a loudness of +/// -LK LUFS by LK dB (by +/// multiplying the amplitude by 10LK/20) will +/// bring the loudness to 0 LUFS. +/// +/// K-weighting refers to a high-shelf and high-pass filter that model the +/// effect that humans perceive a certain amount of power in low frequencies to +/// be less loud than the same amount of power in higher frequencies. In this +/// library the `Power` type is used exclusively to refer to power after applying K-weighting. +/// +/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the +/// mean square of the samples, if no input samples exceeded the full scale, the +/// power will be in the range [0.0, 1.0]. However, the power delivered by +/// multiple channels, which is a weighted sum over individual channel powers, +/// can exceed this range, because the weighted sum is not normalized. +#[derive(Copy, Clone, PartialEq, PartialOrd)] +pub struct Power(pub f32); + +impl Power { + /// Convert Loudness Units relative to Full Scale into a squared sample amplitude. + /// + /// This is the inverse of `loudness_lkfs`. + pub fn from_lkfs(lkfs: f32) -> Power { + // The inverse of the formula below. + Power(10.0_f32.powf((lkfs + 0.691) * 0.1)) + } + + /// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale. + /// + /// This is the inverse of `from_lkfs`. + pub fn loudness_lkfs(&self) -> f32 { + // Equation 2 (p.5) of BS.1770-4. + -0.691 + 10.0 * self.0.log10() + } +} + +/// A `T` value for non-overlapping windows of audio, 100ms in length. +/// +/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power +/// for non-overlapping windows of 100ms duration. +/// +/// These non-overlapping 100ms windows can later be combined into overlapping +/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or +/// to perform a gated measurement, or they can be combined into even larger +/// windows for a momentary loudness measurement. +#[derive(Copy, Clone, Debug)] +pub struct Windows100ms { + pub inner: T, +} + +impl Windows100ms { + /// Wrap a new empty vector. + pub fn new() -> Windows100ms> { + Windows100ms { inner: Vec::new() } + } + + /// Apply `as_ref` to the inner value. + pub fn as_ref(&self) -> Windows100ms<&[Power]> + where + T: AsRef<[Power]>, + { + Windows100ms { + inner: self.inner.as_ref(), + } + } + + /// Apply `as_mut` to the inner value. + pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]> + where + T: AsMut<[Power]>, + { + Windows100ms { + inner: self.inner.as_mut(), + } + } + + #[allow(clippy::len_without_is_empty)] + /// Apply `len` to the inner value. + pub fn len(&self) -> usize + where + T: AsRef<[Power]>, + { + self.inner.as_ref().len() + } +} + +/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio. +/// +/// # Output +/// +/// The output of the meter is an intermediate result in the form of power for +/// 100ms non-overlapping windows. The windows need to be processed further to +/// get one of the instantaneous, momentary, and integrated loudness +/// measurements defined in BS.1770. +/// +/// The windows can also be inspected directly; the data is meaningful +/// on its own (the K-weighted power delivered in that window of time), but it +/// is not something that BS.1770 defines a term for. +/// +/// # Multichannel audio +/// +/// To perform a loudness measurement of multichannel audio, construct a +/// `ChannelLoudnessMeter` per channel, and later combine the measured power +/// with e.g. `reduce_stereo`. +/// +/// # Instantaneous loudness +/// +/// The instantaneous loudness is the power over a 400ms window, so you can +/// average four 100ms windows. No special functionality is implemented to help +/// with that at this time. ([Pull requests would be accepted.][contribute]) +/// +/// # Momentary loudness +/// +/// The momentary loudness is the power over a 3-second window, so you can +/// average thirty 100ms windows. No special functionality is implemented to +/// help with that at this time. ([Pull requests would be accepted.][contribute]) +/// +/// # Integrated loudness +/// +/// Use `gated_mean` to perform an integrated loudness measurement: +/// +/// ```ignore +/// # use std::iter; +/// # use bs1770::{ChannelLoudnessMeter, gated_mean}; +/// # let sample_rate_hz = 44_100; +/// # let samples_per_100ms = sample_rate_hz / 10; +/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz); +/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin())); +/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows()) +/// .unwrap_or(bs1770::Power(0.0)) +/// .loudness_lkfs(); +/// ``` +/// +/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md +#[derive(Clone)] +pub struct ChannelLoudnessMeter { + /// The number of samples that fit in 100ms of audio. + samples_per_100ms: u32, + + /// Stage 1 filter (head effects, high shelf). + filter_stage1: Filter, + + /// Stage 2 filter (high-pass). + filter_stage2: Filter, + + /// Sum of the squares over non-overlapping windows of 100ms. + windows: Windows100ms>, + + /// The number of samples in the current unfinished window. + count: u32, + + /// The sum of the squares of the samples in the current unfinished window. + square_sum: Sum, +} + +impl ChannelLoudnessMeter { + /// Construct a new loudness meter for the given sample rate. + pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter { + ChannelLoudnessMeter { + samples_per_100ms: sample_rate_hz / 10, + filter_stage1: Filter::high_shelf(sample_rate_hz as f32), + filter_stage2: Filter::high_pass(sample_rate_hz as f32), + windows: Windows100ms::new(), + count: 0, + square_sum: Sum::zero(), + } + } + + /// Feed input samples for loudness analysis. + /// + /// # Full scale + /// + /// Full scale for the input samples is the interval [-1.0, 1.0]. If your + /// input consists of signed integer samples, you can convert as follows: + /// + /// ```ignore + /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100); + /// # let bits_per_sample = 16_usize; + /// # let samples = &[0_i16]; + /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`, + /// // one bit is the sign bit. + /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32; + /// meter.push(samples.iter().map(|&s| s as f32 * normalizer)); + /// ``` + /// + /// # Repeated calls + /// + /// You can call `push` multiple times to feed multiple batches of samples. + /// This is equivalent to feeding a single chained iterator. The leftover of + /// samples that did not fill a full 100ms window is not discarded: + /// + /// ```ignore + /// # use std::iter; + /// # use bs1770::ChannelLoudnessMeter; + /// let sample_rate_hz = 44_100; + /// let samples_per_100ms = sample_rate_hz / 10; + /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz); + /// + /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1)); + /// assert_eq!(meter.as_100ms_windows().len(), 0); + /// + /// meter.push(iter::once(0.0)); + /// assert_eq!(meter.as_100ms_windows().len(), 1); + /// ``` + pub fn push>(&mut self, samples: I) { + let normalizer = 1.0 / self.samples_per_100ms as f32; + + // LLVM, if you could go ahead and inline those apply calls, and then + // unroll and vectorize the loop, that'd be terrific. + for x in samples { + let y = self.filter_stage1.apply(x); + let z = self.filter_stage2.apply(y); + + self.square_sum.add(z * z); + self.count += 1; + + // TODO: Should this branch be marked cold? + if self.count == self.samples_per_100ms { + let mean_squares = Power(self.square_sum.sum * normalizer); + self.windows.inner.push(mean_squares); + // We intentionally do not reset the residue. That way, leftover + // energy from this window is not lost, so for the file overall, + // the sum remains more accurate. + self.square_sum.sum = 0.0; + self.count = 0; + } + } + } + + /// Return a reference to the 100ms windows analyzed so far. + pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> { + self.windows.as_ref() + } + + /// Return all 100ms windows analyzed so far. + pub fn into_100ms_windows(self) -> Windows100ms> { + self.windows + } +} + +/// Combine power for multiple channels by taking a weighted sum. +/// +/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted +/// sum over channels which is not normalized. This means that a stereo signal +/// is inherently louder than a mono signal. For a mono signal played back on +/// stereo speakers, you should therefore still apply `reduce_stereo`, passing +/// in the same signal for both channels. +pub fn reduce_stereo( + left: Windows100ms<&[Power]>, + right: Windows100ms<&[Power]>, +) -> Windows100ms> { + assert_eq!( + left.len(), + right.len(), + "Channels must have the same length." + ); + let mut result = Vec::with_capacity(left.len()); + for (l, r) in left.inner.iter().zip(right.inner) { + result.push(Power(l.0 + r.0)); + } + Windows100ms { inner: result } +} + +/// In-place version of `reduce_stereo` that stores the result in the former left channel. +pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) { + assert_eq!( + left.len(), + right.len(), + "Channels must have the same length." + ); + for (l, r) in left.inner.iter_mut().zip(right.inner) { + l.0 += r.0; + } +} + +/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement. +/// +/// The integrated loudness measurement is not just the average power over the +/// entire signal. BS.1770-4 defines two stages of gating that exclude +/// parts of the signal, to ensure that silent parts do not contribute to the +/// loudness measurement. This function performs that gating, and returns the +/// average power over the windows that were not excluded. +/// +/// The result of this function is the integrated loudness measurement. +/// +/// When no signal remains after applying the gate, this function returns +/// `None`. In particular, this happens when all of the signal is softer than +/// -70 LKFS, including a signal that consists of pure silence. +pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option { + let mut gating_blocks = Vec::with_capacity(windows_100ms.len()); + + // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.) + let absolute_threshold = Power::from_lkfs(-70.0); + + // Iterate over all 400ms windows. + for window in windows_100ms.inner.windows(4) { + // Note that the sum over channels has already been performed at this point. + let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::()); + + if gating_block_power > absolute_threshold { + gating_blocks.push(gating_block_power); + } + } + + if gating_blocks.is_empty() { + return None; + } + + // Compute the loudness after applying the absolute gate, in order to + // determine the threshold for the relative gate. + let mut sum_power = Sum::zero(); + for &gating_block_power in &gating_blocks { + sum_power.add(gating_block_power.0); + } + let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32)); + + // Stage 2: Apply the relative gate. + let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0); + let mut sum_power = Sum::zero(); + let mut n_blocks = 0_usize; + for &gating_block_power in &gating_blocks { + if gating_block_power > relative_threshold { + sum_power.add(gating_block_power.0); + n_blocks += 1; + } + } + + if n_blocks == 0 { + return None; + } + + let relative_gated_power = Power(sum_power.sum / n_blocks as f32); + Some(relative_gated_power) +} diff --git a/src/ai/embedding.rs b/src/ai/embedding.rs index e02c3b1..09668e5 100644 --- a/src/ai/embedding.rs +++ b/src/ai/embedding.rs @@ -88,7 +88,15 @@ fn hugging_face(robot_id: &str, info: &HuggingFaceModelInfo, s: &str) -> Result< }; let token_ids = Tensor::new(&tokens[..], &m.device)?.unsqueeze(0)?; let token_type_ids = token_ids.zeros_like()?; - let outputs = m.forward(&token_ids, &token_type_ids)?; + // following attention_mask parameter is needed when batch inputs are of different token length + // let attention_mask = tokens + // .iter() + // .map(|tokens| { + // let tokens = tokens.get_attention_mask().to_vec(); + // Ok(Tensor::new(tokens.as_slice(), device)?) + // }) + // .collect::>>()?; + let outputs = m.forward(&token_ids, &token_type_ids, None)?; let (_n_sentence, n_tokens, _hidden_size) = outputs.dims3()?; let embeddings = (outputs.sum(1)? / (n_tokens as f64))?; let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?; diff --git a/src/ai/huggingface.rs b/src/ai/huggingface.rs index 04898a3..73002f5 100644 --- a/src/ai/huggingface.rs +++ b/src/ai/huggingface.rs @@ -9,7 +9,8 @@ use candle::{DType, Device}; use candle_nn::VarBuilder; use candle_transformers::models::bert::{BertModel, Config, DTYPE}; use candle_transformers::models::gemma::{Config as GemmaConfig, Model as GemmaModel}; -use candle_transformers::models::llama::{Cache as LlamaCache, Llama, LlamaConfig}; +use candle_transformers::models::llama::{Cache as LlamaCache, Llama, LlamaConfig, LlamaEosToks}; +use candle_transformers::models::parler_tts::{Config as ParlerTtsConfig, Model as ParlerTtsModel}; use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3}; use futures_util::StreamExt; use reqwest::header::{HeaderMap, HeaderValue}; @@ -38,11 +39,14 @@ pub(crate) enum HuggingFaceModel { TinyLlama1_1bChatV1_0, Gemma2bInstruct, Gemma7bInstruct, + ParlerTtsMiniV1, + ParlerTtsLargeV1, + WhisperLargeV3, } pub(crate) enum LoadedHuggingFaceModel { Bert((BertModel, Tokenizer)), - Llama((Device, Llama, LlamaCache, Tokenizer, Option)), + Llama((Device, Llama, LlamaCache, Tokenizer, Option)), Gemma((Device, GemmaModel, Tokenizer)), Phi3((Device, Phi3, Tokenizer)), } @@ -376,6 +380,36 @@ impl HuggingFaceModel { dimenssions: 1024, model_type: HuggingFaceModelType::Gemma, }, + HuggingFaceModel::ParlerTtsMiniV1 => HuggingFaceModelInfo { + repository: "parler-tts/parler-tts-mini-v1", + mirror: "parler-tts/parler-tts-mini-v1", + model_files: get_common_model_files(), + model_index_file: "", + tokenizer_filename: "tokenizer.json", + dimenssions: 1024, + model_type: HuggingFaceModelType::Llama, + }, + HuggingFaceModel::ParlerTtsLargeV1 => HuggingFaceModelInfo { + repository: "parler-tts/parler-tts-large-v1", + mirror: "parler-tts/parler-tts-large-v1", + model_files: { + let mut v = get_common_model_files(); + let mut idx = 0usize; + for &f in v.iter() { + if f.eq("model.safetensors") { + break; + } + idx = idx + 1; + } + v.remove(idx); + v + }, + model_index_file: "model.safetensors.index.json", + tokenizer_filename: "tokenizer.json", + dimenssions: 1024, + model_type: HuggingFaceModelType::Gemma, + }, + HuggingFaceModel::WhisperLargeV3 => todo!(), } } } @@ -873,7 +907,7 @@ fn get_model_files(info: &HuggingFaceModelInfo) -> Result> { pub(crate) fn load_llama_model_files( info: &HuggingFaceModelInfo, -) -> Result<(Device, Llama, LlamaCache, Tokenizer, Option)> { +) -> Result<(Device, Llama, LlamaCache, Tokenizer, Option)> { log::info!("load_llama_model_files start"); let tokenizer = init_tokenizer(&info.repository)?; let device = device()?; @@ -881,14 +915,14 @@ pub(crate) fn load_llama_model_files( let config_filename = construct_model_file_path(&info.repository, "config.json"); let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; let config = config.into_config(device.is_cuda()); - let eos_token_id = config - .eos_token_id - .or_else(|| tokenizer.token_to_id("")); - let filenames = get_model_files(info)?; let dtype = DType::F16; let cache = LlamaCache::new(true, dtype, &config, &device)?; + let filenames = get_model_files(info)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let m = Llama::load(vb, &config)?; + let eos_token_id = config + .eos_token_id + .or_else(|| tokenizer.token_to_id("").map(LlamaEosToks::Single)); log::info!("load_llama_model_files end"); Ok((device, m, cache, tokenizer, eos_token_id)) } @@ -912,6 +946,19 @@ pub(crate) fn load_gemma_model_files( Ok((device, model, tokenizer)) } +pub(crate) fn load_parler_tts_model_files( + info: &HuggingFaceModelInfo, +) -> Result<(Device, ParlerTtsModel, Tokenizer)> { + let tokenizer = init_tokenizer(&info.repository)?; + let device = device()?; + let filenames = get_model_files(info)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let config_filename = construct_model_file_path(&info.repository, "config.json"); + let config: ParlerTtsConfig = serde_json::from_reader(std::fs::File::open(config_filename)?)?; + let model = ParlerTtsModel::new(&config, vb)?; + Ok((device, model, tokenizer)) +} + pub(crate) fn load_pytorch_mode_files(info: &HuggingFaceModelInfo, device: &Device) -> Result<()> { let weights_filename = construct_model_file_path(&info.repository, "pytorch_model.bin"); let vb = VarBuilder::from_pth(&weights_filename, DType::BF16, device)?; diff --git a/src/ai/llama.rs b/src/ai/llama.rs index ce25c03..7f28dd0 100644 --- a/src/ai/llama.rs +++ b/src/ai/llama.rs @@ -1,6 +1,6 @@ use candle::{Device, Tensor}; use candle_transformers::generation::{LogitsProcessor, Sampling}; -use candle_transformers::models::llama::{Cache, Llama}; +use candle_transformers::models::llama::{Cache, Llama, LlamaEosToks}; // use crossbeam_channel::Sender; use frand::Rand; use tokenizers::Tokenizer; @@ -28,7 +28,7 @@ pub(super) fn gen_text( model: &Llama, cache: &Cache, tokenizer: &Tokenizer, - eos_token_id: &Option, + eos_token_id: &Option, prompt: &str, sample_len: usize, top_k: Option, @@ -121,8 +121,14 @@ pub(super) fn gen_text( token_generated += 1; tokens.push(next_token); - if Some(next_token) == *eos_token_id { - break; + match eos_token_id { + Some(LlamaEosToks::Single(eos_tok_id)) if next_token == *eos_tok_id => { + break; + } + Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => { + break; + } + _ => (), } if let Some(t) = tokenizer.next_token(next_token)? { // log::info!("{&t}"); diff --git a/src/ai/mod.rs b/src/ai/mod.rs index d529457..37b1c76 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -1,3 +1,6 @@ +pub(crate) mod asr; +pub(crate) mod audio; +pub(crate) mod bs1770; pub(crate) mod chat; pub(crate) mod completion; pub(crate) mod crud; @@ -7,3 +10,4 @@ pub(super) mod huggingface; pub(super) mod llama; pub(super) mod phi3; mod token_output_stream; +pub(crate) mod tts; diff --git a/src/ai/tts.rs b/src/ai/tts.rs new file mode 100644 index 0000000..19b8556 --- /dev/null +++ b/src/ai/tts.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Serialize}; + +use super::huggingface::{load_bert_model_files, HuggingFaceModel, HuggingFaceModelInfo}; + +#[derive(Clone, Deserialize, Serialize)] +#[serde(tag = "id", content = "model")] +pub(crate) enum TtsProvider { + HuggingFace(HuggingFaceModel), +} diff --git a/src/man/settings.rs b/src/man/settings.rs index 4566492..4706941 100644 --- a/src/man/settings.rs +++ b/src/man/settings.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use crate::ai::huggingface::HuggingFaceModel; -use crate::ai::{chat, completion, embedding, huggingface}; +use crate::ai::{asr, chat, completion, embedding, huggingface, tts}; use crate::db; use crate::result::{Error, Result}; use crate::robot::dto::RobotQuery; @@ -53,6 +53,10 @@ pub(crate) struct Settings { pub(crate) text_generation_provider: TextGenerationProvider, #[serde(rename = "sentenceEmbeddingProvider")] pub(crate) sentence_embedding_provider: SentenceEmbeddingProvider, + #[serde(rename = "asrProvider")] + pub(crate) asr_provider: AsrProvider, + #[serde(rename = "ttsProvider")] + pub(crate) tts_provider: TtsProvider, #[serde(rename = "smtpHost")] pub(crate) smtp_host: String, #[serde(rename = "smtpUsername")] @@ -152,6 +156,38 @@ pub(crate) struct SentenceEmbeddingProvider { pub(crate) proxy_url: String, } +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct AsrProvider { + pub(crate) provider: asr::AsrProvider, + #[serde(rename = "apiUrl")] + pub(crate) api_url: String, + #[serde(rename = "apiKey")] + pub(crate) api_key: String, + pub(crate) model: String, + #[serde(rename = "connectTimeoutMillis")] + pub(crate) connect_timeout_millis: u32, + #[serde(rename = "readTimeoutMillis")] + pub(crate) read_timeout_millis: u32, + #[serde(rename = "proxyUrl")] + pub(crate) proxy_url: String, +} + +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct TtsProvider { + pub(crate) provider: tts::TtsProvider, + #[serde(rename = "apiUrl")] + pub(crate) api_url: String, + #[serde(rename = "apiKey")] + pub(crate) api_key: String, + pub(crate) model: String, + #[serde(rename = "connectTimeoutMillis")] + pub(crate) connect_timeout_millis: u32, + #[serde(rename = "readTimeoutMillis")] + pub(crate) read_timeout_millis: u32, + #[serde(rename = "proxyUrl")] + pub(crate) proxy_url: String, +} + impl Default for GlobalSettings { fn default() -> Self { GlobalSettings { @@ -207,6 +243,28 @@ impl Default for Settings { read_timeout_millis: 10000, proxy_url: String::new(), }, + asr_provider: AsrProvider { + provider: asr::AsrProvider::HuggingFace( + huggingface::HuggingFaceModel::WhisperLargeV3, + ), + api_url: String::new(), + api_key: String::new(), + model: String::new(), + connect_timeout_millis: 5000, + read_timeout_millis: 10000, + proxy_url: String::new(), + }, + tts_provider: TtsProvider { + provider: tts::TtsProvider::HuggingFace( + huggingface::HuggingFaceModel::ParlerTtsMiniV1, + ), + api_url: String::new(), + api_key: String::new(), + model: String::new(), + connect_timeout_millis: 5000, + read_timeout_millis: 10000, + proxy_url: String::new(), + }, smtp_host: String::new(), smtp_username: String::new(), smtp_password: String::new(),