Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark more methods as device methods #2336

Draft
wants to merge 9 commits into
base: vc/precompile_tools
Choose a base branch
from
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down Expand Up @@ -65,6 +66,7 @@ Libdl = "1"
LinearAlgebra = "1"
Logging = "1"
NVTX = "0.3.2"
PrecompileTools = "1.2.1"
Preferences = "1"
PrettyTables = "2"
Printf = "1"
Expand Down
6 changes: 3 additions & 3 deletions src/device/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ for A in (AS.Generic, AS.Global, AS.Shared), T in (:Int16, :UInt16)
end

intr = "atom$scope.cas.b16 \$0, [\$1], \$2, \$3;"
@eval @inline atomic_cas!(ptr::LLVMPtr{$T,$A}, cmp::$T, val::$T) =
@eval @device_function @inline atomic_cas!(ptr::LLVMPtr{$T,$A}, cmp::$T, val::$T) =
@asmcall($intr, "=h,l,h,h", true, $T, Tuple{Core.LLVMPtr{$T,$A},$T,$T}, ptr, cmp, val)
end

Expand All @@ -172,7 +172,7 @@ for A in (AS.Generic, AS.Global, AS.Shared)
nb = sizeof(T)*8
fn = Symbol("atomic_$(op)!")
intr = "llvm.nvvm.atomic.load.$op.$nb.p$(convert(Int, A))i$nb"
@eval @inline $fn(ptr::LLVMPtr{$T,$A}, val::$T) =
@eval @device_function @inline $fn(ptr::LLVMPtr{$T,$A}, val::$T) =
@typed_ccall($intr, llvmcall, $T, (LLVMPtr{$T,$A}, $T), ptr, val)
end
end
Expand All @@ -192,7 +192,7 @@ for A in (AS.Generic, AS.Global, AS.Shared), T in (:Float16,)
end

intr = "atom$scope.add.noftz.f16 \$0, [\$1], \$2;"
@eval @inline atomic_add!(ptr::LLVMPtr{$T,$A}, val::$T) =
@eval @device_function @inline atomic_add!(ptr::LLVMPtr{$T,$A}, val::$T) =
@asmcall($intr, "=h,l,h", true, $T, Tuple{Core.LLVMPtr{$T,$A},$T}, ptr, val)
end

Expand Down
11 changes: 6 additions & 5 deletions src/device/intrinsics/cooperative_groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Noteworthy missing functionality:
module CG

using ..CUDA
using ..CUDA: i32, Aligned, alignment
using ..CUDA: i32, Aligned, alignment, @device_function

using ..LLVM.Interop
using ..LLVMLoopInfo
Expand Down Expand Up @@ -70,7 +70,7 @@ const grid_workspace = Ptr{grid_workspace_st}
end
end

function get_grid_workspace()
@device_function function get_grid_workspace()
# interpret the address from envreg 1 and 2 as the driver's grid workspace
hi = ccall("llvm.nvvm.read.ptx.sreg.envreg1", llvmcall, UInt32, ())
lo = ccall("llvm.nvvm.read.ptx.sreg.envreg2", llvmcall, UInt32, ())
Expand Down Expand Up @@ -370,7 +370,7 @@ end
return oldArrive
end

@inline function barrier_wait(gg::grid_group, token)
@device_function @inline function barrier_wait(gg::grid_group, token)
arrived = gg.details.barrier

if is_cta_master()
Expand Down Expand Up @@ -548,11 +548,12 @@ end

## pipeline operations

# Commit all previously issued `cp.async` copies into a group
# (lowers to the `llvm.nvvm.cp.async.commit.group` intrinsic).
@device_function pipeline_commit() = ccall("llvm.nvvm.cp.async.commit.group", llvmcall, Cvoid, ())

# Wait until at most `n` committed `cp.async` groups are still pending
# (lowers to the `llvm.nvvm.cp.async.wait.group` intrinsic; `n` is passed as Int32).
@device_function pipeline_wait_prior(n) =
    ccall("llvm.nvvm.cp.async.wait.group", llvmcall, Cvoid, (Int32,), n)

# TODO device function?
@generated function pipeline_memcpy_async(dst::LLVMPtr{T}, src::LLVMPtr{T}) where T
size_and_align = sizeof(T)
size_and_align in (4, 8, 16) || :(return error($"Unsupported size $size_and_align"))
Expand Down
3 changes: 3 additions & 0 deletions src/device/intrinsics/misc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export clock, nanosleep

@device_functions begin
"""
exit()

Expand Down Expand Up @@ -34,3 +35,5 @@ Puts a thread for a given amount `t`(in nanoseconds).
@asmcall("nanosleep.u32 \$0;", "r", true,
Cvoid, Tuple{UInt32}, convert(UInt32, t))
end

end
5 changes: 4 additions & 1 deletion src/device/intrinsics/synchronization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
export sync_threads, sync_warp
export sync_threads_count, sync_threads_and, sync_threads_or

@device_functions begin
"""
sync_threads()

Expand Down Expand Up @@ -64,7 +65,7 @@ the warp.

export barrier_sync

# Synchronize at the device barrier selected by `id` (defaults to 0);
# lowers to the `llvm.nvvm.barrier.sync` intrinsic, with `id` passed as Int32.
@inline barrier_sync(id=0) = ccall("llvm.nvvm.barrier.sync", llvmcall, Cvoid, (Int32,), id)


