From 2776772ace6fb58723091542f7acea259c493689 Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Fri, 22 Nov 2024 08:37:54 -0800 Subject: [PATCH] [Feature] Add gather/scatter support 1D tensor (#74) Migrated from: https://github.com/rapidsai/wholegraph/pull/229 This PR is to add gather/scatter support 1D tensor on python level, as WholeGraph should support basic indexing operations for both 1D (array) and 2D (matrix) wholememory tensors. Without this PR, if with 1D wholememory tensor, gather/scatter op does not work, e.g., https://github.com/rapidsai/wholegraph/blob/0efba33835d6e4e104b5d7101a91e0ea55a6ca53/python/pylibwholegraph/pylibwholegraph/torch/tensor.py#L89 To test, run ``` pytest --cache-clear --import-mode=append tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py -s ``` **Remaining issue:** On my local test with single GPU, the test can pass. For multiGPU setup, gather op works fine, but 1D scatter seems not working as it would crash at: https://github.com/rapidsai/wholegraph/blob/2e963b98aa6027c300d60e839010d3dd8ca422eb/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py#L108 with incorrect scatter outputs: `Indices where allclose fails: tensor([0., 0., 0., ..., 0., 0., 0.]) tensor([ 1435., 1439., 1443., ..., 257703., 257707., 257711.]) ` This would work if this bugfix is merged: https://github.com/rapidsai/cugraph-gnn/pull/73 cc. @linhu-nv Authors: - Chang Liu (https://github.com/chang-l) Approvers: - https://github.com/linhu-nv - Alex Barghi (https://github.com/alexbarghi-nv) URL: https://github.com/rapidsai/cugraph-gnn/pull/74 --- .../ops/test_wholegraph_gather_scatter.py | 50 ++++++++++++------- .../pylibwholegraph/torch/tensor.py | 12 +++-- .../pylibwholegraph/torch/wholememory_ops.py | 6 ++- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py index 361ae4f6..078e6cf2 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py @@ -25,6 +25,8 @@ def gen_int_embedding(indice_tensor, embedding_dim, output_type): + if embedding_dim == 0: + embedding_dim = 1 # unsqueeze 2D for input (2D is required for scatter op) indice_count = indice_tensor.shape[0] indice_part = ( indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim) @@ -57,9 +59,14 @@ def scatter_gather_test_cast( f"embedding_dim={embedding_dim}, " f"indice_count={indice_count}, dt={dt}, mt={mt}, ml={ml}" ) - wm_embedding = wmb.create_wholememory_matrix( - dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition - ) + if embedding_dim == 0: + wm_embedding = wmb.create_wholememory_array( + dt, embedding_count, wm_comm, mt, ml, entry_partition + ) + else: + wm_embedding = wmb.create_wholememory_matrix( + dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition + ) scatter_indice = torch.arange( world_rank, embedding_count, world_size, dtype=torch.int64 @@ -93,9 +100,13 @@ def scatter_gather_test_cast( local_ref_start = wm_embedding.get_local_entry_start() local_ref_count = wm_embedding.get_local_entry_count() assert local_start == local_ref_start - assert local_tensor_cuda.dim() == 2 + assert local_tensor_cuda.dim() == 2 if embedding_dim > 0 else 1 assert local_tensor_cuda.shape[0] == local_ref_count - assert local_tensor_cuda.shape[1] == embedding_dim + if local_tensor_cuda.dim() == 2: + assert local_tensor_cuda.shape[1] == embedding_dim + else: + # unsqueeze to 2D for comparison + local_tensor_cuda = local_tensor_cuda.unsqueeze(1) local_tensor = local_tensor_cuda.cpu() local_indices = torch.arange( @@ -118,6 +129,9 @@ def scatter_gather_test_cast( ) embedding_after_gather = embedding_after_gather_cuda.cpu() ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float) + if embedding_after_gather.dim() == 1: + # unsqueeze to 2D for comparison + embedding_after_gather = embedding_after_gather.unsqueeze(1) # print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % ( # gather_indice, embedding_after_gather, ref_embedding_gather)) assert torch.allclose(embedding_after_gather, ref_embedding_gather) @@ -138,7 +152,6 @@ def routine_func(world_rank: int, world_size: int): wm_comm = wm_comm.wmb_comm embedding_count = 1024 * 256 * world_size + 3 - embedding_dim = 256 indice_count = 100001 dt = wmb.WholeMemoryDataType.DtFloat entry_partition = random_partition(embedding_count, world_size) @@ -154,18 +167,19 @@ def routine_func(world_rank: int, world_size: int): wmb.WholeMemoryMemoryLocation.MlHost, wmb.WholeMemoryMemoryLocation.MlDevice, ]: - if wm_comm.support_type_location(mt, ml): - scatter_gather_test_cast( - wm_comm, - dt, - mt, - ml, - embedding_count, - embedding_dim, - indice_count, - True, - entry_partition, - ) + for embedding_dim in [0, 256]: # 0 is for 1D tensor + if wm_comm.support_type_location(mt, ml): + scatter_gather_test_cast( + wm_comm, + dt, + mt, + ml, + embedding_count, + embedding_dim, + indice_count, + True, + entry_partition, + ) wmb.finalize() diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index e46ffa2c..41d8fad3 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -62,7 +62,7 @@ def gather( self, indice: torch.Tensor, *, force_dtype: Union[torch.dtype, None] = None ): assert indice.dim() == 1 - embedding_dim = self.shape[1] + embedding_dim = self.shape[1] if self.dim() == 2 else 1 embedding_count = indice.shape[0] current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),) output_dtype = force_dtype if force_dtype is not None else self.dtype @@ -79,13 +79,17 @@ def gather( get_wholegraph_env_fns(), get_stream(), ) - return output_tensor + return output_tensor.view(-1) if self.dim() == 1 else output_tensor def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor): assert indice.dim() == 1 - assert input_tensor.dim() == 2 + assert input_tensor.dim() == self.dim() assert indice.shape[0] == input_tensor.shape[0] - assert input_tensor.shape[1] == self.shape[1] + if self.dim() == 2: + assert input_tensor.shape[1] == self.shape[1] + else: + # unsqueeze to 2D tensor because wmb_tensor is unsqueezed within scatter_op + input_tensor = input_tensor.unsqueeze(1) wmb.wholememory_scatter_op( wrap_torch_tensor(input_tensor), wrap_torch_tensor(indice), diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py index eedf4bbb..773f3a36 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py @@ -39,8 +39,10 @@ def wholememory_gather_forward_functor( assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64 if torch_output_dtype is None: torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype) + + embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1 output_tensor = torch.empty( - [indices_tensor.shape[0], wholememory_tensor.shape[1]], + [indices_tensor.shape[0], embedding_dim], device="cuda", dtype=torch_output_dtype, requires_grad=requires_grad, @@ -52,7 +54,7 @@ def wholememory_gather_forward_functor( get_wholegraph_env_fns(), get_stream(), ) - return output_tensor + return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor def wholememory_scatter_functor(