Commit
Signed-off-by: YdrMaster <[email protected]>
Showing 12 changed files with 192 additions and 72 deletions.
A crate root re-exports cuda; per the hunk count (6 lines before, 7 after), the re-export is the single added line:

```diff
@@ -2,6 +2,7 @@
 #[macro_use]
 extern crate log;
+pub extern crate cuda;

 mod fused_softmax;
 mod gather;
```
A crate manifest swaps its direct cuda dependency for nccl (the `-9,7 +9,7` hunk removes one line and adds one):

```diff
@@ -9,7 +9,7 @@ authors = ["YdrMaster <[email protected]>"]
 [dependencies]
 transformer = { path = "../../transformer" }
 common-nv = { path = "../common" }
-cuda.workspace = true
+nccl.workspace = true
 log.workspace = true
 half.workspace = true
```
The crate root of the NCCL-distributed backend is rewritten: the old `test_load` test moves out (it reappears in the new parameters.rs below), and the crate now declares a `parameters` module and a `Transformer` over an NCCL communicator group:

```diff
@@ -1,60 +1,17 @@
 #![cfg(detected_nccl)]

-#[test]
-fn test_load() {
-    use cuda::{ContextResource, ContextSpore, Device};
-    use std::{io::ErrorKind::NotFound, time::Instant};
-    use transformer::{Distributer, Llama2, Memory, SafeTensorError};
-
-    const N: usize = 1;
-
-    cuda::init();
-    if Device::count() < N {
-        return;
-    }
-    let devices = (0..N as _).map(Device::new).collect::<Vec<_>>();
-    let contexts = devices
-        .iter()
-        .map(Device::retain_primary)
-        .collect::<Vec<_>>();
-    let align = devices.iter().map(Device::alignment).max().unwrap();
-
-    let time = Instant::now();
-    let safetensors = Memory::load_safetensors_from_dir("../../../TinyLlama-1.1B-Chat-v1.0");
-    println!("mmap {:?}", time.elapsed());
-
-    let model = match safetensors {
-        Ok(m) => m,
-        Err(SafeTensorError::Io(e)) if e.kind() == NotFound => return,
-        Err(e) => panic!("{e:?}"),
-    };
-
-    let nlayers = model.num_hidden_layers();
-    let mut matrix = Vec::with_capacity(contexts.len() * nlayers);
-
-    let distributer = Distributer::new(&model, contexts.len(), align);
-    let time = Instant::now();
-    for (i, context) in contexts.iter().enumerate() {
-        context.apply(|ctx| {
-            let stream = ctx.stream();
-            for layer in 0..nlayers {
-                matrix.push(
-                    stream
-                        .from_host(distributer.distribute(layer, i).as_slice())
-                        .sporulate(),
-                );
-            }
-        });
-    }
-    println!("distribute {:?}", time.elapsed());
-
-    let time = Instant::now();
-    for (i, context) in contexts.iter().enumerate() {
-        context.apply(|ctx| {
-            for element in &mut matrix[i * nlayers..][..nlayers] {
-                unsafe { element.kill(ctx) };
-            }
-        });
-    }
-    println!("kill {:?}", time.elapsed());
-}
+mod parameters;
+
+#[macro_use]
+extern crate log;
+
+pub use common_nv::cuda;
+
+use parameters::ParameterMatrix;
+
+pub struct Transformer {
+    comms: nccl::CommunicatorGroup,
+    matrix: ParameterMatrix,
+}
+
+impl Transformer {}
```
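The rewrite leans on the spore pattern of the cuda wrapper: a device allocation is detached from its context with `sporulate`, stored as plain data, and later freed explicitly with `kill` inside the owning context. A minimal sketch of that lifecycle, assuming the API exactly as it appears in this diff (`from_host`, `sporulate`, `kill`); the surrounding wiring is assumed, not code from this commit:

```rust
// Sketch of the sporulate/kill lifecycle; names taken from this diff.
use common_nv::cuda::{self, ContextResource, ContextSpore, Device};

fn spore_lifecycle() {
    cuda::init();
    if Device::count() < 1 {
        return;
    }
    let context = Device::new(0).retain_primary();
    let host = [0u8; 16];

    let mut spores = Vec::new();
    context.apply(|ctx| {
        // Upload host bytes on a stream, then detach the allocation as a
        // context-free "spore" that can outlive this scope.
        spores.push(ctx.stream().from_host(host.as_slice()).sporulate());
    });

    // Spores are not freed on drop: they must be killed back inside the
    // context that owns them, which is why ParameterMatrix::kill is unsafe.
    context.apply(|ctx| {
        for s in &mut spores {
            unsafe { s.kill(ctx) };
        }
    });
}
```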
The new src/parameters.rs (`@@ -0,0 +1,148 @@`, an entirely new file): `ParameterMatrix` loads each layer's shard onto each device and frees it explicitly, while the `Layer` accessors compute the sharded tensor shapes but still leave the data pointer as `todo!()`:

```rust
use common_nv::{
    cuda::{Context, ContextGuard, ContextResource, ContextSpore, DevMem, DevMemSpore, DevSlice},
    udim, Tensor,
};
use std::time::Instant;
use transformer::{DistributeScheme, Distributer, Llama2};

pub struct ParameterMatrix {
    scheme: DistributeScheme,
    matrix: Vec<DevMemSpore>,
}

impl ParameterMatrix {
    pub fn load(model: &dyn Llama2, contexts: &[Context]) -> Self {
        let align = contexts
            .iter()
            .map(|ctx| ctx.device().alignment())
            .max()
            .unwrap();

        let nlayers = model.num_hidden_layers();
        let mut matrix = Vec::with_capacity(contexts.len() * nlayers);

        let distributer = Distributer::new(model, contexts.len(), align);
        let time = Instant::now();
        for (i, context) in contexts.iter().enumerate() {
            context.apply(|ctx| {
                let stream = ctx.stream();
                for layer in 0..nlayers {
                    matrix.push(
                        stream
                            .from_host(distributer.distribute(layer, i).as_slice())
                            .sporulate(),
                    );
                }
            });
        }
        info!("distribute {:?}", time.elapsed());

        Self {
            scheme: distributer.scheme().clone(),
            matrix,
        }
    }

    pub unsafe fn kill(&mut self, contexts: &[Context]) {
        assert_eq!(contexts.len(), self.scheme.n);
        let nlayers = self.matrix.len() / self.scheme.n;
        for (i, context) in contexts.iter().enumerate() {
            context.apply(|ctx| {
                for element in &mut self.matrix[i * nlayers..][..nlayers] {
                    element.kill(ctx);
                }
            });
        }
    }
}

pub struct Layer<'ctx> {
    scheme: &'ctx DistributeScheme,
    mem: DevMem<'ctx>,
}

impl ParameterMatrix {
    pub fn get<'ctx>(&'ctx self, layer: usize, i: usize, ctx: &'ctx ContextGuard) -> Layer<'ctx> {
        let nlayers = self.matrix.len() / self.scheme.n;
        Layer {
            scheme: &self.scheme,
            mem: unsafe { self.matrix[i * nlayers + layer].sprout(ctx) },
        }
    }
}

impl Layer<'_> {
    #[inline]
    pub fn input_layernorm(&self) -> Tensor<&DevSlice> {
        let d = self.scheme.nh * self.scheme.dh;
        Tensor::new(self.scheme.dt, &[d], todo!())
    }

    #[inline]
    pub fn w_qkv(&self) -> Tensor<&[u8]> {
        let nh = self.scheme.nh;
        let nkvh = self.scheme.nkvh;
        let dh = self.scheme.dh;
        let d = nh * dh;
        let n = self.scheme.n as udim;
        Tensor::new(self.scheme.dt, &[(nh + nkvh + nkvh) / n * dh, d], todo!())
    }

    #[inline]
    pub fn w_o(&self) -> Tensor<&[u8]> {
        let d = self.scheme.nh * self.scheme.dh;
        let n = self.scheme.n as udim;
        Tensor::new(self.scheme.dt, &[d / n, d], todo!())
    }

    #[inline]
    pub fn post_att_layernorm(&self) -> Tensor<&[u8]> {
        let d = self.scheme.nh * self.scheme.dh;
        Tensor::new(self.scheme.dt, &[d], todo!())
    }

    #[inline]
    pub fn mlp_gate_up(&self) -> Tensor<&[u8]> {
        let di = self.scheme.di;
        let d = self.scheme.nh * self.scheme.dh;
        let n = self.scheme.n as udim;
        Tensor::new(self.scheme.dt, &[(di + di) / n, d], todo!())
    }

    #[inline]
    pub fn mlp_down(&self) -> Tensor<&[u8]> {
        let di = self.scheme.di;
        let d = self.scheme.nh * self.scheme.dh;
        let n = self.scheme.n as udim;
        Tensor::new(self.scheme.dt, &[d, di / n], todo!())
    }
}

#[test]
fn test_load() {
    use common_nv::cuda::{self, Device};
    use std::io::ErrorKind::NotFound;
    use transformer::{Memory, SafeTensorError};

    const N: usize = 1;

    cuda::init();
    if Device::count() < N {
        return;
    }

    let time = Instant::now();
    let safetensors = Memory::load_safetensors_from_dir("../../../TinyLlama-1.1B-Chat-v1.0");
    println!("mmap {:?}", time.elapsed());

    let model = match safetensors {
        Ok(m) => m,
        Err(SafeTensorError::Io(e)) if e.kind() == NotFound => return,
        Err(e) => panic!("{e:?}"),
    };

    let contexts = (0..N as _)
        .map(|i| Device::new(i).retain_primary())
        .collect::<Vec<_>>();
    unsafe { ParameterMatrix::load(&model, &contexts).kill(&contexts) };
}
```
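Putting the pieces together, a hedged sketch of how a caller inside this crate might drive `ParameterMatrix` across devices. Every name comes from the diff, but the wiring is assumed, and the `Layer` accessors would currently panic on their `todo!()` data pointers, so the sketch only borrows a layer without reading its tensors:

```rust
// Assumed in-crate usage of ParameterMatrix; mirrors test_load plus a
// per-layer get(). Not code from this commit.
use common_nv::cuda::{self, ContextResource, Device};
use transformer::Llama2;

fn shard_and_inspect(model: &dyn Llama2) {
    cuda::init();
    const N: usize = 2; // two-way sharding, an arbitrary choice for the sketch
    if Device::count() < N {
        return;
    }
    let contexts = (0..N as _)
        .map(|i| Device::new(i).retain_primary())
        .collect::<Vec<_>>();

    // Copy every layer's shard onto its device.
    let mut matrix = ParameterMatrix::load(model, &contexts);

    for (i, context) in contexts.iter().enumerate() {
        context.apply(|ctx| {
            // Borrow device i's shard of layer 0; shapes are already divided
            // by the shard count n (e.g. w_qkv is [(nh + 2*nkvh)/n * dh, d]).
            let _layer = matrix.get(0, i, ctx);
        });
    }

    // Device memory must be released explicitly, as in test_load.
    unsafe { matrix.kill(&contexts) };
}
```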
A second manifest simply drops its direct cuda dependency, now that common-nv re-exports it:

```diff
@@ -9,7 +9,6 @@ authors = ["YdrMaster <[email protected]>"]
 [dependencies]
 transformer = { path = "../../transformer" }
 common-nv = { path = "../common" }
-cuda.workspace = true
 log.workspace = true
 half.workspace = true
```
The remaining changed files are not rendered.