Skip to content

Commit

Permalink
Use BLAS for mul! and axpy!
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Jul 5, 2020
1 parent 20f9027 commit 54d11ec
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
43 changes: 40 additions & 3 deletions src/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ Base.:*(A::Matrix{T}, p::Vector{StaticParticles{T,N}}) where {T<:Union{Float32,F
Perform `v'p::Vector{StaticParticles{T,N}` using BLAS matrix-vector multiply. This function is automatically used when applicable and there is no need to call it manually.
"""
function _pdot(
v::Vector{T},
v::AbstractVector{T},
p::Vector{StaticParticles{T,N}},
) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N}
pm = reinterpret(T, p)
Expand All @@ -571,5 +571,42 @@ function _pdot(
StaticParticles{T,N}(Mv)
end

LinearAlgebra.dot(v::Vector{T}, p::Vector{StaticParticles{T,N}}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pdot(v,p)
LinearAlgebra.dot(p::Vector{StaticParticles{T,N}}, v::Vector{T}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pdot(v,p)
LinearAlgebra.dot(v::AbstractVector{T}, p::Vector{StaticParticles{T,N}}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pdot(v,p)
LinearAlgebra.dot(p::Vector{StaticParticles{T,N}}, v::AbstractVector{T}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pdot(v,p)


function _paxpy!(
a::T,
x::Vector{StaticParticles{T,N}},
y::Vector{StaticParticles{T,N}},
) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N}
X = reinterpret(T, x)
Y = reinterpret(T, y)
LinearAlgebra.axpy!(a,X,Y)
StaticParticles{T,N}(reinterpret(StaticParticles{T,N}, Y))
end

LinearAlgebra.axpy!(
a::T,
x::Vector{StaticParticles{T,N}},
y::Vector{StaticParticles{T,N}},
) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _paxpy!(a,x,y)





function LinearAlgebra.mul!(
y::Vector{StaticParticles{T,N}},
A::AbstractMatrix{T},
b::Vector{StaticParticles{T,N}},
) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N}
Bv = reinterpret(T, b)
B = reshape(Bv, N, :)'
# Y0 = A*B
# reinterpret(StaticParticles{T,N}, vec(Y0'))
Yv = reinterpret(T, y)
Y = reshape(Yv, :, N)
mul!(Y,A,B)
reinterpret(StaticParticles{T,N}, vec(Y'))
end
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,14 @@ Random.seed!(0)
@test mean(sum(abs, p'v - MonteCarloMeasurements._pdot(v,p))) < 1e-12
@test mean(sum(abs, v'*Particles.(p) - v'p)) < 1e-12


@test mean(sum(abs, axpy!(2,Matrix(p),copy(Matrix(p))) - Matrix(axpy!(2,p,copy(p))))) < 1e-12
@test mean(sum(abs, axpy!(2,Matrix(p),copy(Matrix(p))) - Matrix(axpy!(2,p,copy(p))))) < 1e-12


y = randn(20) .∓ 1
@test mean(sum(abs, mul!(y,A,p) - mul!(Particles.(y),A,Particles.(p)))) < 1e-12

#
# @btime $A*$p
# @btime _pgemv($A,$p)
Expand All @@ -629,6 +637,11 @@ Random.seed!(0)
# @btime sum($v'*$p)
# @btime sum(_pdot($v,$p))

# @btime mul!($y,$A,$p)
# @btime MonteCarloMeasurements.pmul!($y,$A,$p)
# 178.373 μs (6 allocations: 336 bytes)
# 22.320 μs (0 allocations: 0 bytes)
# 3.705 μs (0 allocations: 0 bytes)
end


Expand Down

0 comments on commit 54d11ec

Please sign in to comment.