From f946526422f01f632028e5dd2d14961fe747bcbb Mon Sep 17 00:00:00 2001
From: Lue FAN <1060056270@qq.com>
Date: Mon, 4 Dec 2023 15:13:24 +0000
Subject: [PATCH] grid hash

---
 setup.py                                  |   5 +
 test/test_grid_hash.py                    |  42 ++++++++
 torchex/operator_py/__init__.py           |   4 +-
 torchex/operator_py/grid_hash.py          | 104 ++++++++++++++++++++
 torchex/src/grid_hash/grid_hash.cpp       | 112 ++++++++++++++++++++++
 torchex/src/grid_hash/grid_hash_kernel.cu | 103 ++++++++++++++++++++
 torchex/src/utils/functions.cuh           |  16 ++--
 7 files changed, 379 insertions(+), 7 deletions(-)
 create mode 100644 test/test_grid_hash.py
 create mode 100644 torchex/operator_py/grid_hash.py
 create mode 100644 torchex/src/grid_hash/grid_hash.cpp
 create mode 100644 torchex/src/grid_hash/grid_hash_kernel.cu

diff --git a/setup.py b/setup.py
index 86a4fd0..37b6c00 100644
--- a/setup.py
+++ b/setup.py
@@ -62,6 +62,11 @@
             ['./torchex/src/incremental_points/incremental_points.cpp',
              './torchex/src/incremental_points/incremental_points_kernel.cu',]
         ),
+        CUDAExtension(
+            'grid_hash_ext',
+            ['./torchex/src/grid_hash/grid_hash.cpp',
+             './torchex/src/grid_hash/grid_hash_kernel.cu',]
+        ),
     ],
     cmdclass={
         'build_ext': BuildExtension
diff --git a/test/test_grid_hash.py b/test/test_grid_hash.py
new file mode 100644
index 0000000..252fd34
--- /dev/null
+++ b/test/test_grid_hash.py
@@ -0,0 +1,42 @@
+import os
+import torch
+from ipdb import set_trace
+from torchex import GridHash
+import random
+
+
+if __name__ == '__main__':
+    random.seed(2)
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+
+    for i in range(100):
+        print('***********')
+        all_size = random.randint(1, 500000)
+        high = random.randint(10, 1024)
+        coors = torch.randint(low=0, high=high, size=(all_size, 3)).int().cuda()
+
+        unq_coors = torch.unique(coors, dim=0)
+
+        # table = grid_hash_build(unq_coors)
+        hasher = GridHash(coors, debug=True)
+        table = hasher.table
+
+        valid_mask = hasher.valid_mask
+
+        valid_values = hasher.valid_table[:, 1]
+
+        # make sure the mapping is right
+
+        out_valid_values = hasher.probe(table[valid_mask, 0])
+        assert (out_valid_values == valid_values).all()
+
+        # make sure the mapping is stable
+
+        index = torch.randint(low=0, high=all_size - 1, size=(random.randint(1, 500000),)).long().cuda()
+        queries = coors
+
+        values_1 = hasher.probe(queries)
+
+        values_2 = hasher.probe(queries[index])
+        assert (values_1[index] == values_2).all()
diff --git a/torchex/operator_py/__init__.py b/torchex/operator_py/__init__.py
index ada997c..d7d1586 100644
--- a/torchex/operator_py/__init__.py
+++ b/torchex/operator_py/__init__.py
@@ -8,8 +8,10 @@
 from .codec_op import mask_encoder, mask_decoder
 from .iou3d_op import boxes_iou_bev, boxes_iou_bev_1to1, boxes_overlap_1to1, nms_gpu, nms_normal_gpu, nms_mixed_gpu, aug_nms_gpu
 from .chamfer_distance_op import chamfer_distance
+from .grid_hash import grid_hash_build, grid_hash_probe, GridHash
 
 __all__ = ['dynamic_point_pool', 'TorchTimer', 'ingroup_inds', 'connected_components',
            'scatter_sum', 'scatter_mean', 'scatter_max', 'ScatterData', 'mask_encoder', 'mask_decoder',
            'boxes_iou_bev', 'boxes_iou_bev_1to1', 'boxes_overlap_1to1',
-           'nms_gpu', 'nms_normal_gpu', 'nms_mixed_gpu', 'aug_nms_gpu', 'incremental_points_mask']
+           'nms_gpu', 'nms_normal_gpu', 'nms_mixed_gpu', 'aug_nms_gpu', 'incremental_points_mask',
+           'grid_hash_build', 'grid_hash_probe', 'GridHash']
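For reference, the Python-facing API exercised by the test above is used roughly as follows. This is a minimal sketch, assuming the package has been reinstalled so that grid_hash_ext is built and a CUDA device is available; the shapes and dtypes mirror the test.

import torch
from torchex import GridHash

coors = torch.randint(0, 100, (10000, 3)).int().cuda()  # integer voxel coordinates, one row per point
hasher = GridHash(coors)       # builds a GPU hash table over the unique coordinates
values = hasher.probe(coors)   # compact index in [0, num_unique) for every query coordinate
assert values.min() >= 0 and values.max() < len(torch.unique(coors, dim=0))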
diff --git a/torchex/operator_py/grid_hash.py b/torchex/operator_py/grid_hash.py
new file mode 100644
index 0000000..217cc0c
--- /dev/null
+++ b/torchex/operator_py/grid_hash.py
@@ -0,0 +1,104 @@
+import torch
+import grid_hash_ext
+
+from torch.autograd import Function
+from ipdb import set_trace
+
+def _flatten_3dim_coors(coors, sizes):
+
+    x, y, z = coors[:, 0], coors[:, 1], coors[:, 2]
+    x_size = sizes[0]
+    y_size = sizes[1]
+    z_size = sizes[2]
+    assert x.max() < x_size and y.max() < y_size and z.max() < z_size
+    flatten_coors = x * y_size * z_size + y * z_size + z
+    return flatten_coors
+
+def _get_dim_size(coors):
+
+    x, y, z = coors[:, 0], coors[:, 1], coors[:, 2]
+    x_size = x.max() + 1
+    y_size = y.max() + 1
+    z_size = z.max() + 1
+    return x_size, y_size, z_size
+
+class _BuildHashTableFunction(Function):
+
+    @staticmethod
+    def forward(ctx, coors):
+
+        # unique the coors
+        table = grid_hash_ext.build_table(coors)
+        ctx.mark_non_differentiable(table)
+
+        return table
+
+    @staticmethod
+    def backward(ctx, g):
+        return None
+
+class _ProbeHashTableFunction(Function):
+
+    @staticmethod
+    def forward(ctx, coors, table):
+
+        out_values = grid_hash_ext.probe_table(coors, table)
+        ctx.mark_non_differentiable(out_values)
+
+        return out_values
+
+    @staticmethod
+    def backward(ctx, g):
+        return None
+
+grid_hash_build = _BuildHashTableFunction.apply
+grid_hash_probe = _ProbeHashTableFunction.apply
+
+
+
+class GridHash:
+
+    def __init__(self, coors, debug=False):
+
+        coors = torch.unique(coors, dim=0)
+
+        if coors.ndim == 1:
+            self.dim_size = None
+        else:
+            assert coors.ndim == 2 and coors.size(1) == 3
+            self.dim_size = _get_dim_size(coors)
+            coors = _flatten_3dim_coors(coors, self.dim_size)
+
+
+        if coors.dtype != torch.int32:
+            coors = coors.int()
+
+        if debug:
+            assert (coors >= 0).all() and (coors < 2 ** 30).all()
+
+        self.table = grid_hash_build(coors)
+        print(len(self.table))
+        self.valid_mask = self.table[:, 0] != -1
+        self.valid_table = self.table[self.valid_mask]
+
+        if debug:
+            assert self.valid_mask.sum() == len(coors)
+            assert self.valid_table[:, 1].max() + 1 == len(coors)
+            assert len(self.valid_table[:, 1]) == len(torch.unique(self.valid_table[:, 1]))
+            if not self.valid_mask.all():
+                assert (self.table[~self.valid_mask, 0] == -1).all()
+
+    def probe(self, coors):
+
+        if coors.ndim == 2:
+            assert coors.size(1) == 3
+            coors = _flatten_3dim_coors(coors, self.dim_size)
+        else:
+            assert coors.ndim == 1
+
+        if coors.dtype != torch.int32:
+            coors = coors.int()
+
+        values = grid_hash_probe(coors, self.table)
+
+        return values
\ No newline at end of file
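The wrapper above stores 3D voxel coordinates in the hash table as a single packed integer key. A small CPU-side sketch of that key construction; the helper name below is illustrative, not part of the library:

import torch

def flatten_reference(coors, sizes):
    # Mirrors _flatten_3dim_coors: pack (x, y, z) into x * y_size * z_size + y * z_size + z.
    x_size, y_size, z_size = sizes
    return coors[:, 0] * y_size * z_size + coors[:, 1] * z_size + coors[:, 2]

coors = torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6]])
keys = flatten_reference(coors, (8, 8, 8))
# Equal coordinates map to equal keys and distinct coordinates to distinct keys (given the
# per-axis bounds), which is what makes GridHash.probe return a stable compact index.
assert keys[0] == keys[1] and keys[0] != keys[2]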
diff --git a/torchex/src/grid_hash/grid_hash.cpp b/torchex/src/grid_hash/grid_hash.cpp
new file mode 100644
index 0000000..2076f54
--- /dev/null
+++ b/torchex/src/grid_hash/grid_hash.cpp
@@ -0,0 +1,112 @@
+#include <torch/serialize/tensor.h>
+#include <torch/extension.h>
+#include <vector>
+#include <assert.h>
+#include <math.h>
+
+#define CHECK_CUDA(x) \
+    TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) \
+    TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) \
+    CHECK_CUDA(x); \
+    CHECK_CONTIGUOUS(x)
+
+
+void build_hash_table_launcher(
+    const int *coors_ptr,
+    int *table_ptr,
+    int table_size,
+    int N
+);
+
+void probe_hash_table_launcher(
+    const int *coors_ptr,
+    const int *table_ptr,
+    int *out_values_ptr,
+    int table_size,
+    int N
+);
+
+__inline__ int up_2n(int n){
+    if (n == 1) return 1;
+    int temp = n - 1;
+    temp |= temp >> 1;
+    temp |= temp >> 2;
+    temp |= temp >> 4;
+    temp |= temp >> 8;
+    temp |= temp >> 16;
+    return temp + 1;
+}
+
+
+torch::Tensor build_hash_table(
+    torch::Tensor coors
+);
+
+torch::Tensor build_hash_table(
+    torch::Tensor coors
+) {
+
+    CHECK_INPUT(coors);
+    int N = coors.size(0);
+    assert (coors.ndimension() == 1);
+
+    auto int_opts = coors.options().dtype(torch::kInt32);
+    int table_size = up_2n(N);
+
+    torch::Tensor table = torch::full({table_size, 2}, -1, int_opts); // the first channel is flattened coors, the second channel is the mapped 1D index
+
+
+    const int *coors_ptr = coors.data_ptr<int>();
+    int *table_ptr = table.data_ptr<int>();
+
+    build_hash_table_launcher(
+        coors_ptr,
+        table_ptr,
+        table_size,
+        N
+    );
+    return table;
+
+}
+
+torch::Tensor probe_hash_table(
+    torch::Tensor coors,
+    torch::Tensor table
+);
+
+torch::Tensor probe_hash_table(
+    torch::Tensor coors,
+    torch::Tensor table
+) {
+
+    CHECK_INPUT(coors);
+    CHECK_INPUT(table);
+    int N = coors.size(0);
+    assert (coors.ndimension() == 1);
+    torch::Tensor out_values = torch::full_like(coors, -1); // holds the mapped 1D index for each query; -1 means not found
+
+    int table_size = table.size(0);
+
+    const int *coors_ptr = coors.data_ptr<int>();
+    const int *table_ptr = table.data_ptr<int>();
+    int *out_values_ptr = out_values.data_ptr<int>();
+
+    probe_hash_table_launcher(
+        coors_ptr,
+        table_ptr,
+        out_values_ptr,
+        table_size,
+        N
+    );
+
+    return out_values;
+
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("build_table", &build_hash_table, "build hash table (coors -> 1D compact index) given coors ");
+    m.def("probe_table", &probe_hash_table, "query the hash table according to the given coors");
+}
diff --git a/torchex/src/grid_hash/grid_hash_kernel.cu b/torchex/src/grid_hash/grid_hash_kernel.cu
new file mode 100644
index 0000000..469713c
--- /dev/null
+++ b/torchex/src/grid_hash/grid_hash_kernel.cu
@@ -0,0 +1,103 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_runtime_api.h>
+#include "cuda_fp16.h"
+#include "../utils/error.cuh"
+#include "../utils/functions.cuh"
+
+#define THREADS_PER_BLOCK 256
+#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+
+#define DEBUG
+// #define ASSERTION
+
+
+__global__ void build_hash_table_kernel(
+    const int *keys,
+    int *table,
+    int table_size,
+    int N
+) {
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= N) return;
+    int key = keys[idx];
+    setvalue(key, idx, table, table_size);
+
+}
+
+__global__ void probe_hash_table_kernel(
+    const int *keys,
+    const int *table,
+    int *out_values,
+    int table_size,
+    int N
+) {
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= N) return;
+    int key = keys[idx];
+    int v = getvalue(key, table, table_size);
+    assert (v != -1); // -1 means key error, because all values are >= 0
+    out_values[idx] = v;
+}
+
+void probe_hash_table_launcher(
+    const int *keys,
+    const int *table,
+    int *out_values,
+    int table_size,
+    int N
+    ) {
+
+    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
+    dim3 threads(THREADS_PER_BLOCK);
+
+    probe_hash_table_kernel<<<blocks, threads>>>(
+        keys,
+        table,
+        out_values,
+        table_size,
+        N
+    );
+
+    #ifdef DEBUG
+    CHECK_CALL(cudaGetLastError());
+    CHECK_CALL(cudaDeviceSynchronize());
+    #endif
+
+    return;
+
+}
+
+
+void build_hash_table_launcher(
+    const int *keys,
+    int *table,
+    int table_size,
+    int N
+    ) {
+
+    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
+    dim3 threads(THREADS_PER_BLOCK);
+    assert (table_size >= N);
+
+    build_hash_table_kernel<<<blocks, threads>>>(
+        keys,
+        table,
+        table_size,
+        N
+    );
+
+    #ifdef DEBUG
+    CHECK_CALL(cudaGetLastError());
+    CHECK_CALL(cudaDeviceSynchronize());
+    #endif
+
+    return;
+
+}
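The kernels above delegate to setvalue/getvalue from functions.cuh (modified below). A plain-Python sketch of that open-addressing scheme, run serially and without atomics: the table interleaves keys and values, -1 marks an empty slot, and insert and lookup walk the same double-hash probe sequence. The early return on an empty slot in the lookup is an assumption about the part of getvalue that lies outside the hunks shown below.

def double_hash(key, probe_i, table_size):
    # Same probe sequence as functions.cuh; table_size is a power of two from up_2n().
    return (key % table_size + probe_i * (key * 2 + 1)) % table_size

def setvalue(key, value, table, table_size):
    # Keys are unique here (GridHash uniques the coordinates first), so only empty slots matter.
    for i in range(table_size):
        slot = double_hash(key, i, table_size)
        if table[2 * slot] == -1:        # even position: key
            table[2 * slot] = key
            table[2 * slot + 1] = value  # odd position: value
            return
    raise RuntimeError('hash table is full')

def getvalue(key, table, table_size):
    for i in range(table_size):
        slot = double_hash(key, i, table_size)
        if table[2 * slot] == key:
            return table[2 * slot + 1]
        if table[2 * slot] == -1:        # empty slot reached: the key was never inserted
            return -1
    return -1

table_size = 4                             # up_2n rounds the number of keys up to a power of two
table = [-1] * (2 * table_size)
for idx, key in enumerate([83, 302, 77]):  # flattened coordinates -> compact indices 0, 1, 2
    setvalue(key, idx, table, table_size)
assert getvalue(302, table, table_size) == 1
assert getvalue(999, table, table_size) == -1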
diff --git a/torchex/src/utils/functions.cuh b/torchex/src/utils/functions.cuh
index 256d5ec..2cf5088 100644
--- a/torchex/src/utils/functions.cuh
+++ b/torchex/src/utils/functions.cuh
@@ -27,16 +27,20 @@ int up_2n(int n){
 
 // A simple hash table
 
-__device__ __forceinline__ int double_hash(int key, int probe_i, int table_size);
-__device__ __forceinline__ int double_hash(int key, int probe_i, int table_size){
+__device__ __forceinline__ int double_hash(const long long key, int probe_i, int table_size);
+__device__ __forceinline__ int double_hash(const long long key, int probe_i, int table_size){
     // equivalent to (key + probe_i * (key * 2 + 1)) % table_size, keep the one more mod op for better understanding.
     return (key % table_size + probe_i * (key * 2 + 1)) % table_size;
 }
 
-__device__ void setvalue(const int key, const int value, int *table, const int table_size);
-__device__ void setvalue(const int key, const int value, int *table, const int table_size){
+__device__ void setvalue(const long long key, const int value, int *table, const int table_size);
+__device__ void setvalue(const long long key, const int value, int *table, const int table_size){
     for (int i = 0; i < table_size; i++){
         int slot_idx = double_hash(key, i, table_size);
+        // if (slot_idx < 0 || slot_idx >= table_size){
+        //     printf("slot_idx: %d, table_size: %d, probe: %d, key: %d\n", slot_idx, table_size, i, key);
+        //     assert(false);
+        // }
         int old_key = atomicCAS(&table[2 * slot_idx], -1, key); // even pos: key, odd pos: value
         if (old_key == -1){
             table[2 * slot_idx + 1] = value;
@@ -47,8 +51,8 @@ __device__ void setvalue(const int key, const int value, int *table, const int t
     assert(false);
 }
 
-__device__ int getvalue(const int key, const int *table, const int table_size);
-__device__ int getvalue(const int key, const int *table, const int table_size){
+__device__ int getvalue(const long long key, const int *table, const int table_size);
+__device__ int getvalue(const long long key, const int *table, const int table_size){
     for (int i = 0; i < table_size; i++){
         int slot_idx = double_hash(key, i, table_size);
         int slot_key = table[2 * slot_idx]; // even pos: key, odd pos: value
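The change above widens the key from int to long long. The apparent motivation, suggested by the commented-out range check in setvalue, is 32-bit overflow: for flattened coordinates near the 2**30 bound that GridHash asserts in debug mode, probe_i * (key * 2 + 1) can wrap negative in int arithmetic and yield an out-of-range slot index, while the same expression stays in range in 64-bit. A small Python sketch of the effect; the to_int32 and c_mod helpers are illustrative stand-ins for C int semantics:

def to_int32(x):
    # Reinterpret an arbitrary Python int as a wrapped 32-bit signed value (C int behaviour).
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

def c_mod(a, b):
    # C's % truncates toward zero, so a negative left operand yields a non-positive result.
    return a - b * int(a / b)

key, probe_i, table_size = 2**30 - 2**19, 2, 2**19   # key is a legal flattened coordinate (< 2**30)

# Old signature (int key): evaluated in 32-bit, the probe term wraps and the slot index goes negative.
step32 = to_int32(probe_i * to_int32(key * 2 + 1))
slot32 = c_mod(to_int32(c_mod(key, table_size) + step32), table_size)
assert step32 < 0 and slot32 < 0

# New signature (long long key): the same expression stays inside [0, table_size).
slot64 = (key % table_size + probe_i * (key * 2 + 1)) % table_size
assert 0 <= slot64 < table_size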