Commit 152266b

refactor(gpt2): further tidy up and simplify the gpt2 model structure
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Dec 30, 2024
1 parent d55ac10 commit 152266b
Showing 7 changed files with 198 additions and 218 deletions.

Cargo.toml (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "91be1cc", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "78b578d", default-features = false }
 
 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }

models/gpt2/common-cpu/src/infer.rs (13 changes: 9 additions & 4 deletions)

@@ -27,13 +27,19 @@ fn test_infer() {
         return;
     };
     let gguf = GGufModel::read(model.iter().map(|s| &**s));
-    let model = Storage::from_gguf(&gguf);
 
     let TokenizerAndPrompt {
         eos,
         tokenizer,
         prompt,
     } = TokenizerAndPrompt::new(&gguf, prompt, as_user);
+
+    let model = Storage::from_gguf(&gguf);
+    println!("{:?}", model.meta);
+
+    let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
+    println!("{sample_args:?}");
+
     let &Gpt2Meta {
         dt_embd,
         nctx,
@@ -49,7 +55,7 @@
 
     test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
         // token embedding cache
-        let mut embd = Tensor::new(dt_embd, &[1, input.len(), d]).map(Blob::new);
+        let mut embd = Tensor::new(dt_embd, &[input.len(), d]).map(Blob::new);
        // token position cache
         let mut logits = model.meta.logits(1).map(Blob::new);
         let l = embd.get().len() / input.len();
@@ -60,7 +66,7 @@
         worker
             .launch(
                 gpt2::args::Args {
-                    token_embd: embd.map_slice_mut(),
+                    embd: embd.map_slice_mut(),
                     logits: logits.map_slice_mut(),
                     idx: postion(input.len(), pos).map_slice(),
                     idx_add: postion(input.len(), 0).map_slice(),
@@ -70,7 +76,6 @@
                         out_len: 1,
                         pos,
                     }],
-                    num_tokens: input.len(),
                     max_seq_len: input.len(),
                     max_att_len: pos + input.len(),
                 },

models/gpt2/common-cpu/src/lib.rs (5 changes: 3 additions & 2 deletions)

@@ -37,13 +37,14 @@ where
 {
     type Hardware = Cpu;
     type TopoNode = N;
+    type AddRows = op!(add_rows);
     type LayerNorm = op!(layer_norm);
     type MatMul = op!(mat_mul);
     type AttnKVCached = op!(attention_kv_cached);
     type Gelu = op!(gelu);
     type Rearrange = op!(rearrange);
     type AllReduce = R;
-    type AddRows = op!(add_rows);
+    type Mlp = op!(gpt2_mlp);
 
     fn debug<T>(tensor: &Tensor<T>)
     where
         T: Deref<Target = [ByteOf<Self::Hardware>]>,
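
The associated types above wire concrete CPU kernels, selected through the op!(...) macro, into the hardware-generic operator trait that the common gpt2 crate expects. That trait's definition is not part of this diff, so the outline below is only a minimal sketch assuming the type names visible in the impl; its real name, bounds, and required methods may differ.

// Minimal sketch only; the actual trait definition is not shown in this commit,
// and bounds plus required methods are omitted because they are not visible here.
pub trait Operators {
    type Hardware;
    type TopoNode;
    type AddRows;      // the CPU impl above binds this to op!(add_rows)
    type LayerNorm;    // op!(layer_norm)
    type MatMul;       // op!(mat_mul)
    type AttnKVCached; // op!(attention_kv_cached)
    type Gelu;         // op!(gelu)
    type Rearrange;    // op!(rearrange)
    type AllReduce;    // the generic all-reduce parameter R
    type Mlp;          // op!(gpt2_mlp)
}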

models/gpt2/common/src/args.rs (3 changes: 1 addition & 2 deletions)

@@ -3,14 +3,13 @@ use tensor::Tensor;
 
 pub struct Args<'a, H: Hardware> {
     /// shape: [nt, d]
-    pub token_embd: Tensor<&'a mut [H::Byte]>,
+    pub embd: Tensor<&'a mut [H::Byte]>,
     /// shape: [nout, nvoc]
     pub logits: Tensor<&'a mut [H::Byte]>,
     pub idx: Tensor<&'a [H::Byte]>,
     pub idx_add: Tensor<&'a [H::Byte]>,
     pub requests: Vec<Request<'a, H>>,
 
-    pub num_tokens: usize,
     pub max_seq_len: usize,
     pub max_att_len: usize,
 }
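
At the call site in models/gpt2/common-cpu/src/infer.rs shown earlier, the renamed and trimmed struct is now populated roughly as below. This fragment is reassembled from the hunks in this commit; Request fields other than out_len and pos are not visible in the diff and are left as a comment, and num_tokens is simply gone, presumably because the token count is already implied by embd's [nt, d] shape.

gpt2::args::Args {
    embd: embd.map_slice_mut(),                   // shape [nt, d], formerly `token_embd`
    logits: logits.map_slice_mut(),               // shape [nout, nvoc]
    idx: postion(input.len(), pos).map_slice(),
    idx_add: postion(input.len(), 0).map_slice(),
    requests: vec![Request {
        // remaining Request fields are not visible in this diff
        out_len: 1,
        pos,
    }],
    max_seq_len: input.len(),
    max_att_len: pos + input.len(),
}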

(Diffs for the remaining 3 changed files are not shown.)

0 comments on commit 152266b
