Skip to content

Commit

Permalink
grid hash
Browse files Browse the repository at this point in the history
  • Loading branch information
Abyssaledge committed Dec 4, 2023
1 parent 32eabf7 commit f946526
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 7 deletions.
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@
['./torchex/src/incremental_points/incremental_points.cpp',
'./torchex/src/incremental_points/incremental_points_kernel.cu',]
),
CUDAExtension(
'grid_hash_ext',
['./torchex/src/grid_hash/grid_hash.cpp',
'./torchex/src/grid_hash/grid_hash_kernel.cu',]
),
],
cmdclass={
'build_ext': BuildExtension
Expand Down
42 changes: 42 additions & 0 deletions test/test_grid_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import torch
from ipdb import set_trace
from torchex import GridHash
import random


if __name__ == '__main__':
    # Randomised round-trip test for GridHash. Requires a CUDA device.
    # Fixed seeds so that a failing iteration can be reproduced.
    random.seed(2)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    for _ in range(100):
        print('***********')
        all_size = random.randint(1, 500000)
        high = random.randint(10, 1024)
        coors = torch.randint(low=0, high=high, size=(all_size, 3)).int().cuda()

        # Duplicated rows are fine here: GridHash is handed the raw coordinates.
        hasher = GridHash(coors, debug=True)
        table = hasher.table
        valid_mask = hasher.valid_mask
        valid_values = hasher.valid_table[:, 1]

        # make sure the mapping is right: probing every key stored in the
        # table must give back the value stored next to it
        out_valid_values = hasher.probe(table[valid_mask, 0])
        assert (out_valid_values == valid_values).all()

        # make sure the mapping is stable: probing a gathered subset must
        # agree with gathering the result of probing everything
        num_queries = random.randint(1, 500000)
        # `high` is exclusive in torch.randint, so `all_size` covers every row
        # and stays valid when all_size == 1 (all_size - 1 would give an
        # empty range and raise).
        index = torch.randint(low=0, high=all_size, size=(num_queries,)).long().cuda()
        queries = coors

        values_1 = hasher.probe(queries)
        values_2 = hasher.probe(queries[index])
        assert (values_1[index] == values_2).all()
4 changes: 3 additions & 1 deletion torchex/operator_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from .codec_op import mask_encoder, mask_decoder
from .iou3d_op import boxes_iou_bev, boxes_iou_bev_1to1, boxes_overlap_1to1, nms_gpu, nms_normal_gpu, nms_mixed_gpu, aug_nms_gpu
from .chamfer_distance_op import chamfer_distance
from .grid_hash import grid_hash_build, grid_hash_probe, GridHash


__all__ = ['dynamic_point_pool', 'TorchTimer', 'ingroup_inds', 'connected_components', 'scatter_sum', 'scatter_mean',
'scatter_max', 'ScatterData', 'mask_encoder', 'mask_decoder', 'boxes_iou_bev', 'boxes_iou_bev_1to1', 'boxes_overlap_1to1',
'nms_gpu', 'nms_normal_gpu', 'nms_mixed_gpu', 'aug_nms_gpu', 'incremental_points_mask']
'nms_gpu', 'nms_normal_gpu', 'nms_mixed_gpu', 'aug_nms_gpu', 'incremental_points_mask',
'grid_hash_build', 'grid_hash_probe', 'GridHash']
104 changes: 104 additions & 0 deletions torchex/operator_py/grid_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import torch
import grid_hash_ext

from torch.autograd import Function
from ipdb import set_trace

def _flatten_3dim_coors(coors, sizes):

x, y, z = coors[:, 0], coors[:, 1], coors[:, 2]
x_size = sizes[0]
y_size = sizes[1]
z_size = sizes[2]
assert x.max() < x_size and y.max() < y_size and z.max() < z_size
flatten_coors = x * y_size * z_size + y * z_size + z
return flatten_coors

def _get_dim_size(coors):

x, y, z = coors[:, 0], coors[:, 1], coors[:, 2]
x_size = x.max() + 1
y_size = y.max() + 1
z_size = z.max() + 1
return x_size, y_size, z_size

class _BuildHashTableFunction(Function):

@staticmethod
def forward(ctx, coors):

