Skip to content

Commit

Permalink
feat(xtask): 为 generate 添加参数,控制是否拷贝参数到内存和最大迭代次数
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 27, 2024
1 parent 910e00c commit a1a8c7f
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 8 deletions.
2 changes: 1 addition & 1 deletion model-parameters/src/memory/realloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl Memory {
w_qkv: writer.write(src.w_qkv(layer)),
self_attn_o_proj: writer.write(src.self_attn_o_proj(layer)),
post_attention_layernorm: writer.write(src.post_attention_layernorm(layer)),
mlp_gate_up: writer.write(src.mlp_gate(layer)),
mlp_gate_up: writer.write(src.mlp_gate_up(layer)),
mlp_down: writer.write(src.mlp_down(layer)),
})
.collect(),
Expand Down
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ pub(super) fn matmul<T, U, V>(
false,
false,
false,
gemm::Parallelism::None,
gemm::Parallelism::Rayon(0),
)
},
_ => unreachable!(),
Expand Down
2 changes: 1 addition & 1 deletion transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl Transformer {
Self {
logits: vec![0.0f32; model.vocab_size()],
model: match model.data_type() {
DataType::BF16 => Box::new(Memory::cast(&*model, DataType::F32)),
DataType::BF16 | DataType::F32 => Box::new(Memory::cast(&*model, DataType::F16)),
_ => model,
},
}
Expand Down
2 changes: 1 addition & 1 deletion transformer-nvidia/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
mod parameters;

use cuda::{driver, Context, Stream};
use model_parameters::{Llama2, Memory};
use model_parameters::Memory;
use parameters::{LayersParameters, ModelParameters};
use std::{
ptr::{null_mut, NonNull},
Expand Down
66 changes: 62 additions & 4 deletions xtask/src/generate.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use common::utok;
use log::LevelFilter;
use simple_logger::SimpleLogger;
use std::{io::Write, path::PathBuf, time::Instant};
use std::{
alloc::Layout, collections::HashMap, io::Write, path::PathBuf, ptr::NonNull, sync::Mutex,
time::Instant,
};
use tokenizer::{Tokenizer, BPE};
use transformer_cpu::{model_parameters::Memory, Transformer};
use transformer_cpu::{
model_parameters::{Allocator, Llama2, Memory},
Transformer,
};

#[derive(Args, Default)]
pub(crate) struct GenerateArgs {
Expand All @@ -13,11 +19,45 @@ pub(crate) struct GenerateArgs {
/// Prompt.
#[clap(short, long)]
prompt: String,
/// Max steps.
#[clap(short, long)]
step: Option<usize>,
/// Copy model parameters inside memory.
#[clap(short, long)]
inside_mem: bool,
/// Log level.
#[clap(short, long)]
log: Option<String>,
}

/// Plain heap allocator that records the size of every allocation so that
/// `deallocate` can rebuild the exact `Layout` used to allocate it —
/// `std::alloc::dealloc` requires the same layout it was allocated with.
/// Keys of the map are the raw pointers handed out by `allocate`.
struct NormalAllocator(Mutex<HashMap<*const u8, usize>>);

impl Allocator for NormalAllocator {
    unsafe fn allocate(&self, size: usize) -> NonNull<u8> {
        // Fixed `usize` alignment: only the size has to be remembered for
        // the matching `dealloc` call. The checked constructor turns an
        // overflowing size into a defined panic instead of UB.
        let layout = Layout::from_size_align(size, std::mem::align_of::<usize>())
            .expect("invalid allocation size");
        // NOTE(review): `alloc` with a zero-sized layout is UB — callers are
        // assumed to request only non-empty parameter tensors; confirm.
        let Some(ptr) = NonNull::new(std::alloc::alloc(layout)) else {
            // Out-of-memory: report through the global OOM hook rather than
            // a plain `unwrap` panic, per the `GlobalAlloc` contract.
            std::alloc::handle_alloc_error(layout)
        };
        self.0.lock().unwrap().insert(ptr.as_ptr(), size);
        ptr
    }

    unsafe fn deallocate(&self, ptr: NonNull<u8>) {
        // Recover the size recorded by `allocate`; a miss means the pointer
        // did not come from this allocator (or was freed twice).
        let size = self
            .0
            .lock()
            .unwrap()
            .remove(&ptr.as_ptr().cast_const())
            .expect("deallocate called with a pointer not owned by this allocator");
        // SAFETY: `ptr` was returned by `allocate`, which used exactly this
        // size/alignment pair, satisfying `dealloc`'s layout requirement.
        std::alloc::dealloc(
            ptr.as_ptr(),
            Layout::from_size_align_unchecked(size, std::mem::align_of::<usize>()),
        )
    }
}

impl GenerateArgs {
pub fn invoke(self) {
let log = self
Expand All @@ -36,9 +76,19 @@ impl GenerateArgs {
let model_dir = PathBuf::from(self.model);

let time = Instant::now();
let model = Box::new(Memory::load_safetensors(&model_dir).unwrap());
let mut model = Box::new(Memory::load_safetensors(&model_dir).unwrap());
info!("load model ... {:?}", time.elapsed());

if self.inside_mem {
let time = Instant::now();
let allocator = NormalAllocator(Mutex::new(HashMap::new()));
model = Box::new(Memory::realloc_with(&*model, allocator));
info!("copy model ... {:?}", time.elapsed());
}
let step = self
.step
.unwrap_or(usize::MAX)
.min(model.max_position_embeddings());
let time = Instant::now();
let mut transformer = Transformer::new(model);
let mut kv_cache = transformer.new_cache();
Expand All @@ -63,7 +113,8 @@ impl GenerateArgs {

let mut token = *last;
let mut pos = tokens.len();
loop {
let time = Instant::now();
while pos < step {
let logits = transformer.forward(token, &mut kv_cache, pos as _);
let next = argmax(&logits);

Expand All @@ -73,6 +124,13 @@ impl GenerateArgs {
print!("{}", tokenizer.decode(next).replace('▁', " "));
std::io::stdout().flush().unwrap();
}
println!();
let duration = time.elapsed();
info!("generate ... {duration:?}");
info!(
"avg. speed ... {} tokens/s",
(pos - tokens.len()) as f32 / duration.as_secs_f32()
)
}
}

Expand Down

0 comments on commit a1a8c7f

Please sign in to comment.