Skip to content

Commit

Permalink
fix: fix bug
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Sep 28, 2024
1 parent 55c8997 commit 6e9dbc1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
17 changes: 14 additions & 3 deletions models/llama/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use memmap2::Mmap;
use operators::{
common_cpu::{Cpu, ThisThread},
random_sample::{common_cpu::Operator as CpuOp, KVPair, SampleArgs},
QueueOf,
ByteOf, QueueOf,
};
use std::slice::from_raw_parts_mut;
use std::{ops::Deref, slice::from_raw_parts_mut};
use tensor::{ArrayLayout, BigEndian, Tensor};

pub struct Llama {
Expand Down Expand Up @@ -62,7 +62,7 @@ impl Llama {
let mut embd_buf = vec![0u8; embd.shape().iter().product::<usize>() * ele];
let mut logits_buf = vec![0u8; logits.shape().iter().product::<usize>() * ele];

let d = embd.shape()[1];
let d = embd.shape()[1] * ele;
for (i, &tok) in input.iter().enumerate() {
embd_buf[i * d..][..d].copy_from_slice(&self.token_embed[tok as usize * d..][..d]);
}
Expand Down Expand Up @@ -132,6 +132,13 @@ impl llama::Operators for Operators {
type AttnKVCached = op!(attention_kv_cached);
type Mlp = op!(mlp);
type Rearrange = op!(rearrange);

fn debug<T>(tensor: &Tensor<T>)
where
T: Deref<Target = [ByteOf<Self::Hardware>]>,
{
println!("{tensor}");
}
}

struct Weights {
Expand Down Expand Up @@ -194,6 +201,10 @@ fn test_load() {
let mut cache_buf = vec![0u8; cache.shape().iter().product::<usize>() * size_of::<f16>()];

let mut prompt = "Once upon a time,".to_string();

print!("{prompt}");
std::io::stdout().flush().unwrap();

let mut tokens = tokenizer.encode(&prompt);
while !tokens.contains(&2) {
let next = llama.infer(&tokens, &mut cache_buf, 0);
Expand Down
8 changes: 6 additions & 2 deletions models/llama/common/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ pub trait Operators {
type AttnKVCached: AttnKVCached<Self::Hardware>;
type Mlp: Mlp<Self::Hardware>;
type Rearrange: Rearrange<Self::Hardware>;

fn debug<T>(tensor: &Tensor<T>)
where
T: Deref<Target = [ByteOf<Self::Hardware>]>;
}

pub enum BlkWeight {
Expand Down Expand Up @@ -255,8 +259,8 @@ where
let x_ = unsafe { x.map_slice_static() };
self.rms_norm(&mut x, &x_, &w, workspace, queue_alloc)?;

let lm_head = self.weights.output(queue);
self.mat_mul(&mut logits, 0., &x, &lm_head, 1., workspace, queue_alloc)
let output = self.weights.output(queue);
self.mat_mul(&mut logits, 0., &x, &output, 1., workspace, queue_alloc)
}
}

Expand Down

0 comments on commit 6e9dbc1

Please sign in to comment.