Skip to content

Commit

Permalink
perf(model-parameters): use rayon to cast faster
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 15, 2024
1 parent 1c9e761 commit dfdc605
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
1 change: 1 addition & 0 deletions model-parameters/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ authors = ["YdrMaster <[email protected]>"]
common = { path = "../common" }
log = "0.4"
half = "2.3"
rayon = "1.8"
memmap2 = "0.9"
safetensors = "0.4"
serde_json = "1.0"
Expand Down
36 changes: 26 additions & 10 deletions model-parameters/src/memory/inside_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@ use half::{bf16, f16};

impl Memory<Vec<u8>> {
pub fn cast<T: AsRef<[u8]>>(src: &Memory<T>, new_dtype: DataType) -> Self {
let mut blob = Vec::new();

let from = src.config.torch_dtype;
let mut blob = Vec::with_capacity(src.blob.as_ref().len() * new_dtype.size() / from.size());
let mut append = |src: &[u8]| {
let start = blob.len();
let end = start + src.len() * new_dtype.size() / from.size();
blob.resize(end, 0);
let additional = src.len() * new_dtype.size() / from.size();
let end = start + additional;

blob.reserve_exact(additional);
unsafe { blob.set_len(end) };

cast(from, src, new_dtype, &mut blob[start..end]);
start
};

if let Some(n) = std::option_env!("RAYON_PARALLELISM") {
rayon::ThreadPoolBuilder::new()
.num_threads(n.parse().unwrap())
.build_global()
.unwrap();
}

let embed_tokens = append(src.embed_tokens());
let layers = (0..src.config.num_hidden_layers)
.map(|layer| LayerParamsOffset {
Expand Down Expand Up @@ -48,12 +58,18 @@ impl Memory<Vec<u8>> {
fn cast(src_dtype: DataType, src: &[u8], dst_dtype: DataType, dst: &mut [u8]) {
// Element-wise conversion between two typed views over raw byte buffers.
//
// Arguments:
//   `$f`       - per-element conversion, invoked as `$f(*src)`.
//   `$src`     - source bytes, reinterpreted as a slice of `$src_ty`.
//   `$dst`     - destination bytes, reinterpreted as a mutable slice of `$dst_ty`.
//
// NOTE(review): as shown here the arm contains two versions of the body back to back: a
// sequential `iter()` version followed by a rayon `par_iter()` version. This looks like the
// removed and added lines of a diff rendered without +/- markers; confirm which lines are
// actually present in the file before relying on these comments.
macro_rules! cast {
($f:expr; $src:expr, $src_ty:ty => $dst:expr, $dst_ty:ty) => {
// Sequential version.
// Element count of the source; integer division drops any trailing partial element.
let len = $src.len() / std::mem::size_of::<$src_ty>();
// The destination byte length must match exactly `len` elements of `$dst_ty`.
assert_eq!(len * std::mem::size_of::<$dst_ty>(), $dst.len());
// NOTE(review): nothing visible here checks that the byte pointers are aligned for
// `$src_ty` / `$dst_ty`; confirm the callers guarantee it.
let src = unsafe { std::slice::from_raw_parts($src.as_ptr() as *const $src_ty, len) };
let dst =
unsafe { std::slice::from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) };
// Convert each source element and store it in the paired destination slot.
src.iter().zip(dst).for_each(|(src, dst)| *dst = $f(*src));
// Parallel version: `rayon::iter::*` brings `par_iter` and the parallel `zip`/`for_each`
// into scope; the `std` imports only shorten the paths used below.
use rayon::iter::*;
use std::{mem::size_of, slice::from_raw_parts, slice::from_raw_parts_mut};

// Same element-count computation as above, using the imported `size_of`.
let len = $src.len() / size_of::<$src_ty>();
// Length agreement is verified with `debug_assert_eq!` here rather than `assert_eq!`.
debug_assert_eq!(len * size_of::<$dst_ty>(), $dst.len());
// NOTE(review): same alignment caveat as the sequential version; TODO confirm.
let src = unsafe { from_raw_parts($src.as_ptr() as *const $src_ty, len) };
let dst = unsafe { from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) };

// Presumably `$f` is passed as a closure literal, so `$f(*src)` expands to an
// immediately-invoked closure, which is what this lint reports; verify at the call sites.
#[allow(clippy::redundant_closure_call)]
// Pair source and destination elements and convert them via rayon's parallel iterator.
src.par_iter()
.zip(dst)
.for_each(|(src, dst)| *dst = $f(*src));
};
}

Expand Down

0 comments on commit dfdc605

Please sign in to comment.