Skip to content

Commit

Permalink
Performance improvement for dot and v'p
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Jul 5, 2020
1 parent 459f1b2 commit 296efcb
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MonteCarloMeasurements"
uuid = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
authors = ["baggepinnen <[email protected]>"]
version = "0.9.1"
version = "0.9.2"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
20 changes: 20 additions & 0 deletions src/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ Base.Matrix(v::MvParticles) = Array(v)
Statistics.mean(v::MvParticles) = mean.(v)
Statistics.cov(v::MvParticles,args...;kwargs...) = cov(Matrix(v), args...; kwargs...)
Statistics.cor(v::MvParticles,args...;kwargs...) = cor(Matrix(v), args...; kwargs...)
Statistics.var(v::MvParticles,args...; corrected = true, kwargs...) = sum(abs2, v)/(length(v) - corrected)
Distributions.fit(d::Type{<:MultivariateDistribution}, p::MvParticles) = fit(d,Matrix(p)')
Distributions.fit(d::Type{<:Distribution}, p::AbstractParticles) = fit(d,p.particles)

Expand Down Expand Up @@ -531,3 +532,22 @@ function _pgemv(
end

Base.:*(A::Matrix{T}, p::Vector{StaticParticles{T,N}}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pgemv(A,p)


"""
_pdot(v::Vector{T}, p::Vector{StaticParticles{T, N}}) where {T, N}
Perform `v'p::Vector{StaticParticles{T,N}` using BLAS matrix-vector multiply. This function is automatically used when applicable and there is no need to call it manually.
"""
function _pdot(
v::Vector{T},
p::Vector{StaticParticles{T,N}},
) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N}
pm = reinterpret(T, p)
M = reshape(pm, N, :)
Mv = M*v
StaticParticles{T,N}(Mv)
end

LinearAlgebra.dot(v::Vector{T}, p::Vector{StaticParticles{T,N}}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pdot(v,p)
LinearAlgebra.dot(p::Vector{StaticParticles{T,N}}, v::Vector{T}) where {T<:Union{Float32,Float64,ComplexF32,ComplexF64},N} = _pdot(v,p)
15 changes: 13 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,24 @@ Random.seed!(0)
A = randn(20,10)
@test mean(sum(abs, A*p - MonteCarloMeasurements._pgemv(A,p))) < 1e-12
@test mean(sum(abs, A*Particles.(p) - A*p)) < 1e-12

v = randn(10)
@test mean(sum(abs, v'p - MonteCarloMeasurements._pdot(v,p))) < 1e-12
@test mean(sum(abs, p'v - MonteCarloMeasurements._pdot(v,p))) < 1e-12
@test mean(sum(abs, v'*Particles.(p) - v'p)) < 1e-12

#
# @btime $A*$p
# @btime pgemv($A,$p)
# @btime _pgemv($A,$p)
#
# @btime sum($A*$p)
# @btime sum(pgemv($A,$p))
# @btime sum(_pgemv($A,$p))
#
# @btime $v'*$p
# @btime _pdot($v,$p)
#
# @btime sum($v'*$p)
# @btime sum(_pdot($v,$p))

end

Expand Down

0 comments on commit 296efcb

Please sign in to comment.