# unique the coors
table = grid_hash_ext.build_table(coors)
ctx.mark_non_differentiable(table)

return table

@staticmethod
def backward(ctx, g):
return None

class _ProbeHashTableFunction(Function):

@staticmethod
def forward(ctx, coors, table):

out_values = grid_hash_ext.probe_table(coors, table)
ctx.mark_non_differentiable(out_values)

return out_values

@staticmethod
def backward(ctx, g):
return None

# Functional entry points: call these instead of using the Function classes directly.
# grid_hash_build(coors) -> table
grid_hash_build = _BuildHashTableFunction.apply
# grid_hash_probe(coors, table) -> values
grid_hash_probe = _ProbeHashTableFunction.apply



class GridHash:
    """Hash table that maps integer grid coordinates to compact indices.

    The table is built once, either from 1D pre-flattened keys or from
    ``(N, 3)`` grid coordinates, and is then queried with :meth:`probe`.

    Attributes:
        dim_size: per-axis extents used to flatten 3D coordinates, or
            ``None`` when the table was built from 1D keys.
        table: tensor returned by ``grid_hash_build``; a row whose column 0
            equals -1 is an empty slot, column 1 holds the stored value.
        valid_mask: boolean mask of the occupied rows of ``table``.
        valid_table: the rows of ``table`` selected by ``valid_mask``.
    """

    def __init__(self, coors, debug=False):
        """Build the table from ``coors``.

        Args:
            coors: 1D tensor of keys, or ``(N, 3)`` tensor of grid coordinates.
                Duplicates are removed before the table is built.
            debug: when true, print the table length and run extra
                consistency checks on the keys and the built table.

        Raises:
            ValueError: if ``coors`` has an unsupported shape, is empty, or
                contains keys that do not fit into int32.
        """
        if coors.ndim == 2:
            if coors.size(1) != 3:
                raise ValueError('2D coors must have shape (N, 3)')
        elif coors.ndim != 1:
            raise ValueError('coors must be a 1D or an (N, 3) tensor')
        if coors.size(0) == 0:
            raise ValueError('cannot build a GridHash from empty coors')

        coors = torch.unique(coors, dim=0)

        if coors.ndim == 1:
            self.dim_size = None
        else:
            # Flatten in int64: with int32 coordinates the products of the
            # grid extents can wrap around before any range check sees them.
            coors = coors.long()
            self.dim_size = _get_dim_size(coors)
            coors = _flatten_3dim_coors(coors, self.dim_size)

        coors = self._to_int32_keys(coors)

        if debug:
            assert (coors >= 0).all() and (coors < 2 ** 30).all()

        self.table = grid_hash_build(coors)
        self.valid_mask = self.table[:, 0] != -1
        self.valid_table = self.table[self.valid_mask]

        if debug:
            print(len(self.table))
            # every unique key occupies exactly one slot ...
            assert self.valid_mask.sum() == len(coors)
            # ... and the stored values are a permutation of 0..len(coors)-1
            assert self.valid_table[:, 1].max() + 1 == len(coors)
            assert len(self.valid_table[:, 1]) == len(torch.unique(self.valid_table[:, 1]))
            if not self.valid_mask.all():
                assert (self.table[~self.valid_mask, 0] == -1).all()

    @staticmethod
    def _to_int32_keys(keys):
        """Return ``keys`` as a contiguous int32 tensor, refusing to wrap."""
        if keys.dtype != torch.int32:
            if keys.numel() > 0:
                limits = torch.iinfo(torch.int32)
                if int(keys.min()) < limits.min or int(keys.max()) > limits.max:
                    raise ValueError('keys do not fit into int32')
            keys = keys.int()
        # The extension rejects non-contiguous inputs.
        return keys.contiguous()

    def probe(self, coors):
        """Look up ``coors`` and return the value stored for each of them.

        Args:
            coors: 1D tensor of keys, or ``(N, 3)`` tensor of grid coordinates
                (only when the table was built from 3D coordinates).

        Returns:
            1D tensor with one looked-up value per query.

        Raises:
            ValueError: if ``coors`` has an unsupported shape, or 3D queries
                are used with a table built from 1D keys.
        """
        if coors.ndim == 2:
            if coors.size(1) != 3:
                raise ValueError('2D coors must have shape (N, 3)')
            if self.dim_size is None:
                raise ValueError(
                    'this GridHash was built from 1D keys; probe it with 1D keys')
        elif coors.ndim != 1:
            raise ValueError('coors must be a 1D or an (N, 3) tensor')

        # Nothing to look up; also avoids reducing over an empty tensor.
        if coors.size(0) == 0:
            return coors.new_empty((0,), dtype=torch.int32)

        if coors.ndim == 2:
            coors = _flatten_3dim_coors(coors, self.dim_size)

        return grid_hash_probe(self._to_int32_keys(coors), self.table)
