Skip to content

Commit

Permalink
tidy, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 27, 2022
1 parent a68470c commit d686232
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 372 deletions.
122 changes: 6 additions & 116 deletions src/train/Train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap

export train!, update!, adjust!, FluxState, @epochs,
export train!, update!, adjust!, FluxState,
Descent, Adam, Momentum, Nesterov, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
# InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
Expand All @@ -15,7 +15,7 @@ export train!, update!, adjust!, FluxState, @epochs,

"""
FluxState(rule, state=missing)
This is an interface between the all-mutable world Flux.jl likes,
and the could-be-immutable world that Optimisers.jl inhabits.
Expand Down Expand Up @@ -56,34 +56,14 @@ end

### Two styles of gradient, and their `train!` functions

using ProgressLogging: @progress, @withprogress, @logprogress
using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again
using Zygote: Zygote, Params

include("explicit_train.jl.jl") # new!
include("implicit_train.jl.jl") # Params etc, Zygote only
include("explicit_train.jl") # new!
include("implicit_train.jl") # Params etc, Zygote only

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor

# using Requires # Flux doesn't use this right now
# @init @require Diffractor="9f5e2b26-1114-432f-b630-d3fe2085c51c" begin
# @eval function explicit_withgradient(f, args...)
# y, back = Diffractor.∂⃖¹(f, args...)
# _, grads... = back(Zygote.sensitivity(y))
# return (; value = y, gradient = grads)
# end
# end

#=
using Diffractor
function Flux.Train.explicit_withgradient(f, args...)
y, back = Diffractor.∂⃖¹(f, args...)
_, grads... = back(one(y))
return (; value = y, gradient = grads)
end
=#

### Misc. related utilities

"""
Expand All @@ -107,94 +87,4 @@ function adjust!(opt::FluxState, eta::Real)
return opt
end

"""
@epochs N body
Run `body` expression `N` times. Mainly useful for quickly doing
multiple epochs of training in a REPL.
Functionally equivalent to this loop:
```
for _ in 1:N
body
end
```
... but adds progress logging and `@info` messages,
and returns the result of the last iteration.
# Examples
```jldoctest
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
@gensym val
body = :(for i in 1:$(esc(n))
@info "Epoch $i"
$(esc(val)) = $(esc(ex))
end)
loop = Expr(:macrocall, Symbol("@progress"), __source__, body)
Expr(:block, :($(esc(val)) = nothing), loop, :($(esc(val))))
# TODO make this actualy return the value? Names aren't right.
#
# $loop
# # @progress for i in 1:$(esc(n))
# # @info "Epoch $i"
# # $(esc(val)) = $(esc(ex))
# # end
# $val # DOESN"T WORK! Expr(:macrocall, ...) ?
# end
end

end


#=
using Flux, Random
data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!;
# This exact code works on [email protected]. There, train! returns nothing:
model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt2 = Flux.Adam()
Flux.train!(Flux.params(model2), data, opt2) do x, y
Flux.mse(model2(x), y)
end
opt2 # contains an IdDict
# This is the new "explicit" method of Train
model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt1 = Flux.Adam()
Flux.train!(model1, data, opt1) do m, x, y
Flux.mse(m(x), y)
end |> sum
opt1 # contains state tree
# This is new 3-arg train!, one step not an iteration over data:
x1, y1 = data[1]
Flux.train!(model1, opt1) do m
Flux.mse(m(x1), y1)
end
julia> using ProgressLogging
julia> @macroexpand1 @loop N body
begin
x = nothing
@progress for i in 1:N
@info "step $i"
x = body
end
x
end
=#
end # module
28 changes: 15 additions & 13 deletions src/train/explicit_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,28 @@ function train!(loss::Function, model, data, opt::FluxState)
_initialise!(opt, model)
losses = Float32[]
s = opt.state
s isa IdDict && error("can't mix explicit & implicit!")
s isa IdDict && error("""Can't mix explicit & implicit modes!
Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""")
for d in data
l, (g, _...) = Zygote.withgradient(loss, model, train_ok(d)...)
l, (g, _...) = Zygote.withgradient(loss, model, data_splat(d)...)
s, model = Optimisers.update!(s, model, g)
push!(losses, l)
opt.state = s
end
return losses
return losses # Not entirely sure returning losses is a good idea. Flux 0.13 returns `nothing`.
end

train_ok(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
To allow this type, define `Flux.Optimise.train_ok(x::$T) = (x,)`""")
train_ok(x::Tuple) = x
train_ok(x::NamedTuple) = x
data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
data_splat(x::Tuple) = x
data_splat(x::NamedTuple) = x

function _initialise!(opt::FluxState, model)
if opt.state isa Missing
opt.state = Optimisers.setup(opt.rule, model)
fmap(model, exclude = Optimisers.isnumeric) do x
Optimisers.maywrite(x) || error("model must be fully mutable for train! to work, got $(typeof(x))")
Optimisers.maywrite(x) || error("""model must be fully mutable for train! to work, got x::$(typeof(x))
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
end
end
opt
Expand Down Expand Up @@ -107,12 +109,12 @@ function train!(loss::Function, model, opt::FluxState)
l
end

# This method lets you use Optimisers.Descent() instead of Flux.Descent(), when there is no state
function train!(loss::Function, model, data, opt::Optimisers.AbstractRule)
_initialise!(opt, model)
# fmap(opt.state) do x
# x isa Union{Number, AbstractArray{<:Number}} && @warn "optimiser state will be lost!"
# x
# end # won't work as you need to look inside Leaf for non-nothings.
@warn "optimiser state will be lost!"
fmap(opt.state, exclude = x -> x isa Optimsers.Leaf) do leaf
leaf.state isa Nothing || @warn "Optimiser state will be lost! Please wrap optimisation rule in `FluxState`, e.g. by using `Flux.Adam()`" leaf
leaf
end
train!(loss, model, data, FluxState(opt))
end
8 changes: 6 additions & 2 deletions src/train/implicit_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function train!(loss::Function, pars::Params, data, opt::FluxState)
losses = Float32[]
for d in data
l, grads = Zygote.withgradient(() -> loss(batchmemaybe(d)...), pars)
update!(opt, pars, grads)
_update!(opt, pars, grads)
push!(losses, l)
end
return losses
Expand All @@ -49,7 +49,7 @@ function train!(loss::Function, pars::Params, opt::FluxState)
Explicit parameters are now preferred, see `train!(loss, model, data, opt)`""", :train!, force=true)
_initialise!(opt, pars)
l, grads = Zygote.withgradient(() -> loss(), pars)
update!(opt, pars, grads)
_update!(opt, pars, grads)
return l
end

Expand All @@ -68,6 +68,10 @@ Legacy method, mimicking the behaviour of Flux <= 0.13.
"""
function update!(opt::FluxState, xs::Params, gs)
Base.depwarn("Flux.update! is a legacy function", :update!)
_update!(opt, xs, gs)
end
# This _update! exists only so that train! above gives one depwarn, not two!
function _update!(opt::FluxState, xs::Params, gs)
for x in xs
isnothing(gs[x]) && continue
update!(opt, x, gs[x])
Expand Down
Loading

0 comments on commit d686232

Please sign in to comment.