From 989286fc2b86ccc1b40713d5554a8e2c366e4834 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 26 Jul 2022 22:24:14 -0400 Subject: [PATCH] tidy, tests --- src/deprecations.jl | 5 - src/train/Train.jl | 122 +----------------- src/train/explicit_train.jl | 28 +++-- src/train/implicit_train.jl | 10 +- test/layers/conv.jl | 8 +- test/optimise.jl | 239 ------------------------------------ test/runtests.jl | 4 +- test/train.jl | 92 ++++++++++++++ 8 files changed, 127 insertions(+), 381 deletions(-) delete mode 100644 test/optimise.jl create mode 100644 test/train.jl diff --git a/src/deprecations.jl b/src/deprecations.jl index 1769a94170..6cb73d2cf2 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -34,11 +34,6 @@ struct Zeros end Zeros(args...) = Zeros() # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros()) -# function Optimise.update!(x::AbstractArray, x̄) -# Base.depwarn("`Flux.Optimise.update!(x, x̄)` was not used internally and has been removed. Please write `x .-= x̄` instead.", :update!) -# x .-= x̄ -# end - function Diagonal(size::Integer...; kw...) Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal) Scale(size...; kw...) diff --git a/src/train/Train.jl b/src/train/Train.jl index 32049b9285..bbbe762fa2 100644 --- a/src/train/Train.jl +++ b/src/train/Train.jl @@ -4,7 +4,7 @@ using LinearAlgebra using Optimisers: Optimisers using Functors: fmap -export train!, update!, adjust!, FluxState, @epochs, +export train!, update!, adjust!, FluxState, Descent, Adam, Momentum, Nesterov, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #, # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, @@ -15,7 +15,7 @@ export train!, update!, adjust!, FluxState, @epochs, """ FluxState(rule, state=missing) - + This is an interface between the all-mutable world Flux.jl likes, and the could-be-immutable world that Optimisers.jl inhabits. @@ -56,34 +56,14 @@ end ### Two styles of gradient, and their `train!` functions -using ProgressLogging: @progress, @withprogress, @logprogress +using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again using Zygote: Zygote, Params -include("explicit_train.jl.jl") # new! -include("implicit_train.jl.jl") # Params etc, Zygote only +include("explicit_train.jl") # new! +include("implicit_train.jl") # Params etc, Zygote only explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor -# using Requires # Flux doesn't use this right now -# @init @require Diffractor="9f5e2b26-1114-432f-b630-d3fe2085c51c" begin -# @eval function explicit_withgradient(f, args...) -# y, back = Diffractor.∂⃖¹(f, args...) -# _, grads... = back(Zygote.sensitivity(y)) -# return (; value = y, gradient = grads) -# end -# end - -#= - -using Diffractor -function Flux.Train.explicit_withgradient(f, args...) - y, back = Diffractor.∂⃖¹(f, args...) - _, grads... = back(one(y)) - return (; value = y, gradient = grads) -end - -=# - ### Misc. related utilities """ @@ -107,94 +87,4 @@ function adjust!(opt::FluxState, eta::Real) return opt end -""" - @epochs N body - -Run `body` expression `N` times. Mainly useful for quickly doing -multiple epochs of training in a REPL. - -Functionally equivalent to this loop: -``` -for _ in 1:N - body -end -``` -... but adds progress logging and `@info` messages, -and returns the result of the last iteration. - -# Examples -```jldoctest -julia> Flux.@epochs 2 println("hello") -[ Info: Epoch 1 -hello -[ Info: Epoch 2 -hello -``` -""" -macro epochs(n, ex) - @gensym val - body = :(for i in 1:$(esc(n)) - @info "Epoch $i" - $(esc(val)) = $(esc(ex)) - end) - loop = Expr(:macrocall, Symbol("@progress"), __source__, body) - Expr(:block, :($(esc(val)) = nothing), loop, :($(esc(val)))) - # TODO make this actualy return the value? Names aren't right. -# -# $loop -# # @progress for i in 1:$(esc(n)) -# # @info "Epoch $i" -# # $(esc(val)) = $(esc(ex)) -# # end -# $val # DOESN"T WORK! Expr(:macrocall, ...) ? -# end -end - -end - - -#= - -using Flux, Random -data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!; - -# This exact code works on Flux@0.13. There, train! returns nothing: -model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1)) -opt2 = Flux.Adam() -Flux.train!(Flux.params(model2), data, opt2) do x, y - Flux.mse(model2(x), y) -end -opt2 # contains an IdDict - -# This is the new "explicit" method of Train -model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1)) -opt1 = Flux.Adam() -Flux.train!(model1, data, opt1) do m, x, y - Flux.mse(m(x), y) -end |> sum -opt1 # contains state tree - -# This is new 3-arg train!, one step not an iteration over data: -x1, y1 = data[1] -Flux.train!(model1, opt1) do m - Flux.mse(m(x1), y1) -end - - - - - -julia> using ProgressLogging -julia> @macroexpand1 @loop N body -begin - x = nothing - @progress for i in 1:N - @info "step $i" - x = body - end - x -end - - - -=# \ No newline at end of file +end # module diff --git a/src/train/explicit_train.jl b/src/train/explicit_train.jl index edd31b281e..673ba6c141 100644 --- a/src/train/explicit_train.jl +++ b/src/train/explicit_train.jl @@ -52,26 +52,28 @@ function train!(loss::Function, model, data, opt::FluxState) _initialise!(opt, model) losses = Float32[] s = opt.state - s isa IdDict && error("can't mix explicit & implicit!") + s isa IdDict && error("""Can't mix explicit & implicit modes! + Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""") for d in data - l, (g, _...) = Zygote.withgradient(loss, model, train_ok(d)...) + l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...) s, model = Optimisers.update!(s, model, g) push!(losses, l) opt.state = s end - return losses + return losses # Not entirely sure returning losses is a good idea. Flux 0.13 returns `nothing`. end -train_ok(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T - To allow this type, define `Flux.Optimise.train_ok(x::$T) = (x,)`""") -train_ok(x::Tuple) = x -train_ok(x::NamedTuple) = x +data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T + To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""") +data_splat(x::Tuple) = x +data_splat(x::NamedTuple) = x function _initialise!(opt::FluxState, model) if opt.state isa Missing opt.state = Optimisers.setup(opt.rule, model) fmap(model, exclude = Optimisers.isnumeric) do x - Optimisers.maywrite(x) || error("model must be fully mutable for train! to work, got $(typeof(x))") + Optimisers.maywrite(x) || error("""model must be fully mutable for train! to work, got x::$(typeof(x)) + If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""") end end opt @@ -107,12 +109,12 @@ function train!(loss::Function, model, opt::FluxState) l end +# This method lets you use Optimisers.Descent() instead of Flux.Descent(), when there is no state function train!(loss::Function, model, data, opt::Optimisers.AbstractRule) _initialise!(opt, model) - # fmap(opt.state) do x - # x isa Union{Number, AbstractArray{<:Number}} && @warn "optimiser state will be lost!" - # x - # end # won't work as you need to look inside Leaf for non-nothings. - @warn "optimiser state will be lost!" + fmap(opt.state, exclude = x -> x isa Optimsers.Leaf) do leaf + leaf.state isa Nothing || @warn "Optimiser state will be lost! Please wrap optimisation rule in `FluxState`, e.g. by using `Flux.Adam()`" leaf + leaf + end train!(loss, model, data, FluxState(opt)) end diff --git a/src/train/implicit_train.jl b/src/train/implicit_train.jl index 43c3b75766..eb2068eaa0 100644 --- a/src/train/implicit_train.jl +++ b/src/train/implicit_train.jl @@ -29,7 +29,7 @@ function train!(loss::Function, pars::Params, data, opt::FluxState) losses = Float32[] for d in data l, grads = Zygote.withgradient(() -> loss(batchmemaybe(d)...), pars) - update!(opt, pars, grads) + _update!(opt, pars, grads) push!(losses, l) end return losses @@ -49,7 +49,7 @@ function train!(loss::Function, pars::Params, opt::FluxState) Explicit parameters are now preferred, see `train!(loss, model, data, opt)`""", :train!, force=true) _initialise!(opt, pars) l, grads = Zygote.withgradient(() -> loss(), pars) - update!(opt, pars, grads) + _update!(opt, pars, grads) return l end @@ -68,6 +68,12 @@ Legacy method, mimicking the behaviour of Flux <= 0.13. """ function update!(opt::FluxState, xs::Params, gs) Base.depwarn("Flux.update! is a legacy function", :update!) + _initialise!(opt, xs) + _update!(opt, xs, gs) +end +# This _update! exists only so that train! above gives one depwarn, not two! +# ... and also to call _initialise! +function _update!(opt::FluxState, xs::Params, gs) for x in xs isnothing(gs[x]) && continue update!(opt, x, gs[x]) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 019f3fd603..32e40f4186 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -55,13 +55,13 @@ end bias = Conv((2, 2), 1=>3, bias = false); ip = zeros(Float32, 28,28,1,1) op = zeros(Float32, 27,27,3,1) .+ 2.f0 - opt = Descent() + opt = Flux.Descent() for _ = 1:10^3 gs = gradient(Flux.params(bias)) do Flux.Losses.mse(bias(ip), op) end - Flux.Optimise.update!(opt, params(bias), gs) + Flux.Optimise.update!(opt, Flux.params(bias), gs) end @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0 @@ -168,7 +168,7 @@ end x = zeros(Float32, 5, 5, 2, 4) m = ConvTranspose((3,3), 2=>3) - @test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads + @test gradient(()->sum(m(x)), Flux.params(m)) isa Flux.Zygote.Grads # test ConvTranspose supports groups argument x = randn(Float32, 10, 10, 2, 3) @@ -178,7 +178,7 @@ end m2 = ConvTranspose((3,3), 2=>4, groups=2, pad=SamePad()) @test size(m2.weight) == (3,3,2,2) @test size(m1(x)) == size(m2(x)) - @test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads + @test gradient(()->sum(m2(x)), Flux.params(m2)) isa Flux.Zygote.Grads x = randn(Float32, 10, 2,1) m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2) diff --git a/test/optimise.jl b/test/optimise.jl deleted file mode 100644 index e922d3c0b8..0000000000 --- a/test/optimise.jl +++ /dev/null @@ -1,239 +0,0 @@ -using Flux.Optimise -using Flux.Optimise: runall -using Flux: Params, gradient -import FillArrays, ComponentArrays -using Test -using Random - -@testset "Optimise" begin - # Ensure rng has different state inside and outside the inner @testset - # so that w and w' are different - Random.seed!(84) - w = randn(10, 10) - @testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), - NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), - Nesterov(), RMSProp(), Momentum()] - Random.seed!(42) - w′ = randn(10, 10) - b = false - loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) - for t = 1: 10^5 - θ = params([w′, b]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) - end - @test loss(rand(10, 10)) < 0.01 - end -end - -@testset "Optimiser" begin - Random.seed!(84) - w = randn(10, 10) - @testset for Opt in [InvDecay, WeightDecay, ExpDecay] - Random.seed!(42) - w′ = randn(10, 10) - loss(x) = Flux.Losses.mse(w*x, w′*x) - opt = Optimiser(Opt(), Adam(0.001)) - for t = 1:10^5 - θ = Params([w′]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) - end - @test loss(rand(10, 10)) < 0.01 - end -end - -@testset "Training Loop" begin - i = 0 - l = 1 - Flux.train!( - () -> (sleep(0.1); Flux.skip(); i+=1), - Params([]), - Iterators.repeated((), 10), - Descent() - ) - - @test i==0 #all skipped - - Flux.train!( - () -> (sleep(0.1); i==8 && Flux.skip(); i+=1), - Params([]), - Iterators.repeated((), 10), - Descent() - ) - - @test i==8 #skip after i hit 8 - - i = 0 - Flux.train!(() -> (sleep(0.1); i += 1; l), - Params([]), - Iterators.repeated((), 100), - Descent(), - cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) - - @test 3 < i < 50 - - # Test multiple callbacks - x = 0 - fs = [() -> (), () -> x = 1] - cbs = runall(fs) - cbs() - @test x == 1 - - r = rand(3, 3) - loss(x) = sum(x .* x) - Flux.train!(loss, Flux.params(r), (r,), Descent()) -end - -@testset "ExpDecay" begin - - @testset "Sanity Check" begin - o = ExpDecay(0.2, 0.5, 1, 1e-3) - p = [0.0] - steps = 1:8 - eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) - eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] - @test eta_actual == eta_expected - end - - @testset "starting step" begin - start = 4 - o = ExpDecay(0.2, 0.5, 1, 1e-3, start) - p = [0.0] - steps = 1:8 - eta_expected = @. max(o.eta * 0.5 ^ max(steps - start, 0), o.clip) - eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] - @test eta_actual == eta_expected - end - - w = randn(10, 10) - o = ExpDecay(0.1, 0.1, 1000, 1e-4) - w1 = randn(10,10) - loss(x) = Flux.Losses.mse(w*x, w1*x) - flag = 1 - decay_steps = [] - for t = 1:10^5 - prev_eta = o.eta - θ = Params([w1]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - prev_grad = collect(θ̄[w1]) - delta = Optimise.apply!(o, w1, θ̄[w1]) - w1 .-= delta - new_eta = o.eta - if new_eta != prev_eta - push!(decay_steps, t) - end - array = fill(o.eta, size(prev_grad)) - if array .* prev_grad != delta - flag = 0 - end - end - @test flag == 1 - # Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). - ground_truth = [] - for i in 1:4 - push!(ground_truth, 1000*i) # Expected decay steps for this example. - end - @test decay_steps == ground_truth - @test o.eta == o.clip -end - -@testset "Clipping" begin - w = randn(10, 10) - loss(x) = sum(w * x) - θ = Params([w]) - x = 1000 * randn(10) - w̄ = gradient(() -> loss(x), θ)[w] - w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) - @test all(w̄_value .<= 1) - w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) - @test norm(w̄_norm) <= 1 -end - -@testset "update!: handle Fills from Zygote" begin - w = randn(10,10) - wold = copy(w) - g = FillArrays.Ones(size(w)) - opt = Descent(0.1) - Flux.update!(opt, w, g) - @test w ≈ wold .- 0.1 - - w = randn(3) - wold = copy(w) - θ = Flux.params([w]) - gs = gradient(() -> w[1], θ) - opt = Descent(0.1) - Flux.update!(opt, θ, gs) - @test w[1] ≈ wold[1] .- 0.1 - @test w[2:3] ≈ wold[2:3] - - ## Issue #1510 - w = randn(10,10) - wold = copy(w) - θ = Flux.params([w]) - gs = gradient(() -> sum(w), θ) - opt = Descent(0.1) - Flux.update!(opt, θ, gs) - @test w ≈ wold .- 0.1 -end - -@testset "update!: handle ComponentArrays" begin - w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) - wold = deepcopy(w) - θ = Flux.params([w]) - gs = gradient(() -> sum(w.a) + sum(w.c.b), θ) - opt = Descent(0.1) - Flux.update!(opt, θ, gs) - @test w.a ≈ wold.a .- 0.1 - @test w.b ≈ wold.b - @test w.c.b ≈ wold.c.b .- 0.1 - @test w.c.a ≈ wold.c.a - - w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) - wold = deepcopy(w) - θ = Flux.params([w]) - gs = gradient(() -> sum(w), θ) - opt = Descent(0.1) - Flux.update!(opt, θ, gs) - @test w ≈ wold .- 0.1 -end - -# Flux PR #1776 -# We need to test that optimisers like Adam that maintain an internal momentum -# estimate properly calculate the second-order statistics on the gradients as -# the flow backward through the model. Previously, we would calculate second- -# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which -# wreaks all sorts of havoc on our training loops. This test ensures that -# a simple optimization is montonically decreasing (up to learning step effects) -@testset "Momentum Optimisers and complex values" begin - # Test every optimizer that has momentum internally - for opt_ctor in [Adam, RMSProp, RAdam, OAdam, AdaGrad, AdaDelta, NAdam, AdaBelief] - # Our "model" is just a complex number - w = zeros(ComplexF32, 1) - - # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x` - function loss() - # Deterministic training data is the best training data - x = ones(1, 1) + 1im*ones(1, 1) - - # Manually implement `mse()` to allow demonstration of brokenness - # on older Flux builds that don't have a fixed `mse()` - return sum(abs2.(w * x .- conj(x))) - end - - params = Flux.Params([w]) - opt = opt_ctor(1e-2) - - # Train for 10 iterations, enforcing that loss is monotonically decreasing - last_loss = Inf - for idx in 1:10 - grads = Flux.gradient(loss, params) - @test loss() < last_loss - last_loss = loss() - Flux.update!(opt, params, grads) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 706f126451..d9a5011879 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,8 +20,8 @@ Random.seed!(0) include("onehot.jl") end - @testset "Optimise" begin - include("optimise.jl") + @testset "Train" begin + include("train.jl") end @testset "Data" begin diff --git a/test/train.jl b/test/train.jl new file mode 100644 index 0000000000..c7af65f509 --- /dev/null +++ b/test/train.jl @@ -0,0 +1,92 @@ +using Flux.Train +using Zygote: Params, gradient + +import FillArrays, ComponentArrays + +using Test +using Random + +@testset "Implicit train!" begin # These tests pass on Flux v0.13 + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for opt in [Descent(0.1), Adam()] + # [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), + # NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), + # Nesterov(), RMSProp(), Momentum()] + w′ = copy(w2) + b = zeros(10) + loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) + @test loss(rand(10, 10)) > 1 + Flux.train!(loss, Flux.params([w′, b]), (rand(10) for _ in 1: 10^5), opt) + @test loss(rand(10, 10)) < 0.01 + end +end + +@testset "Explicit train!" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for opt in [Descent(0.1), Adam()] + @test opt isa FluxState + w′ = copy(w2) + b = zeros(10) + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=w′, bias=b, ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end +end + +#= + +@testset "update!: handle Fills from Zygote" begin + w = randn(10,10) + wold = copy(w) + g = FillArrays.Ones(size(w)) + opt = Descent(0.1) + Flux.update!(opt, w, g) + @test w ≈ wold .- 0.1 + + w = randn(3) + wold = copy(w) + θ = Flux.params([w]) + gs = gradient(() -> w[1], θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w[1] ≈ wold[1] .- 0.1 + @test w[2:3] ≈ wold[2:3] + + ## Issue #1510 + w = randn(10,10) + wold = copy(w) + θ = Flux.params([w]) + gs = gradient(() -> sum(w), θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w ≈ wold .- 0.1 +end + +@testset "update!: handle ComponentArrays" begin + w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) + wold = deepcopy(w) + θ = Flux.params([w]) + gs = gradient(() -> sum(w.a) + sum(w.c.b), θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w.a ≈ wold.a .- 0.1 + @test w.b ≈ wold.b + @test w.c.b ≈ wold.c.b .- 0.1 + @test w.c.a ≈ wold.c.a + + w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) + wold = deepcopy(w) + θ = Flux.params([w]) + gs = gradient(() -> sum(w), θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w ≈ wold .- 0.1 +end + +=# \ No newline at end of file