diff --git a/Cargo.lock b/Cargo.lock
index cfb7f724..71d4470a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -337,7 +337,7 @@ dependencies = [
 [[package]]
 name = "common"
 version = "0.1.0"
-source = "git+https://github.com/YdrMaster/operators-rs?rev=b05568a#b05568a30570c5a25bfc84a4cced2413701f9e0c"
+source = "git+https://github.com/YdrMaster/operators-rs?rev=d059b1a#d059b1ac742d38692f118c31bd7b1f9cd6b5a369"
 dependencies = [
  "digit-layout",
 ]
@@ -380,7 +380,6 @@ name = "common-nv"
 version = "0.0.0"
 dependencies = [
  "build-script-cfg",
- "cc",
  "common 0.0.0",
  "common-devices",
  "digit-layout",
@@ -1180,7 +1179,7 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
 [[package]]
 name = "operators"
 version = "0.0.0"
-source = "git+https://github.com/YdrMaster/operators-rs?rev=b05568a#b05568a30570c5a25bfc84a4cced2413701f9e0c"
+source = "git+https://github.com/YdrMaster/operators-rs?rev=d059b1a#d059b1ac742d38692f118c31bd7b1f9cd6b5a369"
 dependencies = [
  "build-script-cfg",
  "cndrv",
@@ -1194,6 +1193,7 @@ dependencies = [
  "libloading",
  "log",
  "nccl",
+ "rand",
  "rayon",
  "search-cuda-tools",
  "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1615,18 +1615,18 @@ dependencies = [
 [[package]]
 name = "thiserror"
-version = "1.0.62"
+version = "1.0.63"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb"
+checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.62"
+version = "1.0.63"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c"
+checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/Cargo.toml b/Cargo.toml
index a9dbcdae..ab06e164 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,6 +35,6 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
 digit-layout = "0.0"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "b05568a", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "d059b1a", default-features = false }
 search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "fb088b6" }
 search-neuware-tools = "0.0"
diff --git a/devices/common-cpu/src/lib.rs b/devices/common-cpu/src/lib.rs
index daf65aa0..a338ad1f 100644
--- a/devices/common-cpu/src/lib.rs
+++ b/devices/common-cpu/src/lib.rs
@@ -7,12 +7,18 @@ macro_rules! slice {
 
 mod gather;
 
-use common::utok;
+use common::{f16, utok};
 use common_devices::{Operators, SliceOn};
+use digit_layout::types::F16;
 use operators::{
-    fuesd_softmax::common_cpu as softmax, mat_mul::common_cpu as mat_mul,
-    reform::common_cpu as reform, rms_norm::common_cpu as rms_norm, rope::common_cpu as rope,
-    swiglu::common_cpu as swiglu, Operator, QueueOf,
+    fuesd_softmax::common_cpu as softmax,
+    mat_mul::common_cpu as mat_mul,
+    random_sample::{common_cpu as random_sample, Args, KVPair, SampleArgs},
+    reform::common_cpu as reform,
+    rms_norm::common_cpu as rms_norm,
+    rope::common_cpu as rope,
+    swiglu::common_cpu as swiglu,
+    Operator, QueueOf,
 };
 use std::ops::{Deref, DerefMut};
 use tensor::Tensor;
@@ -29,6 +35,23 @@ pub struct CpuKernels {
     rope: rope::Operator,
     softmax: softmax::Operator,
     swiglu: swiglu::Operator,
+    sample: random_sample::Operator,
+}
+
+impl CpuKernels {
+    pub fn sample(&self, temperature: f32, top_p: f32, top_k: usize, logits: &[f16]) -> utok {
+        let mut kv_pair = KVPair::new(0, f16::ZERO);
+        let mut args = Args::<Cpu>::new(F16, logits.len());
+        args.kv_pair_base = &mut kv_pair as *mut _ as _;
+        args.data_base = logits.as_ptr() as _;
+        args.detail = SampleArgs {
+            temperature,
+            top_p,
+            top_k,
+        };
+        self.sample.launch(&args, &ThisThread).unwrap();
+        kv_pair.idx() as _
+    }
 }
 
 impl Default for CpuKernels {
@@ -40,6 +63,7 @@
             rope: rope::Operator::new(&Cpu),
             softmax: softmax::Operator::new(&Cpu),
             swiglu: swiglu::Operator::new(&Cpu),
+            sample: random_sample::Operator::new(&Cpu),
        }
     }
 }
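
Note: the new `CpuKernels::sample` above wraps the `random_sample` operator from operators-rs — it points `Args` at one row of logits and a stack-allocated `KVPair`, launches synchronously on `ThisThread`, and returns the winning index. A minimal usage sketch (the parameter values are illustrative, and the greedy-decoding comment is an assumption about the operator's behavior at zero temperature, not something this patch states):

```rust
use common::{f16, utok};
use common_cpu::CpuKernels;

/// Pick the next token from one row of logits.
fn next_token(kernels: &CpuKernels, logits: &[f16], greedy: bool) -> utok {
    if greedy {
        // Assumption: temperature 0 degenerates to argmax inside the operator.
        kernels.sample(0.0, 1.0, usize::MAX, logits)
    } else {
        // Stochastic decoding bounded by top-p mass and a top-k cutoff.
        kernels.sample(0.7, 0.9, 40, logits)
    }
}
```
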
diff --git a/devices/nvidia-gpu/Cargo.toml b/devices/nvidia-gpu/Cargo.toml
index 5caa0d3a..3d8042e3 100644
--- a/devices/nvidia-gpu/Cargo.toml
+++ b/devices/nvidia-gpu/Cargo.toml
@@ -18,4 +18,3 @@ digit-layout.workspace = true
 [build-dependencies]
 build-script-cfg.workspace = true
 search-cuda-tools.workspace = true
-cc = "1.0"
diff --git a/devices/nvidia-gpu/build.rs b/devices/nvidia-gpu/build.rs
index af19ff54..7273ce09 100644
--- a/devices/nvidia-gpu/build.rs
+++ b/devices/nvidia-gpu/build.rs
@@ -9,13 +9,5 @@ fn main() {
         if find_nccl_root().is_some() {
             nccl.define();
         }
-        println!("cargo:rerun-if-changed=src/sample.cu");
-        cc::Build::new()
-            .cuda(true)
-            .flag("-gencode")
-            .flag("arch=compute_80,code=sm_80")
-            .flag("-allow-unsupported-compiler")
-            .file("src/sample.cu")
-            .compile("sample");
     }
 }
diff --git a/devices/nvidia-gpu/src/.clang-format b/devices/nvidia-gpu/src/.clang-format
deleted file mode 100644
index 66c6e431..00000000
--- a/devices/nvidia-gpu/src/.clang-format
+++ /dev/null
@@ -1,66 +0,0 @@
-# Generated from CLion C/C++ Code Style settings
-BasedOnStyle: LLVM
-AccessModifierOffset: -4
-AlignAfterOpenBracket: Align
-# AlignConsecutiveAssignments: None
-AlignOperands: Align
-AllowAllArgumentsOnNextLine: false
-AllowAllConstructorInitializersOnNextLine: false
-AllowAllParametersOfDeclarationOnNextLine: false
-AllowShortBlocksOnASingleLine: Always
-AllowShortCaseLabelsOnASingleLine: false
-AllowShortFunctionsOnASingleLine: All
-AllowShortIfStatementsOnASingleLine: Always
-AllowShortLambdasOnASingleLine: All
-AllowShortLoopsOnASingleLine: true
-AlwaysBreakAfterReturnType: None
-AlwaysBreakTemplateDeclarations: No
-BreakBeforeBraces: Custom
-BraceWrapping:
-  AfterCaseLabel: false
-  AfterClass: false
-  AfterControlStatement: Never
-  AfterEnum: false
-  AfterFunction: false
-  AfterNamespace: false
-  AfterUnion: false
-  BeforeCatch: false
-  BeforeElse: false
-  IndentBraces: false
-  SplitEmptyFunction: false
-  SplitEmptyRecord: true
-BreakBeforeBinaryOperators: None
-BreakBeforeTernaryOperators: true
-BreakConstructorInitializers: BeforeColon
-BreakInheritanceList: BeforeColon
-ColumnLimit: 0
-CompactNamespaces: true
-ContinuationIndentWidth: 4
-IndentCaseLabels: true
-IndentPPDirectives: None
-IndentWidth: 4
-KeepEmptyLinesAtTheStartOfBlocks: true
-MaxEmptyLinesToKeep: 2
-NamespaceIndentation: All
-ObjCSpaceAfterProperty: false
-ObjCSpaceBeforeProtocolList: true
-PointerAlignment: Right
-ReflowComments: false
-SpaceAfterCStyleCast: true
-SpaceAfterLogicalNot: false
-SpaceAfterTemplateKeyword: false
-SpaceBeforeAssignmentOperators: true
-SpaceBeforeCpp11BracedList: false
-SpaceBeforeCtorInitializerColon: true
-SpaceBeforeInheritanceColon: true
-SpaceBeforeParens: ControlStatements
-SpaceBeforeRangeBasedForLoopColon: true
-SpaceInEmptyParentheses: false
-SpacesBeforeTrailingComments: 0
-SpacesInAngles: false
-SpacesInCStyleCastParentheses: false
-SpacesInContainerLiterals: false
-SpacesInParentheses: false
-SpacesInSquareBrackets: false
-TabWidth: 4
-UseTab: Never
diff --git a/devices/nvidia-gpu/src/lib.rs b/devices/nvidia-gpu/src/lib.rs
index 906c4a24..589f014b 100644
--- a/devices/nvidia-gpu/src/lib.rs
+++ b/devices/nvidia-gpu/src/lib.rs
@@ -1,26 +1,33 @@
 #![cfg(detected_cuda)]
 
 mod gather;
-mod sample;
 
-use common::utok;
+use ::sample::SampleArgs;
+use common::{f16, utok};
 use common_devices::{Operators, SliceOn};
 use cuda::{AsRaw, Device};
 use digit_layout::types::{F16, U32};
 use operators::{
-    dyn_, fuesd_softmax::nvidia_gpu as softmax, mat_mul::nvidia_gpu as mat_mul,
-    reform::nvidia_gpu as reform, rms_norm::nvidia_gpu as rms_norm, rope::nvidia_gpu as rope,
-    swiglu::nvidia_gpu as swiglu, Operator, QueueOf, TensorLayout,
+    cuda::{memcpy_d2h, DevByte, DevMem, Stream},
+    dyn_,
+    fuesd_softmax::nvidia_gpu as softmax,
+    mat_mul::nvidia_gpu as mat_mul,
+    random_sample::{nvidia_gpu as random_sample, KVPair, RandomSample},
+    reform::nvidia_gpu as reform,
+    rms_norm::nvidia_gpu as rms_norm,
+    rope::nvidia_gpu as rope,
+    swiglu::nvidia_gpu as swiglu,
+    Operator, QueueOf, TensorLayout, Workspace,
 };
 use std::{
     collections::HashMap,
+    mem::size_of,
     ops::{Deref, DerefMut},
     ptr::{null, null_mut},
 };
 
 pub use common_devices::{Kernels, KernelsA, KernelsB};
 pub use operators::{cuda, nvidia_gpu::Handle as Gpu};
-pub use sample::{sample_cpu, sample_nv};
 pub use tensor::{reslice, reslice_mut, slice, split, udim, LocalSplitable, Tensor};
 
 #[cfg(detected_nccl)]
@@ -35,10 +42,11 @@ struct Internal {
     reform: reform::Operator,
     softmax: softmax::Operator,
     swiglu: swiglu::Operator,
+    random_sample: random_sample::Operator,
 }
 
 impl Internal {
-    pub fn new(handle: &Gpu, d: usize) -> Self {
+    pub fn new(handle: &Gpu, d: usize, voc: usize) -> Self {
         let mat_mul = mat_mul::Operator::new(handle);
 
         let mut rms_norm = rms_norm::Operator::new(handle);
@@ -92,6 +100,11 @@
             })
             .unwrap();
 
+        let mut random_sample = random_sample::Operator::new(handle);
+        random_sample
+            .scheme(&operators::random_sample::Args::new(F16, voc))
+            .unwrap();
+
         Self {
             mat_mul,
             rms_norm,
@@ -99,30 +112,75 @@
             reform,
             softmax,
             swiglu,
+            random_sample,
         }
     }
 }
 
 impl NvidiaKernels {
-    pub fn new(devices: &[Device], rms_norm_size: usize) -> Self {
+    pub fn new(devices: &[Device], rms_norm_size: usize, voc_size: usize) -> Self {
         Self(
             devices
                 .iter()
                 .map(|d| {
                     (
                         unsafe { d.as_raw() },
-                        Internal::new(&Gpu::new(d.retain_primary()), rms_norm_size),
+                        Internal::new(&Gpu::new(d.retain_primary()), rms_norm_size, voc_size),
                     )
                 })
                 .collect(),
         )
     }
-}
 
-impl NvidiaKernels {
     fn get(&self, queue: &QueueOf<Gpu>) -> &Internal {
         self.0.get(&unsafe { queue.ctx().dev().as_raw() }).unwrap()
     }
+
+    pub fn sample_workspace<'ctx>(&self, queue: &'ctx QueueOf<Gpu>) -> DevMem<'ctx> {
+        let random_sample = &self.get(queue).random_sample;
+        let workspace_len = random_sample.workspace();
+        let scheme_n = random_sample.scheme_n();
+        let mut workspace = queue.malloc::<u8>(workspace_len);
+        let host = (0..scheme_n).map(|i| i as u32).collect::<Vec<_>>();
+        queue.memcpy_h2d(&mut workspace[..scheme_n * size_of::<u32>()], &host);
+        workspace
+    }
+
+    pub fn sample(
+        &self,
+        args: impl IntoIterator<Item = SampleArgs>,
+        logits: &[DevByte],
+        workspace: &mut [DevByte],
+        stream: &Stream,
+    ) -> Vec<utok> {
+        let random_sample = &self.get(stream).random_sample;
+        let voc = random_sample.scheme_n();
+        let logits = logits.as_ptr();
+
+        let details = args.into_iter().collect::<Vec<_>>();
+        let kv_pair_size = KVPair::<()>::LAYOUT.nbytes();
+        let mut kv_pairs = stream.malloc::<u8>(details.len() * kv_pair_size);
+
+        let mut args = operators::random_sample::Args::<Gpu>::new(F16, voc);
+        args.workspace = Workspace {
+            ptr: workspace.as_mut_ptr(),
+            len: workspace.len(),
+        };
+        for (i, arg) in details.iter().enumerate() {
+            args.kv_pair_base = unsafe { kv_pairs.as_mut_ptr().add(i * kv_pair_size) };
+            args.data_base = unsafe { logits.add(i * voc * F16.nbytes()) };
+            args.detail.temperature = arg.temperature;
+            args.detail.top_p = arg.top_p;
+            args.detail.top_k = arg.top_k;
+            random_sample.launch(&args, stream).unwrap();
+        }
+
+        let mut host = vec![KVPair::new(0, f16::ZERO); details.len()];
+        stream.synchronize();
+        memcpy_d2h(&mut host, &kv_pairs);
+
+        host.into_iter().map(|kv| kv.idx() as _).collect()
+    }
 }
 
 impl Kernels for NvidiaKernels {}
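
Note: the GPU path is now batched and stream-ordered. `sample_workspace` allocates one reusable device buffer whose head is pre-seeded with the identity index permutation `0..voc` (the sorted-index input the operator consumes), and `sample` issues one fused `random_sample` launch per logits row, then performs a single synchronize and one device-to-host copy of the packed `KVPair`s. A sketch of the intended call pattern — the argument values and `rows` are hypothetical, and the struct-literal construction of `SampleArgs` is an assumption inferred from the field-by-field reads above:

```rust
use common::utok;
use common_nv::{
    cuda::{DevByte, Stream},
    NvidiaKernels,
};
use sample::SampleArgs;

/// Sample one token per row of `logits` (`rows` rows of `voc` f16 values each).
fn sample_rows(
    kernels: &NvidiaKernels,
    logits: &[DevByte],
    rows: usize,
    stream: &Stream,
) -> Vec<utok> {
    // Real callers allocate this once and reuse it (the llama model below
    // sporulates it into its struct); per-call allocation works but defeats
    // the point of the shared buffer.
    let mut workspace = kernels.sample_workspace(stream);
    let args = (0..rows).map(|_| SampleArgs {
        temperature: 0.7,
        top_p: 0.9,
        top_k: 40,
    });
    kernels.sample(args, logits, &mut workspace, stream)
}
```
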
diff --git a/devices/nvidia-gpu/src/sample.cu b/devices/nvidia-gpu/src/sample.cu
deleted file mode 100644
index 0dcfbfbb..00000000
--- a/devices/nvidia-gpu/src/sample.cu
+++ /dev/null
@@ -1,119 +0,0 @@
-#include <cub/device/device_radix_sort.cuh>
-#include <cub/device/device_reduce.cuh>
-#include <cub/device/device_scan.cuh>
-#include <cuda_fp16.h>
-
-extern "C" cudaError argmax_half(
-    void *temp_storage, size_t *temp_storage_bytes,
-    half const *input, int num_items,
-    cub::KeyValuePair<int, half> *output,
-    cudaStream_t stream) {
-    return cub::DeviceReduce::ArgMax(
-        temp_storage, *temp_storage_bytes,
-        input,
-        output,
-        num_items,
-        stream);
-}
-
-extern "C" cudaError radix_sort_half(
-    void *temp_storage, size_t *temp_storage_bytes,
-    half const *key_in, half *key_out,
-    unsigned int const *value_in, unsigned int *value_out,
-    int num_items,
-    cudaStream_t stream) {
-    return cub::DeviceRadixSort::SortPairsDescending(
-        temp_storage, *temp_storage_bytes,
-        key_in,
-        key_out,
-        value_in,
-        value_out,
-        num_items,
-        0,
-        sizeof(half) * 8,
-        stream);
-}
-
-extern "C" cudaError inclusive_sum_half(
-    void *temp_storage, size_t *temp_storage_bytes,
-    half *data, int num_items,
-    cudaStream_t stream) {
-    return cub::DeviceScan::InclusiveSum(
-        temp_storage, *temp_storage_bytes,
-        data,
-        data,
-        num_items,
-        stream);
-}
-
-#define RUNTIME(statement) \
-    { \
-        auto error = statement; \
-        if (error != cudaSuccess) { \
-            printf("Error: %s (%d) at \"%s\"\n", cudaGetErrorString(error), error, #statement); \
-            return error; \
-        } \
-    }
-
-static __global__ void partial_softmax_half_kernel(
-    half2 *__restrict__ data,
-    float temperature,
-    int n) {
-    int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (0 < i && i < n) {
-        auto max = __ldg((half *) data);
-        data[i] = h2exp((data[i] - half2(max, max)) / half2(temperature, temperature));
-    }
-}
-
-static __global__ void set_softmax_max_kernel(
-    half *__restrict__ data, float temperature) {
-    data[1] = hexp((data[1] - data[0]) / (half) temperature);
-    data[0] = 1;
-}
-
-extern "C" cudaError partial_softmax_half(
-    half *data,
-    float temperature,
-    int voc,
-    cudaStream_t stream) {
-
-    voc /= 2;
-    auto block = min(1024, voc);
-    auto grid = (voc + block - 1) / block;
-    partial_softmax_half_kernel<<<grid, block, 0, stream>>>((half2 *) data, temperature, voc);
-    set_softmax_max_kernel<<<1, 1, 0, stream>>>(data, temperature);
-
-    return cudaGetLastError();
-}
-
-static __global__ void random_sample_kernel(
-    half const *__restrict__ data,
-    unsigned int const *__restrict__ indices,
-    unsigned int *__restrict__ index_,
-    float random, float topp, int topk, int voc) {
-    half p = random * min(topp * (float) data[voc - 1], (float) data[topk - 1]);
-    for (int i = 0;; ++i) {
-        if (data[i] >= p) {
-            *index_ = indices[i];
-            return;
-        }
-    }
-}
-
-extern "C" cudaError random_sample_half(
-    half const *data,
-    unsigned int const *indices,
-    unsigned int *index,
-    float random, float topp, int topk, int voc,
-    cudaStream_t stream) {
-    unsigned int *index_ = nullptr;
-    cudaMallocAsync(&index_, sizeof(unsigned int), stream);
-
-    random_sample_kernel<<<1, 1, 0, stream>>>(data, indices, index_, random, topp, topk, voc);
-
-    cudaMemcpy(index, index_, sizeof(unsigned int), cudaMemcpyDeviceToHost);
-    cudaFree(index_);
-
-    return cudaGetLastError();
-}
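
For the record, the pipeline these deleted kernels implemented (now subsumed by the `random_sample` operator) was: radix-sort the logits descending together with their indices, exponentiate each entry relative to the maximum at temperature T (`partial_softmax_half` / `set_softmax_max_kernel`), take an inclusive prefix sum S over the result, then pick the first position whose running mass crosses a random threshold. In the notation of `random_sample_kernel` above, with u = `random` drawn uniformly from [0, 1):

```latex
t = u \cdot \min\bigl(\texttt{topp} \cdot S_{\texttt{voc}-1},\; S_{\texttt{topk}-1}\bigr),
\qquad
\text{token} = \texttt{indices}\bigl[\min\{\, i : S_i \ge t \,\}\bigr]
```
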
diff --git a/devices/nvidia-gpu/src/sample.rs b/devices/nvidia-gpu/src/sample.rs
deleted file mode 100644
index 89581405..00000000
--- a/devices/nvidia-gpu/src/sample.rs
+++ /dev/null
@@ -1,283 +0,0 @@
-use common::{f16, utok, Blob};
-use operators::cuda::{bindings::CUstream, memcpy_d2h, AsRaw, DevByte, DevMem, Stream};
-use sample::SampleArgs;
-use std::{
-    collections::HashMap,
-    ffi::{c_int, c_void},
-    ptr::{null, null_mut},
-    sync::{Mutex, OnceLock},
-};
-use tensor::{reslice, reslice_mut};
-
-pub fn sample_cpu(
-    args: impl IntoIterator<Item = (usize, SampleArgs)>,
-    logits: &[DevByte],
-    voc: usize,
-    _stream: &Stream,
-) -> Vec<utok> {
-    let mut host = Blob::new(logits.len());
-    memcpy_d2h(&mut host, logits);
-
-    let logits: &[f16] = reslice(&host);
-    args.into_iter()
-        .map(|(i, arg)| arg.random(&logits[voc * i..][..voc]))
-        .collect()
-}
-
-#[derive(Clone, Copy, PartialEq, Eq, Hash, Default, Debug)]
-#[repr(C)]
-struct CubKeyValuePair<K, V> {
-    k: K,
-    v: V,
-}
-
-extern "C" {
-    // extern "C" cudaError argmax_half(
-    //     void *temp_storage, size_t *temp_storage_bytes,
-    //     half const *input, int num_items,
-    //     cub::KeyValuePair<int, half> *output,
-    //     cudaStream_t stream)
-    fn argmax_half(
-        temp_storage: *mut c_void,
-        temp_storage_bytes: *mut usize,
-        input: *const f16,
-        num_items: c_int,
-        output: *mut CubKeyValuePair<c_int, f16>,
-        stream: CUstream,
-    ) -> c_int;
-
-    // extern "C" cudaError radix_sort_half(
-    //     void *temp_storage, size_t *temp_storage_bytes,
-    //     half const *key_in, half *key_out,
-    //     unsigned int const *value_in, unsigned int *value_out,
-    //     int num_items,
-    //     cudaStream_t stream)
-    fn radix_sort_half(
-        temp_storage: *mut c_void,
-        temp_storage_bytes: *mut usize,
-        key_in: *const f16,
-        key_out: *mut f16,
-        value_in: *const u32,
-        value_out: *mut u32,
-        num_items: c_int,
-        stream: CUstream,
-    ) -> c_int;
-
-    // extern "C" cudaError inclusive_sum_half(
-    //     void *temp_storage, size_t *temp_storage_bytes,
-    //     half *data, int num_items,
-    //     cudaStream_t stream)
-    fn inclusive_sum_half(
-        temp_storage: *mut c_void,
-        temp_storage_bytes: *mut usize,
-        data: *mut f16,
-        num_items: c_int,
-        stream: CUstream,
-    ) -> c_int;
-
-    // extern "C" cudaError partial_softmax_half(
-    //     half *data,
-    //     float temperature,
-    //     unsigned int topk,
-    //     cudaStream_t stream)
-    fn partial_softmax_half(
-        data: *mut f16,
-        temperature: f32,
-        topk: c_int,
-        stream: CUstream,
-    ) -> c_int;
-
-    // extern "C" cudaError random_sample_half(
-    //     half const *data,
-    //     unsigned int const *indices,
-    //     unsigned int *index,
-    //     float probability,
-    //     int topk,
-    //     cudaStream_t stream)
-    fn random_sample_half(
-        data: *const f16,
-        indices: *const u32,
-        index: *mut u32,
-        random: f32,
-        topp: f32,
-        topk: c_int,
-        voc: c_int,
-        stream: CUstream,
-    ) -> c_int;
-}
-
-fn prealloc_argmax<'ctx>(stream: &Stream<'ctx>, len: usize) -> DevMem<'ctx> {
-    static MAP: OnceLock<Mutex<HashMap<usize, usize>>> = OnceLock::new();
-    let len = *MAP
-        .get_or_init(Default::default)
-        .lock()
-        .unwrap()
-        .entry(len)
-        .or_insert_with(|| {
-            let mut temp_storage_bytes = 0;
-            assert_eq!(0, unsafe {
-                argmax_half(
-                    null_mut(),
-                    &mut temp_storage_bytes,
-                    null(),
-                    len as _,
-                    null_mut(),
-                    stream.as_raw(),
-                )
-            });
-            temp_storage_bytes
-        });
-    stream.malloc::<u8>(len)
-}
-
-fn prealloc_radix_sort<'ctx>(stream: &Stream<'ctx>, len: usize) -> DevMem<'ctx> {
-    static MAP: OnceLock<Mutex<HashMap<usize, usize>>> = OnceLock::new();
-    let len = *MAP
-        .get_or_init(Default::default)
-        .lock()
-        .unwrap()
-        .entry(len)
-        .or_insert_with(|| {
-            let mut temp_storage_bytes = 0;
-            assert_eq!(0, unsafe {
-                radix_sort_half(
-                    null_mut(),
-                    &mut temp_storage_bytes,
-                    null(),
-                    null_mut(),
-                    null(),
-                    null_mut(),
-                    len as _,
-                    stream.as_raw(),
-                )
-            });
-            temp_storage_bytes
-        });
-    stream.malloc::<u8>(len)
-}
-
-fn prealloc_inclusive_sum<'ctx>(stream: &Stream<'ctx>, len: usize) -> DevMem<'ctx> {
-    static MAP: OnceLock<Mutex<HashMap<usize, usize>>> = OnceLock::new();
-    let len = *MAP
-        .get_or_init(Default::default)
-        .lock()
-        .unwrap()
-        .entry(len)
-        .or_insert_with(|| {
-            let mut temp_storage_bytes = 0;
-            assert_eq!(0, unsafe {
-                inclusive_sum_half(
-                    null_mut(),
-                    &mut temp_storage_bytes,
-                    null_mut(),
-                    len as _,
-                    stream.as_raw(),
-                )
-            });
-            temp_storage_bytes
-        });
-    stream.malloc::<u8>(len)
-}
-
-pub fn sample_nv(
-    args: impl IntoIterator<Item = (usize, SampleArgs)>,
-    logits: &[DevByte],
-    voc: usize,
-    stream: &Stream,
-) -> Vec<utok> {
-    let mut temp_argmax = prealloc_argmax(stream, voc);
-    let mut argmax_host = CubKeyValuePair::<c_int, f16>::default();
-    let mut argmax_out = stream.malloc::<CubKeyValuePair<c_int, f16>>(1);
-
-    let mut temp_sort = prealloc_radix_sort(stream, voc);
-    let mut sort_out = stream.malloc::<f16>(voc);
-    let mut indices_host = stream.ctx().malloc_host::<u32>(voc);
-    reslice_mut::<u32>(&mut indices_host)
-        .iter_mut()
-        .enumerate()
-        .for_each(|(i, idx)| *idx = i as u32);
-    let indices_in = stream.from_host(&indices_host);
-    let mut indices_out = stream.malloc::<u32>(voc);
-
-    let mut temp_sum = prealloc_inclusive_sum(stream, voc);
-
-    let logits = logits.as_ptr().cast::<f16>();
-    let ans = args
-        .into_iter()
-        .map(|(i, args)| {
-            let logits = unsafe { logits.add(i * voc) };
-
-            if args.is_argmax() {
-                assert_eq!(0, unsafe {
-                    argmax_half(
-                        temp_argmax.as_mut_ptr().cast(),
-                        &mut temp_argmax.len(),
-                        logits,
-                        voc as _,
-                        argmax_out.as_mut_ptr().cast(),
-                        stream.as_raw(),
-                    )
-                });
-                memcpy_d2h(std::slice::from_mut(&mut argmax_host), &argmax_out);
-                argmax_host.k as utok
-            } else {
-                let topk = args.top_k.min(voc) as c_int;
-                assert_eq!(0, unsafe {
-                    radix_sort_half(
-                        temp_sort.as_mut_ptr().cast(),
-                        &mut temp_sort.len(),
-                        logits,
-                        sort_out.as_mut_ptr().cast(),
-                        indices_in.as_ptr().cast(),
-                        indices_out.as_mut_ptr().cast(),
-                        voc as _,
-                        stream.as_raw(),
-                    )
-                });
-                assert_eq!(0, unsafe {
-                    partial_softmax_half(
-                        sort_out.as_mut_ptr().cast(),
-                        args.temperature,
-                        voc as _,
-                        stream.as_raw(),
-                    )
-                });
-                assert_eq!(0, unsafe {
-                    inclusive_sum_half(
-                        temp_sum.as_mut_ptr().cast(),
-                        &mut temp_sum.len(),
-                        sort_out.as_mut_ptr().cast(),
-                        voc as _,
-                        stream.as_raw(),
-                    )
-                });
-                let mut index = 0;
-                assert_eq!(0, unsafe {
-                    random_sample_half(
-                        sort_out.as_ptr().cast(),
-                        indices_out.as_ptr().cast(),
-                        &mut index,
-                        rand::random::<f32>(),
-                        args.top_p,
-                        topk,
-                        voc as _,
-                        stream.as_raw(),
-                    )
-                });
-                index as utok
-            }
-        })
-        .collect();
-
-    temp_argmax.drop_on(stream);
-    argmax_out.drop_on(stream);
-
-    temp_sort.drop_on(stream);
-    sort_out.drop_on(stream);
-    indices_in.drop_on(stream);
-    indices_out.drop_on(stream);
-
-    temp_sum.drop_on(stream);
-
-    ans
-}
diff --git a/models/llama/common-cpu/src/lib.rs b/models/llama/common-cpu/src/lib.rs
index 052efa04..d21476ac 100644
--- a/models/llama/common-cpu/src/lib.rs
+++ b/models/llama/common-cpu/src/lib.rs
@@ -201,7 +201,14 @@ impl CausalLM for Transformer {
         args.into_iter()
             .flat_map(|meta| repeat(meta.args).take(meta.num_decode))
             .enumerate()
-            .map(|(i, args)| args.random(&common_cpu::slice!(logits; voc; [i])))
+            .map(|(i, args)| {
+                self.kernels.sample(
+                    args.temperature,
+                    args.top_p,
+                    args.top_k,
+                    &common_cpu::slice!(logits; voc; [i]),
+                )
+            })
             .collect()
     }
 }
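
Note: both model crates expand per-request `SampleMeta` into one `SampleArgs` per decoded token before sampling. The CPU call site above keeps `enumerate()` so it can slice row `i` out of `logits`; the GPU call site below drops it, because `NvidiaKernels::sample` now strides through the logits rows itself. The shared expansion, as a standalone sketch (field names per the `causal_lm` types used at these call sites; the return type assumes `SampleMeta::args` is the sample crate's `SampleArgs`):

```rust
use causal_lm::SampleMeta;
use sample::SampleArgs;
use std::iter::repeat;

/// One `SampleArgs` per decoded token, in logits-row order.
fn expand(metas: impl IntoIterator<Item = SampleMeta>) -> Vec<SampleArgs> {
    metas
        .into_iter()
        // Each request decodes `num_decode` tokens with the same settings.
        .flat_map(|meta| repeat(meta.args).take(meta.num_decode))
        .collect()
}
```
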
diff --git a/models/llama/nvidia-gpu/src/lib.rs b/models/llama/nvidia-gpu/src/lib.rs
index e45e8ea9..51123b22 100644
--- a/models/llama/nvidia-gpu/src/lib.rs
+++ b/models/llama/nvidia-gpu/src/lib.rs
@@ -8,14 +8,13 @@ extern crate log;
 use causal_lm::{CausalLM, DecodingMeta, Model, QueryContext, SampleMeta};
 use common::{upos, utok, Blob, FileLoadError};
 use common_nv::{
-    cuda::memcpy_d2h, sample_nv, slice, udim, Gpu, Kernels, KernelsA, KernelsB, NvidiaKernels,
-    Tensor,
+    cuda::{memcpy_d2h, AsRaw},
+    slice, udim, Gpu, Kernels, KernelsA, KernelsB, NvidiaKernels, Tensor,
 };
 use cuda::{
     ContextResource, ContextSpore, DevByte, DevMem, DevMemSpore, Device, EventSpore,
     HostMemSpore, Stream, StreamSpore,
 };
-use digit_layout::types::F16;
 use llama::{ComputeConst, InferenceConfig, LayerStorage, SliceOn, Weight};
 use resource::Resource;
 use std::{
@@ -26,7 +25,7 @@
     ops::Deref,
     path::Path,
     rc::Rc,
-    slice::from_raw_parts,
+    slice::{from_raw_parts, from_raw_parts_mut},
     sync::{Arc, Mutex, MutexGuard},
     time::Instant,
 };
@@ -42,6 +41,7 @@ struct Internal {
     resource: Arc<Resource>,
     transfer: StreamSpore,
     kernels: NvidiaKernels,
+    sample_workspace: DevMemSpore,
 
     embed_tokens: Tensor<HostMemSpore>,
     layers: Vec<(LayerStorage<DevMemSpore>, EventSpore)>,
@@ -107,8 +107,13 @@
             .map(|l| (l.map(from_host), transfer.record().sporulate()))
             .collect();
 
+        let kernels = NvidiaKernels::new(&[device], host.config.d as _, host.config.voc as _);
+        let sample_workspace = kernels.sample_workspace(compute).sporulate();
+
         Ok(Self(ManuallyDrop::new(Internal {
-            kernels: NvidiaKernels::new(&[device], host.config.d as _),
+            kernels,
+            sample_workspace,
+
             embed_tokens: host.embed_tokens.as_ref().map_physical(page_lock),
             layers,
             lm_layernorm: host
@@ -276,19 +281,16 @@ impl CausalLM for Transformer {
         args: impl IntoIterator<Item = SampleMeta>,
         logits: Tensor<Self::Storage>,
     ) -> Vec<utok> {
-        assert_eq!(logits.data_layout(), F16);
-        let &[_nt, voc] = logits.shape() else {
-            panic!()
-        };
-        let voc = voc as usize;
-
+        let workspace_ptr = unsafe { self.0.sample_workspace.as_raw() };
+        let workspace_len = self.0.sample_workspace.len();
         self.0.resource.apply(|compute| {
-            sample_nv(
+            let workspace =
+                unsafe { from_raw_parts_mut(workspace_ptr as *mut DevByte, workspace_len) };
+            self.0.kernels.sample(
                 args.into_iter()
-                    .flat_map(|meta| repeat(meta.args).take(meta.num_decode))
-                    .enumerate(),
+                    .flat_map(|meta| repeat(meta.args).take(meta.num_decode)),
                 logits.take_physical().mem.sprout_ref(compute.ctx()),
-                voc,
+                workspace,
                 compute,
             )
         })
@@ -303,6 +305,7 @@
             resource,
             transfer,
             kernels: _,
+            sample_workspace,
             embed_tokens,
             layers,
             lm_layernorm,
@@ -312,6 +315,7 @@
         resource.apply(|compute| {
             let ctx = compute.ctx();
             transfer.sprout(ctx);
+            sample_workspace.sprout(ctx);
             embed_tokens.take_physical().sprout(ctx);
             lm_layernorm.take_physical().sprout(ctx);
             lm_head.take_physical().sprout(ctx);
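
One subtlety in the GPU model above: the workspace is stored as a `DevMemSpore`, i.e. device memory detached from any context lifetime. `CausalLM::sample` therefore rebuilds a `&mut [DevByte]` from the spore's raw pointer and length inside `resource.apply`, where a context is current and the spore is known to outlive the call; and `Drop` must `sprout` the spore back into a live context so the allocation is actually released. Schematically (a sketch of the lifecycle that mirrors the lines above rather than adding new API):

```rust
// In Model::load, on the compute stream: allocate, then detach from the context.
let sample_workspace = kernels.sample_workspace(compute).sporulate(); // DevMem -> DevMemSpore

// In CausalLM::sample, with a context current: view the spore as a byte slice.
let workspace = unsafe { from_raw_parts_mut(workspace_ptr as *mut DevByte, workspace_len) };

// In Drop, after re-entering the context: rebind so the memory is freed.
sample_workspace.sprout(ctx);
```
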