Skip to content

Commit

Permalink
Apply thread number fix for AMDGPU and oneAPI backends
Browse files (browse the repository at this point in the history)
  • Loading branch information
PhilipFackler committed Jan 9, 2025
1 parent 2d9e6df commit f975fe3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end
function JACC.parallel_reduce(
::AMDGPUBackend, N::Integer, op, f::Function, x...; init)
numThreads = 512
- threads = min(N, numThreads)
+ threads = numThreads
blocks = ceil(Int, N / threads)
ret = fill!(AMDGPU.ROCArray{typeof(init)}(undef, blocks), init)
rret = AMDGPU.ROCArray([init])
Expand All @@ -84,8 +84,8 @@ end
function JACC.parallel_reduce(
::AMDGPUBackend, (M, N)::Tuple{Integer, Integer}, op, f::Function, x...; init)
numThreads = 16
- Mthreads = min(M, numThreads)
- Nthreads = min(N, numThreads)
+ Mthreads = numThreads
+ Nthreads = numThreads
Mblocks = ceil(Int, M / Mthreads)
Nblocks = ceil(Int, N / Nthreads)
ret = fill!(AMDGPU.ROCArray{typeof(init)}(undef, (Mblocks, Nblocks)), init)
Expand Down
6 changes: 3 additions & 3 deletions ext/JACCONEAPI/JACCONEAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
function JACC.parallel_reduce(
::oneAPIBackend, N::Integer, op, f::Function, x...; init)
numItems = 256
- items = min(N, numItems)
+ items = numItems
groups = ceil(Int, N / items)
ret = oneAPI.zeros(typeof(init), groups)
rret = oneAPI.zeros(typeof(init), 1)
Expand All @@ -70,8 +70,8 @@ end
function JACC.parallel_reduce(
::oneAPIBackend, (M, N)::Tuple{Integer, Integer}, op, f::Function, x...; init)
numItems = 16
- Mitems = min(M, numItems)
- Nitems = min(N, numItems)
+ Mitems = numItems
+ Nitems = numItems
Mgroups = ceil(Int, M / Mitems)
Ngroups = ceil(Int, N / Nitems)
ret = oneAPI.zeros(typeof(init), (Mgroups, Ngroups))
Expand Down

0 comments on commit f975fe3

Please sign in to comment.