diff --git a/pySDC/helpers/NCCL_communicator.py b/pySDC/helpers/NCCL_communicator.py index 4001498507..40da269fcb 100644 --- a/pySDC/helpers/NCCL_communicator.py +++ b/pySDC/helpers/NCCL_communicator.py @@ -27,7 +27,7 @@ def __getattr__(self, name): Args: Name (str): Name of the requested attribute """ - if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']: + if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']: cp.cuda.get_current_stream().synchronize() return getattr(self.commMPI, name) @@ -71,6 +71,26 @@ def get_op(self, MPI_op): else: raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!') + def reduce(self, sendobj, op=MPI.SUM, root=0): + sync = False + if hasattr(sendobj, 'data'): + if hasattr(sendobj.data, 'ptr'): + sync = True + if sync: + cp.cuda.Device().synchronize() + + return self.commMPI.reduce(sendobj, op=op, root=root) + + def allreduce(self, sendobj, op=MPI.SUM): + sync = False + if hasattr(sendobj, 'data'): + if hasattr(sendobj.data, 'ptr'): + sync = True + if sync: + cp.cuda.Device().synchronize() + + return self.commMPI.allreduce(sendobj, op=op) + def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0): if not hasattr(sendbuf.data, 'ptr'): return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root) @@ -113,3 +133,7 @@ def Bcast(self, buf, root=0): stream = cp.cuda.get_current_stream() self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr) + + def Barrier(self): + cp.cuda.get_current_stream().synchronize() + self.commMPI.Barrier()