fixing batch_size
Signed-off-by: jinsolp <[email protected]>
jinsolp committed Feb 15, 2025
1 parent 5aa49a2 commit 650ec3c
Showing 2 changed files with 141 additions and 8 deletions.
5 changes: 4 additions & 1 deletion cpp/src/neighbors/detail/nn_descent.cuh
@@ -1361,8 +1361,11 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
}
update_and_sample_thread.join();
std::cout << "iter " << it + 1 << "inside the loop, before hitting update counter "
<< update_counter_.load() << std::endl;
if (update_counter_ == -1) { break; }
std::cout << "iter " << it + 1 << "inside the loop, after hitting update counter "
<< update_counter_.load() << std::endl;
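// NOTE: update_counter_ == -1 appears to act as the early-termination
// sentinel, set once no further graph updates occur; the two prints bracket
// the check to trace the counter's value on every iteration.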
raft::copy(graph_host_buffer_.data_handle(),
graph_buffer_.data_handle(),
nrow_ * DEGREE_ON_DEVICE,
144 changes: 137 additions & 7 deletions cpp/src/neighbors/detail/nn_descent_batch.cuh
@@ -107,6 +107,10 @@ void get_global_nearest_k(

size_t num_batches = n_clusters;
size_t batch_size = (num_rows + n_clusters) / n_clusters;
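// Ceiling-style split so that num_batches * batch_size >= num_rows.
// Worked example (assumed values): num_rows = 1000, n_clusters = 4:
//   batch_size = (1000 + 4) / 4 = 251; batches 0..2 cover 251 rows each and
//   the last batch covers the remaining 1000 - 3 * 251 = 247 rows.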
// if (n_clusters == k) {
// batch_size = num_rows /
// }
printf("num batches: %lu, batch size: %lu\n", num_batches, batch_size);
if (ptr == nullptr) { // data on host

auto d_dataset_batch =
@@ -169,11 +173,12 @@ void get_global_nearest_k(
size_t batch_size_ = batch_size;

if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; }
printf("\tusing batch size %lu\n", batch_size_);
thrust::copy(raft::resource::get_thrust_policy(res),
-             nearest_clusters_idx.data_handle() + i * batch_size_ * k,
-             nearest_clusters_idx.data_handle() + (i + 1) * batch_size_ * k,
+             nearest_clusters_idx.data_handle() + i * batch_size * k,
+             nearest_clusters_idx.data_handle() + (i + 1) * batch_size * k,
nearest_clusters_idxt.data_handle());
- raft::copy(global_nearest_cluster.data_handle() + i * batch_size_ * k,
+ raft::copy(global_nearest_cluster.data_handle() + i * batch_size * k,
nearest_clusters_idxt.data_handle(),
batch_size_ * k,
resource::get_cuda_stream(res));
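// Why the batch_size fix: the source/destination *offsets* must advance by
// the fixed stride batch_size (batch i always starts at row i * batch_size),
// while the copy *length* stays batch_size_, the actual row count of this
// batch (smaller for the last one). Using batch_size_ in the offsets shifted
// the final batch's range. Caveat (assuming nearest_clusters_idx holds
// exactly num_rows rows): the end iterator (i + 1) * batch_size * k can still
// point past the buffer on the last batch; i * batch_size * k +
// batch_size_ * k would be a tighter bound.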
@@ -219,10 +224,14 @@ void get_inverted_indices(raft::resources const& res,
}
}

raft::print_host_vector("cluster sizes", cluster_size.data_handle(), n_clusters, std::cout);

offset(0) = 0;
for (size_t i = 1; i < n_clusters; i++) {
offset(i) = offset(i - 1) + cluster_size(i - 1);
}
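// offset is an exclusive prefix sum over the cluster sizes, e.g. (assumed
// values) cluster_size = {3, 5, 2} -> offset = {0, 3, 8}, so cluster c's
// members occupy inverted_indices[offset(c), offset(c) + cluster_size(c)).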

raft::print_host_vector("offsets", offset.data_handle(), n_clusters, std::cout);
for (size_t i = 0; i < num_rows; i++) {
for (size_t j = 0; j < k; j++) {
IdxT cluster_id = global_nearest_cluster(i, j);
@@ -384,7 +393,9 @@ void build_and_merge(raft::resources const& res,
IdxT* batch_indices_h,
IdxT* batch_indices_d,
float* batch_distances_d,
- GNND<const T, int>& nnd)
+ GNND<const T, int>& nnd,
+ size_t cluster_id,
+ size_t num_cols)
{
nnd.build(cluster_data, num_data_in_cluster, int_graph, true, batch_distances_d);
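// nnd.build runs NN-descent on this cluster's rows only; the code below
// (partially elided in this hunk) then merges the per-cluster result into
// the global graph buffers, which is what the duplicate check added next
// inspects.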

@@ -397,6 +408,109 @@
}
}

// Debug: look for duplicate neighbor ids and verify that distances agree
// between the batch graph and the global graph
auto batch_distances_h =
raft::make_host_matrix<float, int64_t, raft::row_major>(num_data_in_cluster, graph_degree);
raft::copy(batch_distances_h.data_handle(),
batch_distances_d,
num_data_in_cluster * graph_degree,
raft::resource::get_cuda_stream(res));

auto cluster_data_h = raft::make_host_matrix<T, int64_t, raft::row_major>(1, num_cols);
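// cluster_data_h above is a 1 x num_cols host staging buffer: device rows
// are copied into it one at a time purely for printing. Note (assumption):
// if raft::copy is asynchronous on this stream, a stream sync before each
// print would be needed for the host buffer to be valid.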
raft::print_host_vector("inverted indices", inverted_indices, num_data_in_cluster, std::cout);
for (size_t i = 0; i < num_data_in_cluster; i++) {
size_t global_row_idx = inverted_indices[i];
printf("\nbatch row %lu, global row %lu\n", i, global_row_idx);
raft::print_device_vector(
"batch distances:", batch_distances_d + i * graph_degree, graph_degree, std::cout);
raft::print_host_vector("global distances:",
global_distances_d + global_row_idx * graph_degree,
graph_degree,
std::cout);
raft::print_host_vector(
"batch indices:", batch_indices_h + i * graph_degree, graph_degree, std::cout);
raft::print_host_vector(
"global indices:", global_indices_d + global_row_idx * graph_degree, graph_degree, std::cout);

if (cluster_id == 0) {
raft::copy(cluster_data_h.data_handle(),
cluster_data + i * num_cols,
num_cols,
raft::resource::get_cuda_stream(res));
raft::print_host_vector(
"item1(global row)", cluster_data_h.data_handle(), num_cols, std::cout);
raft::print_host_vector(
"int graph", int_graph + i * int_graph_node_degree, graph_degree, std::cout);
}

for (size_t j = 0; j < graph_degree; j++) {
size_t batch_index_ij = batch_indices_h[i * graph_degree + j];

if (cluster_id == 0) {
printf("\titem2 index %lu (batch row %lu)\t", batch_index_ij, j);
for (size_t p = 0; p < num_data_in_cluster; p++) {
if (inverted_indices[p] == batch_index_ij) {
raft::copy(cluster_data_h.data_handle(),
cluster_data + p * num_cols,
num_cols,
raft::resource::get_cuda_stream(res));
raft::print_host_vector(
"item2(batch)", cluster_data_h.data_handle(), num_cols, std::cout);
break;
}
}
}

for (size_t k = 0; k < graph_degree; k++) {
size_t global_index_ik = global_indices_d[global_row_idx * graph_degree + k];

float batch_dist_ij = batch_distances_h(i, j);
float global_dist_ik = global_distances_d[global_row_idx * graph_degree + k];
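// Invariant under test: if the same neighbor id appears in both the batch
// list and the global list for this row, the two recorded distances must
// match; global entries still at FLT_MAX are unfilled slots and are skipped.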

if (batch_index_ij == global_index_ik) {
// distances should be the same for these two
// printf("Looking at same index for row %lu\n", global_row_idx);

if (batch_dist_ij != global_dist_ik &&
global_dist_ik != std::numeric_limits<float>::max()) {
// raft::print_device_vector("batch distances:", batch_distances_d + i * graph_degree,
// graph_degree, std::cout); raft::print_host_vector("global distances:",
// global_distances_d + global_row_idx * graph_degree, graph_degree, std::cout);
// raft::print_host_vector("batch indices:", batch_indices_h + i * graph_degree,
// graph_degree, std::cout); raft::print_host_vector("global indices:", global_indices_d
// + global_row_idx * graph_degree, graph_degree, std::cout);
printf(
"\tWrong dist calculation [%lu]. For row %lu, distance to item %lu differs: %f VS "
"%f\n",
i,
global_row_idx,
batch_index_ij,
batch_dist_ij,
global_dist_ik);

raft::copy(cluster_data_h.data_handle(),
cluster_data + i * num_cols,
num_cols,
raft::resource::get_cuda_stream(res));
raft::print_host_vector("item1", cluster_data_h.data_handle(), num_cols, std::cout);

for (size_t p = 0; p < num_data_in_cluster; p++) {
if (inverted_indices[p] == batch_index_ij) {
raft::copy(cluster_data_h.data_handle(),
cluster_data + p * num_cols,
num_cols,
raft::resource::get_cuda_stream(res));
printf("batch row %lu\t", p);
raft::print_host_vector("item2", cluster_data_h.data_handle(), num_cols, std::cout);
break;
}
}
}
}
}
}
}

raft::copy(batch_indices_d,
batch_indices_h,
num_data_in_cluster * graph_degree,
@@ -488,6 +602,11 @@ void cluster_nnd(raft::resources const& res,
"# Data on host. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters);
size_t num_data_in_cluster = cluster_size[cluster_id];
size_t offset = offsets[cluster_id];
printf("# Data on host. Running clusters: %lu / %lu (max %lu, num data: %lu)\n",
cluster_id + 1,
params.n_clusters,
max_cluster_size,
num_data_in_cluster);

#pragma omp parallel for
for (size_t i = 0; i < num_data_in_cluster; i++) {
@@ -511,7 +630,9 @@
batch_indices_h,
batch_indices_d,
batch_distances_d,
- nnd);
+ nnd,
+ cluster_id,
+ num_cols);
nnd.reset(res);
}
}
@@ -548,6 +669,11 @@
"# Data on device. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters);
size_t num_data_in_cluster = cluster_size[cluster_id];
size_t offset = offsets[cluster_id];
printf("# Data on device. Running clusters: %lu / %lu (max %lu, num data: %lu)\n",
cluster_id + 1,
params.n_clusters,
max_cluster_size,
num_data_in_cluster);

auto cluster_data_view = raft::make_device_matrix_view<T, IdxT>(
cluster_data_matrix.data_handle(), num_data_in_cluster, num_cols);
@@ -572,7 +698,9 @@
batch_indices_h,
batch_indices_d,
batch_distances_d,
- nnd);
+ nnd,
+ cluster_id,
+ num_cols);
nnd.reset(res);
}
}
Expand All @@ -591,6 +719,7 @@ void batch_build(raft::resources const& res,

size_t num_rows = static_cast<size_t>(dataset.extent(0));
size_t num_cols = static_cast<size_t>(dataset.extent(1));
printf("num rows: %lu, num cols: %lu\n", num_rows, num_cols);

auto centroids =
raft::make_device_matrix<T, IdxT, raft::row_major>(res, params.n_clusters, num_cols);
@@ -612,6 +741,7 @@
auto offset = raft::make_host_vector<IdxT, IdxT, raft::row_major>(params.n_clusters);

size_t max_cluster_size, min_cluster_size;

get_inverted_indices(res,
params.n_clusters,
max_cluster_size,
@@ -620,7 +750,7 @@
inverted_indices.view(),
cluster_size.view(),
offset.view());

printf("max cluster size: %lu, min cluster size: %lu]\n", max_cluster_size, min_cluster_size);
if (intermediate_degree >= min_cluster_size) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than minimum cluster size, reducing it to %lu",
