feat(distributed): implement distributed embed token
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 10, 2024
1 parent 7b0ff98 commit 960da2a
Showing 7 changed files with 99 additions and 68 deletions.
2 changes: 0 additions & 2 deletions nvidia/common/src/lib.rs
@@ -10,11 +10,9 @@ mod mat_mul;
 mod reform;
 mod rms_norm;
 mod rotary_embedding;
-mod storage;
 mod swiglu;
 
 pub use common::utok;
-pub use storage::{tensor, Cache};
 pub use tensor::{slice, udim, DataType, LocalSplitable, Tensor};
 
 use cublas::{Cublas, CublasSpore};
24 changes: 0 additions & 24 deletions nvidia/common/src/storage.rs

This file was deleted.

17 changes: 8 additions & 9 deletions nvidia/distributed/src/gather.rs
@@ -1,18 +1,17 @@
 use common_nv::{
-    cuda::{ContextSpore, DevByte, StreamSpore},
+    cuda::{ContextSpore, DevMemSpore, StreamSpore},
     utok, Tensor,
 };
-use std::ops::{Deref, DerefMut};
+use std::ops::Deref;
 
-pub fn gather<T, U, I>(
-    x: &mut [Tensor<T>],
-    table: &Tensor<U>,
+pub fn gather<T, I>(
+    x: &mut [Tensor<DevMemSpore>],
+    table: &Tensor<T>,
     tokens: I,
     comms: &nccl::CommunicatorGroup,
     streams: &[StreamSpore],
 ) where
-    T: DerefMut<Target = [DevByte]>,
-    U: Deref<Target = [u8]>,
+    T: Deref<Target = [u8]>,
     I: IntoIterator<Item = utok>,
 {
     assert!(!x.is_empty());
@@ -41,12 +40,12 @@ pub fn gather<T, U, I>(
     for (i, comm) in comms.call().iter().enumerate() {
         comm.device().retain_primary().apply(|ctx| {
             let stream = unsafe { streams[i].sprout(ctx) };
-            let dst = &mut **x[i].physical_mut();
+            let mut dst = unsafe { x[i].physical_mut().sprout(ctx) };
             for _ in 0..distributed {
                 let Some((i, t)) = iter.next() else { break };
                 stream.memcpy_h2d(&mut dst[d * i..][..d], &table[d * t..][..d]);
             }
-            comm.all_gather(dst, None, &stream);
+            comm.all_gather(&mut dst, None, &stream);
         });
     }
 }
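The reworked `gather` appears to hand each device one contiguous block of the token list: every device copies only its block of rows from the host embedding table into its own output buffer, and the trailing `all_gather` then merges the blocks so that every device ends up holding the full `[nt, d]` embedding matrix. A minimal CPU-only sketch of that data flow, with hypothetical names (`distributed_embed`, plain `f32` vectors standing in for device memory and NCCL):

// Hypothetical, CPU-only illustration of the distributed embed-token flow.
fn distributed_embed(table: &[f32], d: usize, tokens: &[usize], n_devices: usize) -> Vec<Vec<f32>> {
    let nt = tokens.len();
    let per_dev = nt.div_ceil(n_devices); // block of tokens assigned to each device
    // Step 1: each device embeds only its own block of tokens.
    let mut parts = vec![vec![0.0f32; nt * d]; n_devices];
    for (dev, part) in parts.iter_mut().enumerate() {
        for i in dev * per_dev..((dev + 1) * per_dev).min(nt) {
            let t = tokens[i];
            part[d * i..d * (i + 1)].copy_from_slice(&table[d * t..d * (t + 1)]);
        }
    }
    // Step 2: "all-gather" — every device receives the blocks of all the others,
    // so each one ends up with the complete [nt, d] matrix.
    let mut full = vec![0.0f32; nt * d];
    for (dev, part) in parts.iter().enumerate() {
        let lo = dev * per_dev * d;
        let hi = ((dev + 1) * per_dev).min(nt) * d;
        full[lo..hi].copy_from_slice(&part[lo..hi]);
    }
    vec![full; n_devices]
}

fn main() {
    // 2 devices, 3 tokens, hidden size 2; table row for token id t is [t, t].
    let table: Vec<f32> = (0..4).flat_map(|t| [t as f32, t as f32]).collect();
    let x = distributed_embed(&table, 2, &[2, 0, 3], 2);
    assert_eq!(x[0], x[1]);
    assert_eq!(x[0], vec![2.0, 2.0, 0.0, 0.0, 3.0, 3.0]);
}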
67 changes: 52 additions & 15 deletions nvidia/distributed/src/lib.rs
@@ -1,5 +1,4 @@
 #![cfg(detected_nccl)]
-#![allow(unused)]
 
 mod gather;
 mod parameters;
@@ -10,8 +9,8 @@ extern crate log;
 pub use common_nv::cuda;
 
 use common_nv::{
-    cuda::{AsRaw, ContextResource, ContextSpore, CudaDataType, DevMem, Device, StreamSpore},
-    tensor, udim, utok, DataType, LocalSplitable, Tensor,
+    cuda::{AsRaw, ContextResource, ContextSpore, DevMemSpore, Device, StreamSpore},
+    udim, utok, Tensor,
 };
 use nccl::CommunicatorGroup;
 use parameters::ParameterMatrix;
@@ -49,8 +48,24 @@ impl transformer::Transformer for Transformer {
             .map(|ctx| ctx.apply(|c| c.stream().sporulate()))
             .collect::<Vec<_>>();
 
-        for (context, mut stream) in zip(contexts, streams) {
-            context.apply(|ctx| unsafe { stream.kill(ctx) });
+        let x0 = self.token_embed(&requests, &streams);
+        let x1 = zip(zip(&contexts, &streams), &x0)
+            .map(|((context, stream), x)| {
+                context.apply(|ctx| {
+                    let stream = unsafe { stream.sprout(ctx) };
+                    Tensor::alloc(x.data_type(), x.shape(), |len| {
+                        stream.malloc::<u8>(len).sporulate()
+                    })
+                })
+            })
+            .collect::<Vec<_>>();
+
+        for ((context, mut stream), (mut x0, mut x1)) in zip(zip(contexts, streams), zip(x0, x1)) {
+            context.apply(|ctx| unsafe {
+                stream.kill(ctx);
+                x0.physical_mut().kill(ctx);
+                x1.physical_mut().kill(ctx);
+            });
         }
         todo!()
     }
@@ -85,11 +100,11 @@ impl Transformer {
         }
     }
 
-    fn token_embed<'ctx, Id>(
+    fn token_embed<Id>(
         &self,
         requests: &[Request<Id, ()>],
         streams: &[StreamSpore],
-    ) -> Vec<Tensor<LocalSplitable<DevMem<'ctx>>>> {
+    ) -> Vec<Tensor<DevMemSpore>> {
         let dt = self.host.data_type();
         let nt = requests.iter().map(Request::seq_len).sum::<udim>();
         let d = self.host.hidden_size() as udim;
@@ -98,7 +113,13 @@ impl Transformer {
             .comms
             .contexts()
             .zip(streams)
-            .map(|(context, stream)| context.apply(|ctx| tensor(dt, &[nt, d], todo!())))
+            .map(|(context, stream)| {
+                context.apply(|ctx| {
+                    Tensor::alloc(dt, &[nt, d], |len| {
+                        unsafe { stream.sprout(ctx) }.malloc::<u8>(len).sporulate()
+                    })
+                })
+            })
             .collect::<Vec<_>>();
 
         let tokens = requests.iter().flat_map(Request::tokens).copied();
@@ -122,12 +143,28 @@ impl Drop for Transformer {
     }
 }
 
-fn convert(dt: DataType) -> CudaDataType {
-    match dt {
-        DataType::F16 => CudaDataType::f16,
-        DataType::BF16 => CudaDataType::bf16,
-        DataType::F32 => CudaDataType::f32,
-        DataType::F64 => CudaDataType::f64,
-        _ => unreachable!(),
+#[test]
+fn test() {
+    use common_nv::cuda::{self, Device};
+    use log::LevelFilter::Trace;
+    use simple_logger::SimpleLogger;
+    use transformer::Transformer as _;
+
+    const N: usize = 1;
+
+    cuda::init();
+    if Device::count() < N {
+        return;
     }
+
+    SimpleLogger::new().with_level(Trace).init().unwrap();
+
+    let time = Instant::now();
+    let transformer = Transformer::new(
+        "../../../TinyLlama-1.1B-Chat-v1.0_F16",
+        &[Device::fetch().unwrap()],
+    );
+    info!("load {:?}", time.elapsed());
+
+    transformer.decode(vec![Request::new(0, &[1, 2, 3], &mut [], 0, true)]);
 }
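In `decode` above, the per-device buffers (`x0`, `x1`) are allocated on a stream, immediately `sporulate`d into plain handles that can be stored outside the context borrow, `sprout`ed back whenever the owning context is active, and finally `kill`ed inside that context. A rough stand-alone sketch of that detach/re-attach lifecycle, using hypothetical stand-in types rather than the real `cuda` crate API:

// Hypothetical stand-ins illustrating the sporulate/sprout/kill pattern.
struct Context { id: usize }
struct Mem<'ctx> { ctx: &'ctx Context, bytes: usize } // usable only while the context is active
struct MemSpore { ctx_id: usize, bytes: usize }       // detached handle, free to store or send

impl Context {
    fn malloc(&self, bytes: usize) -> Mem<'_> { Mem { ctx: self, bytes } }
}
impl<'ctx> Mem<'ctx> {
    fn sporulate(self) -> MemSpore { MemSpore { ctx_id: self.ctx.id, bytes: self.bytes } }
}
impl MemSpore {
    fn sprout<'ctx>(&self, ctx: &'ctx Context) -> Mem<'ctx> {
        assert_eq!(self.ctx_id, ctx.id); // may only re-attach in the owning context
        Mem { ctx, bytes: self.bytes }
    }
    fn kill(self, ctx: &Context) {
        assert_eq!(self.ctx_id, ctx.id); // freeing must also happen in the owning context
    }
}

fn main() {
    let ctx = Context { id: 0 };
    let spore = ctx.malloc(1024).sporulate(); // detach, e.g. to keep inside a Tensor<DevMemSpore>
    let view = spore.sprout(&ctx);            // re-attach to use it
    assert_eq!(view.bytes, 1024);
    spore.kill(&ctx);                         // explicit free, as in the final loop of decode
}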
6 changes: 4 additions & 2 deletions nvidia/distributed/src/parameters.rs
@@ -1,4 +1,6 @@
-use common_nv::{
+#![allow(unused)]
+
+use common_nv::{
     cuda::{Context, ContextGuard, ContextResource, ContextSpore, DevByte, DevMem, DevMemSpore},
     udim, Tensor,
 };
@@ -154,7 +156,7 @@ fn test_load() {
     SimpleLogger::new().with_level(Trace).init().unwrap();
 
     let time = Instant::now();
-    let safetensors = Memory::load_safetensors_from_dir("../../../TinyLlama-1.1B-Chat-v1.0");
+    let safetensors = Memory::load_safetensors_from_dir("../../../TinyLlama-1.1B-Chat-v1.0_F16");
     info!("mmap {:?}", time.elapsed());
 
     let model = match safetensors {
44 changes: 28 additions & 16 deletions nvidia/transformer/src/lib.rs
@@ -9,8 +9,8 @@ pub use common_nv::cuda;
 
 use ::half::f16;
 use common_nv::{
-    cuda::{bindings::CUdeviceptr, memcpy_d2h, DevMem},
-    slice, tensor, udim, utok, Cache, DataType, LocalSplitable, NvidiaKernels, Tensor,
+    cuda::{bindings::CUdeviceptr, memcpy_d2h, DevMem, DevMemSpore},
+    slice, udim, utok, DataType, LocalSplitable, NvidiaKernels, Tensor,
 };
 use cuda::{AsRaw, Context, ContextResource, ContextSpore, Device, Stream, StreamSpore};
 use parameters::{LayerParameter, LayersParameters, ModelParameters};
@@ -61,15 +61,15 @@ impl transformer::Transformer for Transformer {
             let compute = ctx.stream();
             // Build the token embeddings and preallocate workspace
             let mut x0 = self.token_embed(&requests, &compute);
-            let mut x1 = tensor(x0.data_type(), x0.shape(), &transfer);
+            let mut x1 =
+                Tensor::alloc(x0.data_type(), x0.shape(), |len| transfer.malloc::<u8>(len));
             let mut buf = LayerBuffer::alloc(&self.host, &requests, |len| {
                 transfer.malloc::<u8>(len).into()
             });
             // Build the position tensor
             let nt = x0.shape()[0]; // `nt` for number of tokens
             let pos_ = pos(&requests, nt);
-            let mut pos = tensor(DataType::U32, &[nt], &transfer);
-            transfer.memcpy_h2d(pos.physical_mut(), &pos_);
+            let pos = Tensor::new(DataType::U32, &[nt], transfer.from_host(&pos_));
             // Inference
             compute.wait_for(&transfer.record());
             {
@@ -170,13 +170,13 @@ impl Transformer {
         &self,
         requests: &[Request<Id, Cache>],
         compute: &Stream<'ctx>,
-    ) -> Tensor<Local<'ctx>> {
+    ) -> Tensor<DevMem<'ctx>> {
         let dt = self.host.data_type();
         let nt = requests.iter().map(Request::seq_len).sum::<udim>();
         let d = self.host.hidden_size() as udim;
         let kernels = self.kernels.on(compute);
 
-        let mut x0 = tensor(dt, &[nt, d], compute);
+        let mut x0 = Tensor::alloc(dt, &[nt, d], |len| compute.malloc::<u8>(len));
         let tokens = requests.iter().flat_map(Request::tokens).copied();
         kernels.gather(&mut x0, &self.host.embed_tokens(), tokens);
         // compute.synchronize();
@@ -188,10 +188,10 @@ impl Transformer {
     fn before_att<'ctx>(
         &self,
         params: &LayerParameter,
-        x0: &Tensor<Local>,
-        x1: &mut Tensor<Local>,
+        x0: &Tensor<DevMem>,
+        x1: &mut Tensor<DevMem>,
         qkv: &mut Tensor<Local<'ctx>>,
-        pos: &Tensor<Local>,
+        pos: &Tensor<DevMem>,
         compute: &Stream,
     ) -> (
         Tensor<Local<'ctx>>,
@@ -240,7 +240,7 @@ impl Transformer {
         q: Tensor<Local>,
         k: Tensor<Local>,
         v: Tensor<Local>,
-        o: &mut Tensor<Local>,
+        o: &mut Tensor<DevMem>,
         q_buf: &mut Local,
         att_buf: &mut Local,
         compute: &Stream,
@@ -324,8 +324,8 @@ impl Transformer {
     fn after_att(
         &self,
         params: &LayerParameter,
-        x0: &mut Tensor<Local>,
-        x1: &mut Tensor<Local>,
+        x0: &mut Tensor<DevMem>,
+        x1: &mut Tensor<DevMem>,
         gate_up: &mut Tensor<Local>,
         compute: &Stream,
     ) {
@@ -366,9 +366,9 @@ impl Transformer {
     fn move_decode<'ctx, Id>(
         &self,
         requests: &[Request<Id, Cache>],
-        x0: Tensor<Local<'ctx>>,
+        x0: Tensor<DevMem<'ctx>>,
         compute: &Stream,
-    ) -> Tensor<Local<'ctx>> {
+    ) -> Tensor<DevMem<'ctx>> {
         let buf = x0.physical().as_ptr() as CUdeviceptr;
         let len = self.host.hidden_size() * self.host.data_type().size();
 
@@ -395,7 +395,7 @@ impl Transformer {
         x0.slice(&[slice![from begin, until dst + 1], slice![all]])
     }
 
-    fn logits(&self, mut x: Tensor<Local>, compute: &Stream) -> Tensor<Cache> {
+    fn logits(&self, mut x: Tensor<DevMem>, compute: &Stream) -> Tensor<Cache> {
         let dt = self.host.data_type();
         let voc = self.host.vocab_size() as udim;
 
@@ -430,6 +430,18 @@ impl Transformer {
     }
 }
 
+pub struct Cache {
+    pub context: Arc<Context>,
+    pub mem: DevMemSpore,
+}
+
+impl Drop for Cache {
+    #[inline]
+    fn drop(&mut self) {
+        self.context.apply(|ctx| unsafe { self.mem.kill(ctx) });
+    }
+}
+
 #[inline]
 fn tensor_cache(
     dt: DataType,
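Several call sites above swap the old `tensor(...)` helper for `Tensor::alloc(dt, shape, |len| ...)`: the tensor works out how many bytes it needs and a caller-supplied closure decides where that storage comes from (a compute stream, a transfer stream, host memory). A small sketch of that shape of API, with a hypothetical stand-in `Tensor` rather than the real `tensor` crate:

// Hypothetical stand-in showing allocation through a byte-length closure.
struct Tensor<P> {
    shape: Vec<usize>,
    physical: P, // backing storage chosen by the caller
}

impl<P> Tensor<P> {
    // `elem_size` is the byte width of one element (e.g. 2 for F16).
    fn alloc(elem_size: usize, shape: &[usize], f: impl FnOnce(usize) -> P) -> Self {
        let len = shape.iter().product::<usize>() * elem_size; // bytes to back the tensor
        Tensor { shape: shape.to_vec(), physical: f(len) }
    }
}

fn main() {
    // Backed by host memory here; in the diff the closure is e.g.
    // `|len| compute.malloc::<u8>(len)` or `|len| stream.malloc::<u8>(len).sporulate()`.
    let x = Tensor::alloc(2, &[4, 8], |len| vec![0u8; len]);
    assert_eq!(x.shape, [4, 8]);
    assert_eq!(x.physical.len(), 64);
}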
7 changes: 7 additions & 0 deletions tensor/src/split.rs
@@ -21,6 +21,13 @@ impl<T: Clone> Splitable for T {
 #[repr(transparent)]
 pub struct LocalSplitable<T>(Rc<T>);
 
+impl<T> AsRef<T> for LocalSplitable<T> {
+    #[inline]
+    fn as_ref(&self) -> &T {
+        &self.0
+    }
+}
+
 impl<T> From<T> for LocalSplitable<T> {
     #[inline]
     fn from(t: T) -> Self {
