Separate mscclpp-test kernels (#122)
Separate different kernel implementations in mscclpp-test to reduce the
number of registers required by the kernels.
chhwang authored Jul 10, 2023
1 parent 2b98333 commit 1d71715
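The underlying issue: when several algorithm variants sit behind a runtime switch inside one monolithic __global__ kernel, the compiled kernel's register count is driven by its most demanding branch, so even the cheapest variant pays the worst-case register cost and loses occupancy. Giving each variant its own __global__ entry point lets ptxas assign each kernel only the registers it actually needs, with the host picking the kernel at launch time instead. A minimal sketch of the pattern, using hypothetical kernels that are not from this repo:

// Before: one dispatcher kernel; its register count is driven by the heaviest branch.
__device__ void variantCheap(float* buf) { buf[threadIdx.x] *= 2.0f; }

__device__ void variantHeavy(float* buf) {
  float acc[32];  // many simultaneously live values -> large register footprint
  for (int i = 0; i < 32; ++i) acc[i] = buf[threadIdx.x * 32 + i];
  for (int i = 0; i < 32; ++i) buf[threadIdx.x * 32 + i] = acc[i] + acc[31 - i];
}

__global__ void dispatcher(float* buf, int which) {
  if (which == 0) variantCheap(buf);  // still pays variantHeavy's register cost
  else variantHeavy(buf);
}

// After: one entry point per variant, each compiled with its own register budget;
// the switch moves from device code to the launch site.
__global__ void kernelCheap(float* buf) { variantCheap(buf); }
__global__ void kernelHeavy(float* buf) { variantHeavy(buf); }

// Host side:
//   if (which == 0) kernelCheap<<<grid, block>>>(buf);
//   else            kernelHeavy<<<grid, block>>>(buf);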
Showing 3 changed files with 65 additions and 64 deletions.
68 changes: 38 additions & 30 deletions test/mscclpp-test/allgather_test.cu
@@ -19,8 +19,12 @@ __constant__ mscclpp::ProxyChannel constRawProxyChan[16];
 
 __constant__ mscclpp::SmChannel constSmChans[8];
 
-__device__ void allgather0(mscclpp::SimpleProxyChannel proxyChan, int rank, int worldSize, int remoteRank,
-                           size_t nelemsPerGPU) {
+__global__ void allgather0(int rank, int worldSize, size_t nelemsPerGPU) {
+  int warpId = threadIdx.x / 32;
+
+  // Each warp is responsible for one of the remote ranks
+  mscclpp::SimpleProxyChannel proxyChan = constProxyChans[warpId];
+
   // this allgather is really simple and implemented as an alltoall
 
   // this thread's role is a sender role
@@ -105,14 +109,24 @@ __device__ void localAllGatherSm(int rank, int nRanksPerNode, int startRankChunk
   constSmChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
 }
 
-__device__ void allgather1(mscclpp::SimpleProxyChannel proxyChan, int rank, int worldSize, int nRanksPerNode,
-                           int remoteRank, size_t nelemsPerGPU) {
+__global__ void allgather1(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
+  int warpId = threadIdx.x / 32;
+  int remoteRank = (warpId < rank) ? warpId : warpId + 1;
+
+  // Each warp is responsible for one of the remote ranks
+  mscclpp::SimpleProxyChannel proxyChan = constProxyChans[warpId];
+
   localAllGather(proxyChan, rank, worldSize, nRanksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
                  nelemsPerGPU * sizeof(int));
 }
 
-__device__ void allgather2(mscclpp::SimpleProxyChannel proxyChan, int rank, int worldSize, int nRanksPerNode,
-                           int remoteRank, size_t nelemsPerGPU) {
+__global__ void allgather2(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
+  int warpId = threadIdx.x / 32;
+  int remoteRank = (warpId < rank) ? warpId : warpId + 1;
+
+  // Each warp is responsible for one of the remote ranks
+  mscclpp::SimpleProxyChannel proxyChan = constProxyChans[warpId];
+
   // this allgather is a pipelined and hierarchical one and only works for two nodes
   // it is implemented as follows:
   // Step 1: each node does a local allgather and concurrently,
@@ -181,7 +195,12 @@ __device__ void allgather2(mscclpp::SimpleProxyChannel proxyChan, int rank, int
   }
 }
 
-__device__ void allgather3(mscclpp::ProxyChannel proxyChan, int rank, int worldSize) {
+__global__ void allgather3(int rank, int worldSize) {
+  int warpId = threadIdx.x / 32;
+
+  // Each warp is responsible for one of the remote ranks
+  mscclpp::ProxyChannel proxyChan = constRawProxyChan[warpId];
+
   int tid = threadIdx.x;
   __syncthreads();
   if (tid == 0) {
@@ -197,7 +216,7 @@ __device__ void allgather3(mscclpp::ProxyChannel proxyChan, int rank, int worldS
   }
 }
 
-__device__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
+__global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
   // this allgather is a pipelined and hierarchical one and only works for two nodes
   // it is implemented as follows:
   // Step 1: each node does a local allgather and concurrently,
@@ -255,27 +274,6 @@ __device__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
                    nBlocksForLocalAllGather);
 }
 
-__global__ void kernel(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int kernel) {
-  // find the mapping between remoteRank and devConns
-  int warpId = threadIdx.x / 32;
-  int remoteRank = (warpId < rank) ? warpId : warpId + 1;
-  // Each warp is responsible for one of the remote ranks
-  mscclpp::SimpleProxyChannel proxyChan = constProxyChans[warpId];
-
-  if (kernel == 0) {
-    allgather0(proxyChan, rank, worldSize, remoteRank, nelemsPerGPU);
-  } else if (kernel == 1) {
-    allgather1(proxyChan, rank, worldSize, nRanksPerNode, remoteRank, nelemsPerGPU);
-  } else if (kernel == 2) {
-    allgather2(proxyChan, rank, worldSize, nRanksPerNode, remoteRank, nelemsPerGPU);
-  } else if (kernel == 3) {
-    mscclpp::ProxyChannel proxyChan = constRawProxyChan[warpId];
-    allgather3(proxyChan, rank, worldSize);
-  } else if (kernel == 4) {
-    allgather4(rank, worldSize, nRanksPerNode, nelemsPerGPU);
-  }
-}
-
 class AllGatherChannelService : public mscclpp::BaseProxyService {
  public:
   AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank, int cudaDevice);
@@ -380,7 +378,17 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
     nBlocks = 1;
     nThreads = 32 * (worldSize - 1);
   }
-  kernel<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_, kernelNum);
+  if (kernelNum == 0) {
+    allgather0<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, paramCount_);
+  } else if (kernelNum == 1) {
+    allgather1<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
+  } else if (kernelNum == 2) {
+    allgather2<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
+  } else if (kernelNum == 3) {
+    allgather3<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize);
+  } else if (kernelNum == 4) {
+    allgather4<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
+  }
 }
 
 void AllGatherTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
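A side note on the indexing the kernels above now compute for themselves: warp w serves remote rank w when w < rank and w + 1 otherwise, so the local rank is skipped and worldSize - 1 warps cover all peers. A quick standalone check of the mapping (hypothetical host-side code, for illustration only):

#include <cstdio>

int main() {
  const int rank = 2, worldSize = 4;  // example values
  for (int warpId = 0; warpId < worldSize - 1; ++warpId) {
    int remoteRank = (warpId < rank) ? warpId : warpId + 1;
    printf("warp %d -> remote rank %d\n", warpId, remoteRank);
  }
  return 0;  // prints 0 -> 0, 1 -> 1, 2 -> 3: the local rank 2 is skipped
}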
42 changes: 20 additions & 22 deletions test/mscclpp-test/allreduce_test.cu
@@ -509,7 +509,7 @@ __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t n
                    nBlocksForLocalAllGather);
 }
 
-__device__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
+__global__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
   int peerId = blockIdx.x / BLOCKS_PER_PEER;
   int isComm = (threadIdx.x == 0) && (blockIdx.x % BLOCKS_PER_PEER == 0);
   int remoteRank = (peerId < rank) ? peerId : peerId + 1;
@@ -560,7 +560,7 @@ __device__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, siz
   }
 }
 
-__device__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
+__global__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
   int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
   int remoteSendRank = (rank + 1) % worldSize;
   int remoteRecvRank = (rank + worldSize - 1) % worldSize;
@@ -671,7 +671,7 @@ __device__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, siz
   }
 }
 
-__device__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getPktBuf, void* result, int rank,
+__global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getPktBuf, void* result, int rank,
                            int nRanksPerNode, int worldSize, size_t nelems) {
   int numPeersPerNode = nRanksPerNode - 1;
   size_t nPkts = nelems / 2;  // 2 elems per packet, assume nelems is even
@@ -776,35 +776,21 @@ __device__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP
   }
 }
 
-__device__ void allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
+__global__ void allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
                            size_t nelems) {
   reduceScatter(buff, scratch, rank, nRanksPerNode, worldSize, nelems);
   if (threadIdx.x == 0 && blockIdx.x == 0) {
     allGather(rank, worldSize, nRanksPerNode, nelems / worldSize);
   }
 }
 
-__device__ void allreduce4(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
+__global__ void allreduce4(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
                            size_t nelems) {
   reduceScatterSm(buff, scratch, rank, nRanksPerNode, worldSize, nelems);
   deviceSyncer.sync(gridDim.x);
   allGatherSm(rank, worldSize, nRanksPerNode, nelems / worldSize);
 }
 
-__global__ void kernel(void* buff, void* scratch, void* result, void* putPktBuf, void* getPktBuf, int rank,
-                       int nRanksPerNode, int worldSize, size_t nelems, size_t scratchDataCount, int kernel) {
-  if (kernel == 0)
-    allreduce0((int*)buff, (int*)scratch, rank, worldSize, nelems, scratchDataCount);
-  else if (kernel == 1)
-    allreduce1((int*)buff, (int*)scratch, rank, worldSize, nelems, scratchDataCount);
-  else if (kernel == 2)
-    allreduce2((int*)buff, scratch, putPktBuf, getPktBuf, result, rank, nRanksPerNode, worldSize, nelems);
-  else if (kernel == 3)
-    allreduce3((int*)buff, (int*)scratch, result, rank, nRanksPerNode, worldSize, nelems);
-  else if (kernel == 4)
-    allreduce4((int*)buff, (int*)scratch, result, rank, nRanksPerNode, worldSize, nelems);
-}
-
 class AllReduceTestColl : public BaseTestColl {
  public:
   AllReduceTestColl() = default;
@@ -845,9 +831,21 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
     tmpBuff = scratchPacketBuff;
     nThreadsPerBlock = 1024;
   }
-  kernel<<<nBlocks, nThreadsPerBlock, 0, stream>>>(inputBuff, tmpBuff, resultBuff, putPacketBuff, getPacketBuff, rank,
-                                                   args.nRanksPerNode, worldSize, paramCount_, scratchDataCount,
-                                                   kernelNum);
+  if (kernelNum == 0)
+    allreduce0<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, rank, worldSize, paramCount_,
+                                                         scratchDataCount);
+  else if (kernelNum == 1)
+    allreduce1<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, rank, worldSize, paramCount_,
+                                                         scratchDataCount);
+  else if (kernelNum == 2)
+    allreduce2<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, tmpBuff, putPacketBuff, getPacketBuff,
+                                                         resultBuff, rank, args.nRanksPerNode, worldSize, paramCount_);
+  else if (kernelNum == 3)
+    allreduce3<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
+                                                         args.nRanksPerNode, worldSize, paramCount_);
+  else if (kernelNum == 4)
+    allreduce4<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
+                                                         args.nRanksPerNode, worldSize, paramCount_);
 }
 
 void AllReduceTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
19 changes: 7 additions & 12 deletions test/mscclpp-test/alltoall_test.cu
@@ -27,7 +27,7 @@ __device__ void localAlltoall(int rank, int nRanksPerNode, size_t nElements) {
   }
 }
 
-__device__ void alltoall0(int rank, int worldSize, size_t nElements) {
+__global__ void alltoall0(int rank, int worldSize, size_t nElements) {
   int remoteRank = (blockIdx.x < rank) ? blockIdx.x : blockIdx.x + 1;
   mscclpp::SimpleProxyChannel proxyChan = constProxyChans[blockIdx.x];
   if (threadIdx.x == 0) {
@@ -42,19 +42,10 @@ __device__ void alltoall0(int rank, int worldSize, size_t nElements) {
   }
 }
 
-__device__ void alltoall1(int rank, int nRanksPerNode, size_t nElements) {
+__global__ void alltoall1(int rank, int nRanksPerNode, size_t nElements) {
   localAlltoall(rank, nRanksPerNode, nElements);
 }
 
-__global__ void kernel(int rank, int worldSize, size_t nElements, int nRanksPerNode, int kernelNum) {
-  if (kernelNum == 0) {
-    alltoall0(rank, worldSize, nElements);
-  }
-  if (kernelNum == 1) {
-    alltoall1(rank, nRanksPerNode, nElements);
-  }
-}
-
 class AllToAllTestColl : public BaseTestColl {
  public:
   AllToAllTestColl() = default;
@@ -74,7 +65,11 @@ void AllToAllTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
   const int nRanksPerNode = args.nRanksPerNode;
   CUDATHROW(cudaMemcpyAsync((int*)localRecvBuff + paramCount_ * rank, (int*)localSendBuff + paramCount_ * rank,
                             paramCount_ * sizeof(int), cudaMemcpyDeviceToDevice, stream));
-  kernel<<<worldSize - 1, 32, 0, stream>>>(rank, worldSize, paramCount_, nRanksPerNode, kernelNum);
+  if (kernelNum == 0) {
+    alltoall0<<<worldSize - 1, 32, 0, stream>>>(rank, worldSize, paramCount_);
+  } else if (kernelNum == 1) {
+    alltoall1<<<worldSize - 1, 32, 0, stream>>>(rank, nRanksPerNode, paramCount_);
+  }
 }
 
 void AllToAllTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
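To confirm the register savings, one could query each separated kernel's attributes from the same translation unit that defines it. The helper below is a hypothetical addition, not part of this commit (CUDATHROW is the error-checking macro the tests already use):

#include <cstdio>

static void printRegUsage(const char* name, const void* kernelFunc) {
  cudaFuncAttributes attr;
  CUDATHROW(cudaFuncGetAttributes(&attr, kernelFunc));
  printf("%s uses %d registers\n", name, attr.numRegs);
}

// e.g. printRegUsage("allgather0", (const void*)allgather0);
// The old monolithic kernel reported a single worst-case figure for every
// variant; compiling with nvcc -Xptxas -v shows the same per-kernel counts
// at build time.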
