Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add setall() #68

Merged
merged 17 commits into from
Nov 2, 2022
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Accessors"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
authors = ["Takafumi Arakaki <[email protected]>", "Jan Weidner <[email protected]> and contributors"]
version = "0.1.21"
version = "0.1.23"
aplavin marked this conversation as resolved.
Show resolved Hide resolved

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
118 changes: 94 additions & 24 deletions src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ julia> getall(obj, @optic _ |> Elements() |> last)
"""
function getall end

# implementations for individual noncomposite optics

getall(obj::Union{Tuple, AbstractVector}, ::Elements) = obj
getall(obj::Union{NamedTuple}, ::Elements) = values(obj)
getall(obj::AbstractArray, ::Elements) = vec(obj)
Expand All @@ -29,17 +31,63 @@ getall(obj, ::Properties) = getproperties(obj) |> values
getall(obj, o::If) = o.modify_condition(obj) ? (obj,) : ()
getall(obj, f) = (f(obj),)

function setall(obj, ::Properties, vs)
names = propertynames(obj)
setproperties(obj, NamedTuple{names}(NTuple{length(names)}(vs)))
end
setall(obj::NamedTuple{NS}, ::Elements, vs) where {NS} = NamedTuple{NS}(NTuple{length(NS)}(vs))
setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; NTuple{N}(vs))
setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj)))
setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs))
setall(obj, o::If, vs) = error("Not supported")
setall(obj, o, vs) = set(obj, o, only(vs))


# implementations for composite optics

# A straightforward recursive approach doesn't actually infer,
# see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68.
# Instead, we need to generate separate functions for each recursion level.

# A recursive implementation of getall doesn't actually infer,
# see https://github.com/JuliaObjects/Accessors.jl/pull/64.
# Instead, we need to generate unrolled code explicitly.
function getall(obj, optic::ComposedFunction)
N = length(decompose(optic))
_GetAll{N}()(obj, optic)
_getall(obj, optic, Val(N))
end

function setall(obj, optic::ComposedFunction, vs)
N = length(decompose(optic))
vss = to_nested_shape(vs, Val(getall_lengths(obj, optic, Val(N))), Val(N))
_setall(obj, optic, vss, Val(N))
end


# _getall: the actual workhorse for getall
_getall(_, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
_getall(obj, optic, ::Val{1}) = getall(obj, optic)
for i in 2:10
@eval function _getall(obj, optic, ::Val{$i})
_reduce_concat(
map(getall(obj, optic.inner)) do obj
_getall(obj, optic.outer, Val($(i-1)))
end
)
end
end

struct _GetAll{N} end
(::_GetAll{N})(_) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
# _setall: the actual workhorse for setall
# takes values as a nested tuple with proper leaf lengths, prepared in setall above
_setall(_, _, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/68.")
_setall(obj, optic, vs, ::Val{1}) = setall(obj, optic, vs)
for i in 2:10
@eval function _setall(obj, optic, vs, ::Val{$i})
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
_setall(obj, optic.outer, vss, Val($(i - 1)))
end)
end
end


# helper functions

_concat(a::Tuple, b::Tuple) = (a..., b...)
_concat(a::Tuple, b::AbstractVector) = vcat(collect(a), b)
Expand All @@ -51,26 +99,48 @@ _reduce_concat(xs::AbstractVector) = reduce(append!, xs; init=eltype(eltype(xs))
_reduce_concat(xs::Tuple{AbstractVector, Vararg{AbstractVector}}) = reduce(vcat, xs)
_reduce_concat(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)

function _generate_getall(N::Int)
syms = [Symbol(:f_, i) for i in 1:N]

expr = :( getall(obj, $(syms[end])) )
for s in syms[1:end - 1] |> reverse
expr = :(
_reduce_concat(
map(getall(obj, $(s))) do obj
$expr
end
)
)
end
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
_staticlength(x::AbstractVector) = length(x)

:(function (::_GetAll{$N})(obj, optic)
($(syms...),) = deopcompose(optic)
$expr
end)
getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic))
for i in 2:10
@eval function getall_lengths(obj, optic, ::Val{$i})
# convert to Tuple: vectors cannot be put into Val
map(getall(obj, optic.inner) |> Tuple) do o
getall_lengths(o, optic.outer, Val($(i - 1)))
end
end
end

_val(N::Int) = N
_val(::Val{N}) where {N} = N

nestedsum(ls::Int) = ls
nestedsum(ls::Val) = ls
nestedsum(ls::Tuple) = sum(_val ∘ nestedsum, ls)

# to_nested_shape() definition uses both @eval and @generated
#
# @eval is needed because the code for different recursion depths should be different for inference,
# not the same method with different parameters.
#
# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible.
#
# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example.
# That's why it's safe to make it @generated.
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
for i in 2:10
eval(_generate_getall(i))
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
aplavin marked this conversation as resolved.
Show resolved Hide resolved
vi = 1
subs = map(LS) do lss
n = nestedsum(lss)
elems = map(vi:vi+_val(n)-1) do j
:( vs[$j] )
end
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
vi += _val(n)
res
end
:( ($(subs...),) )
end
end
2 changes: 1 addition & 1 deletion src/optics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export @optic
export PropertyLens, IndexLens
export set, modify, delete, insert, getall
export set, modify, delete, insert, getall, setall
export ∘, opcompose, var"⨟"
export Elements, Recursive, If, Properties
export setproperties
Expand Down
16 changes: 16 additions & 0 deletions src/testing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,19 @@ function test_modify_law(f, lens, obj)
obj_setfget = set(obj, lens, val)
@test obj_modify == obj_setfget
end

function test_getsetall_laws(optic, obj, vals1, vals2; cmp=(==))

# setall ⨟ getall
vals = getall(obj, optic)
@test cmp(setall(obj, optic, vals), obj)

# getall ⨟ setall
obj1 = setall(obj, optic, vals1)
@test cmp(collect(getall(obj1, optic)), collect(vals1))

# setall idempotent
obj12 = setall(obj1, optic, vals2)
obj2 = setall(obj12, optic, vals2)
@test obj12 == obj2
end
43 changes: 43 additions & 0 deletions test/test_getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,49 @@ if VERSION >= v"1.6" # for ComposedFunction
obj = ([1, 2], [:a, :b])
@test [1, 2, :a, :b] == @inferred getall(obj, @optic _ |> Elements() |> Elements())
end

@testset "setall" begin
for o in [Elements(), Properties()]
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, (2, 3))
@test (a=2, b="3") === @inferred setall((a=1, b="2"), o, (2, "3"))
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, [2, 3])
end
@test (2, 3) === @inferred setall((1, "2"), Elements(), (2, 3))
@test (2, "3") === @inferred setall((1, "2"), Elements(), (2, "3"))
@test (2, 3) === @inferred setall((1, "2"), Elements(), [2, 3])
@test [2, 3] == @inferred setall([1, "2"], Elements(), (2, 3))
@test [2, "3"] == @inferred setall([1, "2"], Elements(), (2, "3"))
@test [2, 3] == @inferred setall([1, "2"], Elements(), [2, 3])

obj = (a=1, b=2.0, c='3')
@test (a="aa", b=2.0, c='3') === @inferred setall(obj, @optic(_.a), ("aa",))
@test (a=9, b=19.0, c='4') === @inferred setall(obj, @optic(_ |> Elements() |> _ + 1), (10, 20.0, '5'))

obj = (a=1, b=((c=3, d=4), (c=5, d=6)))
@test (a=1, b=(:x, :y)) === @inferred setall(obj, @optic(_.b |> Elements()), (:x, :y))
@test (a=1, b=((c=:x, d=4), (c=:y, d=6))) === @inferred setall(obj, @optic(_.b |> Elements() |> _.c), (:x, :y))
@test (a=1, b=((c=:x, d="y"), (c=:z, d=10))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties()), (:x, "y", :z, 10))
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), (-9, -12, -15, -18))
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), [-9, -12, -15, -18])

obj = ([1, 2], 3:5, (6,))
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
# can this infer?..
@test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
end

@testset "getall/setall consistency" begin
for (optic, obj, vals1, vals2) in [
(Elements(), (1, "2"), (2, 3), (4, 5)),
(Properties(), (a=1, b="2"), (2, 3), (4, 5)),
(@optic(_.b |> Elements() |> Properties() |> _ * 3), (a=1, b=((c=3, d=4), (c=5, d=6))), 1:4, (-9, -12, -15, -18)),
]
Accessors.test_getsetall_laws(optic, obj, vals1, vals2)
end
end

end

end