## memory barriers (membar)
Expand Down Expand Up @@ -107,3 +108,5 @@ host threads, and all threads in peer devices as occurring before all writes to
memory made by the calling thread after the call to `threadfence_system()`.
"""
@inline threadfence_system() = ccall("llvm.nvvm.membar.sys", llvmcall, Cvoid, ())

end
2 changes: 1 addition & 1 deletion src/device/intrinsics/version.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
export compute_capability, ptx_isa_version

for var in ["sm_major", "sm_minor", "ptx_major", "ptx_minor"]
@eval @inline $(Symbol(var))() =
@eval @device_function @inline $(Symbol(var))() =
Base.llvmcall(
$("""@$var = external global i32
define i32 @entry() #0 {
Expand Down
6 changes: 3 additions & 3 deletions src/device/intrinsics/warp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
for (T,typ) in ((Int32, "i32"), (UInt32, "i32"), (Float32, "f32"))
intrinsic = "llvm.nvvm.shfl.sync.$mode.$typ"
@eval begin
@inline $fname(mask, val::$T, src, width=$ws) =
@device_function @inline $fname(mask, val::$T, src, width=$ws) =
ccall($intrinsic, llvmcall, $T,
(UInt32, $T, UInt32, UInt32),
mask, val, $(offset(:src)), pack(width, $mask))
Expand Down Expand Up @@ -109,7 +109,7 @@ for mode in (:all, :any, :uni)
@eval export $fname

intrinsic = "llvm.nvvm.vote.$mode.sync"
@eval @inline $fname(mask, pred) =
@eval @device_function @inline $fname(mask, pred) =
@typed_ccall($intrinsic, llvmcall, Bool, (UInt32, Bool), mask, pred)
end

Expand All @@ -119,7 +119,7 @@ for mode in (:ballot, )
@eval export $fname

intrinsic = "llvm.nvvm.vote.$mode.sync"
@eval @inline $fname(mask, pred) =
@eval @device_function @inline $fname(mask, pred) =
@typed_ccall($intrinsic, llvmcall, UInt32, (UInt32, Bool), mask, pred)
end

Expand Down
12 changes: 6 additions & 6 deletions src/device/intrinsics/wmma.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export WMMA
module WMMA

using ..CUDA: AS
using ..CUDA: AS, @device_function
using Core: LLVMPtr

################################################################################
Expand Down Expand Up @@ -196,10 +196,10 @@ for ops in all_ldst_ops,
ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

if sz == 1
@eval $func_name(src_addr, stride) = tuple(ccall($ccall_name, llvmcall, $frag_ty, ($ptr_ty, Int32), src_addr, stride))
@eval @device_function $func_name(src_addr, stride) = tuple(ccall($ccall_name, llvmcall, $frag_ty, ($ptr_ty, Int32), src_addr, stride))
else
struct_ty = Symbol("LLVMStruct$sz")
@eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
@eval @device_function $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_load) $func_name
Expand Down Expand Up @@ -263,7 +263,7 @@ export llvm_wmma_store

ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval @device_function $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval export $func_name
@eval @doc (@doc llvm_wmma_store) $func_name
end
Expand Down Expand Up @@ -340,10 +340,10 @@ for ops in all_wmma_ops,
c_vars = ntuple(i -> :(c[$i]), c_sz)

if d_sz == 1
@eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
@eval @device_function $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
else
struct_ty = Symbol("LLVMStruct$d_sz")
@eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
@eval @device_function $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_mma) $func_name
Expand Down
6 changes: 3 additions & 3 deletions src/device/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
# Register a method definition as an overlay in `CUDA.method_table` via
# `Base.Experimental.@overlay`. The expression is macro-expanded in the calling
# module first, so any macros nested in the definition resolve there. `CUDA` is
# interpolated (`$(CUDA)`) so the expansion works regardless of what names the
# caller has in scope.
macro device_override(ex)
    ex = macroexpand(__module__, ex)
    esc(quote
        Base.Experimental.@overlay($(CUDA).method_table, $ex)
    end)
end

Expand All @@ -31,7 +31,7 @@ macro device_function(ex)

esc(quote
$(combinedef(def))
@device_override $ex
$(CUDA).@device_override $ex
end)
end

Expand All @@ -47,7 +47,7 @@ macro device_functions(ex)
push!(out.args, rewrite(arg))
elseif Meta.isexpr(arg, [:function, :(=)])
# rewrite function definitions
push!(out.args, :(@device_function $arg))
push!(out.args, :($(CUDA).@device_function $arg))
else
# preserve all the rest
push!(out.args, arg)
Expand Down
14 changes: 14 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,17 @@ precompile(run_and_collect, (Cmd,))
precompile(cudaconvert, (Function,))
precompile(Core.kwfunc(cudacall), (NamedTuple{(:threads, :blocks), Tuple{Int64, Int64}},typeof(cudacall),CuFunction,Type{Tuple{}}))
precompile(Core.kwfunc(launch), (NamedTuple{(:threads, :blocks), Tuple{Int64, Int64}},typeof(launch),CuFunction))

using PrecompileTools: @setup_workload, @compile_workload
# Precompilation workload: drive the GPU compiler once on a trivial method
# (`identity(nothing)`) for a fixed PTX target, so the involved code paths get
# cached. Only enabled on sufficiently recent Julia nightlies.
@static if VERSION >= v"1.11.0-DEV.1603"
    @setup_workload let
        @compile_workload begin
            compiler_target = PTXCompilerTarget(; cap=v"7.5")
            compiler_params = CUDACompilerParams(; cap=v"7.5", ptx=v"7.5")
            compiler_config = CompilerConfig(compiler_target, compiler_params)
            # Resolve the method instance for `identity(::Nothing)` and compile
            # it to native (PTX) code, discarding the output.
            method_inst = GPUCompiler.methodinstance(typeof(identity), Tuple{Nothing})
            compile_job = CompilerJob(method_inst, compiler_config)
            GPUCompiler.code_native(devnull, compile_job)
        end
    end
end