refactor(transformer): simplify model loading
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 18, 2024
1 parent 076b05c commit 75c4f29
Showing 11 changed files with 23 additions and 39 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion nvidia/distributed/src/lib.rs
@@ -422,7 +422,7 @@ impl transformer::Transformer for Transformer {
 impl Transformer {
     pub fn new(model_dir: impl AsRef<Path>, dev: &[Device]) -> Self {
         let time = Instant::now();
-        let host = Memory::load_safetensors_from_dir(model_dir).unwrap();
+        let host = Memory::load_safetensors(model_dir).unwrap();
         info!("load host: {:?}", time.elapsed());
 
         let block_size = dev.iter().map(|dev| dev.max_block_dims().0).min().unwrap();
2 changes: 1 addition & 1 deletion nvidia/distributed/src/parameters.rs
@@ -165,7 +165,7 @@ fn test_load() {
     SimpleLogger::new().with_level(Trace).init().unwrap();
 
     let time = Instant::now();
-    let safetensors = Memory::load_safetensors_from_dir(model_dir);
+    let safetensors = Memory::load_safetensors(model_dir);
     info!("mmap {:?}", time.elapsed());
 
     let model = match safetensors {
1 change: 0 additions & 1 deletion nvidia/transformer/Cargo.toml
@@ -9,7 +9,6 @@ authors = ["YdrMaster <[email protected]>"]
 [dependencies]
 transformer = { path = "../../transformer" }
 common-nv = { path = "../common" }
-log.workspace = true
 half.workspace = true
 
 [dev-dependencies]
18 changes: 3 additions & 15 deletions nvidia/transformer/src/lib.rs
@@ -2,25 +2,19 @@
 
 mod parameters;
 
-#[macro_use]
-extern crate log;
-
 pub use common_nv::cuda;
 
 use ::half::f16;
 use common_nv::{
     cuda::{memcpy_d2h, DevMem, DevMemSpore},
-    slice, udim, utok, DataType, LocalSplitable, NvidiaKernels, NvidiaKernelsPtx, SafeTensors,
-    Tensor,
+    slice, udim, utok, DataType, LocalSplitable, NvidiaKernels, NvidiaKernelsPtx, Tensor,
 };
 use cuda::{Context, ContextResource, ContextSpore, Device, Stream, StreamSpore};
 use parameters::{LayerParameter, LayersParameters, ModelParameters};
 use std::{
-    fs::File,
     path::Path,
     slice::from_raw_parts,
     sync::{Arc, Mutex},
-    time::Instant,
 };
 use transformer::{pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request, SampleArgs};
 
@@ -140,15 +134,9 @@ type Splitable<'ctx> = LocalSplitable<DevMem<'ctx>>;
 
 impl Transformer {
     pub fn new(model_dir: impl AsRef<Path>, preload_layers: usize, dev: Device) -> Self {
-        let time = Instant::now();
-        let config = File::open(model_dir.as_ref().join("config.json")).unwrap();
-        let model = SafeTensors::load_from_dir(model_dir).unwrap();
-        info!("open file {:?}", time.elapsed());
-
         let context = Arc::new(dev.retain_primary());
-        let host = Memory::load_safetensors(
-            config,
-            model,
+        let host = Memory::load_safetensors_realloc(
+            model_dir,
             Some(|l| context.apply(|ctx| ctx.malloc_host::<u8>(l).sporulate())),
         )
         .unwrap();
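The call site above is the heart of the refactor: the old open-config / load-safetensors / parse sequence collapses into one call, while the CUDA backend still controls host allocation by passing a closure that returns pinned buffers. A condensed, commented sketch of that pattern follows; the closure body is taken from the diff, but the wrapper function and the way the workspace crates are brought into scope are assumptions, not part of the commit.

// Sketch only — assumes the workspace crates `transformer` and `common_nv` are in scope.
use common_nv::cuda::{self, ContextResource, ContextSpore};
use std::{path::Path, sync::Arc};
use transformer::Memory;

fn load_weights_pinned(model_dir: impl AsRef<Path>, dev: cuda::Device) {
    // Keep the device's primary context alive for as long as the host buffers live.
    let context = Arc::new(dev.retain_primary());
    // One call loads config.json and the safetensors from `model_dir`; the optional
    // closure supplies the host buffers — here page-locked (pinned) memory of the
    // requested size `l`, presumably "sporulated" so it can leave the `apply` scope.
    let host = Memory::load_safetensors_realloc(
        model_dir,
        Some(|l| context.apply(|ctx| ctx.malloc_host::<u8>(l).sporulate())),
    )
    .unwrap();
    let _ = host; // handed off to the Transformer constructor in the real code
}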
2 changes: 1 addition & 1 deletion service/src/cpu.rs
@@ -4,7 +4,7 @@ use transformer_cpu::Transformer;
 
 pub fn transformer(model_dir: impl AsRef<Path>) -> Transformer {
     let time = Instant::now();
-    let model = Memory::load_safetensors_from_dir(model_dir).unwrap();
+    let model = Memory::load_safetensors(model_dir).unwrap();
     info!("load model ... {:?}", time.elapsed());
 
     let time = Instant::now();
2 changes: 1 addition & 1 deletion transformer-cpu/src/lib.rs
@@ -337,7 +337,7 @@ fn test_build() {
     println!("model_dir: {}", model_dir.display());
 
     let t0 = Instant::now();
-    let safetensors = Memory::load_safetensors_from_dir(model_dir);
+    let safetensors = Memory::load_safetensors(model_dir);
     let t1 = Instant::now();
     println!("mmap {:?}", t1 - t0);
 
2 changes: 1 addition & 1 deletion transformer/src/parameters/distribute.rs
@@ -293,7 +293,7 @@ fn test() {
     println!("model_dir: {}", model_dir.display());
 
     let time = Instant::now();
-    let safetensors = Memory::load_safetensors_from_dir(model_dir);
+    let safetensors = Memory::load_safetensors(model_dir);
     println!("mmap {:?}", time.elapsed());
 
     let model = match safetensors {
2 changes: 1 addition & 1 deletion transformer/src/parameters/memory.rs
@@ -184,7 +184,7 @@ fn test_load() {
     println!("model_dir: {}", model_dir.display());
 
     let t0 = Instant::now();
-    let safetensors = Memory::load_safetensors_from_dir(model_dir);
+    let safetensors = Memory::load_safetensors(model_dir);
     let t1 = Instant::now();
     println!("mmap {:?}", t1 - t0);
 
28 changes: 13 additions & 15 deletions transformer/src/parameters/safe_tensors.rs
@@ -1,29 +1,27 @@
 use super::{memory::Layer, storage::HostMem, ConfigJson, Memory, Storage};
 use common::{
-    safe_tensors::{Dtype, SafeTensors, SafeTensorsError},
+    safe_tensors::{
+        Dtype, SafeTensors,
+        SafeTensorsError::{self, Io, Json},
+    },
     Blob,
 };
-use std::{fs::File, io::Read, ops::DerefMut, path::Path, sync::Arc};
+use std::{fs::File, ops::DerefMut, path::Path, sync::Arc};
 use tensor::{udim, DataType, Shape, Tensor};
 
 impl Memory {
-    pub fn load_safetensors_from_dir(
-        model_dir: impl AsRef<Path>,
-    ) -> Result<Self, SafeTensorsError> {
-        let model_dir = model_dir.as_ref();
-        let config = File::open(model_dir.join("config.json")).map_err(SafeTensorsError::Io)?;
-        let model = SafeTensors::load_from_dir(model_dir)?;
-        Self::load_safetensors(config, model, Some(Blob::new)).map_err(SafeTensorsError::Json)
+    pub fn load_safetensors(model_dir: impl AsRef<Path>) -> Result<Self, SafeTensorsError> {
+        Self::load_safetensors_realloc(model_dir, Some(Blob::new))
     }
 
-    pub fn load_safetensors<T: HostMem + DerefMut<Target = [u8]>>(
-        config: impl Read,
-        model: SafeTensors,
+    pub fn load_safetensors_realloc<T: HostMem + DerefMut<Target = [u8]>>(
+        model_dir: impl AsRef<Path>,
         mut realloc: Option<impl FnMut(usize) -> T>,
-    ) -> Result<Self, serde_json::Error> {
-        let config: ConfigJson = serde_json::from_reader(config)?;
+    ) -> Result<Self, SafeTensorsError> {
+        let config = File::open(model_dir.as_ref().join("config.json")).map_err(Io)?;
+        let config: ConfigJson = serde_json::from_reader(&config).map_err(Json)?;
+        let model = SafeTensors::load_from_dir(model_dir)?.share();
 
-        let model = model.share();
         let tensor = |name: &str| {
             let shared = model
                 .share_tensor(name)
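Seen from the caller's side, the rework leaves two entry points: `load_safetensors` for the common case and `load_safetensors_realloc` when the host allocator matters. A minimal usage sketch follows; the crate paths (`transformer::Memory`, `common::Blob`) and the `unwrap`-style error handling are assumptions based on the imports visible above, not part of the commit.

// Sketch only — crate paths and error handling are illustrative.
use common::Blob;
use std::path::Path;
use transformer::Memory;

fn load(model_dir: &Path) -> Memory {
    // Default path: config.json and the safetensors both come from `model_dir`,
    // with host buffers allocated as plain `Blob`s.
    Memory::load_safetensors(model_dir).unwrap()
}

fn load_with_default_alloc(model_dir: &Path) -> Memory {
    // The same thing spelled out through the realloc variant; this mirrors what
    // the convenience wrapper does internally per the diff above.
    Memory::load_safetensors_realloc(model_dir, Some(Blob::new)).unwrap()
}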
2 changes: 1 addition & 1 deletion xtask/src/cast.rs
@@ -26,7 +26,7 @@ impl CastArgs {
         let model_dir = PathBuf::from(self.model);
 
         let time = Instant::now();
-        let model = Memory::load_safetensors_from_dir(&model_dir).unwrap();
+        let model = Memory::load_safetensors(&model_dir).unwrap();
         println!("load model ... {:?}", time.elapsed());
 
         let target = self.target.map(PathBuf::from).unwrap_or_else(|| {