112 changes: 112 additions & 0 deletions torchex/src/grid_hash/grid_hash.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include <assert.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
#include <tuple>

// Input validation helpers; each raises through TORCH_CHECK when the
// condition does not hold.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// Wrapped in do/while(0) so the two checks stay together when the macro is
// used as the body of an unbraced if/else.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)


// Host-side launcher of the table-building kernel.
// Receives the N keys in `coors_ptr` and the table buffer of `table_size` rows.
void build_hash_table_launcher(
const int *coors_ptr,
int *table_ptr,
int table_size,
int N
);

// Host-side launcher of the lookup kernel.
// Writes one looked-up value per key of `coors_ptr` into `out_values_ptr`.
void probe_hash_table_launcher(
const int *coors_ptr,
const int *table_ptr,
int *out_values_ptr,
int table_size,
int N
);

// Round n up to the nearest power of two (values <= 1 give 1).
// The result is only representable for n <= 2^30; callers must enforce that.
inline int up_2n(int n){
  if (n <= 1) return 1;
  // Smear the highest set bit of (n - 1) into every lower position, then add
  // one. Done on unsigned so no step relies on signed overflow.
  unsigned int temp = static_cast<unsigned int>(n) - 1u;
  temp |= temp >> 1;
  temp |= temp >> 2;
  temp |= temp >> 4;
  temp |= temp >> 8;
  temp |= temp >> 16;
  return static_cast<int>(temp + 1u);
}


// Build the hash table for a 1D int32 CUDA tensor of keys.
// Returns an int32 tensor of shape (table_size, 2) initialised to -1:
// the first channel is the flattened coors, the second channel is the
// mapped 1D index.
torch::Tensor build_hash_table(
    torch::Tensor coors
) {

  CHECK_INPUT(coors);
  // TORCH_CHECK instead of assert(): assert is compiled out under NDEBUG,
  // which would let a wrongly shaped tensor reach the raw-pointer kernel.
  TORCH_CHECK(coors.dim() == 1, "coors must be a 1D tensor, got ", coors.dim(), " dims");
  TORCH_CHECK(coors.scalar_type() == torch::kInt32, "coors must be an int32 tensor");
  // up_2n() cannot represent a power of two above 2^30 in an int.
  TORCH_CHECK(coors.size(0) <= (static_cast<int64_t>(1) << 30),
              "too many keys for the hash table: ", coors.size(0));
  const int N = static_cast<int>(coors.size(0));

  auto int_opts = coors.options().dtype(torch::kInt32);
  const int table_size = up_2n(N);

  torch::Tensor table = torch::full({table_size, 2}, -1, int_opts);

  // Nothing to insert; skip the launch (a zero-block grid is invalid).
  if (N == 0) return table;

  const int *coors_ptr = coors.data_ptr<int>();
  int *table_ptr = table.data_ptr<int>();

  build_hash_table_launcher(
    coors_ptr,
    table_ptr,
    table_size,
    N
  );
  return table;

}

