From 370e2fc85c9fa9c86da8b3d03805a2c5b4ce5374 Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Sun, 20 Jun 2021 00:34:52 +1000 Subject: [PATCH] working but slow --- src/optics.jl | 130 +++++++++++++++++++++---------------------- test/test_queries.jl | 36 ++++++------ 2 files changed, 83 insertions(+), 83 deletions(-) diff --git a/src/optics.jl b/src/optics.jl index b2bb340a..3140ae9d 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -42,8 +42,7 @@ julia> obj = (a=1, b=2); lens=@optic _.a; val = 100; julia> set(obj, lens, val) (a = 100, b = 2) -``` -See also [`modify`](@ref). +``` See also [`modify`](@ref). """ function set end @@ -346,15 +345,7 @@ Here `f` has signature `f(::Value, ::State) -> Tuple{NewValue, NewState}`. """ function modify_stateful end -@inline function modify_stateful(f, (obj, state), optic::Properties) - let f=f, obj=obj, state=state - modify_stateful_context((obj, state), optic) do _, fn, pr, st - f(getfield(pr, known(fn)), st) - end - end -end - -@generated function modify_stateful_context(f, (obj, state1)::T, optic::Properties) where T +@generated function modify_stateful(f::F, (obj, state)::T, optic::Properties) where {T,F} _modify_stateful_inner(T) end @@ -363,29 +354,29 @@ function _modify_stateful_inner(::Type{<:Tuple{O,S}}) where {O,S} modifications = [] vals = Expr(:tuple) fns = fieldnames(O) - local st1 = :state0 - local st2 = :state1 for (i, fn) in enumerate(fns) v = Symbol("val$i") - st1 = Symbol("state$i") - st2 = Symbol("state$(i+1)") - ms = if O <: Tuple - :(($v, $st2) = f(obj, StaticInt{$(QuoteNode(fn))}(), props, $st1)) + st = if S <: ContextState + if O <: Tuple + :(ContextState(state.vals, obj, StaticInt{$(QuoteNode(fn))}())) + else + :(ContextState(state.vals, obj, StaticSymbol{$(QuoteNode(fn))}())) + end else - :(($v, $st2) = f(obj, StaticSymbol{$(QuoteNode(fn))}(), props, $st1)) + :state end + ms = :(($v, state) = f(getfield(props, $(QuoteNode(fn))), $st)) push!(modifications, ms) push!(vals.args, v) end patch = O <: Tuple ? vals : :(NamedTuple{$fns}($vals)) - Expr(:block, - :(props = getproperties(obj)), - modifications..., - :(patch = $patch), - :(new_obj = maybesetproperties($st2, obj, patch)), - :(new_state = maybesetstate($st2, obj, patch)), - :(return (setproperties(obj, patch), $st2)), - ) + start = :(props = getproperties(obj)) + rest = MacroTools.@q begin + patch = $patch + new_obj = maybesetproperties(state, obj, patch) + return (new_obj, state) + end + Expr(:block, start, modifications..., rest) end maybesetproperties(state, obj, patch) = setproperties(obj, patch) @@ -426,15 +417,10 @@ Query(; select=Any, descend=x -> true, optic=Properties()) = Query(select, desce OpticStyle(::Type{<:AbstractQuery}) = SetBased() -struct Context{Select,Descend,Optic<:Union{ComposedOptic,Properties}} <: AbstractQuery - select_condition::Select - descent_condition::Descend - optic::Optic -end - - -struct ContextState{V} +struct ContextState{V,O,FN} vals::V + obj::O + fn::FN end struct GetAllState{V} vals::V @@ -445,57 +431,69 @@ struct SetAllState{C,V,I} itr::I end -pop(x) = first(x), Base.tail(x) -push(x, val) = (x..., val) -push(x::GetAllState, val) = GetAllState(push(x.vals, val)) +const GetStates = Union{GetAllState,ContextState} + +@inline pop(x) = first(x), Base.tail(x) +@inline push(x, val) = (x..., val) +@inline push(x::GetAllState, val) = GetAllState(push(x.vals, val)) +@inline push(x::ContextState, val) = ContextState(push(x.vals, val), nothing, nothing) (q::Query)(obj) = getall(obj, q) -function getall(obj, q) +getall(obj, q) = _getall(obj, q).vals +function _getall(obj, q::Q) where Q<:Query initial_state = GetAllState(()) - _, final_state = modify_stateful((obj, initial_state), q) do o, s - new_state = push(s, outer(q.optic, o, s)) - o, new_state + _, final_state = let q=q + modify_stateful((obj, initial_state), q) do o, s + new_state = push(s, outer(q.optic, o, s)) + o, new_state + end end - return final_state.vals + final_state end -function setall(obj, q, vals) +function setall(obj, q::Q, vals) where Q<:Query initial_state = SetAllState(Unchanged(), vals, 1) - final_obj, _ = modify_stateful((obj, initial_state), q) do o, s - new_output = outer(q.optic, o, s) - new_state = SetAllState(Changed(), s.vals, s.itr + 1) - new_output, new_state + final_obj, _ = let obj=obj, q=q, initial_state=initial_state + modify_stateful((obj, initial_state), q) do o, s + new_output = outer(q.optic, o, s) + new_state = SetAllState(Changed(), s.vals, s.itr + 1) + new_output, new_state + end end return final_obj end -function context(f, obj, q) - initial_state = GetAllState(()) - _, final_state = modify_stateful_context((obj, initial_state), Properties()) do o, fn, pr, s - new_state = push(s, f(o, known(fn))) - o, new_state +function context(f::F, obj, q::Q) where {F,Q<:Query} + initial_state = ContextState((), nothing, nothing) + _, final_state = let f=f + modify_stateful((obj, initial_state), q) do o, s + new_state = push(s, f(s.obj, known(s.fn))) + o, new_state + end end return final_state.vals end modify(f, obj, q::Query) = setall(obj, q, map(f, getall(obj, q))) -@inline function modify_stateful(f::F, (obj, state), q::Query) where F - modify_stateful((obj, state), inner(q.optic)) do o, s - if q.select_condition(o) - f(o, s) - elseif q.descent_condition(o) - ds = descent_state(s) - o, s = modify_stateful(f::F, (o, ds), q) - o, merge_state(s, ds) - else - o, s +@inline function modify_stateful(f::F, (obj, state), q::Q) where {F,Q<:Query} + let f=f, q=q + modify_stateful((obj, state), inner(q.optic)) do o, s + if (q::Q).select_condition(o) + (f::F)(o, s) + elseif (q::Q).descent_condition(o) + ds = descent_state(s) + o, ns = modify_stateful(f::F, (o, ds), q::Q) + o, merge_state(ds, ns) + else + o, s + end end end end -maybesetproperties(state::GetAllState, obj, patch) = obj +maybesetproperties(state::GetStates, obj, patch) = obj maybesetproperties(state::SetAllState, obj, patch) = maybesetproperties(state.change, state, obj, patch) maybesetproperties(::Changed, state::SetAllState, obj, patch) = setproperties(obj, patch) @@ -516,8 +514,8 @@ anychanged(::Changed, ::Changed) = Changed() inner(optic) = optic inner(optic::ComposedOptic) = optic.inner -outer(optic, o, state::GetAllState) = o -outer(optic::ComposedOptic, o, state::GetAllState) = optic.outer(o) +outer(optic, o, state::GetStates) = o +outer(optic::ComposedOptic, o, state::GetStates) = optic.outer(o) outer(optic::ComposedOptic, o, state::SetAllState) = set(o, optic.outer, state.vals[state.itr]) outer(optic, o, state::SetAllState) = state.vals[state.itr] @@ -532,7 +530,7 @@ function (l::PropertyLens{field})(obj) where {field} end @inline function set(obj, l::PropertyLens{field}, val) where {field} - patch = (;field => val) + patch = (; field => val) setproperties(obj, patch) end diff --git a/test/test_queries.jl b/test/test_queries.jl index 5a502722..99722716 100644 --- a/test/test_queries.jl +++ b/test/test_queries.jl @@ -1,9 +1,7 @@ using Accessors, Test, BenchmarkTools, Static using Accessors: setall, getall, context - -obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,), [1,])) +obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,)), [1]) vals = (1.0, 2.0, 3.0, 4.0) - # Fields is the default q = Query(; select=x -> x isa NamedTuple, @@ -11,36 +9,35 @@ q = Query(; optic = (Accessors.@optic _.a) ∘ Accessors.Properties() # optic = Accessors.Properties() ) - -println("getall") getall(obj, q) + @code_native getall(obj, q) @code_warntype getall(obj, q) @benchmark getall($obj, $q) @test getall(obj, q) == (17.0, 6.0) +# using ProfileView, Cthulhu +# @descend getall(obj, q) +# f(obj, q) = for i in 1:10000000 getall(obj, q) end +# @profview f(obj, q) + missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2))) @test getall(missings_obj, Query(ismissing)) === (missing, missing, missing) @benchmark getall($missings_obj, Query(ismissing)) -println("setall") # Need a wrapper so we don't have to pass in the starting iterator setall(obj, q, vals) @benchmark setall($obj, $q, $vals) +# using ProfileView +# @profview for i in 1:1000000 setall(obj, q, vals) end @code_native setall(obj, q, vals) @code_warntype setall(obj, q, vals) # @btime Accessors.set($obj, $slowlens, $vals) @test setall(obj, q, vals) == - (7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,), [1])) - -using Cthulhu -@descend getall(obj, q) -# using ProfileView -# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end + (7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,)), [1]) -println("unstable set") unstable_q = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple) @btime setall($obj, $unstable_q, $vals) # slow_unstable_lens = Accessors.Query(; select=x -> x isa Number && x > 4, optic=Properties()) @@ -50,10 +47,15 @@ unstable_q = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x @btime modify(x -> 10x, $obj, $q) # Context -obj = (b=2, c=2) -@test context((o, fn) -> fn, obj, q) == (:b, :c) -@test context((o, fn) -> typeof(o), obj, q) == (typeof(obj), typeof(obj)) -@btime context((o, fn) -> fn, $obj, $q) +q = Query(; + select=x -> x isa Int, + descend=x -> x isa NamedTuple, + optic = Accessors.Properties() +) +obj2 = (1.0, :a, (b=2, c=2)) +@test context((o, fn) -> fn, obj2, q) == (:b, :c) +@test context((o, fn) -> typeof(o), obj2, q) == (typeof(obj2[3]), typeof(obj2[3])) +@btime context((o, fn) -> fn, $obj2, $q) # Macros @test (@getall (x for x in missings_obj if x isa Number)) == (1, 2)