Skip to content

Commit

Permalink
Fix index type
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Feb 24, 2025
1 parent 613ca02 commit 6e9759b
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions cpp/include/raft/linalg/detail/strided_reduction.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ namespace detail {
// Note that the compensation is only performed 'per-block' for performance
// reasons and is therefore not equivalent to a sequential compensation.

template <typename Type, typename MainLambda>
template <typename Type, typename IdxType, typename MainLambda>
RAFT_KERNEL stridedSummationKernel(
Type* out, const Type* data, int D, int N, Type init, MainLambda main_op)
Type* out, const Type* data, IdxType D, IdxType N, Type init, MainLambda main_op)
{
// Thread reduction
Type thread_sum = Type(init);
Type thread_c = Type(0);
int colStart = blockIdx.x * blockDim.x + threadIdx.x;
Type thread_sum = Type(init);
Type thread_c = Type(0);
IdxType colStart = blockIdx.x * blockDim.x + threadIdx.x;
if (colStart < D) {
int rowStart = blockIdx.y * blockDim.y + threadIdx.y;
int stride = blockDim.y * gridDim.y;
for (int j = rowStart; j < N; j += stride) {
int idx = colStart + j * D;
IdxType rowStart = blockIdx.y * blockDim.y + threadIdx.y;
IdxType stride = blockDim.y * gridDim.y;
for (IdxType j = rowStart; j < N; j += stride) {
auto idx = colStart + j * D;

// KahanBabushkaNeumaierSum
const Type cur_value = main_op(data[idx], j);
Expand Down Expand Up @@ -97,8 +97,8 @@ template <typename InType,
typename ReduceLambda>
RAFT_KERNEL stridedReductionKernel(OutType* dots,
const InType* data,
int D,
int N,
IdxType D,
IdxType N,
OutType init,
MainLambda main_op,
ReduceLambda reduce_op)
Expand Down Expand Up @@ -167,7 +167,7 @@ void stridedReduction(OutType* dots,
raft::min((IdxType)MaxBlocksDimY, raft::ceildiv(N, (IdxType)MinRowsPerBlk)));
const size_t shmemSize = sizeof(OutType) * Block.x * 2;

stridedSummationKernel<InType>
stridedSummationKernel<InType, IdxType>
<<<grid, Block, shmemSize, stream>>>(dots, data, D, N, init, main_op);
} else {
// Arbitrary numbers for now, probably need to tune
Expand Down

0 comments on commit 6e9759b

Please sign in to comment.