diff --git a/nvidia/common/src/reform.rs b/nvidia/common/src/reform.rs index 5fef03fa..7e3abd4b 100644 --- a/nvidia/common/src/reform.rs +++ b/nvidia/common/src/reform.rs @@ -81,27 +81,43 @@ extern "C" __global__ void {name}( assert_eq!(dst.data_type(), src.data_type()); assert_eq!(dst.shape(), src.shape()); - let &[r, c, b] = dst.shape() else { + let [r @ .., c, b] = dst.shape() else { unreachable!() }; - let &[rsa, csa, 1] = dst.strides() else { + let [rsa @ .., csa, 1] = dst.strides() else { unreachable!() }; - let &[rsb, csb, 1] = src.strides() else { + let [rsb @ .., csb, 1] = src.strides() else { unreachable!() }; + let r = match r { + [] => unreachable!(), + &[r] => r, + r => { + assert!(r + .iter() + .map(|r| *r as i32) + .enumerate() + .skip(1) + .all(|(i, r)| r * rsa[i] == rsa[i - 1] && r * rsb[i] == rsb[i - 1])); + r.iter().product() + } + }; + let rsa = rsa.last().unwrap(); + let rsb = rsb.last().unwrap(); + let contiguous_bytes = b * dst.data_type().size() as udim; assert_eq!(contiguous_bytes % self.warp_size, 0); let bytes_per_thread = contiguous_bytes / self.warp_size; assert!(bytes_per_thread <= 32 && bytes_per_thread.is_power_of_two()); let dst_ptr = (dst.physical().as_ptr() as isize + dst.bytes_offset()) as CUdeviceptr; - let rsa = rsa as udim / b; - let csa = csa as udim / b; + let rsa = *rsa as udim / b; + let csa = *csa as udim / b; let src_ptr = (src.physical().as_ptr() as isize + src.bytes_offset()) as CUdeviceptr; - let rsb = rsb as udim / b; - let csb = csb as udim / b; + let rsb = *rsb as udim / b; + let csb = *csb as udim / b; let params: [*const c_void; 8] = [ (&dst_ptr) as *const _ as _, (&rsa) as *const _ as _, diff --git a/xtask/src/chat.rs b/xtask/src/chat.rs index 72b47b36..3e360a1a 100644 --- a/xtask/src/chat.rs +++ b/xtask/src/chat.rs @@ -49,9 +49,10 @@ fn print_help() { println!( "\ /list 列出现存的会话及对话次数 -/create 新建会话session -/switch [0-9+] 切换至指定会话 -/drop [0-9+] 丢弃指定会话 +/create 新建会话 +/fork [id] 复制当前会话或指定会话 +/switch 切换至指定会话 +/drop [id] 丢弃当前会话或指定会话 /args 打印当前参数 /args key value 设置指定参数 /help 打印帮助信息 @@ -147,6 +148,27 @@ impl Chatting { self.sessions.insert(self.current, self.service.launch()); println!("Create new session {}.", self.current); } + ["/fork"] => { + let new = self.session().fork(); + self.current = self.next_id; + self.next_id += 1; + self.sessions.insert(self.current, new); + println!("Fork session to {}.", self.current); + } + ["/fork", n] => match n.parse() { + Ok(target_id) => { + if let Some(s) = self.sessions.get(&target_id) { + let new = s.fork(); + self.current = self.next_id; + self.next_id += 1; + self.sessions.insert(self.current, new); + println!("Fork session {} to {}.", target_id, self.current); + } else { + println!("Invalid session ID."); + } + } + Err(_) => println!("Invalid drop command"), + }, ["/switch", n] => match n.parse() { Ok(target_id) => { if target_id == self.current {