Skip to content

Commit

Permalink
fix(xtask): 在 chat 中测试 fork 并扩展 nvidia reform 以支持 fork 的需求
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 28, 2024
1 parent d7aadec commit 8d0cdb7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
30 changes: 23 additions & 7 deletions nvidia/common/src/reform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,43 @@ extern "C" __global__ void {name}(
assert_eq!(dst.data_type(), src.data_type());
assert_eq!(dst.shape(), src.shape());

let &[r, c, b] = dst.shape() else {
let [r @ .., c, b] = dst.shape() else {
unreachable!()
};
let &[rsa, csa, 1] = dst.strides() else {
let [rsa @ .., csa, 1] = dst.strides() else {
unreachable!()
};
let &[rsb, csb, 1] = src.strides() else {
let [rsb @ .., csb, 1] = src.strides() else {
unreachable!()
};

let r = match r {
[] => unreachable!(),
&[r] => r,
r => {
assert!(r
.iter()
.map(|r| *r as i32)
.enumerate()
.skip(1)
.all(|(i, r)| r * rsa[i] == rsa[i - 1] && r * rsb[i] == rsb[i - 1]));
r.iter().product()
}
};
let rsa = rsa.last().unwrap();
let rsb = rsb.last().unwrap();

let contiguous_bytes = b * dst.data_type().size() as udim;
assert_eq!(contiguous_bytes % self.warp_size, 0);
let bytes_per_thread = contiguous_bytes / self.warp_size;
assert!(bytes_per_thread <= 32 && bytes_per_thread.is_power_of_two());

let dst_ptr = (dst.physical().as_ptr() as isize + dst.bytes_offset()) as CUdeviceptr;
let rsa = rsa as udim / b;
let csa = csa as udim / b;
let rsa = *rsa as udim / b;
let csa = *csa as udim / b;
let src_ptr = (src.physical().as_ptr() as isize + src.bytes_offset()) as CUdeviceptr;
let rsb = rsb as udim / b;
let csb = csb as udim / b;
let rsb = *rsb as udim / b;
let csb = *csb as udim / b;
let params: [*const c_void; 8] = [
(&dst_ptr) as *const _ as _,
(&rsa) as *const _ as _,
Expand Down
28 changes: 25 additions & 3 deletions xtask/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ fn print_help() {
println!(
"\
/list 列出现存的会话及对话次数
/create 新建会话session
/switch [0-9+] 切换至指定会话
/drop [0-9+] 丢弃指定会话
/create 新建会话
/fork [id] 复制当前会话或指定会话
/switch <id> 切换至指定会话
/drop [id] 丢弃当前会话或指定会话
/args 打印当前参数
/args key value 设置指定参数
/help 打印帮助信息
Expand Down Expand Up @@ -147,6 +148,27 @@ impl<M: CausalLM> Chatting<M> {
self.sessions.insert(self.current, self.service.launch());
println!("Create new session {}.", self.current);
}
// `/fork` with no argument: duplicate the session returned by `self.session()`
// and make the copy the current session.
["/fork"] => {
    let forked = self.session().fork();
    // Hand out the next free id to the fork, then advance the counter.
    let id = self.next_id;
    self.next_id += 1;
    self.current = id;
    self.sessions.insert(id, forked);
    println!("Fork session to {}.", id);
}
// `/fork <id>`: duplicate the session registered under `<id>` and make the
// copy the current session.
["/fork", n] => match n.parse() {
    Ok(target_id) => {
        if let Some(s) = self.sessions.get(&target_id) {
            let new = s.fork();
            // The fork takes the next free id and becomes the current session.
            self.current = self.next_id;
            self.next_id += 1;
            self.sessions.insert(self.current, new);
            println!("Fork session {} to {}.", target_id, self.current);
        } else {
            // No session is registered under the requested id.
            println!("Invalid session ID.");
        }
    }
    // The argument could not be parsed as a session id. The message now names
    // the command that actually failed (it previously said "drop").
    Err(_) => println!("Invalid fork command"),
},
["/switch", n] => match n.parse() {
Ok(target_id) => {
if target_id == self.current {
Expand Down

0 comments on commit 8d0cdb7

Please sign in to comment.