Added JACC BLAS module. Only dot and axpy for the moment. Added test… #89

Closed
wants to merge 3 commits
1 change: 1 addition & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["pedrovalerolara <[email protected]>", "williamfgc <williamfgc@yah
version = "0.0.4"

[deps]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Collaborator: Remove this, as AMDGPU is a weak dependency.

Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"

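For context on that suggestion: in Julia's package format, a weak dependency is declared under [weakdeps] and tied to an extension module under [extensions], so AMDGPU is loaded only when the user has it installed. A minimal sketch of what that could look like here (the [extensions] entry is an assumption based on the ext/JACCAMDGPU layout, not part of this diff):

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"

[extensions]
JACCAMDGPU = "AMDGPU"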
12 changes: 6 additions & 6 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -34,7 +34,7 @@ function JACC.parallel_reduce(
@roc groupsize=threads gridsize=blocks _parallel_reduce_amdgpu(
N, ret, f, x...)
AMDGPU.synchronize()
-@roc groupsize=threads gridsize=threads reduce_kernel_amdgpu(
+@roc groupsize=threads gridsize=1 reduce_kernel_amdgpu(
blocks, ret, rret)
AMDGPU.synchronize()
return rret
@@ -52,7 +52,7 @@ function JACC.parallel_reduce(
@roc groupsize=(Mthreads, Nthreads) gridsize=(Mblocks, Nblocks) _parallel_reduce_amdgpu_MN(
(M, N), ret, f, x...)
AMDGPU.synchronize()
-@roc groupsize=(Mthreads, Nthreads) gridsize=(Mthreads, Nthreads) reduce_kernel_amdgpu_MN(
+@roc groupsize=(Mthreads, Nthreads) gridsize=(1, 1) reduce_kernel_amdgpu_MN(
(Mblocks, Nblocks), ret, rret)
AMDGPU.synchronize()
return rret
@@ -125,15 +125,15 @@ end

function reduce_kernel_amdgpu(N, red, ret)
shared_mem = @ROCStaticLocalArray(Float64, 512)
-i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
+i = workitemIdx().x
ii = i
tmp::Float64 = 0.0
if N > 512
while ii <= N
tmp += @inbounds red[ii]
ii += 512
end
-else
+elseif (i <= N)
tmp = @inbounds red[i]
end
shared_mem[i] = tmp
@@ -223,8 +223,8 @@ end

function reduce_kernel_amdgpu_MN((M, N), red, ret)
shared_mem = @ROCStaticLocalArray(Float64, 256)
-i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
-j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
+i = workitemIdx().x
+j = workitemIdx().y
ii = i
jj = j

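The changes in this file share one idea: the second-stage reduction kernel is now launched on a single workgroup (gridsize=1), so a work-item's global index is simply workitemIdx().x, and the new elseif guard keeps idle work-items from reading past the end of the partial-results array. A minimal CPU sketch of the corrected logic, with illustrative names that are not part of the PR:

function reduce_stage2_sketch(N, red)
    # `shared` stands in for the 512-slot shared-memory buffer.
    shared = zeros(Float64, 512)
    for i in 1:512                 # each iteration plays one work-item
        tmp = 0.0
        if N > 512
            ii = i
            while ii <= N          # stride-512 accumulation of partials
                tmp += red[ii]
                ii += 512
            end
        elseif i <= N              # the new guard: idle work-items
            tmp = red[i]           # must not read past red[N]
        end
        shared[i] = tmp
    end
    return sum(shared)             # stands in for the in-kernel tree reduction
end

With the old else branch, a work-item with i > N would read red[i] under @inbounds, i.e. out of bounds; the same guard is applied to the CUDA and oneAPI kernels below.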
8 changes: 4 additions & 4 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -128,18 +128,18 @@ end

function reduce_kernel_cuda(N, red, ret)
shared_mem = @cuDynamicSharedMem(Float64, 512)
-i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
+i = threadIdx().x
ii = i
tmp::Float64 = 0.0
if N > 512
while ii <= N
tmp += @inbounds red[ii]
ii += 512
end
-else
-tmp = @inbounds red[i]
+elseif (i <= N)
+tmp = @inbounds red[i]
end
-shared_mem[i] = tmp
+shared_mem[threadIdx().x] = tmp
sync_threads()
if (i <= 256)
shared_mem[i] += shared_mem[i + 256]
2 changes: 1 addition & 1 deletion ext/JACCONEAPI/JACCONEAPI.jl
@@ -125,7 +125,7 @@ function reduce_kernel_oneapi(N, red, ret)
tmp += @inbounds red[ii]
ii += 256
end
-else
+elseif (i <= N)
tmp = @inbounds red[i]
end
shared_mem[i] = tmp
3 changes: 3 additions & 0 deletions src/JACC.jl
@@ -8,6 +8,9 @@ include("helper.jl")
# overloaded array functions
include("array.jl")

include("JACCBLAS.jl")
using .BLAS

export Array, @atomic
export parallel_for

21 changes: 21 additions & 0 deletions src/JACCBLAS.jl
@@ -0,0 +1,21 @@
module BLAS

using JACC

function _axpy(i, alpha, x, y)
@inbounds x[i] += alpha * y[i]
end

function _dot(i, x, y)
return @inbounds x[i] * y[i]
end

function axpy(n::I, alpha, x, y) where {I<:Integer}
JACC.parallel_for(n, _axpy, alpha, x, y)
end

function dot(n::I, x, y) where {I<:Integer}
JACC.parallel_reduce(n, _dot, x, y)
end

end # module BLAS
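Each routine is an elementwise kernel plus a thin wrapper over JACC.parallel_for or JACC.parallel_reduce; note that this axpy updates x in place (x[i] += alpha * y[i]), swapping the roles of x and y relative to the usual BLAS convention. As a hypothetical sketch of how further routines could follow the same pattern (not part of this PR), a scal would be:

# Hypothetical follow-on routine, not part of this PR.
function _scal(i, alpha, x)
    @inbounds x[i] = alpha * x[i]   # scale x in place by alpha
end

function scal(n::I, alpha, x) where {I<:Integer}
    JACC.parallel_for(n, _scal, alpha, x)
end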
33 changes: 33 additions & 0 deletions test/tests_amdgpu.jl
@@ -96,3 +96,36 @@ end
JACC.parallel_for(N, minus_one, x)
@test zeros(N)≈Array(x) rtol=1e-5
end

@testset "JACC.BLAS" begin

function seq_axpy(N, alpha, x, y)
for i in 1:N
@inbounds x[i] += alpha * y[i]
end
end

function seq_dot(N, x, y)
r = 0.0
for i in 1:N
@inbounds r += x[i] * y[i]
end
return r
end

x = ones(1_000)
y = ones(1_000)
jx = JACC.ones(1_000)
jy = JACC.ones(1_000)
alpha = 2.0

seq_axpy(1_000, alpha, x, y)
ref_result = seq_dot(1_000, x, y)

JACC.BLAS.axpy(1_000, alpha, jx, jy)
jresult = JACC.BLAS.dot(1_000, jx, jy)
result = Array(jresult)

@test result[1]≈ref_result rtol=1e-8

end
33 changes: 33 additions & 0 deletions test/tests_cuda.jl
@@ -133,3 +133,36 @@ end
# C[i] = A[i] + B[i]
# end
# end

@testset "JACC.BLAS" begin

function seq_axpy(N, alpha, x, y)
for i in 1:N
@inbounds x[i] += alpha * y[i]
end
end

function seq_dot(N, x, y)
r = 0.0
for i in 1:N
@inbounds r += x[i] * y[i]
end
return r
end

x = ones(1_000)
y = ones(1_000)
jx = JACC.ones(1_000)
jy = JACC.ones(1_000)
alpha = 2.0

seq_axpy(1_000, alpha, x, y)
ref_result = seq_dot(1_000, x, y)

JACC.BLAS.axpy(1_000, alpha, jx, jy)
jresult = JACC.BLAS.dot(1_000, jx, jy)
result = Array(jresult)

@test result[1]≈ref_result rtol=1e-8

end
35 changes: 35 additions & 0 deletions test/tests_oneapi.jl
@@ -48,3 +48,38 @@ end

@test Array(x_device)≈x_expected rtol=1e-1
end

@testset "JACC.BLAS" begin

function seq_axpy(N, alpha, x, y)
for i in 1:N
@inbounds x[i] += alpha * y[i]
end
end

function seq_dot(N, x, y)
r = 0.0
for i in 1:N
@inbounds r += x[i] * y[i]
end
return r
end

SIZE = Int32(1_000)
x = ones(Float32, SIZE)
y = ones(Float32, SIZE)
jx = JACC.ones(Float32, SIZE)
jy = JACC.ones(Float32, SIZE)
alpha = Float32(2.0)

seq_axpy(SIZE, alpha, x, y)
ref_result = seq_dot(SIZE, x, y)

JACC.BLAS.axpy(SIZE, alpha, jx, jy)
jresult = JACC.BLAS.dot(SIZE, jx, jy)
result = Array(jresult)

@test result[1]≈ref_result rtol=1e-8

end

33 changes: 33 additions & 0 deletions test/tests_threads.jl
@@ -277,3 +277,36 @@ end

@test f2≈df2 rtol=1e-1
end

@testset "JACC.BLAS" begin

x = ones(1_000)
y = ones(1_000)
jx = JACC.ones(1_000)
jy = JACC.ones(1_000)
alpha = 2.0

function seq_axpy(N, alpha, x, y)
for i in 1:N
@inbounds x[i] += alpha * y[i]
end
end

function seq_dot(N, x, y)
r = 0.0
for i in 1:N
@inbounds r += x[i] * y[i]
end
return r
end

seq_axpy(1_000, alpha, x, y)
ref_result = seq_dot(1_000, x, y)

JACC.BLAS.axpy(1_000, alpha, jx, jy)
jresult = JACC.BLAS.dot(1_000, jx, jy)
result = jresult[1]

@test result≈ref_result rtol=1e-8

end