Memory Access out of bounds in mra/cuda_kernel.cu::index_max_cuda_kernel() #35507

Open
4 tasks
dingfen opened this issue Jan 4, 2025 · 0 comments

dingfen commented Jan 4, 2025

System Info

  • OS: Linux ubuntu 22.04 LTS
  • Device: A100-80GB
  • docker: nvidia/pytorch:24.04-py3
  • transformers: latest, 4.47.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. pip install the latest transformers
  2. Prepare the unit-test environment with pip install -e .[testing]
  3. Run pytest tests/models/mra/test_modeling_mra.py

Analysis

The CUDA kernel index_max_cuda_kernel() appears to contain out-of-bounds memory accesses:
https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/mra/cuda_kernel.cu#L6C1-L58C2

Note that max_buffer in this kernel is declared as extern __shared__ float, which means it is stored in shared memory.
According to https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/mra/cuda_launch.cu#L24-L35, the kernel is launched with:

  • grid size: batch_size
  • block size: 256
  • shared memory size: A_num_block * 32 * sizeof(float)

When A_num_block < 4, the for statement below may access memory beyond the A_num_block * 32 elements of the buffer, since num_thread is 256 here and threadIdx.x ranges over [0, 255].

for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
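To make the overshoot concrete, here is a small host-side C++ model of that loop (no GPU needed). It assumes each thread computes idx = idx_start + threadIdx.x, which is the usual pattern for a strided loop like this one but is not shown in the line quoted above. count_overshoot is a hypothetical helper written for illustration, not part of transformers.

```cpp
#include <cassert>

// Host-side model of the strided loop; it does not touch a GPU.
// Assumption (not from the kernel source): idx = idx_start + thread_idx.
// Returns how many thread iterations end up with idx >= 32 * num_block.
long count_overshoot(long num_block, long num_thread) {
    assert(num_block > 0 && num_thread > 0);
    const long limit = 32 * num_block;
    long overshoot = 0;
    // Outer loop: same bounds and stride as the for statement in the kernel.
    for (long idx_start = 0; idx_start < limit; idx_start += num_thread) {
        // Inner loop: stands in for the threads of one block running in parallel.
        for (long thread_idx = 0; thread_idx < num_thread; ++thread_idx) {
            const long idx = idx_start + thread_idx;
            if (idx >= limit) {
                ++overshoot;  // this thread is past the loop bound
            }
        }
    }
    return overshoot;
}
```

Under that assumption, count_overshoot(1, 256) returns 224 and count_overshoot(8, 256) returns 0. Whenever 32 * num_block is not a multiple of 256, the last pass of the loop contains threads whose idx is past the loop bound, and any index derived from idx can then fall outside the buffer.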

Therefore, whenever the threads of a block access max_buffer, it would be safer to guard the access with an if statement so that no thread reads or writes out of bounds.

We therefore suggest adding if statements in two places:
[Screenshot titled "捕获" ("Capture"); image not included]
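Since the screenshot is not reproduced here, the following is only a rough host-side C++ sketch of what such guards could look like, not the actual proposed patch. All names (guarded_block_max, values, slots) are hypothetical. The slot computation of the real kernel is deliberately left out and replaced by a precomputed slots array, and the plain comparison stands in for atomicMax on the device.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model of one thread block; it does not touch a GPU.
// Hypothetical names: values[idx] is the candidate for element idx,
// slots[idx] is the buffer slot that element targets.
std::vector<int> guarded_block_max(const std::vector<int>& values,
                                   const std::vector<std::size_t>& slots,
                                   std::size_t buffer_size,
                                   std::size_t num_thread) {
    assert(values.size() == slots.size() && num_thread > 0);
    // Stands in for the shared-memory buffer, initialised to a very small value.
    std::vector<int> max_buffer(buffer_size, -100000000);
    const std::size_t limit = values.size();
    for (std::size_t idx_start = 0; idx_start < limit; idx_start += num_thread) {
        for (std::size_t thread_idx = 0; thread_idx < num_thread; ++thread_idx) {
            const std::size_t idx = idx_start + thread_idx;
            if (idx >= limit) continue;         // guard: tail threads do nothing
            const std::size_t slot = slots[idx];
            if (slot >= buffer_size) continue;  // guard: never index past the buffer
            if (values[idx] > max_buffer[slot]) {
                max_buffer[slot] = values[idx]; // atomicMax in a real kernel
            }
        }
    }
    return max_buffer;
}
```

For example, with values {5, 9, 3, 7}, slots {0, 1, 0, 99}, a buffer of 2 slots and 256 threads, the result is {5, 9}: the 252 tail threads and the element aimed at slot 99 are skipped instead of touching memory outside the buffer.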

Expected behavior

All unit tests should pass.

dingfen added the bug label Jan 4, 2025