Commit

Merge pull request #16 from InfiniTensor/dev
Refactor chat-template and tokenizer
YdrMaster authored Aug 5, 2024
2 parents 30578bf + 227edbf commit f688768
Showing 34 changed files with 1,081 additions and 463 deletions.
226 changes: 151 additions & 75 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions Cargo.toml
@@ -4,6 +4,7 @@ members = [
"tensor",
"tokenizer",
"causal-lm",
"chat-template",
"service",
"web-api",
"xtask",
@@ -12,6 +13,7 @@ members = [
"devices/common-cpu",
"devices/nvidia-gpu",
"devices/cambricon-mlu",
"devices/ascend-card",

"models/llama/common",
"models/llama/common-cpu",
@@ -34,6 +36,7 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
digit-layout = "0.0"
build-script-cfg = "0.0"

operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "5a88159", default-features = false }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "fb088b6" }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "e6ee6ea", default-features = false }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d089ada" }
search-neuware-tools = "0.0"
search-ascend-tools = { git = "https://github.com/InfiniTensor/ascendcl", rev = "1e7a696" }
2 changes: 2 additions & 0 deletions causal-lm/src/lib.rs
@@ -43,6 +43,8 @@ pub trait CausalLM: Model {
    type Storage;
    /// The maximum sequence length.
    fn max_seq_len(&self) -> upos;
    /// The beginning-of-sentence token defined by the model.
    fn bos_token(&self) -> utok;
    /// The end-of-sentence token defined by the model.
    fn eos_token(&self) -> utok;
    /// Creates an unfilled cache tensor (`num_layers x 2 x num_kv_head x max_seq_len x head_dim`).
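A plausible motivation for the new accessor, hedged since this hunk shows only the trait: chat templates interpolate bos_token and eos_token as strings, so a prompt-building caller needs the model-defined ids plus a tokenizer lookup. A minimal sketch; `decode` is an assumed closure, not an API shown in this commit:

// Hypothetical glue: fetch the model-defined special-token ids and map
// them to text via some tokenizer-provided lookup.
fn special_token_strings(
    model: &impl CausalLM,
    decode: impl Fn(utok) -> String, // assumed: maps a token id to its text
) -> (String, String) {
    (decode(model.bos_token()), decode(model.eos_token()))
}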
9 changes: 9 additions & 0 deletions chat-template/Cargo.toml
@@ -0,0 +1,9 @@
[package]
name = "chat-template"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

[dependencies]
serde = { workspace = true, features = ["derive"] }
minijinja = { version = "2.1", default-features = false, features = ["loader"] }
119 changes: 119 additions & 0 deletions chat-template/src/lib.rs
@@ -0,0 +1,119 @@
#![deny(warnings)]

use minijinja::Environment;
use serde::Serialize;
use std::sync::{
    atomic::{AtomicUsize, Ordering::Relaxed},
    OnceLock, RwLock,
};

#[repr(transparent)]
pub struct ChatTemplate(String);

#[derive(Serialize)]
pub struct Message<'a> {
    pub role: &'a str,
    pub content: &'a str,
}

impl ChatTemplate {
    pub fn new(template: String) -> Self {
        static NEXT: AtomicUsize = AtomicUsize::new(0);
        let id = NEXT.fetch_add(1, Relaxed).to_string();

        jinja()
            .write()
            .unwrap()
            .add_template_owned(id.clone(), template)
            .unwrap();

        Self(id)
    }

    pub fn render(
        &self,
        messages: &[Message<'_>],
        bos_token: &str,
        eos_token: &str,
        add_generation_prompt: bool,
    ) -> Result<String, minijinja::Error> {
        #[derive(Serialize)]
        struct Args<'a> {
            messages: &'a [Message<'a>],
            bos_token: &'a str,
            eos_token: &'a str,
            add_generation_prompt: bool,
        }

        jinja()
            .read()
            .unwrap()
            .get_template(&self.0)
            .unwrap()
            .render(Args {
                messages,
                bos_token,
                eos_token,
                add_generation_prompt,
            })
    }
}

impl Drop for ChatTemplate {
    fn drop(&mut self) {
        jinja().write().unwrap().remove_template(&self.0);
    }
}

fn jinja() -> &'static RwLock<Environment<'static>> {
    static ENV: OnceLock<RwLock<Environment<'_>>> = OnceLock::new();
    ENV.get_or_init(|| {
        let mut env = Environment::empty();
        env.set_unknown_method_callback(|_, value, method, args| {
            use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
            match (method, value.kind(), args) {
                ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
                    value.to_str().unwrap().trim().into(),
                )),
                _ => Err(UnknownMethod.into()),
            }
        });
        RwLock::new(env)
    })
}

#[test]
fn test() {
    const TAIDE: &str = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]'}}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
    const MINICPM: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}";

    let result = ChatTemplate::new(TAIDE.into())
        .render(
            &[Message {
                role: "user",
                content: "Hello, who are you?",
            }],
            "<s>",
            "</s>",
            true,
        )
        .unwrap();

    assert_eq!(
        result,
        "<s>[INST] Hello, who are you? [/INST]<|im_start|>assistant\n"
    );

    let result = ChatTemplate::new(MINICPM.into())
        .render(
            &[Message {
                role: "user",
                content: "Hello, who are you?",
            }],
            "<s>",
            "</s>",
            true,
        )
        .unwrap();
    assert_eq!(result, "<用户>Hello, who are you?<AI>");
}
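As a usage sketch (not part of this commit): the render signature above pairs naturally with the new CausalLM::bos_token()/eos_token() accessors once a tokenizer maps those ids to strings. The "<s>"/"</s>" literals below are assumptions for a llama-style vocabulary:

// Hypothetical caller: render a single user turn into a prompt string.
use chat_template::{ChatTemplate, Message};

fn build_prompt(template: &ChatTemplate, user_text: &str) -> String {
    template
        .render(
            &[Message { role: "user", content: user_text }],
            "<s>",  // assumed BOS string
            "</s>", // assumed EOS string
            true,   // add_generation_prompt: append the assistant header
        )
        .expect("chat template should render")
}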
39 changes: 0 additions & 39 deletions common/src/between_f32.rs

This file was deleted.

2 changes: 0 additions & 2 deletions common/src/lib.rs
@@ -10,12 +10,10 @@ pub type utok = u32;
#[allow(non_camel_case_types)]
pub type upos = u32;

mod between_f32;
mod blob;
pub mod safe_tensors;
pub mod test_model;

pub use between_f32::BetweenF32;
pub use blob::Blob;
pub use half::{bf16, f16};

17 changes: 17 additions & 0 deletions devices/ascend-card/Cargo.toml
@@ -0,0 +1,17 @@
[package]
name = "common-acl"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
common = { path = "../../common" }
common-devices = { path = "../common" }
tensor = { path = "../../tensor" }
operators = { workspace = true, features = ["ascend-card"] }

[build-dependencies]
build-script-cfg.workspace = true
search-ascend-tools.workspace = true
9 changes: 9 additions & 0 deletions devices/ascend-card/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fn main() {
    use build_script_cfg::Cfg;
    use search_ascend_tools::find_ascend_toolkit_home;

    let ascend = Cfg::new("detected_ascend");
    if find_ascend_toolkit_home().is_some() {
        ascend.define();
    }
}
4 changes: 4 additions & 0 deletions devices/ascend-card/src/lib.rs
@@ -0,0 +1,4 @@
#![cfg(detected_ascend)]

pub use operators::ascendcl;
pub use tensor::Tensor;
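The build.rs/#![cfg(detected_ascend)] pair is the same toolkit-detection pattern the other device crates use: the build script probes for the Ascend toolkit and defines the cfg, so the crate body compiles to nothing when the toolkit is absent. A minimal sketch of branching on such a build-script cfg; note the cfg is only visible inside the crate whose build script defines it, and backend_name is illustrative, not from the commit:

// Hypothetical items in a crate that is not wholly gated on the cfg.
#[cfg(detected_ascend)]
pub fn backend_name() -> &'static str {
    "ascend" // toolkit found at build time
}

#[cfg(not(detected_ascend))]
pub fn backend_name() -> &'static str {
    "unavailable" // toolkit missing; Ascend paths compiled out
}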
13 changes: 5 additions & 8 deletions devices/common-cpu/src/lib.rs
@@ -13,11 +13,11 @@ use digit_layout::types::F16;
use operators::{
fuesd_softmax::common_cpu as softmax,
mat_mul::common_cpu as mat_mul,
mlp::common_cpu as mlp,
random_sample::{common_cpu as random_sample, Args, KVPair, SampleArgs},
reform::common_cpu as reform,
rms_norm::common_cpu as rms_norm,
rope::common_cpu as rope,
swiglu::common_cpu as swiglu,
Operator, QueueOf,
};
use std::ops::{Deref, DerefMut};
@@ -34,7 +34,7 @@ pub struct CpuKernels {
rms_norm: rms_norm::Operator,
rope: rope::Operator,
softmax: softmax::Operator,
swiglu: swiglu::Operator,
mlp: mlp::Operator,
sample: random_sample::Operator,
}

@@ -62,7 +62,7 @@ impl Default for CpuKernels {
rms_norm: rms_norm::Operator::new(&Cpu),
rope: rope::Operator::new(&Cpu),
softmax: softmax::Operator::new(&Cpu),
swiglu: swiglu::Operator::new(&Cpu),
mlp: mlp::Operator::new(&Cpu),
sample: random_sample::Operator::new(&Cpu),
}
}
@@ -100,11 +100,8 @@ impl Operators for CpuKernels {
) -> &impl operators::fuesd_softmax::FusedSoftmax<Self::Handle> {
&self.softmax
}
fn swiglu_op(
&self,
_: &QueueOf<Self::Handle>,
) -> &impl operators::swiglu::Swiglu<Self::Handle> {
&self.swiglu
fn mlp_op(&self, _: &QueueOf<Self::Handle>) -> &impl operators::mlp::Mlp<Self::Handle> {
&self.mlp
}
}

67 changes: 49 additions & 18 deletions devices/common/src/lib.rs
@@ -1,7 +1,5 @@
use common::utok;
use operators::{
fuesd_softmax, mat_mul, reform, rms_norm, rope, swiglu, Handle, Operator, QueueOf,
};
use operators::{fuesd_softmax, mat_mul, mlp, reform, rms_norm, rope, Handle, Operator, QueueOf};
use std::ops::{Deref, DerefMut};
use tensor::Tensor;

@@ -18,7 +16,7 @@ pub trait Operators {
&self,
queue: &QueueOf<Self::Handle>,
) -> &impl fuesd_softmax::FusedSoftmax<Self::Handle>;
fn swiglu_op(&self, queue: &QueueOf<Self::Handle>) -> &impl swiglu::Swiglu<Self::Handle>;
fn mlp_op(&self, queue: &QueueOf<Self::Handle>) -> &impl mlp::Mlp<Self::Handle>;
}

pub trait KernelsA {
@@ -68,10 +66,23 @@ pub trait KernelsA {
where
T: DerefMut<Target = SliceOn<Self::Handle>>;

fn swiglu<T, U>(&self, gate: &mut Tensor<T>, up: &Tensor<U>, queue: &QueueOf<Self::Handle>)
where
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = SliceOn<Self::Handle>>;
#[allow(clippy::too_many_arguments)]
fn mlp<M0, M1, C0, C1, C2>(
&self,
x: &mut Tensor<M0>,
x1: &Tensor<C0>,
gate_up: &mut Tensor<M1>,
w_gate_up: &Tensor<C1>,
w_down: &Tensor<C2>,
down_alpha: f32,
down_bias: bool,
queue: &QueueOf<Self::Handle>,
) where
M0: DerefMut<Target = SliceOn<Self::Handle>>,
M1: DerefMut<Target = SliceOn<Self::Handle>>,
C0: Deref<Target = SliceOn<Self::Handle>>,
C1: Deref<Target = SliceOn<Self::Handle>>,
C2: Deref<Target = SliceOn<Self::Handle>>;
}

pub trait KernelsB {
@@ -216,18 +227,38 @@ reform failed: {e}
.unwrap();
}

fn swiglu<T, U>(&self, gate: &mut Tensor<T>, up: &Tensor<U>, queue: &QueueOf<Self::Handle>)
where
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = SliceOn<Self::Handle>>,
fn mlp<M0, M1, C0, C1, C2>(
&self,
x: &mut Tensor<M0>,
x1: &Tensor<C0>,
gate_up: &mut Tensor<M1>,
w_gate_up: &Tensor<C1>,
w_down: &Tensor<C2>,
down_alpha: f32,
down_bias: bool,
queue: &QueueOf<Self::Handle>,
) where
M0: DerefMut<Target = SliceOn<Self::Handle>>,
M1: DerefMut<Target = SliceOn<Self::Handle>>,
C0: Deref<Target = SliceOn<Self::Handle>>,
C1: Deref<Target = SliceOn<Self::Handle>>,
C2: Deref<Target = SliceOn<Self::Handle>>,
{
self.swiglu_op(queue)
self.mlp_op(queue)
.launch(
&swiglu::Args {
gate_layout: gate.layout(),
gate_base: gate.base_mut(),
up_layout: up.layout(),
up_base: up.base(),
&mlp::Args {
y_layout: x.layout(),
y_base: x.base_mut(),
x_layout: x1.layout(),
x_base: x1.base(),
gate_up_layout: gate_up.layout(),
gate_up_base: gate_up.base_mut(),
w_gate_up_layout: w_gate_up.layout(),
w_gate_up_base: w_gate_up.base(),
w_down_layout: w_down.layout(),
w_down_base: w_down.base(),
down_alpha,
down_bias,
},
queue,
)
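Taken together, these hunks replace the standalone swiglu kernel with a fused mlp operator that, judging by the argument names, covers the whole feed-forward block: project x1 through w_gate_up, apply SwiGLU to the gate/up halves, project through w_down, and accumulate into x with scale down_alpha. A scalar reference sketch of that computation, with shapes and the residual accumulation inferred rather than taken from the operator's docs (down_bias is omitted):

// Hypothetical scalar reference for the fused MLP semantics assumed above:
//   y += down_alpha * (swiglu(x @ w_gate_up) @ w_down)
// where swiglu(gate, up) = silu(gate) * up, elementwise.
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

fn mlp_reference(
    y: &mut [f32],     // [d]         residual stream, updated in place
    x: &[f32],         // [d]         normalized input
    w_gate_up: &[f32], // [d x 2*di]  row-major; columns 0..di gate, di..2*di up
    w_down: &[f32],    // [di x d]    row-major
    di: usize,
    down_alpha: f32,
) {
    let d = y.len();
    // gate_up = x @ w_gate_up
    let mut gate_up = vec![0.0f32; 2 * di];
    for (j, out) in gate_up.iter_mut().enumerate() {
        *out = (0..d).map(|i| x[i] * w_gate_up[i * 2 * di + j]).sum();
    }
    // h = silu(gate) * up, elementwise over the two halves
    let (gate, up) = gate_up.split_at(di);
    let h: Vec<f32> = gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect();
    // y += down_alpha * (h @ w_down)
    for (j, y_j) in y.iter_mut().enumerate() {
        let dot: f32 = (0..di).map(|k| h[k] * w_down[k * d + j]).sum();
        *y_j += down_alpha * dot;
    }
}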
