Fix bug for asort kernel & faster sampler
guoqingbao committed Jan 27, 2025
1 parent fe31cd1 commit 1625938
Showing 4 changed files with 140 additions and 62 deletions.
2 changes: 1 addition & 1 deletion candle-core/src/cuda_backend/device.rs
@@ -76,7 +76,7 @@ impl CudaDevice {
self.id
}

fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
pub fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
108 changes: 97 additions & 11 deletions candle-core/src/sort.rs
@@ -1,10 +1,12 @@
use crate::{Result, Tensor};
use crate::{DType, Result, Shape, Storage, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
struct ArgSort {
asc: bool,
last_dim: usize,
dtype: DType,
#[cfg(feature = "cuda")]
indices: Tensor,
}

impl ArgSort {
@@ -56,7 +58,8 @@ impl ArgSort {
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
result::memcpy_dtod_sync, CudaSlice, DevicePtr, DeviceRepr, LaunchAsync, LaunchConfig,
ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
@@ -74,28 +77,100 @@ mod cuda {
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let (indices, _) = self.indices.storage_and_layout();

let indices = match &*indices {
Storage::Cuda(k) => k.as_cuda_slice::<u32>()?.to_owned(),
_ => crate::bail!("indices must be a cuda tensor"),
};
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;

// each row must be padded to a power of two for the bitonic sort
let ncols_pad = next_power_of_2(ncols);
// allocate a temp buffer holding the padded rows
let tmp_rows = dev.const_impl(
if self.asc {
std::f64::MAX
} else {
std::f64::MIN
},
&Shape::from((nrows, ncols_pad)),
self.dtype,
)?;
let tmp_indices = unsafe { dev.alloc::<u32>(ncols_pad) }.w()?;
// Determine the number of threads per block and blocks per row
let max_threads_per_block = 1024;
let threads_per_block = max_threads_per_block.min(ncols_pad);
let blocks_per_row = (ncols_pad + threads_per_block - 1) / threads_per_block;

let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (blocks_per_row as u32, nrows as u32, 1),
grid_dim: (blocks_per_row as u32, 1, 1),
block_dim: (threads_per_block as u32, 1, 1),
shared_mem_bytes: (threads_per_block * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;

unsafe {
for row in 0..nrows {
let start_o = row * ncols;
let slice_row = slice.slice(start_o..);
let dst_row = dst.slice(start_o..);
let tmp_row_ptr = match &tmp_rows.slice {
S::U8(inp) => *inp.slice(start_o..).device_ptr(),
S::U32(inp) => *inp.slice(start_o..).device_ptr(),
S::I64(inp) => *inp.slice(start_o..).device_ptr(),
S::BF16(inp) => *inp.slice(start_o..).device_ptr(),
S::F16(inp) => *inp.slice(start_o..).device_ptr(),
S::F32(inp) => *inp.slice(start_o..).device_ptr(),
S::F64(inp) => *inp.slice(start_o..).device_ptr(),
};

memcpy_dtod_sync(
tmp_row_ptr,
*slice_row.device_ptr(),
ncols * std::mem::size_of::<T>(),
)
.w()?;
memcpy_dtod_sync(
*tmp_indices.device_ptr(),
*indices.device_ptr(),
ncols * std::mem::size_of::<u32>(),
)
.w()?;

let mut k = 2;
while k <= ncols_pad {
// Minor step
let mut j = k >> 1;
while j > 0 {
let params = (tmp_row_ptr, &tmp_indices, j as i32, k as i32);
func.clone().launch(cfg, params).w()?;
j = j >> 1;
}
k <<= 1;
}

// copy back the valid (unpadded) elements
memcpy_dtod_sync(
*slice_row.device_ptr(),
tmp_row_ptr,
ncols * std::mem::size_of::<T>(),
)
.w()?;
memcpy_dtod_sync(
*dst_row.device_ptr(),
*tmp_indices.device_ptr(),
ncols * std::mem::size_of::<u32>(),
)
.w()?;
}
}
Ok(S::U32(dst))
}
}
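
The host loop above drives one kernel launch per (k, j) stage of a bitonic network over the padded row. As a reference for that schedule, here is a minimal CPU sketch of the same compare-exchange order (illustrative only, not part of this commit; the function name, f32 element type, and padding choice are assumptions mirroring the Rust code above):

// CPU reference for the bitonic argsort schedule: one pass of the inner
// `for i` loop corresponds to one GPU kernel launch for a given (j, k).
// Illustrative sketch only; `bitonic_argsort_ref` is a hypothetical name.
fn bitonic_argsort_ref(row: &[f32], asc: bool) -> Vec<u32> {
    let ncols = row.len();
    let ncols_pad = ncols.next_power_of_two();
    // Pad with MAX (ascending) or MIN (descending) so padding sorts to the end.
    let pad = if asc { f32::MAX } else { f32::MIN };
    let mut vals: Vec<f32> = row
        .iter()
        .copied()
        .chain(std::iter::repeat(pad))
        .take(ncols_pad)
        .collect();
    let mut idx: Vec<u32> = (0..ncols_pad as u32).collect();

    let mut k = 2;
    while k <= ncols_pad {
        let mut j = k >> 1;
        while j > 0 {
            // Body of one stage: what every GPU thread i does for this (j, k).
            for i in 0..ncols_pad {
                let ij = i ^ j;
                if ij > i {
                    let up = (i & k) == 0; // direction of this bitonic subsequence
                    let out_of_order = if up == asc {
                        vals[i] > vals[ij]
                    } else {
                        vals[i] < vals[ij]
                    };
                    if out_of_order {
                        vals.swap(i, ij);
                        idx.swap(i, ij);
                    }
                }
            }
            j >>= 1;
        }
        k <<= 1;
    }
    // Padding sorts to the tail, so the first `ncols` indices form the argsort.
    idx.truncate(ncols);
    idx
}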
@@ -227,8 +302,18 @@ impl Tensor {
None => crate::bail!("empty last-dim in arg-sort"),
Some(last_dim) => *last_dim,
};
#[cfg(feature = "cuda")]
let indices_cpu = (0..last_dim).into_iter().map(|a| a as u32).collect();
#[cfg(feature = "cuda")]
let indices = Tensor::from_vec(indices_cpu, (1, last_dim), self.device())?;
// No need for a backward pass for arg sort.
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
self.apply_op1_no_bwd(&ArgSort {
asc,
last_dim,
dtype: self.dtype(),
#[cfg(feature = "cuda")]
indices,
})
}

/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
@@ -244,7 +329,8 @@ impl Tensor {
});
}
let asort = self.arg_sort_last_dim(asc)?;
let sorted = self.gather(&asort, crate::D::Minus1)?;
// let sorted = self.gather(&asort, crate::D::Minus1)?;
let sorted = self.clone();
Ok((sorted, asort))
}
}
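
For context, a hedged usage sketch of the Tensor sorting API touched in this file (assuming the standard candle-core `arg_sort_last_dim` / `sort_last_dim` signatures shown above; the example values and program are illustrative, not part of the commit):

// Illustrative usage only.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &device)?;
    // Indices that would sort each row in descending order.
    let indices = t.arg_sort_last_dim(false)?;
    // Sorted values together with their argsort indices, ascending.
    let (sorted, asort) = t.sort_last_dim(true)?;
    println!("{indices}\n{sorted}\n{asort}");
    Ok(())
}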
88 changes: 39 additions & 49 deletions candle-kernels/src/sort.cu
@@ -5,73 +5,63 @@
#include<stdint.h>

template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
inline __device__ void swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}

template<int order, typename T>
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x + blockIdx.x * blockDim.x; // Global column index
int row = blockIdx.y;
template <typename T>
__device__ void bitonicSortGPU(T* arr, uint32_t * dst, int j, int k, bool ascending) {
unsigned int i, ij;
i = threadIdx.x + blockDim.x * blockIdx.x;
ij = i ^ j;

if (col >= ncols_pad) {
return;
}

const T * x_row = x + row * ncols;
extern __shared__ int dst_row[];

// Initialize indices
dst_row[threadIdx.x] = (col < ncols) ? col : ncols; // Use ncols as a placeholder for padding

__syncthreads();

// Perform bitonic sort within the block
for (int k = 2; k <= blockDim.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = threadIdx.x ^ j;
if (ixj > threadIdx.x) {
if ((threadIdx.x & k) == 0) {
if (dst_row[threadIdx.x] < ncols &&
(dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[threadIdx.x]] > x_row[dst_row[ixj]] :
x_row[dst_row[threadIdx.x]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[threadIdx.x], dst_row[ixj]);
}
} else {
if (dst_row[ixj] < ncols &&
(dst_row[threadIdx.x] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[threadIdx.x]] < x_row[dst_row[ixj]] :
x_row[dst_row[threadIdx.x]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[threadIdx.x], dst_row[ixj]);
}
if (ij > i) {
if ((i & k) == 0) {
// Sort in ascending order
if (ascending) {
if (arr[i] > arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
// Sort in descending order
else {
if (arr[i] < arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
} else {
// Sort in ascending order
if (ascending) {
if (arr[i] < arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
// Sort in descending order
else {
if (arr[i] > arr[ij]) {
swap(arr[i], arr[ij]);
swap(dst[i], dst[ij]);
}
}
__syncthreads();
}
}

// Copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[threadIdx.x];
}
}

#define ASORT_OP(TYPENAME, RUST_NAME) \
extern "C" __global__ void asort_asc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
TYPENAME * x, uint32_t * dst, const int j, const int k \
) { \
k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
bitonicSortGPU(x, dst, j, k, true);\
} \
extern "C" __global__ void asort_desc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
TYPENAME * x, uint32_t * dst, const int j, const int k \
) { \
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
bitonicSortGPU(x, dst, j, k, false);\
} \

#if __CUDA_ARCH__ >= 800
4 changes: 3 additions & 1 deletion candle-transformers/src/generation/mod.rs
@@ -65,7 +65,9 @@ impl LogitsProcessor {
fn sample_topp(&self, logits: &Tensor, top_p: f32) -> Result<u32> {
let mut prs: Vec<f32> = logits.to_vec1()?;
let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;

// // Slower approach on CPU.
// let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
// argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
// Clamp smaller probabilities to zero.
let mut cumsum = 0.;
for index in &argsort_indices {
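The hunk above is truncated at the clamping loop. As a standalone illustration of how top-p (nucleus) clamping typically proceeds once the descending argsort indices are available (a hedged sketch, not the file's exact code; `clamp_top_p` is a hypothetical helper):

// Hedged sketch: walk the indices from most to least probable, accumulate the
// probability mass, and zero everything after the cumulative sum reaches top_p
// so those tokens can never be sampled. Not candle's API.
fn clamp_top_p(prs: &mut [f32], argsort_indices: &[u32], top_p: f32) {
    let mut cumsum = 0f32;
    for &index in argsort_indices {
        let i = index as usize;
        if cumsum >= top_p {
            prs[i] = 0.0;
        } else {
            cumsum += prs[i];
        }
    }
}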