// Look up every key of `coors` (1D int32 CUDA tensor) in `table`
// (int32 tensor of shape (table_size, 2)).
// Returns an int32 tensor shaped like `coors` holding the value found for
// each key; entries are initialised to -1.
torch::Tensor probe_hash_table(
    torch::Tensor coors,
    torch::Tensor table
) {

  CHECK_INPUT(coors);
  CHECK_INPUT(table);
  // TORCH_CHECK instead of assert(): assert is compiled out under NDEBUG.
  TORCH_CHECK(coors.dim() == 1, "coors must be a 1D tensor, got ", coors.dim(), " dims");
  TORCH_CHECK(coors.scalar_type() == torch::kInt32, "coors must be an int32 tensor");
  TORCH_CHECK(table.dim() == 2 && table.size(1) == 2, "table must have shape (table_size, 2)");
  TORCH_CHECK(table.scalar_type() == torch::kInt32, "table must be an int32 tensor");
  TORCH_CHECK(table.device() == coors.device(), "coors and table must be on the same device");
  TORCH_CHECK(coors.size(0) < (static_cast<int64_t>(1) << 31),
              "too many queries: ", coors.size(0));

  // one output slot per query key
  torch::Tensor out_values = torch::full_like(coors, -1);

  const int N = static_cast<int>(coors.size(0));
  // Nothing to look up; skip the launch (a zero-block grid is invalid).
  if (N == 0) return out_values;

  TORCH_CHECK(table.size(0) > 0, "table must not be empty");
  const int table_size = static_cast<int>(table.size(0));

  const int *coors_ptr = coors.data_ptr<int>();
  const int *table_ptr = table.data_ptr<int>();
  int *out_values_ptr = out_values.data_ptr<int>();

  probe_hash_table_launcher(
    coors_ptr,
    table_ptr,
    out_values_ptr,
    table_size,
    N
  );

  return out_values;

}


// Python bindings: exposes build_hash_table as `build_table` and
// probe_hash_table as `probe_table`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("build_table", &build_hash_table, "build hash table (coors -> 1D compact index) given coors ");
m.def("probe_table", &probe_hash_table, "query the hash table acoording to the given coors");
}
103 changes: 103 additions & 0 deletions torchex/src/grid_hash/grid_hash_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <assert.h>
#include <vector>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <torch/types.h>
#include "cuda_fp16.h"
#include "../utils/error.cuh"
#include "../utils/functions.cuh"

// Block size used by both launchers in this file.
#define THREADS_PER_BLOCK 256
// Integer division rounded up; used to compute the number of blocks.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// While DEBUG is defined, each launcher checks cudaGetLastError() and calls
// cudaDeviceSynchronize() after its kernel launch.
#define DEBUG
// #define ASSERTION


// One thread per key: thread `tid` passes keys[tid] together with its own
// position `tid` to setvalue(), which stores the pair in `table`.
__global__ void build_hash_table_kernel(
    const int *keys,
    int *table,
    int table_size,
    int N
) {

  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // threads past the end of `keys` have nothing to do
  if (tid < N) {
    setvalue(keys[tid], tid, table, table_size);
  }

}

// One thread per query: thread `tid` fetches the value getvalue() reports for
// keys[tid] and writes it to out_values[tid].
__global__ void probe_hash_table_kernel(
    const int *keys,
    const int *table,
    int *out_values,
    int table_size,
    int N
) {

  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // threads past the end of `keys` have nothing to do
  if (tid < N) {
    const int v = getvalue(keys[tid], table, table_size);
    assert (v != -1); // -1 means key error, because all values are >= 0
    out_values[tid] = v;
  }
}

// Launch probe_hash_table_kernel with one thread per query key.
//   keys        N query keys
//   table       hash table buffer with table_size rows
//   out_values  receives one looked-up value per key
void probe_hash_table_launcher(
    const int *keys,
    const int *table,
    int *out_values,
    int table_size,
    int N
) {

  // DIVUP(0, ...) is 0 and a grid with zero blocks is an invalid launch
  // configuration, so an empty query is a no-op instead of a CUDA error.
  if (N <= 0) return;

  dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  probe_hash_table_kernel<<<blocks, threads>>>(
    keys,
    table,
    out_values,
    table_size,
    N
  );

  #ifdef DEBUG
  CHECK_CALL(cudaGetLastError());
  CHECK_CALL(cudaDeviceSynchronize());
  #endif

  return;

}


void build_hash_table_launcher(
const int *keys,
int *table,
int table_size,
int N
) {

dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
assert (table_size >= N);

build_hash_table_kernel<<<blocks, threads>>>(
keys,
table,
table_size,
N
);

#ifdef DEBUG
CHECK_CALL(cudaGetLastError());
CHECK_CALL(cudaDeviceSynchronize());
#endif

return;

}
Loading

0 comments on commit f946526

Please sign in to comment.