Skip to content

Commit

Permalink
refactor(transformer): 使用封装的 safetensor 组件
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 18, 2024
1 parent cd04944 commit 076b05c
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 180 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

95 changes: 81 additions & 14 deletions common/src/safe_tensors.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
//! safetensors 文件的加载和访问。
use memmap2::Mmap;
use safetensors::tensor::TensorInfo;
use std::{
collections::HashMap,
collections::{hash_map, HashMap},
fs::File,
io::{Error as IoError, ErrorKind::NotFound},
mem::size_of_val,
ops::Deref,
path::Path,
pin::Pin,
sync::Arc,
};
use SafeTensorsError::{Io, Json};

pub use safetensors::Dtype;
pub use safetensors::{tensor::TensorInfo, Dtype};

/// safetensors 文件的统一结构。
pub struct SafeTensors {
tensors: HashMap<String, (usize, TensorInfo)>,
files: Vec<(Mmap, String)>,
tensors: HashMap<String, (usize, TensorInfo)>, // name -> (file_index, tensor_info)
files: Vec<(Mmap, String)>, // file_index -> (mmap, format)
}

/// 加载 safetensors 文件可能产生的错误。
Expand Down Expand Up @@ -44,7 +46,7 @@ pub struct SafeTensor<'a> {
/// [SafeTensors] 的张量迭代器。
pub struct Iter<'a> {
obj: &'a SafeTensors,
iter: std::collections::hash_map::Iter<'a, String, (usize, TensorInfo)>,
iter: hash_map::Iter<'a, String, (usize, TensorInfo)>,
}

