Skip to content

Commit

Permalink
feat: 完成寒武纪模型搭建,适配 c 算子库
Browse files Browse the repository at this point in the history
  • Loading branch information
kilinchange committed Jul 5, 2024
1 parent cf26efb commit 862e797
Show file tree
Hide file tree
Showing 8 changed files with 532 additions and 151 deletions.
94 changes: 57 additions & 37 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
digit-layout = "0.0"
build-script-cfg = "0.0"

operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "189c2d5", default-features = false }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "f5d2b6b", default-features = false }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" }

cndrv = "0.1"
search-neuware-tools = "0.0"
4 changes: 2 additions & 2 deletions devices/cambricon-mlu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ authors = ["YdrMaster <[email protected]>"]
common = { path = "../../common" }
common-devices = { path = "../common" }
tensor = { path = "../../tensor" }
sample = { path = "../../sample" }
digit-layout="0.0"
cndrv.workspace = true
# operators = { workspace = true, features = ["cambricon-mlu"] }
operators = { workspace = true, features = ["cambricon-mlu"] }

[build-dependencies]
build-script-cfg.workspace = true
Expand Down
28 changes: 28 additions & 0 deletions devices/cambricon-mlu/src/gather.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use common::utok;
use operators::cndrv::{DevByte, Queue};
use std::ops::{Deref, DerefMut};
use tensor::Tensor;

/// Fills the rows of `x` by looking up each token's row in `table`.
///
/// For the `i`-th token `t` yielded by `tokens`, the bytes of row `t` of
/// `table` (a host byte slice) are copied into row `i` of `x` (a `DevByte`
/// slice) with `queue.memcpy_h2d`. Row size in bytes is the second dimension
/// of `x` multiplied by the element size reported by `x.data_layout()`.
///
/// Debug builds additionally assert that both tensors share the same data
/// layout, that `table` is two-dimensional with the same row width as `x`,
/// and that both tensors are contiguous.
///
/// # Panics
///
/// Panics if `x` is not two-dimensional, or if a destination or source row
/// range falls outside the underlying slice (for example, more tokens than
/// rows in `x`, or a token index beyond the rows of `table`).
// NOTE(review): whether the copy has completed when this function returns
// depends on `Queue::memcpy_h2d` semantics, which are not visible here.
pub fn gather<T, U, I>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: I, queue: &Queue)
where
    T: DerefMut<Target = [DevByte]>,
    U: Deref<Target = [u8]>,
    I: IntoIterator<Item = utok>,
{
    // `x` must be a matrix; only its row width is needed.
    let &[_, width] = x.shape() else { panic!() };

    debug_assert_eq!(x.data_layout(), table.data_layout());
    debug_assert_eq!(table.shape().len(), 2);
    debug_assert_eq!(table.shape()[1], width);
    debug_assert!(x.is_contiguous());
    debug_assert!(table.is_contiguous());

    // Size of a single row, in bytes.
    let row_bytes = width as usize * x.data_layout().nbytes();

    let device_rows = &mut **x.physical_mut();
    let host_rows = table.as_slice();

    let mut row = 0usize;
    for token in tokens {
        // Destination: the `row`-th row of `x`.
        let dst = &mut device_rows[row_bytes * row..][..row_bytes];
        // Source: the row of `table` selected by the token id.
        let src = &host_rows[row_bytes * token as usize..][..row_bytes];
        queue.memcpy_h2d(dst, src);
        row += 1;
    }
}
Loading

0 comments on commit 862e797

Please sign in to comment.