NCCL communicators (#392)
* Added wrapper for MPI communicator to use NCCL under the hood

* Small fix

* Moved NCCL communicator wrapper to helpers

---------

Co-authored-by: Thomas <[email protected]>
brownbaerchen and Thomas authored Jan 20, 2024
1 parent a22818f commit 2f72ea3
Showing 3 changed files with 176 additions and 20 deletions.
98 changes: 98 additions & 0 deletions pySDC/helpers/NCCL_communicator.py
@@ -0,0 +1,98 @@
from mpi4py import MPI
from cupy.cuda import nccl
import cupy as cp
import numpy as np


class NCCLComm(object):
"""
Wraps an MPI communicator and performs some calls to NCCL functions instead.
"""

def __init__(self, comm):
"""
Args:
comm (mpi4py.Intracomm): MPI communicator
"""
self.commMPI = comm

uid = comm.bcast(nccl.get_unique_id(), root=0)
self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)

def __getattr__(self, name):
"""
Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
When performing any operations that depend on data, we have to synchronize host and device beforehand.
Args:
name (str): Name of the requested attribute
"""
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
cp.cuda.get_current_stream().synchronize()
return getattr(self.commMPI, name)

@staticmethod
def get_dtype(data):
"""
As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
"""
dtype = data.dtype
if dtype in [np.dtype('float32'), np.dtype('complex64')]:
return nccl.NCCL_FLOAT32
elif dtype in [np.dtype('float64'), np.dtype('complex128')]:
return nccl.NCCL_FLOAT64
elif dtype in [np.dtype('int32')]:
return nccl.NCCL_INT32
elif dtype in [np.dtype('int64')]:
return nccl.NCCL_INT64
else:
raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!')

@staticmethod
def get_count(data):
"""
As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
"""
if cp.iscomplexobj(data):
return data.size * 2
else:
return data.size

def get_op(self, MPI_op):
if MPI_op == MPI.SUM:
return nccl.NCCL_SUM
elif MPI_op == MPI.PROD:
return nccl.NCCL_PROD
elif MPI_op == MPI.MAX:
return nccl.NCCL_MAX
elif MPI_op == MPI.MIN:
return nccl.NCCL_MIN
else:
raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
dtype = self.get_dtype(sendbuf)
count = self.get_count(sendbuf)
op = self.get_op(op)
recvbuf = cp.empty(1) if recvbuf is None else recvbuf
stream = cp.cuda.get_current_stream()

self.commNCCL.reduce(
sendbuf=sendbuf.data.ptr,
recvbuf=recvbuf.data.ptr,
count=count,
datatype=dtype,
op=op,
root=root,
stream=stream.ptr,
)

def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
dtype = self.get_dtype(sendbuf)
count = self.get_count(sendbuf)
op = self.get_op(op)
stream = cp.cuda.get_current_stream()

self.commNCCL.allReduce(
sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr
)
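As a usage illustration (not part of the commit), here is a minimal sketch of how the wrapper might be driven on its own, assuming one GPU per MPI rank and a working CuPy/NCCL installation; the variable names are made up for the example:

from mpi4py import MPI
import cupy as cp

from pySDC.helpers.NCCL_communicator import NCCLComm

# Wrap the MPI communicator; the NCCL unique id is broadcast internally in __init__.
comm = NCCLComm(MPI.COMM_WORLD)

# Complex data is supported by sending twice as many real numbers (see get_dtype/get_count).
sendbuf = cp.full(4, comm.rank + 1j * comm.rank, dtype='complex128')
recvbuf = cp.empty_like(sendbuf)

# The reduction itself goes through NCCL on the current CUDA stream ...
comm.Allreduce(sendbuf, recvbuf, op=MPI.SUM)

# ... while attributes that are not overridden (e.g. `rank`, `bcast`) fall through
# to the wrapped MPI communicator via __getattr__.
print(f'rank {comm.rank}: {recvbuf}')

Launched with something like `mpirun -np 2 python example.py`, each rank would need its own GPU.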
3 changes: 0 additions & 3 deletions pySDC/implementations/sweeper_classes/generic_implicit_MPI.py
@@ -85,9 +85,6 @@ def compute_residual(self, stage=None):
L.status.residual = 0.0 if L.status.residual is None else L.status.residual
return None

# check if there are new values (e.g. from a sweep)
# assert L.status.updated

# compute the residual for each node

# build QF(u)
95 changes: 78 additions & 17 deletions pySDC/tests/test_sweepers/test_MPI_sweeper.py
@@ -1,7 +1,7 @@
import pytest


def run(use_MPI, num_nodes, quad_type, residual_type, imex):
def run(use_MPI, num_nodes, quad_type, residual_type, imex, useNCCL):
"""
Run a single sweep for a problem and compute the solution at the end point with a sweeper as specified.
Expand Down Expand Up @@ -35,8 +35,18 @@ def run(use_MPI, num_nodes, quad_type, residual_type, imex):

dt = 1e-1
sweeper_params = {'num_nodes': num_nodes, 'quad_type': quad_type, 'QI': 'IEpar', 'QE': 'PIC'}
problem_params = {}

if useNCCL:
from pySDC.helpers.NCCL_communicator import NCCLComm
from mpi4py import MPI

sweeper_params['comm'] = NCCLComm(MPI.COMM_WORLD)
problem_params['useGPU'] = True

description = {}
description['problem_class'] = problem_class
description['problem_params'] = problem_params
description['sweeper_class'] = sweeper_class
description['sweeper_params'] = sweeper_params
description['level_params'] = {'dt': dt, 'residual_type': residual_type}
@@ -47,18 +57,13 @@ def run(use_MPI, num_nodes, quad_type, residual_type, imex):
if imex:
u0 = controller.MS[0].levels[0].prob.u_exact(0)
else:
u0 = np.ones_like(controller.MS[0].levels[0].prob.u_exact(0))
u0 = controller.MS[0].levels[0].prob.u_exact(0) + 1.0
controller.run(u0, 0, dt)
controller.MS[0].levels[0].sweep.compute_end_point()
return controller.MS[0].levels[0]


@pytest.mark.mpi4py
@pytest.mark.parametrize("num_nodes", [2])
@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT'])
@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel'])
@pytest.mark.parametrize("imex", [True, False])
def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True):
def individual_test(num_nodes, quad_type, residual_type, imex, useNCCL, launch=True):
"""
Test whether the result matches between the MPI and non-MPI versions of a sweeper.
Tests solution at the right end point and the residual.
@@ -79,7 +84,7 @@ def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True):
my_env['PYTHONPATH'] = '../../..:.'
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'

cmd = f"mpirun -np {num_nodes} python {__file__} --test_sweeper {num_nodes} {quad_type} {residual_type} {imex}".split()
cmd = f"mpirun -np {num_nodes} python {__file__} --test_sweeper {num_nodes} {quad_type} {residual_type} {imex} {useNCCL}".split()

p = subprocess.Popen(cmd, env=my_env, cwd=".")

@@ -89,20 +94,76 @@
num_nodes,
)
else:
import numpy as np

imex = False if imex == 'False' else True
MPI = run(use_MPI=True, num_nodes=int(num_nodes), quad_type=quad_type, residual_type=residual_type, imex=imex)
if useNCCL:
import cupy as xp
else:
import numpy as xp

MPI = run(
use_MPI=True,
num_nodes=int(num_nodes),
quad_type=quad_type,
residual_type=residual_type,
imex=imex,
useNCCL=useNCCL,
)
nonMPI = run(
use_MPI=False, num_nodes=int(num_nodes), quad_type=quad_type, residual_type=residual_type, imex=imex
use_MPI=False,
num_nodes=int(num_nodes),
quad_type=quad_type,
residual_type=residual_type,
imex=imex,
useNCCL=False,
)

assert np.allclose(MPI.uend, nonMPI.uend, atol=1e-14), 'Got different solutions at end point!'
assert np.allclose(MPI.status.residual, nonMPI.status.residual, atol=1e-14), 'Got different residuals!'
assert xp.allclose(MPI.uend, nonMPI.uend, atol=1e-14), 'Got different solutions at end point!'
assert xp.allclose(MPI.status.residual, nonMPI.status.residual, atol=1e-14), 'Got different residuals!'


@pytest.mark.mpi4py
@pytest.mark.parametrize("num_nodes", [2])
@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT'])
@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel'])
@pytest.mark.parametrize("imex", [True, False])
def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True):
"""
Test whether the result matches between the MPI and non-MPI versions of a sweeper.
Tests solution at the right end point and the residual.
Args:
num_nodes (int): The number of nodes to use
quad_type (str): Type of nodes
residual_type (str): Type of residual computation
imex (bool): Use IMEX sweeper or not
launch (bool): If yes, it will launch `mpirun` with the required number of processes
"""
individual_test(num_nodes, quad_type, residual_type, imex, useNCCL=False, launch=launch)


@pytest.mark.cupy
@pytest.mark.parametrize("num_nodes", [2])
@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT'])
@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel'])
@pytest.mark.parametrize("imex", [False])
def test_sweeper_NCCL(num_nodes, quad_type, residual_type, imex, launch=True):
"""
Test whether the result matches between the MPI and non-MPI versions of a sweeper.
Tests solution at the right end point and the residual.
Args:
num_nodes (int): The number of nodes to use
quad_type (str): Type of nodes
residual_type (str): Type of residual computation
imex (bool): Use IMEX sweeper or not
launch (bool): If yes, it will launch `mpirun` with the required number of processes
"""
individual_test(num_nodes, quad_type, residual_type, imex, useNCCL=True, launch=launch)


if __name__ == '__main__':
import sys

if '--test_sweeper' in sys.argv:
test_sweeper(sys.argv[-4], sys.argv[-3], sys.argv[-2], sys.argv[-1], launch=False)
imex = False if sys.argv[-2] == 'False' else True
useNCCL = False if sys.argv[-1] == 'False' else True
individual_test(sys.argv[-5], sys.argv[-4], sys.argv[-3], imex=imex, useNCCL=useNCCL, launch=False)
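Going by the command string assembled in `individual_test` and the argument parsing in the `__main__` block above, the NCCL variant can presumably also be launched by hand along these lines; the path and PYTHONPATH here are assumptions about the local checkout, not taken from the commit:

import os
import subprocess

my_env = os.environ.copy()
my_env['PYTHONPATH'] = '.'  # assuming we run from the repository root; the test itself uses '../../..:.'

# Positional arguments in the order num_nodes quad_type residual_type imex useNCCL, matching the parsing above:
# two ranks, Gauss nodes, absolute residual at the last node, fully implicit, NCCL enabled.
cmd = 'mpirun -np 2 python pySDC/tests/test_sweepers/test_MPI_sweeper.py --test_sweeper 2 GAUSS last_abs False True'.split()

p = subprocess.Popen(cmd, env=my_env, cwd='.')
p.wait()
assert p.returncode == 0, 'test run failed'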