impl SafeTensors {
Expand Down Expand Up @@ -94,7 +96,7 @@ impl SafeTensors {
// 加载索引文件
let index = File::open(&path).map_err(Io)?;
let index = unsafe { Mmap::map(&index) }.map_err(Io)?;
let index: SafeTensorIndex = serde_json::from_slice(&index).map_err(Json)?;
let index: SafeTensorsIndex = serde_json::from_slice(&index).map_err(Json)?;
// 初始化状态
let mut tensors = HashMap::new();
let mut file_map = HashMap::new();
Expand Down Expand Up @@ -131,6 +133,29 @@ impl SafeTensors {
Ok(Self { tensors, files })
}

/// Consumes this [SafeTensors] and wraps it in a pinned, reference-counted handle.
///
/// The resulting `Pin<Arc<Self>>` is the receiver type expected by
/// [`share_tensor`](Self::share_tensor).
#[inline]
pub fn share(self) -> Pin<Arc<Self>> {
Arc::pin(self)
}

/// Gets a shared tensor from a shared [SafeTensors].
///
/// Returns `None` when no tensor called `name` is registered. On success the returned
/// [SharedTensor] stores a clone of this `Pin<Arc<SafeTensors>>` handle together with two
/// references into it whose lifetimes have been extended to `'static`.
pub fn share_tensor(self: &Pin<Arc<Self>>, name: &str) -> Option<SharedTensor> {
// `value` is the `(file_index, tensor_info)` entry stored in the `tensors` map.
let value = self.tensors.get(name)?;
// Presumably the tensor's bytes inside the mapped file — `get_internal` is not visible here.
let data = self.get_internal(value.0, &value.1).data;
Some(SharedTensor {
// Cloning the handle ties the life of the owning `SafeTensors` to the new `SharedTensor`.
safetensors: self.clone(),
// SAFETY (NOTE(review) — confirm): both casts erase the borrow of `self` and yield
// `'static` references. This is only sound if the map entry and the data bytes stay at
// the same address for as long as the handle cloned above is alive, i.e. `tensors` and
// `files` are never mutated once the value has been shared.
value: unsafe { &*(value as *const _) },
data: unsafe { &*(data as *const _) },
})
}

/// Checks whether a tensor called `name` exists.
#[inline]
pub fn contains(&self, name: &str) -> bool {
// Key lookup in the name -> (file_index, tensor_info) map only.
self.tensors.contains_key(name)
}

/// 获取张量。
#[inline]
pub fn get(&self, name: &str) -> Option<SafeTensor> {
Expand Down Expand Up @@ -192,35 +217,77 @@ impl<'a> Iterator for Iter<'a> {
}
}

/// A shared tensor.
///
/// Holds a clone of the `Pin<Arc<SafeTensors>>` it was taken from next to references
/// into that collection.
#[derive(Clone)]
pub struct SharedTensor {
// Handle on the owning collection; `format()` reads the file entry through it.
safetensors: Pin<Arc<SafeTensors>>,
// The `(file_index, tensor_info)` entry describing this tensor.
// NOTE(review): the `'static` lifetime here and on `data` is produced by raw-pointer
// casts in `SafeTensors::share_tensor`; the references presumably stay valid only
// while `safetensors` is alive — confirm before exposing them with `'static`.
value: &'static (usize, TensorInfo),
// Raw bytes of the tensor, returned by `data()` and `deref()`.
data: &'static [u8],
}

/// Lets a [SharedTensor] be used directly as the byte slice holding its data.
impl Deref for SharedTensor {
type Target = [u8];

/// Returns the same bytes as [`SharedTensor::data`].
#[inline]
fn deref(&self) -> &[u8] {
self.data
}
}

impl SharedTensor {
/// Data type of the tensor, as recorded in its tensor info.
#[inline]
pub fn dtype(&self) -> Dtype {
let (_, info) = self.value;
info.dtype
}

/// Shape of the tensor, as recorded in its tensor info.
#[inline]
pub fn shape(&self) -> &[usize] {
let (_, info) = self.value;
&info.shape
}

/// Format string of the file this tensor belongs to.
#[inline]
pub fn format(&self) -> &str {
// The first element of `value` indexes into the owning collection's file list.
let file_index = self.value.0;
let (_, format) = &self.safetensors.files[file_index];
format.as_str()
}

/// Raw bytes of the tensor.
#[inline]
pub fn data(&self) -> &[u8] {
self.data
}
}

#[allow(missing_docs)]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SafeTensorIndex {
pub metadata: SafeTensorIndexMetadata,
pub struct SafeTensorsIndex {
pub metadata: SafeTensorsIndexMetadata,
pub weight_map: HashMap<String, String>,
}

#[allow(missing_docs)]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SafeTensorIndexMetadata {
pub struct SafeTensorsIndexMetadata {
pub total_size: usize,
}

#[allow(missing_docs)]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SafeTensorHeader {
pub struct SafeTensorsHeader {
#[serde(flatten)]
pub tensors: HashMap<String, TensorInfo>,
#[serde(rename = "__metadata__")]
pub metadata: SafeTensorHeaderMetadata,
pub metadata: SafeTensorsHeaderMetadata,
}

#[allow(missing_docs)]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SafeTensorHeaderMetadata {
pub struct SafeTensorsHeaderMetadata {
pub format: String,
}

fn load_header(file: &Mmap) -> Result<SafeTensorHeader, SafeTensorsError> {
fn load_header(file: &Mmap) -> Result<SafeTensorsHeader, SafeTensorsError> {
let header_len = unsafe { *file.as_ptr().cast::<u64>() };
let header = &file[size_of_val(&header_len)..][..header_len as _];
serde_json::from_slice(header).map_err(Json)
Expand Down
5 changes: 4 additions & 1 deletion nvidia/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ mod rms_norm;
mod rotary_embedding;
mod swiglu;

pub use common::{test_model, utok};
pub use common::{
safe_tensors::{SafeTensors, SafeTensorsError},
test_model, utok,
};
pub use tensor::{slice, udim, DataType, LocalSplitable, Tensor};

use cublas::{Cublas, CublasSpore};
Expand Down
9 changes: 6 additions & 3 deletions nvidia/distributed/src/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,14 @@ impl Layer<'_> {

#[test]
fn test_load() {
use common_nv::cuda::{self, Device};
use common_nv::{
cuda::{self, Device},
SafeTensorsError,
};
use log::LevelFilter::Trace;
use simple_logger::SimpleLogger;
use std::io::ErrorKind::NotFound;
use transformer::{Memory, SafeTensorError};
use transformer::Memory;

let Some(model_dir) = common_nv::test_model::find() else {
return;
Expand All @@ -167,7 +170,7 @@ fn test_load() {

let model = match safetensors {
Ok(m) => m,
Err(SafeTensorError::Io(e)) if e.kind() == NotFound => return,
Err(SafeTensorsError::Io(e)) if e.kind() == NotFound => return,
Err(e) => panic!("{e:?}"),
};

Expand Down
26 changes: 14 additions & 12 deletions nvidia/transformer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ pub use common_nv::cuda;
use ::half::f16;
use common_nv::{
cuda::{memcpy_d2h, DevMem, DevMemSpore},
slice, udim, utok, DataType, LocalSplitable, NvidiaKernels, NvidiaKernelsPtx, Tensor,
slice, udim, utok, DataType, LocalSplitable, NvidiaKernels, NvidiaKernelsPtx, SafeTensors,
Tensor,
};
use cuda::{Context, ContextResource, ContextSpore, Device, Stream, StreamSpore};
use parameters::{LayerParameter, LayersParameters, ModelParameters};
use std::{
fs::File,
io::Read,
path::Path,
slice::from_raw_parts,
sync::{Arc, Mutex},
time::Instant,
Expand Down Expand Up @@ -138,18 +139,19 @@ impl transformer::Transformer for Transformer {
type Splitable<'ctx> = LocalSplitable<DevMem<'ctx>>;

impl Transformer {
pub fn new(config: File, mut safetensors: File, preload_layers: usize, dev: Device) -> Self {
let context = Arc::new(dev.retain_primary());
pub fn new(model_dir: impl AsRef<Path>, preload_layers: usize, dev: Device) -> Self {
let time = Instant::now();
let mut host = context.apply(|ctx| {
ctx.malloc_host::<u8>(safetensors.metadata().unwrap().len() as _)
.sporulate()
});
safetensors.read_exact(&mut host).unwrap();
drop(safetensors);
info!("read to host {:?}", time.elapsed());
let config = File::open(model_dir.as_ref().join("config.json")).unwrap();
let model = SafeTensors::load_from_dir(model_dir).unwrap();
info!("open file {:?}", time.elapsed());

let host = Memory::load_safetensors(config, host, false).unwrap();
let context = Arc::new(dev.retain_primary());
let host = Memory::load_safetensors(
config,
model,
Some(|l| context.apply(|ctx| ctx.malloc_host::<u8>(l).sporulate())),
)
.unwrap();
let load_layers = preload_layers.min(host.num_hidden_layers());

let (model, layers, kernels, transfer) = context.apply(|ctx| {
Expand Down
12 changes: 3 additions & 9 deletions service/src/nvidia.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
use std::{fs::File, path::Path, time::Instant};
use std::{path::Path, time::Instant};

pub fn transformer(model_dir: impl AsRef<Path>, device: i32) -> transformer_nv::Transformer {
use transformer_nv::{cuda, Transformer};
cuda::init();

let time = Instant::now();
let model_dir = model_dir.as_ref();
let config = File::open(model_dir.join("config.json")).unwrap();
let safetensors = File::open(model_dir.join("model.safetensors")).unwrap();
info!("open file {:?}", time.elapsed());

let time = Instant::now();
cuda::init();
let dev = cuda::Device::new(device);
dev.set_mempool_threshold(u64::MAX);
let transformer = Transformer::new(config, safetensors, usize::MAX, dev);
let transformer = Transformer::new(model_dir, usize::MAX, dev);
info!("build transformer ... {:?}", time.elapsed());

transformer
Expand Down
5 changes: 3 additions & 2 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,9 @@ fn tensor(dt: DataType, shape: &[udim]) -> Tensor<Blob> {

#[test]
fn test_build() {
use common::safe_tensors::SafeTensorsError;
use std::{io::ErrorKind::NotFound, time::Instant};
use transformer::{Memory, SafeTensorError};
use transformer::Memory;

let Some(model_dir) = common::test_model::find() else {
return;
Expand All @@ -342,7 +343,7 @@ fn test_build() {

let safetensors = match safetensors {
Ok(m) => m,
Err(SafeTensorError::Io(e)) if e.kind() == NotFound => return,
Err(SafeTensorsError::Io(e)) if e.kind() == NotFound => return,
Err(e) => panic!("{e:?}"),
};

Expand Down
2 changes: 0 additions & 2 deletions transformer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ authors = ["YdrMaster <[email protected]>"]
[dependencies]
common = { path = "../common" }
tensor = { path = "../tensor" }
memmap2 = "0.9"
safetensors = "0.4"
rand = "0.8"
half.workspace = true
rayon.workspace = true
Expand Down
4 changes: 1 addition & 3 deletions transformer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ pub use blas::Matrix;
pub use buffer::LayerBuffer;
pub use cache::LayerCache;
pub use kernels::Kernels;
pub use parameters::{
save, DistributeScheme, DistributedLayer, Distributer, Llama2, Memory, SafeTensorError,
};
pub use parameters::{save, DistributeScheme, DistributedLayer, Distributer, Llama2, Memory};
pub use pos::pos;
pub use request::Request;
pub use sample::{BetweenF32, SampleArgs};
Expand Down
3 changes: 2 additions & 1 deletion transformer/src/parameters/cast.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{memory::Layer, ConfigJson, Llama2, Memory, Storage};
use common::Blob;
use half::{bf16, f16};
use std::sync::Arc;
use tensor::{DataType, Tensor, Ty};

impl Memory {
Expand Down Expand Up @@ -59,5 +60,5 @@ fn typed<T: Ty + Sync, U: Ty + Send>(
.zip(reslice_mut(ans.physical_mut()))
.for_each(|(src, dst)| *dst = cast(src));

unsafe { ans.map_physical(Storage::new) }
unsafe { ans.map_physical(|b| Storage::Others(Arc::new(b))) }
}
5 changes: 3 additions & 2 deletions transformer/src/parameters/distribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ impl DistributeScheme {

#[test]
fn test() {
use super::{Memory, SafeTensorError};
use super::Memory;
use common::safe_tensors::SafeTensorsError;
use std::{io::ErrorKind::NotFound, time::Instant};

let Some(model_dir) = common::test_model::find() else {
Expand All @@ -297,7 +298,7 @@ fn test() {

let model = match safetensors {
Ok(m) => m,
Err(SafeTensorError::Io(e)) if e.kind() == NotFound => return,
Err(SafeTensorsError::Io(e)) if e.kind() == NotFound => return,
Err(e) => panic!("{e:?}"),
};

Expand Down
Loading

0 comments on commit 076b05c

Please sign in to comment.