Skip to content

Commit

Permalink
add to_csr for bitset & bitmap
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Jan 8, 2025
1 parent ecfe45b commit 66503b1
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 5 deletions.
8 changes: 8 additions & 0 deletions cpp/include/raft/core/bitmap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/convert/csr.cuh>

#include <type_traits>

Expand All @@ -42,4 +43,11 @@ _RAFT_DEVICE void bitmap_view<bitmap_t, index_t>::set(const index_t row,
set(row * cols_ + col, new_value);
}

template <typename bitmap_t, typename index_t>
template <typename csr_matrix_t>
void bitmap_view<bitmap_t, index_t>::to_csr(const raft::resources& res, csr_matrix_t& csr) const
{
raft::sparse::convert::bitmap_to_csr(res, *this, csr);
}

} // end namespace raft::core
22 changes: 22 additions & 0 deletions cpp/include/raft/core/bitmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,28 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
*/
inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; }

/**
* @brief Converts to a Compressed Sparse Row (CSR) format matrix.
*
* This method transforms a two-dimensional bitmap matrix into a CSR representation,
* where each '1' bit in the bitmap corresponds to a non-zero entry in the CSR matrix.
* The bitmap is interpreted as a row-major matrix, with rows and columns defined by
* the dimensions of the bitmap.
*
* @tparam bitmap_t The data type of the elements in the bitmap matrix.
* @tparam index_t The data type used for indexing the elements in the matrices.
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to raft::device_csr_matrix.
*
* @param[in] res RAFT resources for managing CUDA streams and execution policies.
* @param[out] csr Output parameter where the resulting CSR matrix is stored. Each '1' bit in
* the bitmap corresponds to a non-zero element in the CSR matrix.
*
* The caller must ensure that: The `csr` matrix is pre-allocated with dimensions and non-zero
* count matching the expected output.
*/
template <typename csr_matrix_t>
void to_csr(const raft::resources& res, csr_matrix_t& csr) const;

private:
index_t rows_;
index_t cols_;
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/util/device_atomics.cuh>
#include <raft/util/popc.cuh>

Expand Down Expand Up @@ -165,6 +166,13 @@ double bitset_view<bitset_t, index_t>::sparsity(const raft::resources& res) cons
return static_cast<double>((1.0 * (size_h - count_h)) / (1.0 * size_h));
}

template <typename bitset_t, typename index_t>
template <typename csr_matrix_t>
void bitset_view<bitset_t, index_t>::to_csr(const raft::resources& res, csr_matrix_t& csr) const
{
raft::sparse::convert::bitset_to_csr(res, *this, csr);
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
Expand Down
65 changes: 65 additions & 0 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,71 @@ struct bitset_view {
return (bitset_len + bits_per_element - 1) / bits_per_element;
}

/**
* @brief Converts to a Compressed Sparse Row (CSR) format matrix.
*
* This method transforms the bitset view into a CSR matrix representation, where each '1' bit in
* the bitset corresponds to a non-zero entry in the CSR matrix. The bitset format supports
* only a single-row matrix, so if the CSR matrix requires multiple rows, the bitset data is
* repeated for each row in the output.
*
* Example usage:
*
* @code{.cpp}
* #include <raft/core/resource/cuda_stream.hpp>
* #include <raft/sparse/convert/csr.cuh>
* #include <rmm/device_uvector.hpp>
*
* using bitset_t = uint32_t;
* using index_t = int;
* using value_t = float;
*
* raft::resources handle;
* auto stream = resource::get_cuda_stream(handle);
* index_t n_rows = 3;
* index_t n_cols = 30;
*
* // Compute bitset size and initialize device memory
* index_t bitset_size = (n_cols + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8);
* rmm::device_uvector<bitset_t> bitset_d(bitset_size, stream);
* std::vector<bitset_t> bitset_h = {
* bitset_t(0b11001010),
* }; // Example bitset, with 4 non-zero entries.
*
* raft::copy(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream);
*
* // Create bitset view and CSR matrix
* auto bitset_view = raft::core::bitset_view<bitset_t, index_t>(bitset_d.data(), n_cols);
* auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, n_rows, n_cols, 4 * n_rows);
*
* // Convert bitset to CSR
* bitset_view.to_csr(handle, csr);
* resource::sync_stream(handle);
*
* // Results:
* // csr.indptr = [0, 4, 8, 12];
* // csr.indices = [1, 3, 6, 7,
* // 1, 3, 6, 7,
* // 1, 3, 6, 7];
* // csr.values = [1, 1, 1, 1,
* // 1, 1, 1, 1,
* // 1, 1, 1, 1];
* @endcode
*
* @tparam bitset_t The data type of the elements in the bitset matrix.
* @tparam index_t The data type used for indexing the elements in the matrices.
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to raft::device_csr_matrix.
*
* @param[in] res RAFT resources for managing CUDA streams and execution policies.
* @param[out] csr Output parameter where the resulting CSR matrix is stored. Each '1' bit in
* the bitset corresponds to a non-zero element in the CSR matrix.
*
* The caller must ensure that: The `csr` matrix is pre-allocated with dimensions and non-zero
* count matching the expected output, i.e., `nnz_for_csr = nnz_for_bitset * n_rows`.
*/
template <typename csr_matrix_t>
void to_csr(const raft::resources& res, csr_matrix_t& csr) const;

private:
bitset_t* bitset_ptr_;
index_t bitset_len_;
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/sparse/convert/csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

#pragma once

#include <raft/core/bitmap.cuh>
#include <raft/core/bitmap.hpp>
#include <raft/core/bitset.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/sparse/convert/detail/adj_to_csr.cuh>
#include <raft/sparse/convert/detail/bitmap_to_csr.cuh>
Expand Down
8 changes: 4 additions & 4 deletions cpp/test/sparse/convert_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
raft::make_device_csr_matrix<value_t, index_t>(handle, params.n_rows, params.n_cols, nnz);
auto csr_view = csr.structure_view();

convert::bitmap_to_csr(handle, bitmap, csr);
bitmap.to_csr(handle, csr);
raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream);
raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream);
raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream);
Expand All @@ -380,7 +380,7 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz);
auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, csr_view);

convert::bitmap_to_csr(handle, bitmap, csr);
bitmap.to_csr(handle, csr);
raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream);
}
resource::sync_stream(handle);
Expand Down Expand Up @@ -661,7 +661,7 @@ class BitsetToCSRTest : public ::testing::TestWithParam<BitsetToCSRInputs<index_
raft::make_device_csr_matrix<value_t, index_t>(handle, params.n_repeat, params.n_cols, nnz);
auto csr_view = csr.structure_view();

convert::bitset_to_csr(handle, bitset, csr);
bitset.to_csr(handle, csr);
raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream);
raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream);
raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream);
Expand All @@ -670,7 +670,7 @@ class BitsetToCSRTest : public ::testing::TestWithParam<BitsetToCSRInputs<index_
indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz);
auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, csr_view);

convert::bitset_to_csr(handle, bitset, csr);
bitset.to_csr(handle, csr);
raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream);
}
resource::sync_stream(handle);
Expand Down

0 comments on commit 66503b1

Please sign in to comment.