Skip to content

Commit

Permalink
Revert coo_remove_scalar_kernel code
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Jan 22, 2025
1 parent 3cd3880 commit 347ea4e
Showing 1 changed file with 44 additions and 34 deletions.
78 changes: 44 additions & 34 deletions cpp/include/raft/sparse/op/detail/filter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,37 +42,35 @@ namespace sparse {
namespace op {
namespace detail {

// NOTE(review): this span is a rendered diff hunk with the +/- markers lost, so
// two versions of coo_remove_scalar_kernel are interleaved below. Judging from
// the following hunk header (which quotes the `in_rows` signature as the old
// context) and the commit title ("Revert coo_remove_scalar_kernel code"), the
// `in_*` / `row_indices` / `rows_lenght_acc` lines are the version being
// removed and the `rows` / `ex_scan` / `cur_ex_scan` lines are the version
// being restored. The comments below label each run of lines accordingly.

// [removed version] template header and input arrays.
template <uint64_t TPB_X, typename T>
RAFT_KERNEL coo_remove_scalar_kernel(const int* in_rows,
const int* in_cols,
const T* in_vals,
// [restored version] template header and input arrays.
template <int TPB_X, typename T>
RAFT_KERNEL coo_remove_scalar_kernel(const int* rows,
const int* cols,
const T* vals,
// [shared by both versions] number of input entries and the output arrays.
uint64_t nnz,
int* out_rows,
int* out_cols,
T* out_vals,
// [removed version] per-row output start offsets, a per-row counter that is
// advanced with atomicAdd, the scalar to drop, and a row count.
uint64_t* row_indices,
int* rows_lenght_acc,
T scalar,
int n_rows)
// [restored version] ex_scan: per-row start offset into the output arrays;
// cur_ex_scan: per-row start offset into the input arrays; m: number of rows
// (used as the bound on the thread's row index); scalar: value to drop.
uint64_t* ex_scan,
uint64_t* cur_ex_scan,
int m,
T scalar)
{
// [removed version body] one thread per input entry.
uint64_t in_idx = (blockIdx.x * TPB_X) + threadIdx.x;

// Threads past the end of the input do nothing.
if (in_idx >= nnz)
return;

// NOTE(review): the value is held in an `int` although the array element type
// is T; for a non-integer T this conversion would presumably lose precision,
// and the converted value is also what gets written to out_vals — confirm.
int val = in_vals[in_idx];

// Entries equal to the scalar are dropped.
if (val == scalar)
return;

int row = in_rows[in_idx];

// Output slot = row's start offset + the value returned by atomicAdd on the
// row's counter, so each kept entry of a row claims a distinct slot. The
// order of kept entries within a row depends on the order of the atomics.
uint64_t row_start_index = row_indices[row];
uint64_t out_idx = row_start_index + atomicAdd(rows_lenght_acc + row, 1);

out_rows[out_idx] = row;
out_cols[out_idx] = in_cols[in_idx];
out_vals[out_idx] = val;
// [restored version body] one thread per row.
int row = (blockIdx.x * TPB_X) + threadIdx.x;

// Threads whose row index is >= m do nothing.
if (row < m) {
// [start, stop) is this row's range in the input arrays.
// get_stop_idx is defined outside this view; presumably it returns the next
// row's start offset, or nnz for the last row — TODO confirm.
uint64_t start = cur_ex_scan[row];
uint64_t stop = get_stop_idx(row, m, nnz, cur_ex_scan);
// First output slot for this row.
uint64_t cur_out_idx = ex_scan[row];

// Walk the row's input entries in order and copy those whose value differs
// from the scalar, advancing the output cursor once per copied entry.
for (uint64_t idx = start; idx < stop; idx++) {
if (vals[idx] != scalar) {
out_rows[cur_out_idx] = rows[idx];
out_cols[cur_out_idx] = cols[idx];
out_vals[cur_out_idx] = vals[idx];
++cur_out_idx;
}
}
}
}

/**
Expand All @@ -92,7 +90,7 @@ RAFT_KERNEL coo_remove_scalar_kernel(const int* in_rows,
* @param d_alloc device allocator for temporary buffers
* @param stream: cuda stream to use
*/
template <uint64_t TPB_X, typename T>
template <int TPB_X, typename T>
void coo_remove_scalar(const int* rows,
const int* cols,
const T* vals,
Expand All @@ -101,22 +99,27 @@ void coo_remove_scalar(const int* rows,
int* ccols,
T* cvals,
uint64_t* cnnz,
uint64_t* cur_cnnz,
T scalar,
int n,
cudaStream_t stream)
{
rmm::device_uvector<uint64_t> ex_scan(n, stream);
rmm::device_uvector<uint64_t> cur_ex_scan(n, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(ex_scan.data(), 0, (uint64_t)n * sizeof(uint64_t), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(cur_ex_scan.data(), 0, (uint64_t)n * sizeof(uint64_t), stream));

thrust::device_ptr<uint64_t> dev_cnnz = thrust::device_pointer_cast(cnnz);
thrust::device_ptr<uint64_t> dev_ex_scan = thrust::device_pointer_cast(ex_scan.data());
thrust::exclusive_scan(rmm::exec_policy(stream), dev_cnnz, dev_cnnz + n, dev_ex_scan);
RAFT_CUDA_TRY(cudaPeekAtLastError());

rmm::device_uvector<int> rows_length_acc(n, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(rows_length_acc.data(), 0, (uint64_t)n * sizeof(int), stream));
thrust::device_ptr<uint64_t> dev_cur_cnnz = thrust::device_pointer_cast(cur_cnnz);
thrust::device_ptr<uint64_t> dev_cur_ex_scan = thrust::device_pointer_cast(cur_ex_scan.data());
thrust::exclusive_scan(rmm::exec_policy(stream), dev_cur_cnnz, dev_cur_cnnz + n, dev_cur_ex_scan);
RAFT_CUDA_TRY(cudaPeekAtLastError());

dim3 grid(raft::ceildiv(nnz, TPB_X), 1, 1);
dim3 grid(raft::ceildiv(n, TPB_X), 1, 1);
dim3 blk(TPB_X, 1, 1);

coo_remove_scalar_kernel<TPB_X><<<grid, blk, 0, stream>>>(rows,
Expand All @@ -127,9 +130,9 @@ void coo_remove_scalar(const int* rows,
ccols,
cvals,
dev_ex_scan.get(),
rows_length_acc.data(),
scalar,
n);
dev_cur_ex_scan.get(),
n,
scalar);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

Expand All @@ -145,13 +148,19 @@ template <int TPB_X, typename T>
void coo_remove_scalar(COO<T>* in, COO<T>* out, T scalar, cudaStream_t stream)
{
rmm::device_uvector<uint64_t> row_count_nz(in->n_rows, stream);
rmm::device_uvector<uint64_t> row_count(in->n_rows, stream);

RAFT_CUDA_TRY(cudaMemsetAsync(row_count_nz.data(), 0, (uint64_t)in->n_rows * sizeof(uint64_t), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(row_count.data(), 0, (uint64_t)in->n_rows * sizeof(uint64_t), stream));

linalg::coo_degree(in->rows(), in->nnz, row_count.data(), stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

linalg::coo_degree_scalar(in->rows(), in->vals(), in->nnz, scalar, (unsigned long long int*)row_count_nz.data(), stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

thrust::device_ptr<uint64_t> d_row_count_nz = thrust::device_pointer_cast(row_count_nz.data());
uint64_t out_nnz = thrust::reduce(rmm::exec_policy(stream), d_row_count_nz, d_row_count_nz + in->n_rows, (uint64_t)0);
uint64_t out_nnz = thrust::reduce(rmm::exec_policy(stream), d_row_count_nz, d_row_count_nz + in->n_rows);

out->allocate(out_nnz, in->n_rows, in->n_cols, false, stream);

Expand All @@ -163,6 +172,7 @@ void coo_remove_scalar(COO<T>* in, COO<T>* out, T scalar, cudaStream_t stream)
out->cols(),
out->vals(),
row_count_nz.data(),
row_count.data(),
scalar,
in->n_rows,
stream);
Expand Down

0 comments on commit 347ea4e

Please sign in to comment.