Commit f946526 (1 parent: 32eabf7), showing 7 changed files with 379 additions and 7 deletions.
@@ -0,0 +1,42 @@
import os
import torch
from ipdb import set_trace
from torchex import GridHash
import random


if __name__ == '__main__':
    random.seed(2)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    for i in range(100):
        print('***********')
        # random number of random 3D integer coordinates (duplicates allowed)
        all_size = random.randint(1, 500000)
        high = random.randint(10, 1024)
        coors = torch.randint(low=0, high=high, size=(all_size, 3)).int().cuda()

        unq_coors = torch.unique(coors, dim=0)

        # table = grid_hash_build(unq_coors)
        hasher = GridHash(coors, debug=True)
        table = hasher.table

        valid_mask = hasher.valid_mask

        valid_values = hasher.valid_table[:, 1]

        # make sure the mapping is right: probing the stored keys returns the stored values
        out_valid_values = hasher.probe(table[valid_mask, 0])
        assert (out_valid_values == valid_values).all()

        # make sure the mapping is stable: probing a random subset agrees with probing everything
        index = torch.randint(low=0, high=all_size, size=(random.randint(1, 500000),)).long().cuda()
        queries = coors

        values_1 = hasher.probe(queries)
        values_2 = hasher.probe(queries[index])
        assert (values_1[index] == values_2).all()
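
For reference, here is a minimal usage sketch of the API this test exercises, assuming only what the diff shows (torchex exports GridHash, a CUDA device is available, and probe returns the compact 1D index assigned to each coordinate):

# Minimal usage sketch; assumes torchex exposes GridHash as in this commit
# and that a CUDA device is available.
import torch
from torchex import GridHash

coors = torch.randint(0, 32, (1000, 3)).int().cuda()  # integer 3D coordinates, duplicates allowed
hasher = GridHash(coors)                               # builds the GPU hash table over the unique coors

# probe maps every query coordinate to its compact index in [0, num_unique)
idx = hasher.probe(coors)
num_unique = torch.unique(coors, dim=0).size(0)
assert idx.min() >= 0 and idx.max() < num_unique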
@@ -0,0 +1,104 @@
import torch
import grid_hash_ext

from torch.autograd import Function
from ipdb import set_trace


def _flatten_3dim_coors(coors, sizes):
    # row-major flattening of [N, 3] integer coordinates into unique 1D keys
    x, y, z = coors[:, 0], coors[:, 1], coors[:, 2]
    x_size = sizes[0]
    y_size = sizes[1]
    z_size = sizes[2]
    assert x.max() < x_size and y.max() < y_size and z.max() < z_size
    flatten_coors = x * y_size * z_size + y * z_size + z
    return flatten_coors


def _get_dim_size(coors):
    x, y, z = coors[:, 0], coors[:, 1], coors[:, 2]
    x_size = x.max() + 1
    y_size = y.max() + 1
    z_size = z.max() + 1
    return x_size, y_size, z_size


class _BuildHashTableFunction(Function):

    @staticmethod
    def forward(ctx, coors):
        # coors are expected to be unique already
        table = grid_hash_ext.build_table(coors)
        ctx.mark_non_differentiable(table)
        return table

    @staticmethod
    def backward(ctx, g):
        return None


class _ProbeHashTableFunction(Function):

    @staticmethod
    def forward(ctx, coors, table):
        out_values = grid_hash_ext.probe_table(coors, table)
        ctx.mark_non_differentiable(out_values)
        return out_values

    @staticmethod
    def backward(ctx, g):
        return None


grid_hash_build = _BuildHashTableFunction.apply
grid_hash_probe = _ProbeHashTableFunction.apply


class GridHash:
    """Hash table that maps unique integer coordinates to a compact 1D index."""

    def __init__(self, coors, debug=False):
        coors = torch.unique(coors, dim=0)

        if coors.ndim == 1:
            self.dim_size = None
        else:
            assert coors.ndim == 2 and coors.size(1) == 3
            self.dim_size = _get_dim_size(coors)
            coors = _flatten_3dim_coors(coors, self.dim_size)

        if coors.dtype != torch.int32:
            coors = coors.int()

        if debug:
            assert (coors >= 0).all() and (coors < 2 ** 30).all()

        # table[:, 0] holds the flattened coordinate (key), table[:, 1] the mapped 1D index;
        # unused slots are filled with -1
        self.table = grid_hash_build(coors)
        print(len(self.table))
        self.valid_mask = self.table[:, 0] != -1
        self.valid_table = self.table[self.valid_mask]

        if debug:
            assert self.valid_mask.sum() == len(coors)
            assert self.valid_table[:, 1].max() + 1 == len(coors)
            assert len(self.valid_table[:, 1]) == len(torch.unique(self.valid_table[:, 1]))
            if not self.valid_mask.all():
                assert (self.table[~self.valid_mask, 0] == -1).all()

    def probe(self, coors):
        if coors.ndim == 2:
            assert coors.size(1) == 3
            coors = _flatten_3dim_coors(coors, self.dim_size)
        else:
            assert coors.ndim == 1

        if coors.dtype != torch.int32:
            coors = coors.int()

        values = grid_hash_probe(coors, self.table)
        return values
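
The flattening in _flatten_3dim_coors is a standard row-major linearization, key = x * y_size * z_size + y * z_size + z, which is injective as long as each coordinate stays below its dim size (the assert above enforces this). A small sanity sketch in plain PyTorch, no extension required:

# Row-major flattening sanity check; the bounds 6 and 7 are arbitrary sizes
# chosen strictly larger than the maximum y and z values used here.
import torch

coors = torch.tensor([[0, 0, 0], [1, 2, 3], [4, 5, 6]])
y_size, z_size = 6, 7
keys = coors[:, 0] * y_size * z_size + coors[:, 1] * z_size + coors[:, 2]
print(keys.tolist())  # [0, 59, 209] -- all distinct, bounded by x_size * y_size * z_size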
@@ -0,0 +1,112 @@
#include <assert.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
#include <tuple>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
  CHECK_CONTIGUOUS(x)


void build_hash_table_launcher(
    const int *coors_ptr,
    int *table_ptr,
    int table_size,
    int N
);

void probe_hash_table_launcher(
    const int *coors_ptr,
    const int *table_ptr,
    int *out_values_ptr,
    int table_size,
    int N
);

// round n up to the next power of two
__inline__ int up_2n(int n){
  if (n == 1) return 1;
  int temp = n - 1;
  temp |= temp >> 1;
  temp |= temp >> 2;
  temp |= temp >> 4;
  temp |= temp >> 8;
  temp |= temp >> 16;
  return temp + 1;
}


torch::Tensor build_hash_table(
    torch::Tensor coors
) {
  CHECK_INPUT(coors);
  int N = coors.size(0);
  assert (coors.ndimension() == 1);

  auto int_opts = coors.options().dtype(torch::kInt32);
  int table_size = up_2n(N);

  // the first channel is the flattened coors (key), the second channel is the mapped 1D index
  torch::Tensor table = torch::full({table_size, 2}, -1, int_opts);

  const int *coors_ptr = coors.data_ptr<int>();
  int *table_ptr = table.data_ptr<int>();

  build_hash_table_launcher(
    coors_ptr,
    table_ptr,
    table_size,
    N
  );
  return table;
}


torch::Tensor probe_hash_table(
    torch::Tensor coors,
    torch::Tensor table
) {
  CHECK_INPUT(coors);
  CHECK_INPUT(table);
  int N = coors.size(0);
  assert (coors.ndimension() == 1);
  // out_values[i] will hold the mapped 1D index of coors[i]
  torch::Tensor out_values = torch::full_like(coors, -1);

  int table_size = table.size(0);

  const int *coors_ptr = coors.data_ptr<int>();
  const int *table_ptr = table.data_ptr<int>();
  int *out_values_ptr = out_values.data_ptr<int>();

  probe_hash_table_launcher(
    coors_ptr,
    table_ptr,
    out_values_ptr,
    table_size,
    N
  );

  return out_values;
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("build_table", &build_hash_table, "build hash table (coors -> 1D compact index) given coors");
  m.def("probe_table", &probe_hash_table, "query the hash table according to the given coors");
}
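
up_2n rounds the number of keys up to the next power of two, and that value becomes the table size. A quick reference check of the bit trick, written as a hypothetical Python mirror (not part of the extension itself):

# Python mirror of up_2n: round n up to the next power of two (valid for 1 <= n < 2**31).
def up_2n(n: int) -> int:
    if n == 1:
        return 1
    t = n - 1
    for shift in (1, 2, 4, 8, 16):  # smear the highest set bit downwards
        t |= t >> shift
    return t + 1

assert [up_2n(n) for n in (1, 2, 3, 5, 500000)] == [1, 2, 4, 8, 524288]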
@@ -0,0 +1,103 @@
#include <assert.h>
#include <vector>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <torch/types.h>
#include "cuda_fp16.h"
#include "../utils/error.cuh"
#include "../utils/functions.cuh"

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

#define DEBUG
// #define ASSERTION


__global__ void build_hash_table_kernel(
    const int *keys,
    int *table,
    int table_size,
    int N
) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= N) return;
  int key = keys[idx];
  setvalue(key, idx, table, table_size);
}

__global__ void probe_hash_table_kernel(
    const int *keys,
    const int *table,
    int *out_values,
    int table_size,
    int N
) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= N) return;
  int key = keys[idx];
  int v = getvalue(key, table, table_size);
  assert (v != -1); // -1 means key error, because all values are >= 0
  out_values[idx] = v;
}

void probe_hash_table_launcher(
    const int *keys,
    const int *table,
    int *out_values,
    int table_size,
    int N
) {
  dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  probe_hash_table_kernel<<<blocks, threads>>>(
      keys,
      table,
      out_values,
      table_size,
      N
  );

#ifdef DEBUG
  CHECK_CALL(cudaGetLastError());
  CHECK_CALL(cudaDeviceSynchronize());
#endif

  return;
}


void build_hash_table_launcher(
    const int *keys,
    int *table,
    int table_size,
    int N
) {
  dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);
  assert (table_size >= N);

  build_hash_table_kernel<<<blocks, threads>>>(
      keys,
      table,
      table_size,
      N
  );

#ifdef DEBUG
  CHECK_CALL(cudaGetLastError());
  CHECK_CALL(cudaDeviceSynchronize());
#endif

  return;
}
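
setvalue and getvalue come from ../utils/functions.cuh, which is not part of this diff, so their exact hashing scheme is not visible here. For orientation only, here is a hypothetical CPU sketch of an open-addressing table that would be consistent with how these kernels use them (power-of-two table, -1 as the empty marker, linear probing); the real CUDA code would additionally need an atomic claim of each slot:

# Hypothetical CPU sketch only -- the actual setvalue/getvalue live in
# ../utils/functions.cuh and may hash and probe differently.
def setvalue(key, value, table, table_size):
    slot = key & (table_size - 1)           # table_size is a power of two (see up_2n)
    while table[slot][0] != -1:             # linear probing; build keys are unique, so no update path
        slot = (slot + 1) & (table_size - 1)
    table[slot][0] = key                    # on the GPU this claim would use atomicCAS
    table[slot][1] = value

def getvalue(key, table, table_size):
    slot = key & (table_size - 1)
    for _ in range(table_size):             # at most one full sweep of the table
        if table[slot][0] == key:
            return table[slot][1]
        if table[slot][0] == -1:            # reached an empty slot: key is absent
            return -1
        slot = (slot + 1) & (table_size - 1)
    return -1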