System Info
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
pip install -e .[testing]
pytest tests/models/mra/test_modeling_mra.py
Analysis
There may be out-of-bounds memory accesses in the CUDA kernel index_max_cuda_kernel():
https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/mra/cuda_kernel.cu#L6C1-L58C2
Note that max_buffer in this kernel is declared extern __shared__ float, which means max_buffer lives in shared memory. According to https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/mra/cuda_launch.cu#L24-L35, CUDA launches this kernel with batch_size blocks, each allocating A_num_block * 32 * sizeof(float) bytes of shared memory.
If A_num_block < 4, the for statement below may index memory beyond A_num_block * 32, since num_thread here is 256 and threadIdx.x ranges over [0, 255]. Therefore, whenever the threads of a block access max_buffer, it is safer to guard the access with an if statement so it cannot go out of bounds. We suggest adding if statements in two places:

Expected behavior
All unit tests should pass.