feat(distributed): start implementing distributed inference
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 4, 2024
1 parent 88979c2 commit 5a63355
Showing 5 changed files with 93 additions and 19 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -21,7 +21,7 @@ serde = "1.0"
log = "0.4"
tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] }

cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e95b84" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e95b84" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e95b84" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e95b84" }
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e92d252" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e92d252" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e92d252" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e92d252" }
80 changes: 78 additions & 2 deletions nvidia/distributed/src/lib.rs
@@ -7,11 +7,87 @@ extern crate log;

pub use common_nv::cuda;

use common_nv::{
    cuda::{AsRaw, ContextResource, ContextSpore, Device},
    utok, Tensor,
};
use nccl::CommunicatorGroup;
use parameters::ParameterMatrix;
use std::{iter::zip, path::Path, time::Instant};
use transformer::{LayerCache, Llama2, Memory, Request};

pub struct Transformer {
    comms: nccl::CommunicatorGroup,
    host: Memory,
    comms: CommunicatorGroup,
    matrix: ParameterMatrix,
}

impl Transformer {}
impl transformer::Transformer for Transformer {
    type Cache = ();

    #[inline]
    fn model(&self) -> &dyn Llama2 {
        &self.host
    }

    fn new_cache(&self) -> Vec<LayerCache<Self::Cache>> {
        todo!()
    }

    fn decode<Id>(
        &self,
        mut requests: Vec<Request<Id, Self::Cache>>,
    ) -> (Vec<Id>, Tensor<Self::Cache>) {
        // Gather all purely-decode requests together to reduce the copy
        // overhead of batched decoding.
        requests.sort_unstable_by_key(Request::purely_decode);

        let contexts = self.comms.context_iter().collect::<Vec<_>>();
        let streams = contexts
            .iter()
            .map(|ctx| ctx.apply(|c| c.stream().sporulate()))
            .collect::<Vec<_>>();

        for (context, mut stream) in zip(contexts, streams) {
            context.apply(|ctx| unsafe { stream.kill(ctx) });
        }
        todo!()
    }

    fn sample<Id>(
        &self,
        _args: &transformer::SampleArgs,
        _requests: Vec<Id>,
        _logits: Tensor<Self::Cache>,
    ) -> Vec<(Id, utok)> {
        todo!()
    }
}
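The comment at the top of `decode` explains the sort: requests of the same kind are grouped so that batched decoding can copy contiguous runs instead of scattered entries. A tiny self-contained sketch of that grouping idea, with the request data and the boolean key as hypothetical stand-ins for `Request::purely_decode`:

fn main() {
    // Hypothetical requests tagged with a "purely decode" flag.
    let mut requests = vec![
        ("prefill-a", false),
        ("decode-b", true),
        ("prefill-c", false),
        ("decode-d", true),
    ];
    // An unstable sort is fine here: only the grouping matters, and all
    // `false` keys end up before all `true` keys, so each kind is contiguous.
    requests.sort_unstable_by_key(|&(_, purely_decode)| purely_decode);
    println!("{requests:?}");
}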

impl Transformer {
    pub fn new(model_dir: impl AsRef<Path>, dev: &[Device]) -> Self {
        let time = Instant::now();
        let host = Memory::load_safetensors_from_dir(model_dir).unwrap();
        info!("load host: {:?}", time.elapsed());

        Self {
            comms: CommunicatorGroup::new(
                &dev.iter()
                    .map(|dev| unsafe { dev.as_raw() })
                    .collect::<Vec<_>>(),
            ),
            matrix: ParameterMatrix::load(
                &host,
                &dev.iter().map(Device::retain_primary).collect::<Vec<_>>(),
            ),
            host,
        }
    }
}

impl Drop for Transformer {
    #[inline]
    fn drop(&mut self) {
        let contexts = self.comms.context_iter().collect::<Vec<_>>();
        unsafe { self.matrix.kill(&contexts) }
    }
}
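Putting the new constructor together with the single-device builder shown in service/src/nvidia.rs below, a minimal caller might look like the following sketch. The crate path and the model directory are assumptions, not taken from this diff, and at this commit `new_cache`, `decode`, and `sample` still end in `todo!()`:

// Hypothetical caller; the crate name `distributed_nv` is assumed.
use distributed_nv::{cuda, Transformer};

fn main() {
    cuda::init();
    // Spread the model over the first two CUDA devices; the NCCL communicator
    // group and the per-device parameter matrix are built inside `new`.
    let devices = (0..2).map(cuda::Device::new).collect::<Vec<_>>();
    let _transformer = Transformer::new("path/to/model", &devices); // model dir assumed
}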
10 changes: 4 additions & 6 deletions nvidia/transformer/src/lib.rs
@@ -17,11 +17,9 @@ use std::{
    sync::{Arc, Mutex},
    time::Instant,
};
use transformer::{
    pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request, SampleArgs, Transformer,
};
use transformer::{pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request, SampleArgs};

pub struct NvidiaTransformer {
pub struct Transformer {
    host: Memory,
    model: ModelParameters,
    layers: Mutex<LayersParameters>,
@@ -30,7 +28,7 @@ pub struct NvidiaTransformer {
    kernels: NvidiaKernels,
}

impl Transformer for NvidiaTransformer {
impl transformer::Transformer for Transformer {
    type Cache = Cache;

    #[inline]
@@ -127,7 +125,7 @@ impl Transformer for NvidiaTransformer {
    }
}

impl NvidiaTransformer {
impl Transformer {
    pub fn new(config: File, mut safetensors: File, preload_layers: usize, dev: Device) -> Self {
        let context = Arc::new(dev.retain_primary());
        let time = Instant::now();
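After the rename, the struct and the trait share the name `Transformer`, which is why the trait is now referenced through its crate path (`impl transformer::Transformer for Transformer`). A self-contained illustration of that naming pattern, with a local module standing in for the external `transformer` crate:

mod transformer {
    pub trait Transformer {
        fn model_name(&self) -> &'static str;
    }
}

pub struct Transformer;

impl transformer::Transformer for Transformer {
    fn model_name(&self) -> &'static str {
        "nvidia"
    }
}

fn main() {
    // The trait must be in scope (anonymously here) to call its method on the struct.
    use transformer::Transformer as _;
    println!("{}", Transformer.model_name());
}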
6 changes: 3 additions & 3 deletions service/src/nvidia.rs
@@ -1,7 +1,7 @@
use std::{fs::File, path::Path, time::Instant};
use transformer_nv::{cuda, NvidiaTransformer};
use transformer_nv::{cuda, Transformer};

pub fn transformer(model_dir: impl AsRef<Path>, device: i32) -> NvidiaTransformer {
pub fn transformer(model_dir: impl AsRef<Path>, device: i32) -> Transformer {
    cuda::init();

    let time = Instant::now();
@@ -13,7 +13,7 @@ pub fn transformer(model_dir: impl AsRef<Path>, device: i32) -> NvidiaTransforme
    let time = Instant::now();
    let dev = cuda::Device::new(device);
    dev.set_mempool_threshold(u64::MAX);
    let transformer = NvidiaTransformer::new(config, safetensors, usize::MAX, dev);
    let transformer = Transformer::new(config, safetensors, usize::MAX, dev);
    info!("build transformer ... {:?}", time.elapsed());

    transformer
