From bd82e5f8600ce2581c4af9395f98de4691958d66 Mon Sep 17 00:00:00 2001 From: Alexandre Boulch Date: Thu, 3 Dec 2020 20:15:28 +0000 Subject: [PATCH 1/3] added no compilation support using pytorch geometric --- doc/install.md | 14 ++- lightconvpoint/knn/__init__.py | 4 + lightconvpoint/knn/furthest_point_sampling.py | 39 ++++++++ lightconvpoint/knn/knn.py | 34 +++++++ lightconvpoint/knn/quantized_sampling.py | 90 +++++++++++++++++++ lightconvpoint/knn/random_sampling.py | 37 ++++++++ setup.py | 45 ++++++---- 7 files changed, 246 insertions(+), 17 deletions(-) create mode 100644 lightconvpoint/knn/__init__.py create mode 100644 lightconvpoint/knn/furthest_point_sampling.py create mode 100644 lightconvpoint/knn/knn.py create mode 100644 lightconvpoint/knn/quantized_sampling.py create mode 100644 lightconvpoint/knn/random_sampling.py diff --git a/doc/install.md b/doc/install.md index b587f59..b30f291 100644 --- a/doc/install.md +++ b/doc/install.md @@ -7,6 +7,18 @@ - Pytorch ## Library installation + +We provide two intallation modes, one that rely only on the pytorch-geometric packqge and the other that uses our c++ libraries for k-nearest neighbor search and quantized sampling. +The later is faster on our computers. + + +### Installation with compilation of our knn libraries +``` +pip install -ve /path/to/LightConvPoint/ --install-option="--compile" +``` + +### Installation without compilation ``` pip install -ve /path/to/LightConvPoint/ -``` \ No newline at end of file +``` + diff --git a/lightconvpoint/knn/__init__.py b/lightconvpoint/knn/__init__.py new file mode 100644 index 0000000..a46ed8c --- /dev/null +++ b/lightconvpoint/knn/__init__.py @@ -0,0 +1,4 @@ +from .knn import knn +from .quantized_sampling import quantized_pick_knn +from .random_sampling import random_pick_knn +from .furthest_point_sampling import farthest_pick_knn \ No newline at end of file diff --git a/lightconvpoint/knn/furthest_point_sampling.py b/lightconvpoint/knn/furthest_point_sampling.py new file mode 100644 index 0000000..d216940 --- /dev/null +++ b/lightconvpoint/knn/furthest_point_sampling.py @@ -0,0 +1,39 @@ +import torch +import math +from torch_geometric.nn.pool import fps +from lightconvpoint.knn import knn + +import importlib +knn_c_func_spec = importlib.util.find_spec('lightconvpoint.knn_c_func') + +if knn_c_func_spec is not None: + knn_c_func = importlib.util.module_from_spec(knn_c_func_spec) + knn_c_func_spec.loader.exec_module(knn_c_func) + +def farthest_pick_knn(points: torch.Tensor, nqueries: int, K: int): + + if knn_c_func_spec is not None: + return knn_c_func.farthest_pick_knn(points, nqueries, K) + + bs, dim, nx = points.shape + + ratio = nqueries / nx + + batch_x = torch.arange(0, bs, dtype=torch.long, device=points.device).unsqueeze(1).expand(bs,nx) + + x = points.transpose(1,2).reshape(-1, dim) + batch_x = batch_x.view(-1) + + indices_queries = fps(x, batch_x, ratio) + + points_queries = x[indices_queries] + + indices_queries = indices_queries.view(bs, -1) + points_queries = points_queries.view(bs,-1,3) + points_queries = points_queries.transpose(1,2) + + assert(indices_queries.shape[1] == nqueries) + + indices_knn = knn(points, points_queries, K) + + return indices_queries, indices_knn, points_queries \ No newline at end of file diff --git a/lightconvpoint/knn/knn.py b/lightconvpoint/knn/knn.py new file mode 100644 index 0000000..8023190 --- /dev/null +++ b/lightconvpoint/knn/knn.py @@ -0,0 +1,34 @@ +import torch +from torch_geometric.nn.pool import knn as tc_knn + +import importlib +knn_c_func_spec = importlib.util.find_spec('lightconvpoint.knn_c_func') + +if knn_c_func_spec is not None: + knn_c_func = importlib.util.module_from_spec(knn_c_func_spec) + knn_c_func_spec.loader.exec_module(knn_c_func) + +def knn(points: torch.Tensor, queries: torch.Tensor, K: int): + + if knn_c_func_spec is not None: + return knn_c_func.knn(points, queries, K) + + bs = points.shape[0] + dim= points.shape[1] + nx = points.shape[2] + ny = queries.shape[2] + + K = min(K, nx) + + batch_x = torch.arange(0, bs, dtype=torch.long, device=points.device).unsqueeze(1).expand(bs,nx) + batch_y = torch.arange(0, bs, dtype=torch.long, device=queries.device).unsqueeze(1).expand(bs,ny) + + x = points.transpose(1,2).reshape(-1, dim) + y = queries.transpose(1,2).reshape(-1, dim) + batch_x = batch_x.view(-1) + batch_y = batch_y.view(-1) + + indices = tc_knn(x,y,K,batch_x=batch_x, batch_y=batch_y) + indices = indices[1] + return indices.view(bs,ny,K) + diff --git a/lightconvpoint/knn/quantized_sampling.py b/lightconvpoint/knn/quantized_sampling.py new file mode 100644 index 0000000..d3bbf46 --- /dev/null +++ b/lightconvpoint/knn/quantized_sampling.py @@ -0,0 +1,90 @@ +import torch +import math +from torch_geometric.nn.pool import voxel_grid +from lightconvpoint.knn import knn + +import importlib +knn_c_func_spec = importlib.util.find_spec('lightconvpoint.knn_c_func') + +if knn_c_func_spec is not None: + knn_c_func = importlib.util.module_from_spec(knn_c_func_spec) + knn_c_func_spec.loader.exec_module(knn_c_func) + +def unique(x, dim=None): + """Unique elements of x and indices of those unique elements + https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810 + + e.g. + + unique(tensor([ + [1, 2, 3], + [1, 2, 4], + [1, 2, 3], + [1, 2, 5] + ]), dim=0) + => (tensor([[1, 2, 3], + [1, 2, 4], + [1, 2, 5]]), + tensor([0, 1, 3])) + """ + unique, inverse = torch.unique( + x, sorted=True, return_inverse=True, dim=dim) + perm = torch.arange(inverse.size(0), dtype=inverse.dtype, + device=inverse.device) + inverse, perm = inverse.flip([0]), perm.flip([0]) + return unique, inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm) + +def quantized_pick_knn(points: torch.Tensor, nqueries: int, K: int): + + if knn_c_func_spec is not None: + return knn_c_func.quantized_pick_knn(points, nqueries, K) + + bs, dim, nx = points.shape + + mini = points.min(dim=2)[0] + maxi = points.max(dim=2)[0] + + initial_voxel_size = (maxi-mini).norm(2, dim=1) / math.sqrt(nqueries) + + indices_queries = [] + points_queries = [] + + for b_id in range(bs): + voxel_size = initial_voxel_size[b_id] + x = points[b_id].transpose(0,1) + + b_selected_points = [] + count = 0 + + while(True): + batch_x = torch.zeros(x.shape[0], device=points.device, dtype=torch.long) + + voxel_ids = voxel_grid(x,batch_x, voxel_size) + _, unique_indices = unique(voxel_ids) + + if count + unique_indices.shape[0] >= nqueries: + unique_indices = unique_indices[torch.randperm(unique_indices.shape[0])] + b_selected_points.append(unique_indices[:nqueries-count]) + count += unique_indices.shape[0] + break + + b_selected_points.append(unique_indices) + count += unique_indices.shape[0] + + select = torch.ones(x.shape[0], dtype=torch.bool, device=x.device) + select[unique_indices] = False + x = x[select] + voxel_size /= 2 + + b_selected_points = torch.cat(b_selected_points, dim=0) + indices_queries.append(b_selected_points) + + points_queries.append(points[b_id].transpose(0,1)[b_selected_points]) + + indices_queries = torch.stack(indices_queries, dim=0) + points_queries = torch.stack(points_queries, dim=0) + points_queries = points_queries.transpose(1,2) + + indices_knn = knn(points, points_queries, K) + + return indices_queries, indices_knn, points_queries diff --git a/lightconvpoint/knn/random_sampling.py b/lightconvpoint/knn/random_sampling.py new file mode 100644 index 0000000..2a4ae41 --- /dev/null +++ b/lightconvpoint/knn/random_sampling.py @@ -0,0 +1,37 @@ +import torch +import math +from lightconvpoint.knn import knn + +import importlib +knn_c_func_spec = importlib.util.find_spec('lightconvpoint.knn_c_func') + +if knn_c_func_spec is not None: + knn_c_func = importlib.util.module_from_spec(knn_c_func_spec) + knn_c_func_spec.loader.exec_module(knn_c_func) + +def random_pick_knn(points: torch.Tensor, nqueries: int, K: int): + + if knn_c_func_spec is not None: + return knn_c_func.random_pick_knn(points, nqueries, K) + + bs, dim, nx = points.shape + + indices_queries = [] + points_queries = [] + + for b_id in range(bs): + + indices_queries_ = torch.randperm(nx)[:nqueries] + indices_queries.append(indices_queries_) + + x = points[b_id].transpose(0,1) + points_queries.append(x[indices_queries_]) + + + indices_queries = torch.stack(indices_queries, dim=0) + points_queries = torch.stack(points_queries, dim=0) + points_queries = points_queries.transpose(1,2) + + indices_knn = knn(points, points_queries, K) + + return indices_queries, indices_knn, points_queries \ No newline at end of file diff --git a/setup.py b/setup.py index 869e6c4..3192dbc 100755 --- a/setup.py +++ b/setup.py @@ -1,22 +1,35 @@ from setuptools import setup from torch.utils import cpp_extension +import sys + +print(sys.argv) + +if "--compile" in sys.argv: + print("LIGHTCONVPOINT -- COMPILING CPP MODULES") + ext_modules=[ + cpp_extension.CppExtension( + "lightconvpoint.knn_c_func", + [ + "lightconvpoint/src/knn.cxx", + "lightconvpoint/src/knn_bind.cxx", + "lightconvpoint/src/knn_random.cxx", + "lightconvpoint/src/knn_farthest.cxx", + "lightconvpoint/src/knn_convpoint.cxx", + "lightconvpoint/src/knn_quantized.cxx", + ], + extra_compile_args=["-fopenmp"], + extra_link_args=["-fopenmp"], + ) + ] + cmdclass={"build_ext": cpp_extension.BuildExtension} + sys.argv.remove("--compile") +else: + print("LIGHTCONVPOINT -- PYTHON MODULES") + ext_modules=[] + cmdclass={} setup( name="lightconvpoint", - ext_modules=[ - cpp_extension.CppExtension( - "lightconvpoint.knn", - [ - "lightconvpoint/src/knn.cxx", - "lightconvpoint/src/knn_bind.cxx", - "lightconvpoint/src/knn_random.cxx", - "lightconvpoint/src/knn_farthest.cxx", - "lightconvpoint/src/knn_convpoint.cxx", - "lightconvpoint/src/knn_quantized.cxx", - ], - extra_compile_args=["-fopenmp"], - extra_link_args=["-fopenmp"], - ) - ], - cmdclass={"build_ext": cpp_extension.BuildExtension}, + ext_modules=ext_modules, + cmdclass=cmdclass, ) From eb99fe0d3bfaf65828fc8364b89fb9ce2482bee0 Mon Sep 17 00:00:00 2001 From: aboulch Date: Fri, 4 Dec 2020 16:58:01 +0100 Subject: [PATCH 2/3] indices correction --- lightconvpoint/knn/quantized_sampling.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lightconvpoint/knn/quantized_sampling.py b/lightconvpoint/knn/quantized_sampling.py index d3bbf46..c7e9281 100644 --- a/lightconvpoint/knn/quantized_sampling.py +++ b/lightconvpoint/knn/quantized_sampling.py @@ -56,6 +56,8 @@ def quantized_pick_knn(points: torch.Tensor, nqueries: int, K: int): b_selected_points = [] count = 0 + x_ids = torch.arange(x.shape[0]) + while(True): batch_x = torch.zeros(x.shape[0], device=points.device, dtype=torch.long) @@ -64,16 +66,17 @@ def quantized_pick_knn(points: torch.Tensor, nqueries: int, K: int): if count + unique_indices.shape[0] >= nqueries: unique_indices = unique_indices[torch.randperm(unique_indices.shape[0])] - b_selected_points.append(unique_indices[:nqueries-count]) + b_selected_points.append(x_ids[unique_indices[:nqueries-count]]) count += unique_indices.shape[0] break - b_selected_points.append(unique_indices) + b_selected_points.append(x_ids[unique_indices]) count += unique_indices.shape[0] select = torch.ones(x.shape[0], dtype=torch.bool, device=x.device) select[unique_indices] = False x = x[select] + x_ids = x_ids[select] voxel_size /= 2 b_selected_points = torch.cat(b_selected_points, dim=0) From 810fd86e438a6a861dbfe237789e94e3eb2ed662 Mon Sep 17 00:00:00 2001 From: Alexandre Boulch Date: Tue, 8 Dec 2020 14:44:25 +0000 Subject: [PATCH 3/3] update setup.py and doc --- doc/install.md | 13 +++++++++---- setup.py | 10 +++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/doc/install.md b/doc/install.md index b30f291..a749f28 100644 --- a/doc/install.md +++ b/doc/install.md @@ -11,14 +11,19 @@ We provide two intallation modes, one that rely only on the pytorch-geometric packqge and the other that uses our c++ libraries for k-nearest neighbor search and quantized sampling. The later is faster on our computers. - ### Installation with compilation of our knn libraries + +This is the default installation procedure, and the one used in the FKAConv paper. + ``` -pip install -ve /path/to/LightConvPoint/ --install-option="--compile" +pip install -ve /path/to/LightConvPoint/ ``` ### Installation without compilation + +We also provide a version that rely only on pytorch geometric. +However, this version is slower and comes with no performance guaranty. + ``` -pip install -ve /path/to/LightConvPoint/ +pip install -ve /path/to/LightConvPoint/ --install-option="--nocompile" ``` - diff --git a/setup.py b/setup.py index 3192dbc..b34bc75 100755 --- a/setup.py +++ b/setup.py @@ -4,7 +4,11 @@ print(sys.argv) -if "--compile" in sys.argv: +if "--nocompile" in sys.argv: + print("LIGHTCONVPOINT -- PYTHON MODULES") + ext_modules=[] + cmdclass={} +else: print("LIGHTCONVPOINT -- COMPILING CPP MODULES") ext_modules=[ cpp_extension.CppExtension( @@ -23,10 +27,6 @@ ] cmdclass={"build_ext": cpp_extension.BuildExtension} sys.argv.remove("--compile") -else: - print("LIGHTCONVPOINT -- PYTHON MODULES") - ext_modules=[] - cmdclass={} setup( name="lightconvpoint",