Skip to content

Commit

Permalink
Added a few things to the NCCL communicator (#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen authored Nov 14, 2024
1 parent bf940bd commit b81b47b
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion pySDC/helpers/NCCL_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __getattr__(self, name):
Args:
Name (str): Name of the requested attribute
"""
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']:
cp.cuda.get_current_stream().synchronize()

return getattr(self.commMPI, name)
Expand Down Expand Up @@ -71,6 +71,26 @@ def get_op(self, MPI_op):
else:
raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

def reduce(self, sendobj, op=MPI.SUM, root=0):
    """
    Forward ``reduce`` to the wrapped MPI communicator, synchronizing the CUDA device first
    when the object to be sent exposes a ``data.ptr`` attribute.

    Args:
        sendobj: Object to be reduced
        op: Reduction operation handed to the MPI communicator
        root (int): Rank handed to the MPI communicator as the root of the reduction

    Returns:
        Whatever ``self.commMPI.reduce`` returns
    """
    # An object exposing `data.ptr` is taken to live on the GPU (presumably a CuPy array - TODO confirm),
    # so pending device work is finished before MPI touches it.
    on_device = hasattr(sendobj, 'data') and hasattr(sendobj.data, 'ptr')
    if on_device:
        cp.cuda.Device().synchronize()

    return self.commMPI.reduce(sendobj, op=op, root=root)

def allreduce(self, sendobj, op=MPI.SUM):
    """
    Forward ``allreduce`` to the wrapped MPI communicator, synchronizing the CUDA device first
    when the object to be sent exposes a ``data.ptr`` attribute.

    Args:
        sendobj: Object to be reduced
        op: Reduction operation handed to the MPI communicator

    Returns:
        Whatever ``self.commMPI.allreduce`` returns
    """
    # An object exposing `data.ptr` is taken to live on the GPU (presumably a CuPy array - TODO confirm),
    # so pending device work is finished before MPI touches it.
    on_device = hasattr(sendobj, 'data') and hasattr(sendobj.data, 'ptr')
    if on_device:
        cp.cuda.Device().synchronize()

    return self.commMPI.allreduce(sendobj, op=op)

def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
if not hasattr(sendbuf.data, 'ptr'):
return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)
Expand Down Expand Up @@ -113,3 +133,7 @@ def Bcast(self, buf, root=0):
stream = cp.cuda.get_current_stream()

self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)

def Barrier(self):
    """
    Synchronize the current CUDA stream, then call ``Barrier`` on the wrapped MPI communicator.
    """
    # Finish the work queued on the current stream before entering the MPI barrier
    current_stream = cp.cuda.get_current_stream()
    current_stream.synchronize()

    self.commMPI.Barrier()

0 comments on commit b81b47b

Please sign in to comment.