From 2f72ea3b923bfd58b2ede884f003416885a32fd7 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Sat, 20 Jan 2024 12:04:05 +0100 Subject: [PATCH] NCCL communicators (#392) * Added wrapper for MPI communicator to use NCCL under the hood * Small fix * Moved NCCL communicator wrapper to helpers --------- Co-authored-by: Thomas --- pySDC/helpers/NCCL_communicator.py | 98 +++++++++++++++++++ .../sweeper_classes/generic_implicit_MPI.py | 3 - pySDC/tests/test_sweepers/test_MPI_sweeper.py | 95 ++++++++++++++---- 3 files changed, 176 insertions(+), 20 deletions(-) create mode 100644 pySDC/helpers/NCCL_communicator.py diff --git a/pySDC/helpers/NCCL_communicator.py b/pySDC/helpers/NCCL_communicator.py new file mode 100644 index 0000000000..155c47622b --- /dev/null +++ b/pySDC/helpers/NCCL_communicator.py @@ -0,0 +1,98 @@ +from mpi4py import MPI +from cupy.cuda import nccl +import cupy as cp +import numpy as np + + +class NCCLComm(object): + """ + Wraps an MPI communicator and performs some calls to NCCL functions instead. + """ + + def __init__(self, comm): + """ + Args: + comm (mpi4py.Intracomm): MPI communicator + """ + self.commMPI = comm + + uid = comm.bcast(nccl.get_unique_id(), root=0) + self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank) + + def __getattr__(self, name): + """ + Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator. + When performing any operations that depend on data, we have to synchronize host and device beforehand. + + Args: + name (str): Name of the requested attribute + """ + if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']: + cp.cuda.get_current_stream().synchronize() + return getattr(self.commMPI, name) + + @staticmethod + def get_dtype(data): + """ + As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
+ """ + dtype = data.dtype + if dtype in [np.dtype('float32'), np.dtype('complex64')]: + return nccl.NCCL_FLOAT32 + elif dtype in [np.dtype('float64'), np.dtype('complex128')]: + return nccl.NCCL_FLOAT64 + elif dtype in [np.dtype('int32')]: + return nccl.NCCL_INT32 + elif dtype in [np.dtype('int64')]: + return nccl.NCCL_INT64 + else: + raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!') + + @staticmethod + def get_count(data): + """ + As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex. + """ + if cp.iscomplexobj(data): + return data.size * 2 + else: + return data.size + + def get_op(self, MPI_op): + if MPI_op == MPI.SUM: + return nccl.NCCL_SUM + elif MPI_op == MPI.PROD: + return nccl.NCCL_PROD + elif MPI_op == MPI.MAX: + return nccl.NCCL_MAX + elif MPI_op == MPI.MIN: + return nccl.NCCL_MIN + else: + raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!') + + def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0): + dtype = self.get_dtype(sendbuf) + count = self.get_count(sendbuf) + op = self.get_op(op) + recvbuf = cp.empty(1) if recvbuf is None else recvbuf + stream = cp.cuda.get_current_stream() + + self.commNCCL.reduce( + sendbuf=sendbuf.data.ptr, + recvbuf=recvbuf.data.ptr, + count=count, + datatype=dtype, + op=op, + root=root, + stream=stream.ptr, + ) + + def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM): + dtype = self.get_dtype(sendbuf) + count = self.get_count(sendbuf) + op = self.get_op(op) + stream = cp.cuda.get_current_stream() + + self.commNCCL.allReduce( + sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr + ) diff --git a/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py b/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py index 1533a1129a..0189b51831 100644 --- a/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py 
+++ b/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py @@ -85,9 +85,6 @@ def compute_residual(self, stage=None): L.status.residual = 0.0 if L.status.residual is None else L.status.residual return None - # check if there are new values (e.g. from a sweep) - # assert L.status.updated - # compute the residual for each node # build QF(u) diff --git a/pySDC/tests/test_sweepers/test_MPI_sweeper.py b/pySDC/tests/test_sweepers/test_MPI_sweeper.py index 1e46dd734e..4d5be4ece9 100644 --- a/pySDC/tests/test_sweepers/test_MPI_sweeper.py +++ b/pySDC/tests/test_sweepers/test_MPI_sweeper.py @@ -1,7 +1,7 @@ import pytest -def run(use_MPI, num_nodes, quad_type, residual_type, imex): +def run(use_MPI, num_nodes, quad_type, residual_type, imex, useNCCL): """ Run a single sweep for a problem and compute the solution at the end point with a sweeper as specified. @@ -35,8 +35,18 @@ def run(use_MPI, num_nodes, quad_type, residual_type, imex): dt = 1e-1 sweeper_params = {'num_nodes': num_nodes, 'quad_type': quad_type, 'QI': 'IEpar', 'QE': 'PIC'} + problem_params = {} + + if useNCCL: + from pySDC.helpers.NCCL_communicator import NCCLComm + from mpi4py import MPI + + sweeper_params['comm'] = NCCLComm(MPI.COMM_WORLD) + problem_params['useGPU'] = True + description = {} description['problem_class'] = problem_class + description['problem_params'] = problem_params description['sweeper_class'] = sweeper_class description['sweeper_params'] = sweeper_params description['level_params'] = {'dt': dt, 'residual_type': residual_type} @@ -47,18 +57,13 @@ def run(use_MPI, num_nodes, quad_type, residual_type, imex): if imex: u0 = controller.MS[0].levels[0].prob.u_exact(0) else: - u0 = np.ones_like(controller.MS[0].levels[0].prob.u_exact(0)) + u0 = controller.MS[0].levels[0].prob.u_exact(0) + 1.0 controller.run(u0, 0, dt) controller.MS[0].levels[0].sweep.compute_end_point() return controller.MS[0].levels[0] -@pytest.mark.mpi4py -@pytest.mark.parametrize("num_nodes", [2]) 
-@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT']) -@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel']) -@pytest.mark.parametrize("imex", [True, False]) -def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True): +def individual_test(num_nodes, quad_type, residual_type, imex, useNCCL, launch=True): """ Make a test if the result matches between the MPI and non-MPI versions of a sweeper. Tests solution at the right end point and the residual. @@ -79,7 +84,7 @@ def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True): my_env['PYTHONPATH'] = '../../..:.' my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml' - cmd = f"mpirun -np {num_nodes} python {__file__} --test_sweeper {num_nodes} {quad_type} {residual_type} {imex}".split() + cmd = f"mpirun -np {num_nodes} python {__file__} --test_sweeper {num_nodes} {quad_type} {residual_type} {imex} {useNCCL}".split() p = subprocess.Popen(cmd, env=my_env, cwd=".") @@ -89,20 +94,76 @@ def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True): num_nodes, ) else: - import numpy as np - - imex = False if imex == 'False' else True - MPI = run(use_MPI=True, num_nodes=int(num_nodes), quad_type=quad_type, residual_type=residual_type, imex=imex) + if useNCCL: + import cupy as xp + else: + import numpy as xp + + MPI = run( + use_MPI=True, + num_nodes=int(num_nodes), + quad_type=quad_type, + residual_type=residual_type, + imex=imex, + useNCCL=useNCCL, + ) nonMPI = run( - use_MPI=False, num_nodes=int(num_nodes), quad_type=quad_type, residual_type=residual_type, imex=imex + use_MPI=False, + num_nodes=int(num_nodes), + quad_type=quad_type, + residual_type=residual_type, + imex=imex, + useNCCL=False, ) - assert np.allclose(MPI.uend, nonMPI.uend, atol=1e-14), 'Got different solutions at end point!' - assert np.allclose(MPI.status.residual, nonMPI.status.residual, atol=1e-14), 'Got different residuals!' 
+ assert xp.allclose(MPI.uend, nonMPI.uend, atol=1e-14), 'Got different solutions at end point!' + assert xp.allclose(MPI.status.residual, nonMPI.status.residual, atol=1e-14), 'Got different residuals!' + + +@pytest.mark.mpi4py +@pytest.mark.parametrize("num_nodes", [2]) +@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT']) +@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel']) +@pytest.mark.parametrize("imex", [True, False]) +def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True): + """ + Make a test if the result matches between the MPI and non-MPI versions of a sweeper. + Tests solution at the right end point and the residual. + + Args: + num_nodes (int): The number of nodes to use + quad_type (str): Type of nodes + residual_type (str): Type of residual computation + imex (bool): Use IMEX sweeper or not + launch (bool): If yes, it will launch `mpirun` with the required number of processes + """ + individual_test(num_nodes, quad_type, residual_type, imex, useNCCL=False, launch=launch) + + +@pytest.mark.cupy +@pytest.mark.parametrize("num_nodes", [2]) +@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT']) +@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel']) +@pytest.mark.parametrize("imex", [False]) +def test_sweeper_NCCL(num_nodes, quad_type, residual_type, imex, launch=True): + """ + Make a test if the result matches between the MPI and non-MPI versions of a sweeper. + Tests solution at the right end point and the residual. 
+ + Args: + num_nodes (int): The number of nodes to use + quad_type (str): Type of nodes + residual_type (str): Type of residual computation + imex (bool): Use IMEX sweeper or not + launch (bool): If yes, it will launch `mpirun` with the required number of processes + """ + individual_test(num_nodes, quad_type, residual_type, imex, useNCCL=True, launch=launch) if __name__ == '__main__': import sys if '--test_sweeper' in sys.argv: - test_sweeper(sys.argv[-4], sys.argv[-3], sys.argv[-2], sys.argv[-1], launch=False) + imex = False if sys.argv[-2] == 'False' else True + useNCCL = False if sys.argv[-1] == 'False' else True + individual_test(sys.argv[-5], sys.argv[-4], sys.argv[-3], imex=imex, useNCCL=useNCCL, launch=False)