Fixing potential integer overflow on sequence counter (NVIDIA#729)
* Fixing potential integer overflow on sequence counter

The current implementation may potentially cause hangs or data corruption.

Signed-off-by: Pasha (Pavel) Shamis <[email protected]>

* Fixing typo in comments

Addressing reviewers' comments

Signed-off-by: Pasha (Pavel) Shamis <[email protected]>

---------

Signed-off-by: Pasha (Pavel) Shamis <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
shamisp and ksivaman authored Apr 4, 2024
1 parent 1fa5bf1 commit e1e2b76
Showing 1 changed file with 35 additions and 31 deletions.
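Before the diff, a minimal host-side sketch of the bug being fixed. The CHECK_IDS macro below is copied from the patch; the main() harness and its printouts are illustrative assumptions, with flag and reduce_id named after the kernel variables. With a plain signed compare, a consumer can spin forever once the producer's 32-bit sequence counter wraps past INT_MAX (hang), or stop waiting too early once its own expected id wraps (data corruption); the unsigned-difference check handles both, provided the two ids stay within 2^31 of each other.

#include <climits>
#include <cstdio>

// Wraparound-safe "producer is behind consumer" check introduced by this commit.
#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX))

int main() {
  // Hang: the producer's flag wrapped past INT_MAX to INT_MIN, but the consumer
  // still waits on id INT_MAX. The old check spins forever; the new check sees
  // an unsigned difference of 1 (high bit clear) and exits the loop.
  int flag = INT_MIN, reduce_id = INT_MAX;
  printf("hang case:       old spins=%d new spins=%d\n", flag < reduce_id, !!CHECK_IDS(flag, reduce_id));

  // Corruption: the consumer's expected id wrapped to INT_MIN while the producer
  // is still at INT_MAX. The old check exits immediately (stale data); the new
  // check correctly keeps waiting.
  flag = INT_MAX; reduce_id = INT_MIN;
  printf("corruption case: old spins=%d new spins=%d\n", flag < reduce_id, !!CHECK_IDS(flag, reduce_id));
  return 0;
}

Expected output: the old check spins in the hang case (1) and exits in the corruption case (0), while the new check does the opposite.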
transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu (66 changes: 35 additions & 31 deletions)
@@ -51,6 +51,10 @@
asm volatile("fence.sc.gpu;\n"); \
}

+// Return true if producer is behind consumer (producer < consumer), otherwise false.
+// Unsigned arithmetic avoids integer overflow; valid while the ids stay within 2B (2^31) of each other.
+#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX))

template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank,
@@ -74,7 +78,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -128,7 +132,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -162,7 +166,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -211,7 +215,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -273,7 +277,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -348,7 +352,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -422,7 +426,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -490,7 +494,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -525,7 +529,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -610,7 +614,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
@@ -740,7 +744,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -800,7 +804,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
@@ -888,7 +892,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
@@ -975,7 +979,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
@@ -1072,7 +1076,7 @@ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8(
volatile int* flag = (volatile int*)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu+handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64()-s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n",
myrank, blockIdx.x, threadIdx.x, reduce_id, *flag);
@@ -1171,7 +1175,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
@@ -1270,7 +1274,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
@@ -1389,7 +1393,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -1486,7 +1490,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
-while (*flag < reduce_id) {
+while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
@@ -1517,7 +1521,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]);
-while (*flag < basecounter) {
+while (CHECK_IDS(*flag, basecounter)) {
}
}
__syncthreads();
@@ -1635,7 +1639,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int end_aligned = start_elem + aligned_elem;

if (mythreadIdx == 0) {
-while (*flag < gathercounter) {
+while (CHECK_IDS(*flag, gathercounter)) {
}
gathercounter++;
}
@@ -1694,7 +1698,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter;
}
volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank];
-while (*flag < basecounter) {
+while (CHECK_IDS(*flag, basecounter)) {
}
}
__syncthreads();
@@ -1864,7 +1868,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
const int end_aligned = start_elem + aligned_elem;

if (mythreadIdx == 0) {
-while (*flag < gathercounter) {
+while (CHECK_IDS(*flag, gathercounter)) {
}
gathercounter++;
}
@@ -1908,7 +1912,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter;
}
volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank];
-while (*flag < basecounter) {
+while (CHECK_IDS(*flag, basecounter)) {
}
}
__syncthreads();
@@ -2114,7 +2118,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
const int end_aligned = start_elem + aligned_elem;

if (mythreadIdx == 0) {
-while (*flag < gathercounter) {
+while (CHECK_IDS(*flag, gathercounter)) {
}
gathercounter++;
}
@@ -3013,7 +3017,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int signal_id = (*recv_id) + 1;
volatile int *flag = (volatile int *)recv_flagptr;
clock_t s = clock64();
-while (*flag < signal_id) {
+while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag);
@@ -3073,7 +3077,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int signal_id = (*recv_id) + 1;
volatile int *flag = (volatile int *)flagptr;
clock_t s = clock64();
-while (*flag < signal_id) {
+while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) {
printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag);
@@ -3142,7 +3146,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f
if (*flag >= signal_id)
return;
clock_t s = clock64();
-while (atomicAdd_system(flagptr, 0) < signal_id) {
+while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag);
return;
@@ -3193,7 +3197,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
if (*flag >= signal_id)
return;
clock_t s = clock64();
-while (*flag < signal_id) {
+while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag);
@@ -3245,7 +3249,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)recv_flagptr;
// if(*flag>=signal_id) return;
clock_t s = clock64();
-while (*flag < signal_id) {
+while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); /*return;*/
@@ -3312,7 +3316,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)recv_flagptr;
// if(*flag>=signal_id) return;
clock_t s = clock64();
-while (*flag < signal_id) {
+while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); /*return;*/
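All of the hunks above instantiate the same producer/consumer handshake. A condensed standalone CUDA sketch of that pattern follows; spin_wait_kernel and TIMEOUT_CYCLES are illustrative stand-ins (the real kernels use TIMEOUT and add rank bookkeeping, memory fences, and the reduction itself), while CHECK_IDS is copied from the patch.

#include <climits>
#include <cstdio>

#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX))
#define TIMEOUT_CYCLES 200000000ull  // stand-in for the TIMEOUT constant in userbuffers.cu

// Spin until the producer publishes expected_id (or any later id), printing a
// diagnostic and bailing out after a timeout instead of hanging the kernel.
__global__ void spin_wait_kernel(volatile int *flag, int expected_id) {
  clock_t s = clock64();
  while (CHECK_IDS(*flag, expected_id)) {  // wraparound-safe form of *flag < expected_id
    if (clock64() - s > TIMEOUT_CYCLES) {
      printf("timeout: expecting %d got %d\n", expected_id, *flag);
      break;
    }
  }
}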
