Skip to content

Commit

Permalink
overhaul promotion (#49)
Browse files Browse the repository at this point in the history
* fix promotion

* Update src/particles.jl

Co-Authored-By: Fredrik Bagge Carlson <[email protected]>

* add back lost comment

* add some tests

* fix ambiguity
  • Loading branch information
simeonschaub authored and baggepinnen committed Dec 19, 2019
1 parent 5c0fd3e commit 2982c79
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
14 changes: 12 additions & 2 deletions src/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,6 @@ for PT in (:Particles, :StaticParticles)
@eval begin
Base.length(::Type{$PT{T,N}}) where {T,N} = N
Base.eltype(::Type{$PT{T,N}}) where {T,N} = $PT{T,N}
Base.promote_rule(::Type{S}, ::Type{$PT{T,N}}) where {S<:Number,T,N} = $PT{promote_type(S,T),N} # This is hard to hit due to method for real 3 lines down
Base.promote_rule(::Type{Bool}, ::Type{$PT{T,N}}) where {T,N} = $PT{promote_type(Bool,T),N} # Needed since above is not specific enough.

Base.convert(::Type{StaticParticles{T,N}}, p::$PT{T,N}) where {T,N} = StaticParticles(p.particles)
Base.convert(::Type{$PT{T,N}}, f::Real) where {T,N} = $PT{T,N}(fill(T(f),N))
Expand Down Expand Up @@ -270,8 +268,20 @@ for PT in (:Particles, :StaticParticles)
Base.:\(p::Vector{<:$PT}, p2::Vector{<:$PT}) = Matrix(p)\Matrix(p2) # Must be here to be most specific
end

@eval Base.promote_rule(::Type{S}, ::Type{$PT{T,N}}) where {S<:Number,T,N} = $PT{promote_type(S,T),N} # This is hard to hit due to method for real 3 lines down
@eval Base.promote_rule(::Type{Bool}, ::Type{$PT{T,N}}) where {T,N} = $PT{promote_type(Bool,T),N}

for PT2 in (:Particles, :StaticParticles)
if PT == PT2
@eval Base.promote_rule(::Type{$PT{S,N}}, ::Type{$PT{T,N}}) where {S,T,N} = $PT{promote_type(S,T),N}
elseif any(==(:StaticParticles), (PT, PT2))
@eval Base.promote_rule(::Type{$PT{S,N}}, ::Type{$PT2{T,N}}) where {S,T,N} = StaticParticles{promote_type(S,T),N}
else
@eval Base.promote_rule(::Type{$PT{S,N}}, ::Type{$PT2{T,N}}) where {S,T,N} = Particles{promote_type(S,T),N}
end
end

@eval Base.promote_rule(::Type{<:AbstractParticles}, ::Type{$PT{T,N}}) where {T,N} = Union{}
end

Base.length(p::AbstractParticles{T,N}) where {T,N} = N
Expand Down
29 changes: 19 additions & 10 deletions src/register_primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,41 @@ function register_primitive_multi(ff, eval=eval)
for PT in (:Particles, :StaticParticles)
eval(quote
function ($m.$f)(p::$PT{T,N},a::Real...) where {T,N}
$PT{T,N}(($m.$f).(p.particles, MonteCarloMeasurements.maybe_particles.(a)...)) # maybe_particles introduced to handle >2 arg operators
res = ($m.$f).(p.particles, MonteCarloMeasurements.maybe_particles.(a)...) # maybe_particles introduced to handle >2 arg operators
return $PT{eltype(res),N}(res)
end
function ($m.$f)(a::Real,p::$PT{T,N}) where {T,N}
$PT{T,N}(map(x->($m.$f)(a,x), p.particles))
res = map(x->($m.$f)(a,x), p.particles)
return $PT{eltype(res),N}(res)
end
function ($m.$f)(p1::$PT{T,N},p2::$PT{T,N}) where {T,N}
$PT{T,N}(map(($m.$f), p1.particles, p2.particles))
res = map(($m.$f), p1.particles, p2.particles)
return $PT{eltype(res),N}(res)
end
function ($m.$f)(p1::$PT{T,N},p2::$PT{S,N}) where {T,S,N} # Needed for particles of different float types :/
$PT{promote_type(T,S),N}(map(($m.$f), p1.particles, p2.particles))
res = map(($m.$f), p1.particles, p2.particles)
return $PT{eltype(res),N}(res)
end
end)
end
# The code below is resolving some method ambiguities
eval(quote
function ($m.$f)(p1::StaticParticles{T,N},p2::Particles{T,N}) where {T,N}
StaticParticles{T,N}(map(($m.$f), p1.particles, p2.particles))
res = map(($m.$f), p1.particles, p2.particles)
return StaticParticles{eltype(res),N}(res)
end
function ($m.$f)(p1::StaticParticles{T,N},p2::Particles{S,N}) where {T,S,N} # Needed for particles of different float types :/
StaticParticles{promote_type(T,S),N}(map(($m.$f), p1.particles, p2.particles))
res = map(($m.$f), p1.particles, p2.particles)
return StaticParticles{eltype(res),N}(res)
end

function ($m.$f)(p1::Particles{T,N},p2::StaticParticles{T,N}) where {T,N}
StaticParticles{T,N}(map(($m.$f), p1.particles, p2.particles))
res = map(($m.$f), p1.particles, p2.particles)
return StaticParticles{eltype(res),N}(res)
end
function ($m.$f)(p1::Particles{T,N},p2::StaticParticles{S,N}) where {T,S,N} # Needed for particles of different float types :/
StaticParticles{promote_type(T,S),N}(map(($m.$f), p1.particles, p2.particles))
res = map(($m.$f), p1.particles, p2.particles)
return StaticParticles{eltype(res),N}(res)
end
end)
end
Expand All @@ -63,8 +71,9 @@ function register_primitive_single(ff, eval=eval)
m = Base.parentmodule(ff)
for PT in (:Particles, :StaticParticles)
eval(quote
function ($m.$f)(p::$PT)
$PT(map(($m.$f), p.particles))
function ($m.$f)(p::$PT{T,N}) where {T,N}
res = map(($m.$f), p.particles)
return $PT{eltype(res),N}(res)
end
end)
end
Expand Down
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ Random.seed!(0)
show(io, p)
s = String(take!(io))
@test occursin('±', s)

# issue #50
@test 2.5 * p isa PT{Float64}
@test p / 3 isa PT{Float64}
@test sqrt(p) isa PT{Float64}
end
end
end
Expand Down Expand Up @@ -315,6 +320,13 @@ Random.seed!(0)
@test promote_type(Particles{Float64,10}, Int64) == Particles{Float64,10}
@test promote_type(Particles{Float64,10}, ComplexF64) == Complex{Particles{Float64,10}}
@test promote_type(Particles{Float64,10}, Missing) == Union{Particles{Float64,10},Missing}
@testset "promotion of $PT" for PT in (Particles, StaticParticles)
@test promote_type(PT{Float64,10}, PT{Float64,10}) == PT{Float64,10}
@test promote_type(PT{Float64,10}, PT{Int,10}) == PT{Float64,10}
@test promote_type(PT{Int,5}, PT{Float64,10}) == PT
end
@test promote_type(Particles{Float64,10}, StaticParticles{Float64,10}) == StaticParticles{Float64,10}
@test promote_type(Particles{Int,10}, StaticParticles{Float64,10}) == StaticParticles{Float64,10}
@test convert(Float64, 0p) isa Float64
@test convert(Float64, 0p) == 0
@test convert(Int, 0p) isa Int
Expand Down

0 comments on commit 2982c79

Please sign in to comment.