From 844beaf383597338fa9e1953a8316a4baecc37fb Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 10 Sep 2024 15:37:28 -0400
Subject: [PATCH] feat: starting simple oneDNN wrapper

---
 src/onednn/api.jl    |  78 ++++++++++
 src/onednn/memory.jl | 344 +++++++++++++++++++++++++++++++++++++++++++
 src/onednn/oneDNN.jl |  22 +++
 src/onednn/types.jl  |  63 ++++++++
 src/onednn/utils.jl  | 129 ++++++++++++++++
 src/utils.jl         |   6 +
 6 files changed, 642 insertions(+)
 create mode 100644 src/onednn/api.jl
 create mode 100644 src/onednn/memory.jl
 create mode 100644 src/onednn/types.jl
 create mode 100644 src/onednn/utils.jl
diff --git a/src/onednn/api.jl b/src/onednn/api.jl
new file mode 100644
index 00000000..bdd2b59e
--- /dev/null
+++ b/src/onednn/api.jl
@@ -0,0 +1,78 @@
+"""
+    engine()
+
+Create a new oneDNN engine. Currently creates a CPU engine.
+"""
+engine() = Engine()
+
+"""
+    global_engine()
+
+Fetch the global oneDNN engine created in LuxLib. If it doesn't exist, create it.
+"""
+function global_engine()
+    if !GLOBAL_ENGINE_INITIALIZED[]
+        GLOBAL_ENGINE[] = engine()
+        GLOBAL_ENGINE_INITIALIZED[] = true
+    end
+    return GLOBAL_ENGINE[]
+end
+
+"""
+    get_math_mode()
+
+Get the current math mode for oneDNN.
+"""
+function get_math_mode()
+    mode = Ref{Lib.dnnl_fpmath_mode_t}()
+    @dnnlcall Lib.dnnl_get_default_fpmath_mode(mode)
+    dnnl_mode = unwrap_ref(mode)
+    return if dnnl_mode == Lib.dnnl_fpmath_mode_strict
+        :strict
+    elseif dnnl_mode == Lib.dnnl_fpmath_mode_bf16
+        :bf16
+    elseif dnnl_mode == Lib.dnnl_fpmath_mode_f16
+        :f16
+    elseif dnnl_mode == Lib.dnnl_fpmath_mode_tf32
+        :tf32
+    elseif dnnl_mode == Lib.dnnl_fpmath_mode_any
+        :fastest
+    else
+        error("Unknown math mode: $(dnnl_mode). This should not happen. Please open an \
+               issue in `LuxLib.jl`.")
+    end
+end
+
+"""
+    set_math_mode!(mode)
+
+Set the current math mode for oneDNN. `mode` must be one of the following:
+
+  - `:strict` -- `Lib.dnnl_fpmath_mode_strict`
+  - `:bf16` -- `Lib.dnnl_fpmath_mode_bf16`
+  - `:f16` -- `Lib.dnnl_fpmath_mode_f16`
+  - `:tf32` -- `Lib.dnnl_fpmath_mode_tf32`
+  - `:fastest` -- `Lib.dnnl_fpmath_mode_any`
+
+For details, see [`Lib.dnnl_fpmath_mode_t`](@ref).
+
+See also [`get_math_mode`](@ref).
+"""
+function set_math_mode!(mode::Symbol)
+    dnnl_mode = if mode == :strict
+        Lib.dnnl_fpmath_mode_strict
+    elseif mode == :bf16
+        Lib.dnnl_fpmath_mode_bf16
+    elseif mode == :f16
+        Lib.dnnl_fpmath_mode_f16
+    elseif mode == :tf32
+        Lib.dnnl_fpmath_mode_tf32
+    elseif mode == :fastest
+        Lib.dnnl_fpmath_mode_any
+    else
+        error("Invalid math mode: $(mode). Valid modes are `:strict`, `:bf16`, `:f16`, \
+               `:tf32`, and `:fastest`.")
+    end
+    @dnnlcall Lib.dnnl_set_default_fpmath_mode(dnnl_mode)
+    return nothing
+end
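Note: a minimal usage sketch of the API above -- illustrative only, not part of the diff, and assuming the wrapper ends up exposed as `LuxLib.oneDNN`:

    using LuxLib: oneDNN

    oneDNN.get_math_mode()        # :fastest -- the module `__init__` below sets this default
    oneDNN.set_math_mode!(:bf16)  # allow implicit f32 -> bf16 down-conversion
    @assert oneDNN.get_math_mode() == :bf16
    oneDNN.set_math_mode!(:strict)

    eng = oneDNN.global_engine()  # lazily created, process-wide CPU engine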
diff --git a/src/onednn/memory.jl b/src/onednn/memory.jl
new file mode 100644
index 00000000..69c4fae7
--- /dev/null
+++ b/src/onednn/memory.jl
@@ -0,0 +1,344 @@
+# Memory Descriptor
+struct MemoryDesc
+    handle::Lib.dnnl_memory_desc_t
+
+    MemoryDesc(x::Lib.dnnl_memory_desc_t) = new(x)
+    MemoryDesc(x) = memory_descriptor(x)
+end
+
+Base.unsafe_convert(::Type{Lib.dnnl_memory_desc_t}, x::MemoryDesc) = x.handle
+function Base.unsafe_convert(::Type{Ptr{Lib.dnnl_memory_desc_t}}, x::MemoryDesc)
+    return Base.unsafe_convert(Ptr{Lib.dnnl_memory_desc_t}, Base.pointer_from_objref(x))
+end
+
+memory_descriptor(x::MemoryDesc) = x
+function Base.cconvert(::Type{Ptr{Lib.dnnl_memory_desc_t}}, x::MemoryDesc)
+    return Base.cconvert(Ptr{Lib.dnnl_memory_desc_t}, Ref(x.handle))
+end
+
+function Base.eltype(md::MemoryDesc)
+    result = Ref{Lib.dnnl_data_type_t}()
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_data_type, result)
+    return dnnl_type_to_julia(unwrap_ref(result))
+end
+
+function Base.size(md::MemoryDesc)
+    result = Ref(Vector{Int64}(undef, Lib.DNNL_MAX_NDIMS))
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_dims, result)
+    return Tuple(reverse(result[][1:ndims(md)]))
+end
+
+function Base.strides(md::MemoryDesc)
+    result = Ref(Vector{Int64}(undef, Lib.DNNL_MAX_NDIMS))
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_strides, result)
+    return Tuple(reverse(result[][1:ndims(md)]))
+end
+
+function Base.ndims(md::MemoryDesc)
+    result = Ref{Lib.dnnl_dim_t}()
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_ndims_s32, result)
+    return Int(unwrap_ref(result))
+end
+
+function padded_size(md::MemoryDesc)
+    result = Ref(Vector{Int64}(undef, Lib.DNNL_MAX_NDIMS))
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_padded_dims, result)
+    padded_dims = result[]
+    return Tuple(reverse(padded_dims[1:findlast(!=(0), padded_dims)]))
+end
+
+function padded_offsets(md::MemoryDesc)
+    result = Ref(Vector{Int64}(undef, Lib.DNNL_MAX_NDIMS))
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_padded_offsets, result)
+    return Tuple(reverse(result[][1:ndims(md)]))
+end
+
+function format_kind(md::MemoryDesc)
+    result = Ref{Lib.dnnl_format_kind_t}()
+    @dnnlcall Lib.dnnl_memory_desc_query(md, Lib.dnnl_query_format_kind, result)
+    return unwrap_ref(result)
+end
+
+function print_memory_descriptor(io::IO, md::MemoryDesc, level::Int=0)
+    base_desc = "oneDNN Memory Description:"
+    # TODO: Additional information if the format is "blocked"
+    join_str = "\n" * " "^(level + 1)
+    ndims_str = "ndims: $(ndims(md))"
+    size_str = "size: $(size(md))"
+    datatype_str = "datatype: $(eltype(md))"
+    format_kind_str = "format kind: $(format_kind(md))"
+    padded_dims_str = "padded dims: $(padded_size(md))"
+    padded_offsets_str = "padded offsets: $(padded_offsets(md))"
+    desc = join(
+        [base_desc, ndims_str, size_str, datatype_str,
+            format_kind_str, padded_dims_str, padded_offsets_str],
+        join_str)
+    print(io, desc)
+end
+
+function Base.show(io::IO, ::MIME"text/plain", md::MemoryDesc)
+    print_memory_descriptor(io, md)
+end
+
+memory_descriptor(x::AbstractArray{T}) where {T} = memory_descriptor(T, size(x), strides(x))
+
+function memory_descriptor(
+        ::Type{T}, dims::Dims{N}, strides::Dims{N}=default_strides(dims)) where {T, N}
+    handle = Ref{Lib.dnnl_memory_desc_t}()
+    @dnnlcall dnnl_memory_desc_create_with_strides(
+        handle, N, reverse(dims), T, reverse(strides))
+    return MemoryDesc(unwrap_ref(handle))
+end
+
+# convenience creation by tag.
+function memory_descriptor(
+        ::Type{T}, dims::Dims{N}, tag::Union{Lib.dnnl_format_tag_t, UInt32}) where {T, N}
+    handle = Ref{Lib.dnnl_memory_desc_t}()
+    @dnnlcall dnnl_memory_desc_create_with_tag(handle, N, reverse(dims), T, tag)
+    return MemoryDesc(unwrap_ref(handle))
+end
+
+# toany(a::MemoryDesc) = memorydesc(a.data_type, logicalsize(a), dnnl_format_any())
+
+# isany(a::Ptr{MemoryDesc}) = isany(unsafe_load(a))
+# isany(a::MemoryDesc) = a.format_kind == Lib.dnnl_format_kind_any
+
+# function Base.:(==)(a::MaybeRef{MemoryDesc}, b::MaybeRef{MemoryDesc})
+#     return Bool(Lib.dnnl_memory_desc_equal(wrap_ref(a), wrap_ref(b)))
+# end
+
+function get_bytes(a::MaybeRef{MemoryDesc})
+    return signed(Lib.dnnl_memory_desc_get_size(unwrap_ref(a).handle))
+end
+
+# Memory Type for oneDNN -- distinct from Memory in Base
+struct Memory{T, N, A <: AbstractArray{T}} <: AbstractArray{T, N}
+    # The underlying array that is supplying the data.
+    array::A
+    offset::Int
+
+    # Keep around some information about size and padding.
+    logicalsize::Dims{N}
+
+    # Memory object from DNNL
+    memory::MemoryPtr
+end
+
+ArrayInterface.fast_scalar_indexing(::Type{<:Memory}) = false
+ArrayInterface.can_setindex(::Type{<:Memory}) = false
+
+function Base.convert(::Type{Memory{T, N, A}}, x::Memory{T, N, B}) where {T, N, A, B}
+    return Memory(convert(A, x.array), x.offset, x.logicalsize, x.memory)
+end
+
+memory_descriptor(x::Memory) = MemoryDesc(memory_descriptor_ptr(x))
+
+Base.sizeof(x::Memory) = get_bytes(memory_descriptor(x))
+
+# toany(x::Memory) = toany(memorydesc(x))
+
+Base.size(x::Memory) = x.logicalsize
+# logicalsize(x::Memory) = size(x)
+Base.strides(x::Memory) = strides(memory_descriptor(x))
+# padded_size(x::Memory{T,N}) where {T,N} = padded_size(memorydesc(x), Val(N))
+
+Base.parent(x::Memory) = x.array
+# function ChainRulesCore.rrule(::typeof(Base.parent), x::Memory)
+#     return parent(x), Δ -> (ChainRulesCore.NoTangent(), Δ)
+# end
+
+# arraytype(::Memory{T,N,A}) where {T,N,A} = A
+
+function Base.show(io::IO, x::Memory)
+    print(io, "Opaque Memory with ")
+    print_memory_descriptor(io, memory_descriptor(x))
+    x.offset != 1 && print(io, " - SubArray")
+    return
+end
+Base.show(io::IO, ::MIME"text/plain", x::Memory) = show(io, x)
+
+# Base.any(f::F, x::Memory) where {F <: Function} = any(f, materialize(x))
+
+# for creating oneDNN arguments
+# @inline access_pointer(x, offset, context) = pointer(x, offset)
+# function setptr!(x::Memory{T}, context::AccessContext = Reading()) where {T}
+#     ptr = access_pointer(x.array, x.offset, context)
+#     @apicall dnnl_memory_set_data_handle_v2(x.memory, ptr, global_stream())
+# end
+
+# function Base.cconvert(
+#     ::Type{T}, x::Memory
+# ) where {T<:Union{Lib.dnnl_memory_t,Ptr{Lib.dnnl_memory_t}}}
+#     setptr!(x)
+#     return Base.cconvert(T, x.memory)
+# end
+
+# Base.cconvert(::Type{Ptr{Lib.dnnl_memory_desc_t}}, x::Memory) = memorydesc_ptr(x)
+
+# Base.elsize(::Type{<:Memory{T}}) where {T} = sizeof(T)
+# function Base.unsafe_convert(::Type{Ptr{T}}, x::Memory{T}) where {T}
+#     return pointer(x.array)
+# end
+
+# # For constructing DNNL arguments.
+# function dnnl_exec_arg(x::Memory, context::AccessContext = Reading())
+#     setptr!(x, context)
+#     return x.memory
+# end
+
+# Try to remove as many layers of wrapping around `A` as possible.
+# Since all of the dimension and layout information will be stored in the oneDNN
+# `memorydesc`, we don't need to hold onto it on the Julia level, which can potentially
+# cause downstream type instabilities.
+Memory(A::AbstractArray) = Memory(ancestor(A), offset(A), size(A), MemoryPtr(A))
+
+offset(::AbstractArray) = one(Int64)
+offset(x::SubArray) = Base.first_index(x)
+
+Memory(M::Memory) = M
+
+# function ChainRulesCore.rrule(::Type{<:Memory}, x)
+#     return (Memory(x), Δ -> (ChainRulesCore.NoTangent(), Δ))
+# end
+
+# # Convenience method for creating destination memories from a source memory.
+# Base.size(M::Memory) = M.logicalsize
+# Base.eltype(M::Memory{T}) where {T} = T
+
+function Base.getindex(::Memory, I::Vararg{Int, N}) where {N}
+    throw(ArgumentError("Cannot index opaque memory formats."))
+end
+
+function Base.setindex!(::Memory, v, I::Vararg{Int, N}) where {N}
+    throw(ArgumentError("Cannot index opaque memory formats."))
+end
+
+memory_descriptor(M::Memory) = MemoryDesc(memory_descriptor_ptr(M))
+function memory_descriptor_ptr(M::Memory)
+    md = Ref{Lib.dnnl_memory_desc_t}()
+    @dnnlcall Lib.dnnl_memory_get_memory_desc(M.memory, md)
+    return unwrap_ref(md)
+end
+
+# #####
+# ##### Lazy Transpose
+# #####
+
+# # General idea: swap the dims and strides.
+# # TODO: Need to validate that this is a blocked layout with no tiling ...
+# function Base.adjoint(M::Memory{T,2}) where {T}
+#     dims = size(M)
+#     strides = Base.strides(memorydesc(M), Val(2))
+
+#     reversed_dims = reverse(dims)
+#     desc = memorydesc(T, reversed_dims, reverse(strides))
+#     memory = MemoryPtr(parent(M), desc)
+#     return Memory(parent(M), M.offset, reversed_dims, memory)
+# end
+
+# function Base.permutedims(M::Memory{T,N}, perm::NTuple{N,Int}) where {T,N}
+#     dims = size(M)
+#     strides = Base.strides(memorydesc(M), Val(N))
+#     dims_permuted = unsafe_permute(dims, perm)
+#     strides_permuted = unsafe_permute(strides, perm)
+
+#     desc = memorydesc(T, dims_permuted, strides_permuted)
+#     memory = MemoryPtr(parent(M), desc)
+#     return Memory(parent(M), M.offset, dims_permuted, memory)
+# end
+
+# function unsafe_permute(a::NTuple{N,Int}, b::NTuple{N,Int}) where {N}
+#     return ntuple(i -> @inbounds(a[@inbounds b[i]]), Val(N))
+# end
+
+# #####
+# ##### Construct more memories!!
+# #####
+
+# function Base.similar(
+#     x::Memory{U,M},
+#     ::Type{T} = eltype(x),
+#     dims::NTuple{N,Int} = size(x),
+#     desc::MemoryDesc = (M == N && U === T) ? memorydesc(x) : memorydesc(T, dims),
+# ) where {U,T,M,N}
+#     # Number of bytes to allocate.
+#     # Since oneDNN is free to reorder and pad, we need to explicitly ask it.
+#     bytes = getbytes(desc)
+
+#     # Allocate the output array.
+#     # This will be allocated as just a plain vector with dimensions padded with ones so it
+#     # has the same dimension as the wrapped "Memory"
+#     padded_dims = (div(bytes, sizeof(T)), ntuple(_ -> 1, Val(N - 1))...)
+#     out = similar(ancestor(x), T, padded_dims)
+
+#     # Since we specifically created this array, the offset will always start at one.
+#     return Memory(out, 1, dims, MemoryPtr(out, desc))
+# end
+
+# Base.similar(x::Memory{T,M}, dims::NTuple{N,Int}) where {T,M,N} = similar(x, T, dims)
+# function Base.similar(x::Memory{T,M}, dims::NTuple{N,Int}, desc::MemoryDesc) where {T,M,N}
+#     return similar(x, T, dims, desc)
+# end
+
+# materialize(x::AbstractArray, args...; kw...) = x
+
+# function Array(M::Memory{T, N}) where {T, N}
+#     # Check if this memory is already in the requested layout.
+#     # If so, return the underlying array.
+#     desired_strides = default_strides(size(M))
+#     actual_strides = strides(M)
+
+#     # In order to return the underlying object, we need to ensure that:
+#     # 1. The length of the wrapped object is the same as the length of the Memory.
+#     #    This helps handle views correctly.
+#     #
+#     # 2. Strides are the same
+#     if length(parent(M)) == length(M) && desired_strides == actual_strides
+#         return reshape(parent(M), size(M))
+#     end
+
+#     desc = memory_descriptor(T, size(M), desired_strides)
+# end
+
+# function materialize(M::Memory{T,N}; allowreorder = true) where {T,N}
+#     # Check if this memory is already in the requested layout.
+#     # If so, return the underlying array.
+#     desired_strides = default_strides(logicalsize(M))
+#     actual_strides = strides(M)
+
+#     # In order to return the underlying object, we need to ensure that:
+#     # 1. The length of the wrapped object is the same as the length of the Memory.
+#     #    This helps handle views correctly.
+#     #
+#     # 2. Strides are the same
+#     if length(parent(M)) == length(M) && desired_strides == actual_strides
+#         return reshape(parent(M), logicalsize(M))
+#     end
+
+#     if !allowreorder
+#         msg = """
+#         Expected strides: $desired_strides.
+#         Found strides: $actual_strides.
+#         """
+#         throw(ArgumentError(msg))
+#     end
+
+#     desc = memorydesc(T, logicalsize(M), desired_strides)
+#     return reshape(parent(reorder(desc, M)), logicalsize(M))
+# end
+
+# function ChainRulesCore.rrule(
+#     ::typeof(materialize), x, args::Vararg{Any,N}; kw...
+# ) where {N}
+#     return materialize(x, args...; kw...),
+#     Δ -> (ChainRulesCore.NoTangent(), Δ, ntuple(_ -> ChainRulesCore.NoTangent(), Val(N)))
+# end
+
+# #####
+# ##### Reshape
+# #####
+
+# function Base.reshape(memory::Memory{T}, dims::NTuple{N,Int}) where {T,N}
+#     md = Ref{MemoryDesc}()
+#     @apicall dnnl_memory_desc_reshape(md, memory, N, Ref(reverse(dims)))
+#     new_memory = MemoryPtr(parent(memory), md)
+#     return Memory(parent(memory), memory.offset, dims, new_memory)
+# end
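Note: an illustrative sketch of how the descriptor constructors and the `Memory` wrapper above are intended to compose (assuming the `LuxLib.oneDNN` prefix; the commented values follow from the query helpers defined in this file):

    x  = randn(Float32, 8, 16)
    md = oneDNN.memory_descriptor(x)   # == memory_descriptor(Float32, size(x), strides(x))

    eltype(md)            # Float32
    ndims(md)             # 2
    size(md)              # (8, 16) -- dims are reversed on creation and reversed back on query
    strides(md)           # (1, 8)
    oneDNN.get_bytes(md)  # 512 bytes for this dense, unpadded f32 layout

    M = oneDNN.Memory(x)  # wraps the ancestor array plus a dnnl_memory handle
    size(M)               # (8, 16), the logical size
    parent(M) === x       # true
    # M[1, 1]             # would throw: "Cannot index opaque memory formats."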
diff --git a/src/onednn/oneDNN.jl b/src/onednn/oneDNN.jl
index c5163e24..3792d268 100644
--- a/src/onednn/oneDNN.jl
+++ b/src/onednn/oneDNN.jl
@@ -1,5 +1,27 @@
 module oneDNN
 
+using ArrayInterface: ArrayInterface
+using ..Utils: ancestor
+
 include("lib.jl") # Low-level bindings to oneDNN C API -- automatically generated
+include("utils.jl")
+
+include("types.jl")
+include("memory.jl")
+
+include("api.jl")
+
+const GLOBAL_ENGINE_INITIALIZED = Ref{Bool}(false)
+const GLOBAL_ENGINE = Ref{Engine}()
+
+function __init__()
+    # Initialize the global engine.
+    GLOBAL_ENGINE[] = engine()
+    GLOBAL_ENGINE_INITIALIZED[] = true
+
+    # Set the default math mode. We set it to the fastest mode.
+    set_math_mode!(:fastest)
+end
+
 
 end
diff --git a/src/onednn/types.jl b/src/onednn/types.jl
new file mode 100644
index 00000000..84916205
--- /dev/null
+++ b/src/onednn/types.jl
@@ -0,0 +1,63 @@
+@wrap_type MemoryPtr dnnl_memory_t dnnl_memory_destroy
+
+function MemoryPtrNoFinalizer(A::AbstractArray, desc=memory_descriptor(A))
+    return MemoryPtrNoFinalizer(convert(Ptr{Nothing}, pointer(A)), desc)
+end
+
+function MemoryPtrNoFinalizer(ptr::Ptr{Nothing}, desc)
+    memory = MemoryPtr(InnerConstructor())
+    @dnnlcall dnnl_memory_create(memory, desc, global_engine(), ptr)
+    return memory
+end
+
+@wrap_type Engine dnnl_engine_t dnnl_engine_destroy
+
+function EngineNoFinalizer(kind=Lib.dnnl_cpu, index=0)
+    engine = Engine(InnerConstructor())
+    @dnnlcall dnnl_engine_create(engine, kind, index)
+    return engine
+end
+
+@wrap_type Stream dnnl_stream_t dnnl_stream_destroy
+
+function StreamNoFinalizer(engine::Engine)
+    stream = Stream(InnerConstructor())
+    @dnnlcall dnnl_stream_create(stream, engine, Lib.dnnl_stream_default_flags)
+    return stream
+end
+
+@wrap_type Attributes dnnl_primitive_attr_t dnnl_primitive_attr_destroy
+
+function AttributesNoFinalizer()
+    attributes = Attributes(InnerConstructor())
+    @dnnlcall dnnl_primitive_attr_create(attributes)
+    @dnnlcall dnnl_primitive_attr_set_scratchpad_mode(
+        attributes, Lib.dnnl_scratchpad_mode_user)
+    return attributes
+end
+
+@wrap_type PostOps dnnl_post_ops_t dnnl_post_ops_destroy
+
+function PostOpsNoFinalizer()
+    postops = PostOps(InnerConstructor())
+    @dnnlcall dnnl_post_ops_create(postops)
+    return postops
+end
+
+@wrap_type PrimitiveDescriptor dnnl_primitive_desc_t dnnl_primitive_desc_destroy
+
+function PrimitiveDescriptorNoFinalizer(args...)
+    return PrimitiveDescriptorNoFinalizer(Lib.dnnl_primitive_desc_create, args...)
+end
+
+function PrimitiveDescriptorNoFinalizer(f::F, args...) where {F <: Function}
+    descriptor = PrimitiveDescriptor(InnerConstructor())
+    @dnnlcall f(descriptor, args...)
+    return descriptor
+end
+
+function Base.copy(x::PrimitiveDescriptor)
+    descriptor = PrimitiveDescriptor(InnerConstructor())
+    @dnnlcall dnnl_primitive_desc_clone(descriptor, x)
+    return descriptor
+end
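Note: a short sketch (illustrative, not part of the diff) of how the wrapped types above compose. Each call forwards to the corresponding `*NoFinalizer` constructor and attaches the matching `dnnl_*_destroy` finalizer via `@wrap_type`:

    eng    = oneDNN.Engine()      # CPU engine (kind = Lib.dnnl_cpu, index = 0)
    stream = oneDNN.Stream(eng)   # execution stream bound to that engine
    attrs  = oneDNN.Attributes()  # primitive attributes with a user-managed scratchpad
    post   = oneDNN.PostOps()     # empty post-ops list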
diff --git a/src/onednn/utils.jl b/src/onednn/utils.jl
new file mode 100644
index 00000000..7aba8287
--- /dev/null
+++ b/src/onednn/utils.jl
@@ -0,0 +1,129 @@
+macro dnnlcall(ex)
+    expr = dnnlcall_partial_impl(ex)
+    return quote
+        status = $(expr)
+        if status != Lib.dnnl_success
+            throw(ErrorException("oneDNN call failed with status $(status)."))
+        end
+        status
+    end
+end
+
+function dnnlcall_partial_impl(expr)
+    expr.head != :call && error("Only call `@dnnlcall` on function calls")
+
+    # Prefix "Lib." in front of the function call.
+    # However, sometimes the function to call is passed as a higher order function.
+    # Thus, we only implicitly attach "Lib" if the function name starts with "dnnl".
+    fname = expr.args[1]
+    if isa(fname, Symbol)
+        fname = startswith(string(fname), "dnnl") ? :(Lib.$(fname)) : :($(esc(fname)))
+    end
+
+    # Escape and convert each of the arguments.
+    args = expr.args[2:end]
+    for i in eachindex(args)
+        # Handle splats.
+        arg = args[i]
+        if isa(arg, Expr) && arg.head == :...
+            args[i] = :(dnnl_convert($(esc(arg.args[1]))...)...)
+        else
+            args[i] = :(dnnl_convert($(esc(args[i]))))
+        end
+    end
+    return :($fname($(args...)))
+end
+
+struct InnerConstructor end
+
+macro wrap_type(jl_name, c_name, destructor)
+    lower_constructor_name = Symbol(jl_name, :NoFinalizer)
+
+    # Automatically add the "Lib" prefix if required.
+    c_name isa Symbol && (c_name = :(Lib.$(c_name)))
+
+    return esc(quote
+        # Type definition
+        mutable struct $(jl_name)
+            handle::$(c_name)
+            $(jl_name)(::InnerConstructor) = new($(c_name)())
+        end
+
+        # Use a trick of lower and higher constructors.
+        # Lower constructors should have the name `$(jl_name)NoFinalizer` and not
+        # attach finalizers.
+        #
+        # The higher constructor will simply forward to the lower constructor but
+        # attach a finalizer before returning.
+        function $(jl_name)(args...)
+            x = $(lower_constructor_name)(args...)
+            attach_finalizer!(x)
+            return x
+        end
+
+        # Finalizer
+        destroy(x::$(jl_name)) = @dnnlcall $(destructor)(x)
+        attach_finalizer!(x::$(jl_name)) = finalizer(destroy, x)
+
+        # Conversion functions
+        Base.unsafe_convert(::Type{$(c_name)}, x::$(jl_name)) = x.handle
+        function Base.unsafe_convert(::Type{Ptr{$(c_name)}}, x::$(jl_name))
+            return Base.unsafe_convert(Ptr{$(c_name)}, Base.pointer_from_objref(x))
+        end
+    end)
+end
+
+const MaybeRef{T} = Union{Ref{T}, T}
+const MaybePtr{T} = Union{Ptr{T}, T}
+
+wrap_ref(x::Ref) = x
+wrap_ref(x) = Ref(x)
+
+unwrap_ref(x::Ref) = x[]
+unwrap_ref(x) = x
+
+dnnl_type(::Type{Float16}) = Lib.dnnl_f16
+dnnl_type(::Type{Float32}) = Lib.dnnl_f32
+dnnl_type(::Type{Float64}) = Lib.dnnl_f64
+dnnl_type(::Type{Int32}) = Lib.dnnl_s32
+dnnl_type(::Type{Int8}) = Lib.dnnl_s8
+dnnl_type(::Type{UInt8}) = Lib.dnnl_u8
+dnnl_type(::Type{Bool}) = Lib.dnnl_boolean
+dnnl_type(::T) where {T <: Number} = dnnl_type(T)
+dnnl_type(::Type{T}) where {T} = error("No DNNL type for type $T")
+dnnl_type(::T) where {T} = error("No DNNL type for $T")
+
+function dnnl_type_to_julia(x::Lib.dnnl_data_type_t)
+    x == Lib.dnnl_f16 && return Float16
+    x == Lib.dnnl_f32 && return Float32
+    x == Lib.dnnl_f64 && return Float64
+    x == Lib.dnnl_s32 && return Int32
+    x == Lib.dnnl_s8 && return Int8
+    x == Lib.dnnl_u8 && return UInt8
+    x == Lib.dnnl_boolean && return Bool
+    error("No Julia type for DNNL type $x")
+end
+
+dnnl_convert(x) = x
+dnnl_convert(x, y...) = (dnnl_convert(x), dnnl_convert.(y)...)
+dnnl_convert(::Type{T}) where {T} = dnnl_type(T)
+dnnl_convert(x::Dims{N}) where {N} = Ref(dnnl_dims(x))
+# dnnl_convert(x::NTuple{N, oneDNNMemoryDesc}) where {N} = Ref(x)
+
+# Make a DIMS array
+# NOTE: The oneDNN C-API expects a pointer, so we can't just pass a tuple.
+# We either need to pass an array, or a Ref{Tuple}.
+# Here, we choose to do the latter.
+function dnnl_dims(x::Dims{N}) where {N}
+    f(i) = i ≤ length(x) ? Lib.dnnl_dim_t(x[i]) : zero(Lib.dnnl_dim_t)
+    return ntuple(f, Val(Lib.DNNL_MAX_NDIMS))
+end
+dnnl_dims(x::Dims{Lib.DNNL_MAX_NDIMS}) = x
+
+dnnl_dims(x::AbstractArray) = dnnl_dims(strides(x))
+dnnl_dims() = ntuple(Returns(zero(Int64)), Val(Lib.DNNL_MAX_NDIMS))
+dnnl_dims(::Tuple{}) = dnnl_dims()
+
+# Formats
+default_strides(size::Tuple{Vararg{Int, N}}) where {N} = Base.size_to_strides(1, size...)
+# dnnl_format_any() = Lib.dnnl_format_tag_any
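Note: roughly what `@wrap_type Engine dnnl_engine_t dnnl_engine_destroy` (as used in `types.jl`) is expected to expand to -- an illustrative sketch derived from the macro body above, not the literal macro output:

    mutable struct Engine
        handle::Lib.dnnl_engine_t
        Engine(::InnerConstructor) = new(Lib.dnnl_engine_t())
    end

    # Higher constructor: forward to the lower constructor, then attach the finalizer.
    function Engine(args...)
        x = EngineNoFinalizer(args...)  # lower constructor, defined next to the macro call
        attach_finalizer!(x)
        return x
    end

    destroy(x::Engine) = @dnnlcall dnnl_engine_destroy(x)
    attach_finalizer!(x::Engine) = finalizer(destroy, x)

    Base.unsafe_convert(::Type{Lib.dnnl_engine_t}, x::Engine) = x.handle
    function Base.unsafe_convert(::Type{Ptr{Lib.dnnl_engine_t}}, x::Engine)
        return Base.unsafe_convert(Ptr{Lib.dnnl_engine_t}, Base.pointer_from_objref(x))
    end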
diff --git a/src/utils.jl b/src/utils.jl
index 0a94d8c5..0aaf0f98 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -319,4 +319,10 @@ end
 
 CRC.@non_differentiable static_training_mode_check(::Any...)
 
+function ancestor(x::AbstractArray)
+    p = parent(x)
+    p === x && return x
+    return ancestor(p)
+end
+
 end
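Note: a quick illustration of the new `ancestor` helper (assumed usage): it peels wrappers until `parent` is a no-op, which is what `oneDNN.Memory` relies on to recover the raw buffer behind views and reshapes:

    x = randn(Float32, 4, 4)
    v = view(reshape(x, 2, 8), :, 1:3)

    ancestor(v) === x  # true: both wrapper layers are stripped
    ancestor(x) === x  # true: plain arrays are returned unchanged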