From f19a7b18b711e39d77d383c402726ee5f5304c0e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 19 Nov 2022 15:11:33 -0500 Subject: [PATCH 01/28] re-write training.md --- docs/src/training/training.md | 343 ++++++++++++++++++++-------------- 1 file changed, 206 insertions(+), 137 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 76aa40f5b8..fe32a977b2 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -1,210 +1,279 @@ # [Training](@id man-training) -To actually train a model we need four things: +Training refers to the process of slowly adjusting the parameters of a model to make it work better. +Besides the model itself, we will need three things: -* A *objective function*, that evaluates how well a model is doing given some input data. -* The trainable parameters of the model. -* A collection of data points that will be provided to the objective function. -* An [optimiser](optimisers.md) that will update the model parameters appropriately. +* An *objective function* that evaluates how well a model is doing on some input. +* An *optimisation rule* which describes how the model's parameters should be adjusted. +* Some *training data* to use as the input during this process. -Training a model is typically an iterative process, where we go over the data set, -calculate the objective function over the data points, and optimise that. -This can be visualised in the form of a simple loop. +Usually the training data is some collection of examples (or batches of examples) which +are handled one-by-one. One *epoch* of training means that each example is used once, +something like this: ```julia -for d in datapoints +for data in train_set + # Unpack this datapoint into the input and the + # desired result (for "supervised" training): + input, label = data + + # Calculate the gradient of the objective + # with respect to the parameters within the model: + grads = Flux.gradient(model) do m + result = m(input) + loss(result, label) + end - # `d` should produce a collection of arguments - # to the loss function + # Update the parameters so as to reduce the objective, + # according to a particular optimiser: + Flux.update!(opt, model, grads[1]) +end +``` - # Calculate the gradients of the parameters - # with respect to the loss function - grads = Flux.gradient(parameters) do - loss(d...) - end +This isn't pseudo-code, but is precisely how traning is done. +This loop can also be written using the function [`train!`](@ref Train.train!), +but it's helpful to undersand the pieces first: - # Update the parameters based on the chosen - # optimiser (opt) - Flux.Optimise.update!(opt, parameters, grads) +```julia +train!(model, train_set, opt) do m, x, y + loss(m(x), y) end ``` -To make it easy, Flux defines `train!`: +## Model Gradients + +Fist recall from the section on [taking gradients](@ref man-taking-gradients) that +`Flux.gradient(f, a, b)` always calls `f(a, b)`, and returns a tuple `(∂f_∂a, ∂f_∂b)`. +In the code above, the function `f` is an anonymous function with one argument, +created by the `do` block, hence `grads` is a tuple with one element. +Instead of a `do` block, we could have written: -```@docs -Flux.Optimise.train! +```julia +grads = Flux.gradient(m -> loss(m(input), label), model) ``` -There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo), and -more information can be found on [Custom Training Loops](@ref man-advanced). +Since the model is some nested set of layers, `grads[1]` is a similarly nested set of +`NamedTuple`s, ultimately containing gradient components. These matching tree-like +structures are what Zygote calls "explicit" gradients. + +It is important that the execution of the model takes place inside the call to `gradient`, +in order for the influence of the model's parameters to be observed by Zygote. + +!!! note + Flux used to use Zygote's "implicit" mode, which looks like this: + ``` + pars = Flux.params(model) + grad = Flux.gradient(() -> loss(model(input), label), pars) + ``` + Here `pars::Params` and `grad::Grads` are two dictionary-like structures. + ## Loss Functions -The objective function must return a number representing how far the model is from its target – the *loss* of the model. The `loss` function that we defined in [basics](@ref man-basics) will work as an objective. -In addition to custom losses, model can be trained in conjuction with -the commonly used losses that are grouped under the `Flux.Losses` module. -We can also define an objective in terms of some model: +The objective function must return a number representing how far the model is from +the desired result. This is termed the *loss* of the model. +This number can be produced by any ordinary Julia code, but this must be executed +within the call to `gradient`. For instance, we could define a function ```julia -m = Chain( - Dense(784 => 32, σ), - Dense(32 => 10), softmax) +loss(y_hat, y) = sum((y_hat .- y).^2) +``` +or write this directly inside the `do` block above. Many commonly used functions, +like `mse` for mean squared error or `crossentropy` for cross-entropy loss, +are available from the [`Flux.Losses`](../models/losses.md) module. -loss(x, y) = Flux.Losses.mse(m(x), y) -ps = Flux.params(m) +!!! note + Flux used to need a loss function which closed over a reference to the model, + instead of being a pure function. Thus in old code you may see something like + ``` + loss(x, y) = sum((model(x) .- y).^2) + ``` + which defines a function making reference to a particular global variable `model`. + This is no longer the preferred style. -# later -Flux.train!(loss, ps, data, opt) -``` +## Optimisation Rules + +The simplest kind of optimisation using the gradient is termed *gradient descent* +(or sometimes *stochastic gradient descent* when it is applied to individual examples +in a loop, not to the entire dataset at once). -The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built-in, like `mse` for mean squared error or `crossentropy` for cross-entropy loss, but you can calculate it however you want. -For a list of all built-in loss functions, check out the [losses reference](../models/losses.md). +This needs a *learning rate* which is a small number describing how fast to walk downhill, +usually written as the Greek letter "eta", `η`. -At first glance, it may seem strange that the model that we want to train is not part of the input arguments of `Flux.train!` too. However the target of the optimizer is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility and the possibility of optimizing the calculations. +```julia +η = 0.01 # learning rate -## Model parameters +# For each parameter array, update +# according to the corresponding gradient: +fmap(model, grads[1]) do p, g + p .= p .- η .* g +end +``` -The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. In the [basics](@ref man-basics) section it is explained how to create models with such parameters. The second argument of the function `Flux.train!` must be an object containing those parameters, which can be obtained from a model `m` as `Flux.params(m)`. +This is wrapped up as a function `update!`, which can be used as follows: -Such an object contains a reference to the model's parameters, not a copy, such that after their training, the model behaves according to their updated values. +```julia +Flux.update!(Descent(0.01), model, grads[1]) +``` + +There are many other optimisation rules, which adjust the step size and direction. +Most require some memory of the gradients from earlier steps. The function `setup` +creates the necessary storage for this, for a particular model. This should be done +once, before training, and looks like this: -Handling all the parameters on a layer by layer basis is explained in the [Layer Helpers](@ref man-basics) section. Also, for freezing model parameters, see the [Advanced Usage Guide](@ref man-advanced). +```julia +# Initialise momentum +opt = Flux.setup(Adam(0.001), model) -```@docs -Flux.params +for data in train_set + ... + + # + Flux.update!(opt, model, grads[1]) +end ``` +Many commonly used optimisation rules, such as `Adam`, are built-in. +These are listed on the [optimisers](@ref man-optimisers) page. + + +!!! note + This `setep` makes another tree-like structure. Old versions of Flux did not do this, + and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. + This was initialised on first use of the version of `update!` for "implicit" parameters. + + ## Datasets -The `data` argument of `train!` provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy dataset with only one data point: +The loop above iterates through `train_set`, expecting at each step a tuple `(input, label)`. +The very simplest such object is a vector of tuples, such as this: ```julia -x = rand(784) +x = randn(28, 28) y = rand(10) data = [(x, y)] ``` -`Flux.train!` will call `loss(x, y)`, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times: +or `data = [(x, y), (x, y), (x, y)]` for the same values three times. -```julia -data = [(x, y), (x, y), (x, y)] -# Or equivalently -using IterTools: ncycle -data = ncycle([(x, y)], 3) -``` - -It's common to load the `x`s and `y`s separately. Here you can use `zip`: +To get data into this format, you might want `zip` to combine a list of different `x`s +with a list of different `y`s: ```julia -xs = [rand(784), rand(784), rand(784)] -ys = [rand( 10), rand( 10), rand( 10)] +xs = [rand(28, 28), rand(28, 28), rand(28, 28)] +ys = [rand(10), rand(10), rand(10)] data = zip(xs, ys) + +first(data) isa Tuple{Matrix, Vector} # true ``` -Training data can be conveniently partitioned for mini-batch training using the [`Flux.Data.DataLoader`](@ref) type: +Very often, the initial data is large arrays which you need to slice into examples: ```julia -X = rand(28, 28, 60000) -Y = rand(0:9, 60000) -data = DataLoader((X, Y), batchsize=128) +X = rand(28, 28, 60_000) +Y = rand(10, 60_000) +data = zip(eachslice(X; dims=3), eachcol(Y)) + +first(data) isa Tuple{Matrix, Vector} # true ``` -Note that, by default, `train!` only loops over the data once (a single "epoch"). -A convenient way to run multiple epochs from the REPL is provided by `@epochs`. +Here each iteration will use one matrix `x` (an image, perhaps) and one vector `y`. +It is very common to instead train on *batches* of such inputs (or *mini-batches*, +the two words mean the same thing) both for efficiency and for better results. +This can be easily done using the [`DataLoader`](@ref Flux.Data.DataLoader): ```julia -julia> using Flux: @epochs - -julia> @epochs 2 println("hello") -[ Info: Epoch 1 -hello -[ Info: Epoch 2 -hello +X = rand(28, 28, 60_000) +Y = rand(0:9, 60_000) +data = Flux.DataLoader((X, Y), batchsize=32) -julia> @epochs 2 Flux.train!(...) -# Train for two epochs +x1, y1 = first(data) +size(x1) == (28, 28, 32) +length(data) == 1875 === 60_000 ÷ 32 ``` -```@docs -Flux.@epochs -``` +Flux's layers are set up to accept such a batch of input data, +and the convolutional layers such as [Conv](@ref Flux.Conv) require it. -## Callbacks -`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: +## Training Loops + +Very simple training loops like the one above can be written compactly using +the [`train!`](@ref) function. Including `setup`, this reads: ```julia -train!(objective, ps, data, opt, cb = () -> println("training")) -``` +opt = Flux.setup(Adam(), model) -Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds. +train!(model, train_set, opt) do m, x, y + loss(m(x), y) +end +``` -A more typical callback might look like this: +!!! note + This is the "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. + The 1st argument (from the `do` block) is a function which accepts the model itself. + Old Flux versions provided a method of `train!` for "implicit" parameters, + which works like this: + ``` + train!((x,y) -> loss(model(x), y), Flux.params(model), train_set, Adam()) + ``` + +Real training loops often need more flexibility, and the best way to do this is just +to write the loop. This is ordinary Julia code, without any need to work through some +callback API. Here is an example, in which it may be helpful to note: + +* The function [`withgradient`](@ref Zygote.withgradient) is like `gradient` but also + returns the value of the function, for logging or diagnostic use. +* Logging or printing is best done outside of the `gradient` call, + as there is no need to differentiate these commands. +* Julia's `break` and `continue` keywords let you exit from parts of the loop. ```julia -test_x, test_y = # ... create single batch of test data ... -evalcb() = @show(loss(test_x, test_y)) -throttled_cb = throttle(evalcb, 5) -Flux.@epochs 20 Flux.train!(objective, ps, data, opt, cb = throttled_cb) -``` +opt = Flux.setup(Adam(), model) + +log = [] +for epoch in 1:100 + losses = Float32[] + for (i, data) in enumerate(train_set) + input, label = data + + val, grads = Flux.withgradient(model) do m + # Any code inside here is differentiated. + # Evaluation of the model and loss must be inside! + result = m(input) + my_loss(result, label) + end -Calling `Flux.stop()` in a callback will exit the training loop early. + # Save the loss from the forward pass. (Done outside of gradient.) + push!(losses, val) -```julia -cb = function () - accuracy() > 0.9 && Flux.stop() -end -``` + # Detect loss of Inf or NaN. Print a warning, and then skip update! + if !isfinite(val) + @warn "loss is $val on item $i" epoch + continue + end -## Custom Training loops + Flux.update!(opt, model, grads[1]) + end -The `Flux.train!` function can be very convenient, especially for simple problems. -For some problems, however, it's much cleaner to write your own custom training loop. -An example follows that works similar to the default `Flux.train` but with no callbacks. -You don't need callbacks if you just code the calls to your functions directly into the loop. -E.g. in the places marked with comments. + # Compute some accuracy, and save details to log + acc = my_accuracy(model, train_set) + push!(log, (; acc, losses)) -```julia -function my_custom_train!(loss, ps, data, opt) - # training_loss is declared local so it will be available for logging outside the gradient calculation. - local training_loss - ps = Params(ps) - for d in data - gs = gradient(ps) do - training_loss = loss(d...) - # Code inserted here will be differentiated, unless you need that gradient information - # it is better to do the work outside this block. - return training_loss - end - # Insert whatever code you want here that needs training_loss, e.g. logging. - # logging_callback(training_loss) - # Insert whatever code you want here that needs gradients. - # e.g. logging histograms with TensorBoardLogger.jl to check for exploding gradients. - update!(opt, ps, gs) - # Here you might like to check validation set accuracy, and break out to do early stopping. + # Stop training when some criterion is reached + if acc > 0.95 + println("stopping after $epoch epochs") + break end end ``` -You could simplify this further, for example by hard-coding in the loss function. -Another possibility is to use [`Zygote.pullback`](https://fluxml.ai/Zygote.jl/dev/adjoints/#Pullbacks-1) -to access the training loss and the gradient simultaneously. +## Implicit vs Explicit + +Flux used to handle gradients, training, and optimisation rules quite differently. +The new style described above is called "explicit" by Zygote, and the old style "implicit". +Flux 0.13 is the transitional version which supports both. + +For full details on the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). -```julia -function my_custom_train!(loss, ps, data, opt) - ps = Params(ps) - for d in data - # back is a method that computes the product of the gradient so far with its argument. - train_loss, back = Zygote.pullback(() -> loss(d...), ps) - # Insert whatever code you want here that needs training_loss, e.g. logging. - # logging_callback(training_loss) - # Apply back() to the correct type of 1.0 to get the gradient of loss. - gs = back(one(train_loss)) - # Insert whatever code you want here that needs gradient. - # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge. - update!(opt, ps, gs) - # Here you might like to check validation set accuracy, and break out to do early stopping. - end -end -``` From 97dd55661e63ba76ef032f5f136c13b3a2005ac9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 19 Nov 2022 15:12:04 -0500 Subject: [PATCH 02/28] add train_api page for docstrings --- docs/make.jl | 3 +- docs/src/training/train_api.md | 71 ++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 docs/src/training/train_api.md diff --git a/docs/make.jl b/docs/make.jl index ee836b216b..70256de407 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,7 +18,6 @@ makedocs( "Fitting a Line" => "models/overview.md", "Gradients and Layers" => "models/basics.md", "Training" => "training/training.md", - "Regularisation" => "models/regularisation.md", # consolidated in #2114 "Recurrence" => "models/recurrence.md", "GPU Support" => "gpu.md", "Saving & Loading" => "saving.md", @@ -31,7 +30,7 @@ makedocs( "Activation Functions" => "models/activation.md", "Weight Initialisation" => "utilities.md", "Loss Functions" => "models/losses.md", - "Optimisation Rules" => "training/optimisers.md", # TODO move optimiser intro up to Training + "Optimisation Rules" => "training/optimisers.md", "Shape Inference" => "outputsize.md", "Flat vs. Nested" => "destructure.md", "Callback Helpers" => "training/callbacks.md", diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md new file mode 100644 index 0000000000..5ef6ba12a1 --- /dev/null +++ b/docs/src/training/train_api.md @@ -0,0 +1,71 @@ +# Training API + + +```@docs +Flux.Train.setup +Flux.Train.update! +Flux.Train.train! +``` + +## Implicit style + +Flux used to handle gradients, training, and optimisation rules quite differently. +The new style described above is called "explicit" by Zygote, and the old style "implicit". +Flux 0.13 is the transitional version which supports both. + +For full details on how to use the implicit style, see [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). + + +```@docs +Flux.params +Flux.Optimise.update! +Flux.Optimise.train! +``` + + +Note that, by default, `train!` only loops over the data once (a single "epoch"). +A convenient way to run multiple epochs from the REPL is provided by `@epochs`. + +```julia +julia> using Flux: @epochs + +julia> @epochs 2 println("hello") +[ Info: Epoch 1 +hello +[ Info: Epoch 2 +hello + +julia> @epochs 2 Flux.train!(...) +# Train for two epochs +``` + +```@docs +Flux.@epochs +``` + +## Callbacks + +`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: + +```julia +train!(objective, ps, data, opt, cb = () -> println("training")) +``` + +Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds. + +A more typical callback might look like this: + +```julia +test_x, test_y = # ... create single batch of test data ... +evalcb() = @show(loss(test_x, test_y)) +throttled_cb = throttle(evalcb, 5) +Flux.@epochs 20 Flux.train!(objective, ps, data, opt, cb = throttled_cb) +``` + +Calling `Flux.stop()` in a callback will exit the training loop early. + +```julia +cb = function () + accuracy() > 0.9 && Flux.stop() +end +``` From 3ffe394f188a56782e14f62bc07e5a5af1ce5c56 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 19 Nov 2022 15:12:21 -0500 Subject: [PATCH 03/28] update basic.md to introduce explicit not implicit --- docs/src/models/basics.md | 80 +++++++++++++++++++++++++-------- docs/src/training/optimisers.md | 2 +- 2 files changed, 62 insertions(+), 20 deletions(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 140ed7e13d..71f5b5b50f 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -1,6 +1,6 @@ # [How Flux Works: Gradients and Layers](@id man-basics) -## Taking Gradients +## [Taking Gradients](@id man-taking-gradients) Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.) @@ -29,35 +29,77 @@ julia> gradient(f, [2, 1], [2, 0]) ([0.0, 2.0], [-0.0, -2.0]) ``` -These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model. +These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model. - -Machine learning often can have *hundreds* of parameters, so Flux lets you work with collections of parameters, via the `params` functions. You can get the gradient of all parameters used in a program without explicitly passing them in. +Machine learning often can have *hundreds* of parameter arrays. +Instead of passing them to `gradient` individually, we can store them together in a structure. +The simplest example is a named tuple, created by the following syntax: ```jldoctest basics -julia> x = [2, 1]; +julia> nt = (a = [2, 1], b = [2, 0], c = abs2); + +julia> g(x::NamedTuple) = sum(abs2, x.a .- x.b); + +julia> g(nt) +1 + +julia> dg_nt = gradient(g, nt)[1] +(a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing) +``` + +Notice that `gradient` has returned a matching structure. The field `dg_nt.a` is the gradient +for `nt.a`, and so on. Some fields have no gradient, indicated by `nothing`. -julia> y = [2, 0]; +Rather than define a function like `g` every time (and think up a name for it), +it is often useful to use anonymous functions: this one is `x -> sum(abs2, x.a .- x.b)`. +Anonymous functions can be defined either with `->` or with `do`, +and such `do` blocks are often useful if you have a few steps to perform: + +```jldoctest basics +julia> gradient((x, y) -> sum(abs2, x.a ./ y .- x.b), nt, [1, 2]) +((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25]) -julia> gs = gradient(Flux.params(x, y)) do - f(x, y) +julia> gradient(nt, [1, 2]) do x, y + z = x.a ./ y + sum(x.c, z .- x.b) end -Grads(...) +((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25]) +``` -julia> gs[x] -2-element Vector{Float64}: - 0.0 - 2.0 +Sometimes you may want to know the value of the function, as well as its gradient. +Rather than calling the function a second time, you can call [`withgradient`](@ref Zygote.withgradient) instead: -julia> gs[y] -2-element Vector{Float64}: - -0.0 - -2.0 ``` +julia> Flux.withgradient(g, nt) +(val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),)) +``` + +!!! note + Flux used to handle many parameters in a different way, using the [`params`](@ref Flux.params) function. + This uses a method of `gradient` which takes a zero-argument function, and returns a dictionary + through which the resulting gradients can be looked up: + + ```jldoctest basics + julia> x = [2, 1]; + + julia> y = [2, 0]; + + julia> gs = gradient(Flux.params(x, y)) do + f(x, y) + end + Grads(...) + + julia> gs[x] + 2-element Vector{Float64}: + 0.0 + 2.0 -Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate. + julia> gs[y] + 2-element Vector{Float64}: + -0.0 + -2.0 + ``` -This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple. ## Building Simple Models diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 8b3a86d975..87bbd33aa8 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -2,7 +2,7 @@ CurrentModule = Flux ``` -# Optimisers +# [Optimisers](@id man-optimisers) Consider a [simple linear regression](@ref man-linear-regression). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. From 01b409458b9413779da0e62fc54f0d5d53148a9f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 19 Nov 2022 15:52:59 -0500 Subject: [PATCH 04/28] more links, comments on notes --- docs/src/models/basics.md | 2 +- docs/src/training/training.md | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 71f5b5b50f..f142292621 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -74,7 +74,7 @@ julia> Flux.withgradient(g, nt) (val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),)) ``` -!!! note +!!! note "Implicit gradients" Flux used to handle many parameters in a different way, using the [`params`](@ref Flux.params) function. This uses a method of `gradient` which takes a zero-argument function, and returns a dictionary through which the resulting gradients can be looked up: diff --git a/docs/src/training/training.md b/docs/src/training/training.md index fe32a977b2..09ee69fa14 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -59,7 +59,7 @@ structures are what Zygote calls "explicit" gradients. It is important that the execution of the model takes place inside the call to `gradient`, in order for the influence of the model's parameters to be observed by Zygote. -!!! note +!!! note "Explicit vs implicit gradients" Flux used to use Zygote's "implicit" mode, which looks like this: ``` pars = Flux.params(model) @@ -79,10 +79,10 @@ within the call to `gradient`. For instance, we could define a function loss(y_hat, y) = sum((y_hat .- y).^2) ``` or write this directly inside the `do` block above. Many commonly used functions, -like `mse` for mean squared error or `crossentropy` for cross-entropy loss, +like [`mse`](@ref Flux.Losses.mse) for mean-squared error or [`crossentropy`](@ref Flux.Losses.crossentropy) for cross-entropy loss, are available from the [`Flux.Losses`](../models/losses.md) module. -!!! note +!!! note "Implicit-style loss functions" Flux used to need a loss function which closed over a reference to the model, instead of being a pure function. Thus in old code you may see something like ``` @@ -110,14 +110,14 @@ fmap(model, grads[1]) do p, g end ``` -This is wrapped up as a function `update!`, which can be used as follows: +This is wrapped up as a function [`update!`](@ref Flux.Optimise.update!), which can be used as follows: ```julia Flux.update!(Descent(0.01), model, grads[1]) ``` There are many other optimisation rules, which adjust the step size and direction. -Most require some memory of the gradients from earlier steps. The function `setup` +Most require some memory of the gradients from earlier steps. The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. This should be done once, before training, and looks like this: @@ -133,11 +133,11 @@ for data in train_set end ``` -Many commonly used optimisation rules, such as `Adam`, are built-in. +Many commonly used optimisation rules, such as [`Adam`](@ref Flux.Optimise.Adam), are built-in. These are listed on the [optimisers](@ref man-optimisers) page. -!!! note +!!! note "Implicit-style optimiser state" This `setep` makes another tree-like structure. Old versions of Flux did not do this, and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. This was initialised on first use of the version of `update!` for "implicit" parameters. @@ -183,8 +183,6 @@ the two words mean the same thing) both for efficiency and for better results. This can be easily done using the [`DataLoader`](@ref Flux.Data.DataLoader): ```julia -X = rand(28, 28, 60_000) -Y = rand(0:9, 60_000) data = Flux.DataLoader((X, Y), batchsize=32) x1, y1 = first(data) @@ -209,7 +207,7 @@ train!(model, train_set, opt) do m, x, y end ``` -!!! note +!!! note "Implicit-style `train!`" This is the "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. The 1st argument (from the `do` block) is a function which accepts the model itself. Old Flux versions provided a method of `train!` for "implicit" parameters, From 250f0bdccb954b368bdf5b682dee671ab2cc124e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 18:15:18 -0500 Subject: [PATCH 05/28] updates, rm some Optimisers detail --- docs/src/training/optimisers.md | 83 ++++----------------------------- docs/src/training/train_api.md | 19 +++++--- docs/src/training/training.md | 2 +- src/Flux.jl | 4 +- src/optimise/optimisers.jl | 6 +++ 5 files changed, 30 insertions(+), 84 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 87bbd33aa8..600cae4d88 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -4,53 +4,24 @@ CurrentModule = Flux # [Optimisers](@id man-optimisers) -Consider a [simple linear regression](@ref man-linear-regression). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. +Flux builds in many optimisation rules for use with [`train!`](@ref Flux.Optimise.train!) and +other training functions. -```julia -using Flux - -W = rand(2, 5) -b = rand(2) - -predict(x) = (W * x) .+ b -loss(x, y) = sum((predict(x) .- y).^2) +The mechanism by which these work is gradually being replaced as part of the change +from "implicit" dictionary-based to "explicit" tree-like structures. +At present, the same struct (such as `Adam`) can be used with either form, +and will be automatically translated. -x, y = rand(5), rand(2) # Dummy data -l = loss(x, y) # ~ 3 - -θ = Flux.params(W, b) -grads = gradient(() -> loss(x, y), θ) -``` +For full details of how the new "explicit" interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). -We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: - -```julia -η = 0.1 # Learning Rate -for p in (W, b) - p .-= η * grads[p] -end -``` +For full details on how the "implicit" interface worked, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/optimisers/#Optimiser-Interface). -Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this. - -```julia -using Flux: update! - -opt = Descent(0.1) # Gradient descent with learning rate 0.1 - -for p in (W, b) - update!(opt, p, grads[p]) -end -``` - -An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `Adam`. ## Optimiser Reference All optimisers return an object that, when passed to `train!`, will update the parameters passed to it. ```@docs -Flux.Optimise.update! Descent Momentum Nesterov @@ -67,44 +38,6 @@ OAdam AdaBelief ``` -## Optimiser Interface - -Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function which takes the optimiser as the first argument followed by the parameter and its corresponding gradient. - -In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work on this with a simple example. - -```julia -mutable struct Momentum - eta - rho - velocity -end - -Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict()) -``` - -The `Momentum` type will act as our optimiser in this case. Notice that we have added all the parameters as fields, along with the velocity which we will use as our state dictionary. Each parameter in our models will get an entry in there. We can now define the rule applied when this optimiser is invoked. - -```julia -function Flux.Optimise.apply!(o::Momentum, x, Δ) - η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v -end -``` - -This is the basic definition of a Momentum update rule given by: - -```math -v = ρ * v - η * Δ -w = w - v -``` - -The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser. - -Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. - ## Composing Optimisers Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md index 5ef6ba12a1..81ad86d012 100644 --- a/docs/src/training/train_api.md +++ b/docs/src/training/train_api.md @@ -1,10 +1,16 @@ # Training API - ```@docs Flux.Train.setup -Flux.Train.update! -Flux.Train.train! +Flux.Optimise.train!(loss, model, data, opt; cb) +``` + +The new version of Flux's training code was written as an independent package, called Optimisers.jl. +However, at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) +which can be updated in-place. Thus objects returned by `update!` can be ignored. + +```@docs +Optimisers.update! ``` ## Implicit style @@ -15,14 +21,12 @@ Flux 0.13 is the transitional version which supports both. For full details on how to use the implicit style, see [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). - ```@docs Flux.params -Flux.Optimise.update! -Flux.Optimise.train! +Optimisers.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Flux.Params, gs) +Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) ``` - Note that, by default, `train!` only loops over the data once (a single "epoch"). A convenient way to run multiple epochs from the REPL is provided by `@epochs`. @@ -69,3 +73,4 @@ cb = function () accuracy() > 0.9 && Flux.stop() end ``` + diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 09ee69fa14..48d4126ed8 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -1,4 +1,4 @@ -# [Training](@id man-training) +# [Training a Flux Model](@id man-training) Training refers to the process of slowly adjusting the parameters of a model to make it work better. Besides the model itself, we will need three things: diff --git a/src/Flux.jl b/src/Flux.jl index 3853712f7b..53749b85fa 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,9 +34,11 @@ export Descent, Adam, Momentum, Nesterov, RMSProp, AdamW, RAdam, AdaBelief, InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm +export ClipGrad, OptimiserChain # these are const defined in deprecations, for ClipValue, Optimiser + include("train.jl") using .Train -# using .Train: setup, @train_autodiff +using .Train: setup using CUDA const use_cuda = Ref{Union{Nothing,Bool}}(nothing) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index ce72a4b0ce..8a60ba97a5 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -564,6 +564,9 @@ end Combine several optimisers into one; each optimiser produces a modified gradient that will be fed into the next, and this is finally applied to the parameter as usual. + +!!! note + This will be replaced by `Optimisers.OptimiserChain` in Flux 0.14. """ mutable struct Optimiser <: AbstractOptimiser os::Vector{Any} @@ -699,6 +702,9 @@ end ClipValue(thresh) Clip gradients when their absolute value exceeds `thresh`. + +!!! note + This will be replaced by `Optimisers.ClipGrad` in Flux 0.14. """ mutable struct ClipValue{T} <: AbstractOptimiser thresh::T From 8e023fd1fb124ec6b34262d9c5bd4c37a78b83ce Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 22:28:48 -0500 Subject: [PATCH 06/28] mention TerminalLoggers --- docs/src/training/train_api.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md index 81ad86d012..79148023af 100644 --- a/docs/src/training/train_api.md +++ b/docs/src/training/train_api.md @@ -5,6 +5,10 @@ Flux.Train.setup Flux.Optimise.train!(loss, model, data, opt; cb) ``` +`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode. +To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) +and follow its setup instructions. + The new version of Flux's training code was written as an independent package, called Optimisers.jl. However, at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) which can be updated in-place. Thus objects returned by `update!` can be ignored. From 1348b7036afb732e1f719ef71818d1f0dff82ce3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 21 Nov 2022 17:49:31 -0500 Subject: [PATCH 07/28] tweaks --- docs/src/training/callbacks.md | 2 +- docs/src/training/train_api.md | 10 ++++++---- docs/src/training/training.md | 28 ++++++++++++++-------------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/docs/src/training/callbacks.md b/docs/src/training/callbacks.md index 99c80986f1..6e9840ad1d 100644 --- a/docs/src/training/callbacks.md +++ b/docs/src/training/callbacks.md @@ -1,4 +1,4 @@ -# Callback Helpers +# [Callback Helpers](@id man-callback-helpers) ```@docs Flux.throttle diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md index 79148023af..bd2f0c83c3 100644 --- a/docs/src/training/train_api.md +++ b/docs/src/training/train_api.md @@ -21,9 +21,9 @@ Optimisers.update! Flux used to handle gradients, training, and optimisation rules quite differently. The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 is the transitional version which supports both. +Flux 0.13 is the transitional version which supports both; Flux 0.14 will remove the old. -For full details on how to use the implicit style, see [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). +For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). ```@docs Flux.params @@ -51,9 +51,9 @@ julia> @epochs 2 Flux.train!(...) Flux.@epochs ``` -## Callbacks +### Callbacks -`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: +Implicit `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: ```julia train!(objective, ps, data, opt, cb = () -> println("training")) @@ -78,3 +78,5 @@ cb = function () end ``` +See the page about [callback helpers](@ref man-callback-helpers) for more. + diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 48d4126ed8..57c9638332 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -97,8 +97,8 @@ The simplest kind of optimisation using the gradient is termed *gradient descent (or sometimes *stochastic gradient descent* when it is applied to individual examples in a loop, not to the entire dataset at once). -This needs a *learning rate* which is a small number describing how fast to walk downhill, -usually written as the Greek letter "eta", `η`. +Gradient descent needs a *learning rate* which is a small number describing how fast to walk downhill, +usually written as the Greek letter "eta", `η`. This is what it does: ```julia η = 0.01 # learning rate @@ -110,16 +110,14 @@ fmap(model, grads[1]) do p, g end ``` -This is wrapped up as a function [`update!`](@ref Flux.Optimise.update!), which can be used as follows: - -```julia -Flux.update!(Descent(0.01), model, grads[1]) -``` +This update of all parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`. There are many other optimisation rules, which adjust the step size and direction. -Most require some memory of the gradients from earlier steps. The function [`setup`](@ref Flux.Train.setup) -creates the necessary storage for this, for a particular model. This should be done -once, before training, and looks like this: +Most require some memory of the gradients from earlier steps, rather than always +walking straight downhill. The function [`setup`](@ref Flux.Train.setup) creates the +necessary storage for this, for a particular model. +It should be called once, before training, and returns a tree-like object which is the +first argument of `update!`. Like this: ```julia # Initialise momentum @@ -128,7 +126,7 @@ opt = Flux.setup(Adam(0.001), model) for data in train_set ... - # + # Update both model parameters and optimiser state: Flux.update!(opt, model, grads[1]) end ``` @@ -138,7 +136,7 @@ These are listed on the [optimisers](@ref man-optimisers) page. !!! note "Implicit-style optimiser state" - This `setep` makes another tree-like structure. Old versions of Flux did not do this, + This `setup` makes another tree-like structure. Old versions of Flux did not do this, and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. This was initialised on first use of the version of `update!` for "implicit" parameters. @@ -266,12 +264,14 @@ for epoch in 1:100 end ``` - ## Implicit vs Explicit Flux used to handle gradients, training, and optimisation rules quite differently. The new style described above is called "explicit" by Zygote, and the old style "implicit". Flux 0.13 is the transitional version which supports both. -For full details on the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). +The blue boxes above describe the changes. +For more details on training in the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). + +For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1). From 0ec52977bbf6f5d154133ed1d45bc41964aeef04 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 21 Nov 2022 17:50:45 -0500 Subject: [PATCH 08/28] perhaps we should build regularisation into the same page --- docs/src/training/training.md | 63 +++++++++++++++++++++++++++++++++++ test/train.jl | 41 +++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 57c9638332..7c34e7b9a8 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -275,3 +275,66 @@ For more details on training in the implicit style, see [Flux 0.13.6 documentati For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1). +## Regularisation + +The term *regularisation* covers a wide variety of techniques aiming to improve the +result of training. This is often done to avoid overfitting. + +Some of these are can be implemented by simply modifying the loss function. +L2 or ... umm ... adds to the loss a penalty proportional to `θ^2` for every scalar parameter, +and for a simple model could be implemented as follows: + +```julia +Flux.gradient(model) do m + result = m(input) + penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 + my_loss(result, label) + 0.42 * penalty +end +``` + +Accessing each individual parameter array by hand won't work well for large models. +Instead, we can use [`Flux.params`](@ref) to collect all of them, +and then apply a function to each one, and sum the result: + +```julia +pen_l2(x::AbstractArray) = sum(abs2, x)/2 + +Flux.gradient(model) do m + result = m(input) + penalty = sum(pen_l2, Flux.params(m)) + my_loss(result, label) + 0.42 * penalty +end +``` + +However, the gradient of this penalty term is very simple: It is proportional to the original weights. +So there is a simpler way to implement exactly the same thing, by modifying the optimiser +instead of the loss function. This is done by replacing this: + +```julia +opt = Flux.setup(Adam(0.1), model) +``` + +with this: + +```julia +decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) +``` + +Flux's optimisers are really modifications applied to the gradient before using it to update +the parameters, and `OptimiserChain` applies two such modifications. +The first, [`WeightDecay`](@ref) adds `0.42` times original parameter to the gradient, +matching the gradient of the penalty above (with the same, unrealistically large, constant). +After that, in either case, [`Adam`](@ref) computes the final update. + +The same mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ). + +Besides L2 / weight decay, another common and quite different kind of regularisation is +provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some ... ?? + +?? do we discuss test/train mode here too? + +## Freezing, Schedules + +?? maybe these also fit in here. + + diff --git a/test/train.jl b/test/train.jl index f8d66a1e4b..8c3c726184 100644 --- a/test/train.jl +++ b/test/train.jl @@ -98,3 +98,44 @@ end @test y5 < y4 end +@testset "L2 regularisation" begin + # New docs claim an exact equivalent. It's a bit long to put the example in there, + # but perhaps the tests should contain it. + + model = Dense(3 => 2, tanh); + init_weight = copy(model.weight); + data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10]; + + # Take 1: explicitly add a penalty in the loss function + opt = Flux.setup(Adam(0.1), model) + Flux.train!(model, data, opt) do m, x, y + err = Flux.mse(m(x), y) + l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 + err + 0.33 * l2 + end + diff1 = model.weight .- init_weight + + # Take 2: the same, but with Flux.params. Was broken for a bit, no tests! + model.weight .= init_weight + model.bias .= 0 + pen2(x::AbstractArray) = sum(abs2, x)/2 + opt = Flux.setup(Adam(0.1), model) + Flux.train!(model, data, opt) do m, x, y + err = Flux.mse(m(x), y) + l2 = sum(pen2, Flux.params(m)) + err + 0.33 * l2 + end + diff2 = model.weight .- init_weight + @test_broken diff1 ≈ diff2 + + # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. + model.weight .= init_weight + model.bias .= 0 + decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model); + Flux.train!(model, data, decay_opt) do m, x, y + Flux.mse(m(x), y) + end + diff3 = model.weight .- init_weight + @test diff1 ≈ diff3 +end + From 0980dcd128e01baefdc0f4ca25cd9d4491018a4f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 25 Nov 2022 08:23:08 -0500 Subject: [PATCH 09/28] tweaks --- Project.toml | 2 +- docs/src/training/train_api.md | 18 ++++++++++++++++-- docs/src/training/training.md | 22 +++++++++++++++++----- src/Flux.jl | 1 + 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 0b02d1b583..6c43609973 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ MacroTools = "0.5" NNlib = "0.8.9" NNlibCUDA = "0.2.4" OneHotArrays = "0.1, 0.2" -Optimisers = "0.2.10" +Optimisers = "0.2.11" ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md index bd2f0c83c3..b8c05b240e 100644 --- a/docs/src/training/train_api.md +++ b/docs/src/training/train_api.md @@ -9,14 +9,28 @@ Flux.Optimise.train!(loss, model, data, opt; cb) To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) and follow its setup instructions. -The new version of Flux's training code was written as an independent package, called Optimisers.jl. -However, at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) +The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). +This is designed to allow for immutable objects. +But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) which can be updated in-place. Thus objects returned by `update!` can be ignored. ```@docs Optimisers.update! ``` +### Modifiers + +The state returned by `setup` can be modified to temporarily prevent training of +some parts of the model, or to change the learning rate uses. +The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust`: + +```@docs +Optimisers.adjust +Optimisers.freeze! +Optimisers.thaw! +``` + + ## Implicit style Flux used to handle gradients, training, and optimisation rules quite differently. diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 7c34e7b9a8..791345ce7e 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -326,15 +326,27 @@ The first, [`WeightDecay`](@ref) adds `0.42` times original parameter to the gra matching the gradient of the penalty above (with the same, unrealistically large, constant). After that, in either case, [`Adam`](@ref) computes the final update. -The same mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ). +The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ). Besides L2 / weight decay, another common and quite different kind of regularisation is -provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some ... ?? - -?? do we discuss test/train mode here too? +provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the +previous layer during training. +It should switch automatically, but see [trainmode!](@ref Flux.trainmode!) / [testmode!](@ref Flux.testmode!) to manually enable or disable this layer. ## Freezing, Schedules -?? maybe these also fit in here. +Finer control of training + +```julia +model = Chain(enc = encoder, dec = decoder) + +opt = Flux.setup(Adam(), model) + +Flux.freeze!(opt.layers.enc) # corresponds to model.layers.end +``` +!!! note + This `freeze!` goes with the "explicit" style. + The earlier "implicit" equivalent was to pass to `gradient` an object referencing only + part of the model, such as `Flux.params(model.layers.enc)`. diff --git a/src/Flux.jl b/src/Flux.jl index 53749b85fa..4933b95aa2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,6 +8,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions +using Optimisers: freeze!, thaw!, adjust using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback, @nograd From c6c253afa98306c92b52d2a8a280f70dee5d35d9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 27 Nov 2022 00:48:15 -0500 Subject: [PATCH 10/28] update quickstart + readme too --- README.md | 5 ++--- docs/src/models/quickstart.md | 27 +++++++++++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index b3dda36a0f..5d2faf255d 100644 --- a/README.md +++ b/README.md @@ -25,10 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2] model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only) -mloss(x,y) = (model(x) - y)^2 -optim = Flux.Adam() +optim = Flux.setup(Adam(), model) for epoch in 1:1000 - Flux.train!(mloss, Flux.params(model), data, optim) + Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim) end plot(x -> 2x-x^3, -2, 2, legend=false) diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index 33603196ec..82d09bf6a1 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -27,25 +27,23 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true); # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix) -pars = Flux.params(model) # contains references to arrays in model -opt = Flux.Adam(0.01) # will store optimiser momentum, etc. +optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. # Training loop, using the whole data set 1000 times: losses = [] @showprogress for epoch in 1:1_000 for (x, y) in loader - loss, grad = Flux.withgradient(pars) do + loss, grads = Flux.withgradient(model) do m # Evaluate model and loss inside gradient context: - y_hat = model(x) + y_hat = m(x) Flux.crossentropy(y_hat, y) end - Flux.update!(opt, pars, grad) + Flux.update!(optim, model, grads[1]) push!(losses, loss) # logging, outside gradient context end end -pars # parameters, momenta and output have all changed -opt +optim # parameters, momenta and output have all changed out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false) mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! @@ -89,7 +87,7 @@ Some things to notice in this example are: * The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)). -* But the model does not contain the loss function, nor the optimisation rule. The [`Adam`](@ref Flux.Adam) object stores between iterations the momenta it needs. And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function. +* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function. * The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated. @@ -97,9 +95,18 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux. ```julia for epoch in 1:1_000 - Flux.train!(pars, loader, opt) do x, y - y_hat = model(x) + Flux.train!(model, loader, optim) do m, x, y + y_hat = m(x) Flux.crossentropy(y_hat, y) end end ``` + +!!! note "Implicit-style training" + Until recently Flux's training looked a bit different. + Any code which looks like `gradient(() -> loss(model, x, y), Flux.params(model))` + (gradient of a zero-argument function) or + `train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim)` + is in the old "implicit" style. + This still works on Flux 0.13, but will be removed from Flux 0.14. + See the [training section](@ref man-training) for more details. From df20a6c235d2eb665ca202d984a76260dbdf7d01 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 27 Nov 2022 11:40:22 -0500 Subject: [PATCH 11/28] finish freezing etc, update everything --- Project.toml | 2 +- docs/Project.toml | 1 + docs/src/models/quickstart.md | 13 ++-- docs/src/training/train_api.md | 20 +++--- docs/src/training/training.md | 127 ++++++++++++++++++++------------- docs/src/training/zygote.md | 28 +++++--- src/Flux.jl | 2 +- src/deprecations.jl | 6 +- src/optimise/train.jl | 4 +- src/train.jl | 9 ++- 10 files changed, 132 insertions(+), 80 deletions(-) diff --git a/Project.toml b/Project.toml index 6c43609973..76a6e25c34 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ MacroTools = "0.5" NNlib = "0.8.9" NNlibCUDA = "0.2.4" OneHotArrays = "0.1, 0.2" -Optimisers = "0.2.11" +Optimisers = "0.2.12" ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" diff --git a/docs/Project.toml b/docs/Project.toml index c1812ee385..4af31f2254 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index 82d09bf6a1..8cad3fe839 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -103,10 +103,15 @@ end ``` !!! note "Implicit-style training" - Until recently Flux's training looked a bit different. - Any code which looks like `gradient(() -> loss(model, x, y), Flux.params(model))` + Until recently Flux's training worked a bit differently. + Any code which looks like + ``` + gradient(() -> loss(model, x, y), Flux.params(model)) + ``` (gradient of a zero-argument function) or - `train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim)` - is in the old "implicit" style. + ``` + train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim) + ``` + (with `Flux.params`) is in the old "implicit" style. This still works on Flux 0.13, but will be removed from Flux 0.14. See the [training section](@ref man-training) for more details. diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md index b8c05b240e..3228cc9965 100644 --- a/docs/src/training/train_api.md +++ b/docs/src/training/train_api.md @@ -2,10 +2,10 @@ ```@docs Flux.Train.setup -Flux.Optimise.train!(loss, model, data, opt; cb) +Flux.Train.train!(loss, model, data, opt; cb) ``` -`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode. +`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode automatically. To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) and follow its setup instructions. @@ -21,27 +21,31 @@ Optimisers.update! ### Modifiers The state returned by `setup` can be modified to temporarily prevent training of -some parts of the model, or to change the learning rate uses. -The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust`: +some parts of the model, or to change the learning rate or other hyperparameter. +The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust!`. +All mutate the state (or part of it) and return `nothing`. ```@docs -Optimisers.adjust +Optimisers.adjust! Optimisers.freeze! Optimisers.thaw! ``` - -## Implicit style +## Implicit style (Flux ≤ 0.13) Flux used to handle gradients, training, and optimisation rules quite differently. The new style described above is called "explicit" by Zygote, and the old style "implicit". Flux 0.13 is the transitional version which supports both; Flux 0.14 will remove the old. +!!! compat "How to upgrade" + The blue-green boxes in the [training section](@ref man-training) describe + the changes needed to upgrade old code. + For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). ```@docs Flux.params -Optimisers.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Flux.Params, gs) +Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Flux.Params, gs) Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) ``` diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 791345ce7e..cf71176f11 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -13,8 +13,8 @@ something like this: ```julia for data in train_set - # Unpack this datapoint into the input and the - # desired result (for "supervised" training): + # Unpack this element into the input and the + # desired result (for supervised training): input, label = data # Calculate the gradient of the objective @@ -30,8 +30,8 @@ for data in train_set end ``` -This isn't pseudo-code, but is precisely how traning is done. -This loop can also be written using the function [`train!`](@ref Train.train!), +It is important that every `update!` step receives a newly gradient computed gradient. +This loop can also be written using the function [`train!`](@ref Flux.Train.train!), but it's helpful to undersand the pieces first: ```julia @@ -59,14 +59,16 @@ structures are what Zygote calls "explicit" gradients. It is important that the execution of the model takes place inside the call to `gradient`, in order for the influence of the model's parameters to be observed by Zygote. -!!! note "Explicit vs implicit gradients" - Flux used to use Zygote's "implicit" mode, which looks like this: +!!! compat "Explicit vs implicit gradients" + Flux ≤ 0.13 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. + It looks like this: ``` pars = Flux.params(model) - grad = Flux.gradient(() -> loss(model(input), label), pars) + grad = gradient(() -> loss(model(input), label), pars) ``` Here `pars::Params` and `grad::Grads` are two dictionary-like structures. - + Support for this will be removed from Flux 0.14, and these blue (teal?) boxes + explain what needs to change. ## Loss Functions @@ -82,14 +84,13 @@ or write this directly inside the `do` block above. Many commonly used functions like [`mse`](@ref Flux.Losses.mse) for mean-squared error or [`crossentropy`](@ref Flux.Losses.crossentropy) for cross-entropy loss, are available from the [`Flux.Losses`](../models/losses.md) module. -!!! note "Implicit-style loss functions" - Flux used to need a loss function which closed over a reference to the model, +!!! compat "Implicit-style loss functions" + Flux ≤ 0.13 needed a loss function which closed over a reference to the model, instead of being a pure function. Thus in old code you may see something like ``` loss(x, y) = sum((model(x) .- y).^2) ``` which defines a function making reference to a particular global variable `model`. - This is no longer the preferred style. ## Optimisation Rules @@ -135,13 +136,13 @@ Many commonly used optimisation rules, such as [`Adam`](@ref Flux.Optimise.Adam) These are listed on the [optimisers](@ref man-optimisers) page. -!!! note "Implicit-style optimiser state" +!!! compat "Implicit-style optimiser state" This `setup` makes another tree-like structure. Old versions of Flux did not do this, and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. This was initialised on first use of the version of `update!` for "implicit" parameters. -## Datasets +## Datasets & Batches The loop above iterates through `train_set`, expecting at each step a tuple `(input, label)`. The very simplest such object is a vector of tuples, such as this: @@ -153,7 +154,6 @@ data = [(x, y)] ``` or `data = [(x, y), (x, y), (x, y)]` for the same values three times. - To get data into this format, you might want `zip` to combine a list of different `x`s with a list of different `y`s: @@ -189,13 +189,13 @@ length(data) == 1875 === 60_000 ÷ 32 ``` Flux's layers are set up to accept such a batch of input data, -and the convolutional layers such as [Conv](@ref Flux.Conv) require it. - +and the convolutional layers such as [`Conv`](@ref Flux.Conv) require it. +The batch index is always the last dimension. ## Training Loops Very simple training loops like the one above can be written compactly using -the [`train!`](@ref) function. Including `setup`, this reads: +the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads: ```julia opt = Flux.setup(Adam(), model) @@ -205,10 +205,10 @@ train!(model, train_set, opt) do m, x, y end ``` -!!! note "Implicit-style `train!`" - This is the "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. +!!! compat "Implicit-style `train!`" + This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. The 1st argument (from the `do` block) is a function which accepts the model itself. - Old Flux versions provided a method of `train!` for "implicit" parameters, + Flux versions ≤ 0.13 provided a method of `train!` for "implicit" parameters, which works like this: ``` train!((x,y) -> loss(model(x), y), Flux.params(model), train_set, Adam()) @@ -264,28 +264,18 @@ for epoch in 1:100 end ``` -## Implicit vs Explicit - -Flux used to handle gradients, training, and optimisation rules quite differently. -The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 is the transitional version which supports both. - -The blue boxes above describe the changes. -For more details on training in the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). - -For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1). - ## Regularisation The term *regularisation* covers a wide variety of techniques aiming to improve the result of training. This is often done to avoid overfitting. Some of these are can be implemented by simply modifying the loss function. -L2 or ... umm ... adds to the loss a penalty proportional to `θ^2` for every scalar parameter, -and for a simple model could be implemented as follows: +*L₂ regularisation* (sometimes called ridge regression) adds to the loss a penalty +proportional to `θ^2` for every scalar parameter. +For a very simple model could be implemented as follows: ```julia -Flux.gradient(model) do m +grads = Flux.gradient(densemodel) do m result = m(input) penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 my_loss(result, label) + 0.42 * penalty @@ -299,7 +289,7 @@ and then apply a function to each one, and sum the result: ```julia pen_l2(x::AbstractArray) = sum(abs2, x)/2 -Flux.gradient(model) do m +grads = Flux.gradient(model) do m result = m(input) penalty = sum(pen_l2, Flux.params(m)) my_loss(result, label) + 0.42 * penalty @@ -322,31 +312,72 @@ decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) Flux's optimisers are really modifications applied to the gradient before using it to update the parameters, and `OptimiserChain` applies two such modifications. -The first, [`WeightDecay`](@ref) adds `0.42` times original parameter to the gradient, +The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times original parameter to the gradient, matching the gradient of the penalty above (with the same, unrealistically large, constant). -After that, in either case, [`Adam`](@ref) computes the final update. +After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. -The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ). +The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref Flux.Optimise.ClipValue) or [`ClipNorm`](@ref Flux.Optimise.ClipNorm). -Besides L2 / weight decay, another common and quite different kind of regularisation is +Besides L₂ / weight decay, another common and quite different kind of regularisation is provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the previous layer during training. -It should switch automatically, but see [trainmode!](@ref Flux.trainmode!) / [testmode!](@ref Flux.testmode!) to manually enable or disable this layer. +It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. -## Freezing, Schedules +## Freezing & Schedules -Finer control of training +Finer control of training, you may wish to alter the learning rate mid-way through training. +This can be done with [`adjust!`](@ref Flux.adjust!), like this: ```julia -model = Chain(enc = encoder, dec = decoder) +opt = Flux.setup(Adam(0.1), model) # initialise once -opt = Flux.setup(Adam(), model) +for epoch in 1:1000 + train!([...], opt) # train with η = 0.1 for first 100, + if epoch == 100 # then change to use η = 0.01 for the rest. + Flux.adjust!(opt, 0.01) + end +end +``` + +!!! compat "Flux ≤ 0.13" + With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to + directly mutate the `Adam` struct, `opt.eta = 0.001`. + +Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt, beta = (0.8, 0.99))`. +And such modifications can be applied to just one part of the model. +For instance, this sets a different learning rate for the encoder and the decoder: -Flux.freeze!(opt.layers.enc) # corresponds to model.layers.end +```julia +bimodel = Chain(enc = [...], dec = [...]) # some model with two parts + +opt = Flux.setup(Adam(0.02), bimodel) + +Flux.adjust!(opt.layers.enc, 0.03) # corresponds to bimodel.layers.enc +``` + +To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). +This is a temporary modification, reversed by `thaw!`: + +```julia +Flux.freeze!(opt.layers.enc) + +train!(loss, bimodel, data, opt) # won't touch bimodel.layers.enc + +Flux.thaw!(opt) # this applies to the entire model ``` -!!! note - This `freeze!` goes with the "explicit" style. +!!! compat "Flux ≤ 0.13" The earlier "implicit" equivalent was to pass to `gradient` an object referencing only - part of the model, such as `Flux.params(model.layers.enc)`. + part of the model, such as `Flux.params(bimodel.layers.enc)`. + +## Implicit or Explicit? + +Flux used to handle gradients, training, and optimisation rules quite differently. +The new style described above is called "explicit" by Zygote, and the old style "implicit". +Flux 0.13 is the transitional version which supports both. + +The blue-green boxes above describe the changes. +For more details on training in the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). + +For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1). diff --git a/docs/src/training/zygote.md b/docs/src/training/zygote.md index f38f3d467e..6c2f078843 100644 --- a/docs/src/training/zygote.md +++ b/docs/src/training/zygote.md @@ -2,20 +2,11 @@ Flux re-exports the `gradient` from [Zygote](https://github.com/FluxML/Zygote.jl), and uses this function within [`train!`](@ref Flux.train!) to differentiate the model. Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/). -## Implicit style - -Flux uses primarily what Zygote calls "implicit" gradients, [described here](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) in its documentation. - -```@docs -Zygote.gradient -Zygote.Params -Zygote.Grads -Zygote.jacobian(loss, ::Params) -``` ## Explicit style -The other way of using Zygote, and using most other AD packages, is to explicitly provide a function and its arguments. +The preferred way of using Zygote, and the only way of using most other AD packages, +is to explicitly provide a function and its arguments. ```@docs Zygote.gradient(f, args...) @@ -24,6 +15,21 @@ Zygote.jacobian(f, args...) Zygote.withgradient ``` +## Implicit style (Flux ≤ 0.13) + +Flux used to use what Zygote calls "implicit" gradients, [described here](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) in its documentation. +However, support for this will be removed from Flux 0.14. + +!!! compat "Training" + The blue-green boxes in the [training section](@ref man-training) describe + the changes needed to upgrade old code from implicit to explicit style. + +```@docs +Zygote.gradient +Zygote.Params +Zygote.Grads +Zygote.jacobian(loss, ::Params) +``` ## ChainRules diff --git a/src/Flux.jl b/src/Flux.jl index 4933b95aa2..d10f6ea010 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,7 +8,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions -using Optimisers: freeze!, thaw!, adjust +using Optimisers: freeze!, thaw!, adjust! using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback, @nograd diff --git a/src/deprecations.jl b/src/deprecations.jl index 0d9985dcd1..8a445266a4 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -98,19 +98,19 @@ Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. =# import .Optimise: train! -train!(loss, ps::Params, data, opt) = error( +train!(loss, ps::Params, data, opt; cb=nothing) = error( """can't mix implict Params with explict state! To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. But better to use the new explicit style, in which `m` itself is the 2nd argument. """) -train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error( +train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error( """can't mix implict Params with explict rule from Optimisers.jl To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. But better to use the new explicit style, in which `m` itself is the 2nd argument. """) -train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) +train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb) # Next, to use the new `setup` with the still-exported old-style `Adam` etc: import .Train: setup diff --git a/src/optimise/train.jl b/src/optimise/train.jl index d0de78e01a..b5586bc974 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -15,7 +15,7 @@ according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gra As a result, the parameters are mutated and the optimizer's internal state may change. The gradient could be mutated as well. -!!! note +!!! compat "Deprecated" This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14. The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. """ @@ -95,7 +95,7 @@ batchmemaybe(x::Tuple) = x Uses a `loss` function and training `data` to improve the model's parameters according to a particular optimisation rule `opt`. -!!! note +!!! compat "Deprecated" This method with implicit `Params` will be removed from Flux 0.14. It should be replaced with the explicit method `train!(loss, model, data, opt)`. diff --git a/src/train.jl b/src/train.jl index d548e0ac02..c3f9b91be1 100644 --- a/src/train.jl +++ b/src/train.jl @@ -21,6 +21,10 @@ It differs from `Optimisers.setup` in that it: * has methods which accept Flux's old optimisers, and convert them. (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.) +!!! compat "New" + This function was added in Flux 0.13.9. It was not used by the old "implicit" + interface, using `Flux.Optimise` module and [`Flux.params`](@ref). + # Example ```jldoctest julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); @@ -83,8 +87,9 @@ It adds only a few features to the loop above: * Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl). -!!! note - This method has significant changes from the one in Flux ≤ 0.13: +!!! compat "New" + This method was added in Flux 0.13.9. + It has significant changes from the one used by Flux ≤ 0.13: * It now takes the `model` itself, not the result of [`Flux.params`](@ref). (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, From 76051cf7c264d7ae4d883043556d1e0b1edaf825 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 27 Nov 2022 12:35:26 -0500 Subject: [PATCH 12/28] fix a test, etc --- docs/src/training/training.md | 14 +++++++------- test/train.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index cf71176f11..6721cf77fd 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -95,8 +95,7 @@ are available from the [`Flux.Losses`](../models/losses.md) module. ## Optimisation Rules The simplest kind of optimisation using the gradient is termed *gradient descent* -(or sometimes *stochastic gradient descent* when it is applied to individual examples -in a loop, not to the entire dataset at once). +(or sometimes *stochastic gradient descent* when, as here, it is not applied to the entire dataset at once). Gradient descent needs a *learning rate* which is a small number describing how fast to walk downhill, usually written as the Greek letter "eta", `η`. This is what it does: @@ -111,9 +110,11 @@ fmap(model, grads[1]) do p, g end ``` -This update of all parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`. +A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`. +And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct. -There are many other optimisation rules, which adjust the step size and direction. +However, there are many other optimisation rules, which adjust the step size and +direction in various clever ways. Most require some memory of the gradients from earlier steps, rather than always walking straight downhill. The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. @@ -125,17 +126,16 @@ first argument of `update!`. Like this: opt = Flux.setup(Adam(0.001), model) for data in train_set - ... + grads = [...] # Update both model parameters and optimiser state: Flux.update!(opt, model, grads[1]) end ``` -Many commonly used optimisation rules, such as [`Adam`](@ref Flux.Optimise.Adam), are built-in. +Many commonly-used optimisation rules, such as [`Adam`](@ref Flux.Optimise.Adam), are built-in. These are listed on the [optimisers](@ref man-optimisers) page. - !!! compat "Implicit-style optimiser state" This `setup` makes another tree-like structure. Old versions of Flux did not do this, and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. diff --git a/test/train.jl b/test/train.jl index 8c3c726184..310102331e 100644 --- a/test/train.jl +++ b/test/train.jl @@ -126,7 +126,7 @@ end err + 0.33 * l2 end diff2 = model.weight .- init_weight - @test_broken diff1 ≈ diff2 + @test diff1 ≈ diff2 # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. model.weight .= init_weight From 6c35ea16bf96035c0e101604e52fec5ddf3cfec3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 27 Nov 2022 12:46:20 -0500 Subject: [PATCH 13/28] add note to "advanced" page --- docs/src/models/advanced.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 2c8ce33f7a..28a31ecaff 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -69,6 +69,10 @@ However, doing this requires the `struct` to have a corresponding constructor th When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. +!!! compat "Flux ≤ 0.13" + The mechanism described here is for Flux's old "implicit" training style. + When upgrading for Flux 0.14, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. + Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain this using the slicing features `Chain` provides: @@ -155,6 +159,10 @@ model(xs) # returns a single float vector with one value ``` +!!! note + This `Join` layer is available from the [Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) package. + + #### Using `Parallel` Flux already provides [`Parallel`](@ref) that can offer the same functionality. In this case, `Join` is going to just be syntactic sugar for `Parallel`. From c148b860c97567c7349257412b52e1c2c353f8de Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 28 Nov 2022 09:36:02 -0500 Subject: [PATCH 14/28] tweaks --- docs/src/models/advanced.md | 4 ++ docs/src/models/regularisation.md | 80 ------------------------ docs/src/training/train_api.md | 100 ------------------------------ docs/src/training/training.md | 35 +++++------ 4 files changed, 19 insertions(+), 200 deletions(-) delete mode 100644 docs/src/models/regularisation.md delete mode 100644 docs/src/training/train_api.md diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 28a31ecaff..0fe05414b4 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -231,3 +231,7 @@ function loss(x, ys, model) return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs))) end ``` + +!!! note + This `Split` layer is available from the [Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) package. + diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md deleted file mode 100644 index 9ca5551171..0000000000 --- a/docs/src/models/regularisation.md +++ /dev/null @@ -1,80 +0,0 @@ -# Regularisation - -Applying regularisation to model parameters is straightforward. We just need to -apply an appropriate regulariser to each model parameter and -add the result to the overall loss. - -For example, say we have a simple regression. - -```jldoctest regularisation -julia> using Flux - -julia> using Flux.Losses: logitcrossentropy - -julia> m = Dense(10 => 5) -Dense(10 => 5) # 55 parameters - -julia> loss(x, y) = logitcrossentropy(m(x), y); -``` - -We can apply L2 regularisation by taking the squared norm of the parameters , `m.weight` and `m.bias`. - -```jldoctest regularisation -julia> penalty() = sum(abs2, m.weight) + sum(abs2, m.bias); - -julia> loss(x, y) = logitcrossentropy(m(x), y) + penalty(); -``` - -When working with layers, Flux provides the `params` function to grab all -parameters at once. We can easily penalise everything with `sum`: - -```jldoctest regularisation; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> Flux.params(m) -Params([Float32[0.34704182 -0.48532376 … -0.06914271 -0.38398427; 0.5201164 -0.033709668 … -0.36169025 -0.5552353; … ; 0.46534058 0.17114447 … -0.4809643 0.04993277; -0.47049698 -0.6206029 … -0.3092334 -0.47857067], Float32[0.0, 0.0, 0.0, 0.0, 0.0]]) - -julia> sqnorm(x) = sum(abs2, x); - -julia> sum(sqnorm, Flux.params(m)) -8.34994f0 -``` - -Here's a larger example with a multi-layer perceptron. - -```jldoctest regularisation; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> m = Chain(Dense(28^2 => 128, relu), Dense(128 => 32, relu), Dense(32 => 10)) -Chain( - Dense(784 => 128, relu), # 100_480 parameters - Dense(128 => 32, relu), # 4_128 parameters - Dense(32 => 10), # 330 parameters -) # Total: 6 arrays, 104_938 parameters, 410.289 KiB. - -julia> sqnorm(x) = sum(abs2, x); - -julia> loss(x, y) = logitcrossentropy(m(x), y) + sum(sqnorm, Flux.params(m)); - -julia> loss(rand(28^2), rand(10)) -300.76693683244997 -``` - -One can also easily add per-layer regularisation via the `activations` function: - -```jldoctest regularisation; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> using Flux: activations - -julia> c = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax) -Chain( - Dense(10 => 5, σ), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. - -julia> activations(c, rand(10)) -([0.3274892431795043, 0.5360197770386552, 0.3447464835514667, 0.5273025865532305, 0.7513168089280781], [-0.3533774181890544, -0.010937055274926138], [0.4152168057978045, 0.5847831942021956]) - -julia> sum(sqnorm, ans) -1.9953131077618562 -``` - -```@docs -Flux.activations -``` diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md deleted file mode 100644 index 3228cc9965..0000000000 --- a/docs/src/training/train_api.md +++ /dev/null @@ -1,100 +0,0 @@ -# Training API - -```@docs -Flux.Train.setup -Flux.Train.train!(loss, model, data, opt; cb) -``` - -`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode automatically. -To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) -and follow its setup instructions. - -The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). -This is designed to allow for immutable objects. -But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) -which can be updated in-place. Thus objects returned by `update!` can be ignored. - -```@docs -Optimisers.update! -``` - -### Modifiers - -The state returned by `setup` can be modified to temporarily prevent training of -some parts of the model, or to change the learning rate or other hyperparameter. -The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust!`. -All mutate the state (or part of it) and return `nothing`. - -```@docs -Optimisers.adjust! -Optimisers.freeze! -Optimisers.thaw! -``` - -## Implicit style (Flux ≤ 0.13) - -Flux used to handle gradients, training, and optimisation rules quite differently. -The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 is the transitional version which supports both; Flux 0.14 will remove the old. - -!!! compat "How to upgrade" - The blue-green boxes in the [training section](@ref man-training) describe - the changes needed to upgrade old code. - -For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). - -```@docs -Flux.params -Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Flux.Params, gs) -Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) -``` - -Note that, by default, `train!` only loops over the data once (a single "epoch"). -A convenient way to run multiple epochs from the REPL is provided by `@epochs`. - -```julia -julia> using Flux: @epochs - -julia> @epochs 2 println("hello") -[ Info: Epoch 1 -hello -[ Info: Epoch 2 -hello - -julia> @epochs 2 Flux.train!(...) -# Train for two epochs -``` - -```@docs -Flux.@epochs -``` - -### Callbacks - -Implicit `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: - -```julia -train!(objective, ps, data, opt, cb = () -> println("training")) -``` - -Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds. - -A more typical callback might look like this: - -```julia -test_x, test_y = # ... create single batch of test data ... -evalcb() = @show(loss(test_x, test_y)) -throttled_cb = throttle(evalcb, 5) -Flux.@epochs 20 Flux.train!(objective, ps, data, opt, cb = throttled_cb) -``` - -Calling `Flux.stop()` in a callback will exit the training loop early. - -```julia -cb = function () - accuracy() > 0.9 && Flux.stop() -end -``` - -See the page about [callback helpers](@ref man-callback-helpers) for more. - diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 6721cf77fd..b01e95e54c 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -13,8 +13,7 @@ something like this: ```julia for data in train_set - # Unpack this element into the input and the - # desired result (for supervised training): + # Unpack this element (for supervised training): input, label = data # Calculate the gradient of the objective @@ -154,25 +153,16 @@ data = [(x, y)] ``` or `data = [(x, y), (x, y), (x, y)]` for the same values three times. -To get data into this format, you might want `zip` to combine a list of different `x`s -with a list of different `y`s: -```julia -xs = [rand(28, 28), rand(28, 28), rand(28, 28)] -ys = [rand(10), rand(10), rand(10)] -data = zip(xs, ys) - -first(data) isa Tuple{Matrix, Vector} # true -``` - -Very often, the initial data is large arrays which you need to slice into examples: +Very often, the initial data is large arrays which you need to slice into examples. +To produce one iterator of pairs `(x, y)`, you might want `zip`: ```julia -X = rand(28, 28, 60_000) +X = rand(28, 28, 60_000); # many images, each 28 × 28 Y = rand(10, 60_000) data = zip(eachslice(X; dims=3), eachcol(Y)) -first(data) isa Tuple{Matrix, Vector} # true +first(data) isa Tuple{AbstractMatrix, AbstractVector} # true ``` Here each iteration will use one matrix `x` (an image, perhaps) and one vector `y`. @@ -194,20 +184,25 @@ The batch index is always the last dimension. ## Training Loops -Very simple training loops like the one above can be written compactly using +Simple training loops like the one above can be written compactly using the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads: ```julia opt = Flux.setup(Adam(), model) -train!(model, train_set, opt) do m, x, y - loss(m(x), y) +for epoch in 1:100 + Flux.train!(model, train_set, opt) do m, x, y + loss(m(x), y) + end end ``` +Or explicitly writing the anonymous function which this `do` block creates, +`train!((m,x,y) -> loss(m(x),y), model, train_set, opt)` is exactly equivalent. + !!! compat "Implicit-style `train!`" This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. - The 1st argument (from the `do` block) is a function which accepts the model itself. + The 1st argument is a function which accepts the model itself. Flux versions ≤ 0.13 provided a method of `train!` for "implicit" parameters, which works like this: ``` @@ -318,7 +313,7 @@ After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref Flux.Optimise.ClipValue) or [`ClipNorm`](@ref Flux.Optimise.ClipNorm). -Besides L₂ / weight decay, another common and quite different kind of regularisation is +Besides L2 / weight decay, another common and quite different kind of regularisation is provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the previous layer during training. It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. From 6042e8effc5edf6a4df209bc5bee3ab670eb8d36 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 28 Nov 2022 10:08:51 -0500 Subject: [PATCH 15/28] comments --- docs/src/training/training.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index b01e95e54c..17d16a1c9e 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -327,7 +327,7 @@ This can be done with [`adjust!`](@ref Flux.adjust!), like this: opt = Flux.setup(Adam(0.1), model) # initialise once for epoch in 1:1000 - train!([...], opt) # train with η = 0.1 for first 100, + train!([...], opt) # Train with η = 0.1 for first 100, if epoch == 100 # then change to use η = 0.01 for the rest. Flux.adjust!(opt, 0.01) end @@ -343,11 +343,14 @@ And such modifications can be applied to just one part of the model. For instance, this sets a different learning rate for the encoder and the decoder: ```julia -bimodel = Chain(enc = [...], dec = [...]) # some model with two parts +# Consider some model with two parts: +bimodel = Chain(enc = [...], dec = [...]) +# This returns a tree whose structure matches the model: opt = Flux.setup(Adam(0.02), bimodel) -Flux.adjust!(opt.layers.enc, 0.03) # corresponds to bimodel.layers.enc +# Adjust the learning rate to be used for bimodel.layers.enc +Flux.adjust!(opt.layers.enc, 0.03) ``` To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). @@ -356,9 +359,11 @@ This is a temporary modification, reversed by `thaw!`: ```julia Flux.freeze!(opt.layers.enc) -train!(loss, bimodel, data, opt) # won't touch bimodel.layers.enc +# Now training won't update parameters in bimodel.layers.enc +train!(loss, bimodel, data, opt) -Flux.thaw!(opt) # this applies to the entire model +# Un-freeze the entire model: +Flux.thaw!(opt) ``` !!! compat "Flux ≤ 0.13" From b72454a74e222a820eda735546892b042da56dc8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 28 Nov 2022 10:42:49 -0500 Subject: [PATCH 16/28] tweaks, bugs, missing files, etc --- docs/src/training/reference.md | 106 +++++++++++++++++++++++++++++++++ docs/src/training/training.md | 18 +++--- 2 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 docs/src/training/reference.md diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md new file mode 100644 index 0000000000..a6054515ea --- /dev/null +++ b/docs/src/training/reference.md @@ -0,0 +1,106 @@ +# Training API Reference + +The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). +Only the function `train!` belongs to Flux itself. + +The Optimisers package is designed to allow for immutable objects. But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) which can be updated in-place. +Because of this: + +* The objects returned by `Optimisers.update!` can be ignored. +* Flux defines its own version of `setup` which checks this assumption. + (Using instead `Optimisers.setup` will also work, they return the same thing.) + +The new implementation of rules such as Adam in the Optimisers is quite different from the old one in `Flux.Optimise`. In Flux 0.13, `Flux.Adam()` returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. +The available rules are listed the [optimisation rules](@ref man-optimisers) page here; +see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the new rules work. + +```@docs +Flux.Train.setup +Flux.Train.train!(loss, model, data, opt; cb) +Optimisers.update! +``` + +`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode automatically. +To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) +and follow its setup instructions. + +## Optimisation Modifiers + +The state returned by `setup` can be modified to temporarily prevent training of +some parts of the model, or to change the learning rate or other hyperparameter. +The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust!`. +All mutate the state (or part of it) and return `nothing`. + +```@docs +Optimisers.adjust! +Optimisers.freeze! +Optimisers.thaw! +``` + +## Implicit style (Flux ≤ 0.13) + +Flux used to handle gradients, training, and optimisation rules quite differently. +The new style described above is called "explicit" by Zygote, and the old style "implicit". +Flux 0.13 is the transitional version which supports both; Flux 0.14 will remove the old. + +!!! compat "How to upgrade" + The blue-green boxes in the [training section](@ref man-training) describe + the changes needed to upgrade old code. + +For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). + +```@docs +Flux.params +Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Flux.Params, gs) +Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) +``` + +Note that, by default, `train!` only loops over the data once (a single "epoch"). +A convenient way to run multiple epochs from the REPL is provided by `@epochs`. + +```julia +julia> using Flux: @epochs + +julia> @epochs 2 println("hello") +[ Info: Epoch 1 +hello +[ Info: Epoch 2 +hello + +julia> @epochs 2 Flux.train!(...) +# Train for two epochs +``` + +```@docs +Flux.@epochs +``` + +## Callbacks + +Implicit `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: + +```julia +train!(objective, ps, data, opt, cb = () -> println("training")) +``` + +Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds. + +A more typical callback might look like this: + +```julia +test_x, test_y = # ... create single batch of test data ... +evalcb() = @show(loss(test_x, test_y)) +throttled_cb = throttle(evalcb, 5) +Flux.@epochs 20 Flux.train!(objective, ps, data, opt, cb = throttled_cb) +``` + +Calling `Flux.stop()` in a callback will exit the training loop early. + +```julia +cb = function () + accuracy() > 0.9 && Flux.stop() +end +``` + +See the page about [callback helpers](@ref man-callback-helpers) for more. + diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 17d16a1c9e..d278ff87b2 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -97,7 +97,9 @@ The simplest kind of optimisation using the gradient is termed *gradient descent (or sometimes *stochastic gradient descent* when, as here, it is not applied to the entire dataset at once). Gradient descent needs a *learning rate* which is a small number describing how fast to walk downhill, -usually written as the Greek letter "eta", `η`. This is what it does: +usually written as the Greek letter "eta", `η`. This is often described as a *hyperparameter*, +to distinguish it from the parameters which are being updated `θ = θ - η * ∂loss_∂θ`. +We want to update all the parameters in the model, like this: ```julia η = 0.01 # learning rate @@ -115,14 +117,14 @@ And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Opti However, there are many other optimisation rules, which adjust the step size and direction in various clever ways. Most require some memory of the gradients from earlier steps, rather than always -walking straight downhill. The function [`setup`](@ref Flux.Train.setup) creates the -necessary storage for this, for a particular model. +walking straight downhill -- [`Momentum`](@ref Flux.Optimise.Momentum) is the simplest. +The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. It should be called once, before training, and returns a tree-like object which is the -first argument of `update!`. Like this: +first argument of `update!`. Like this: ```julia # Initialise momentum -opt = Flux.setup(Adam(0.001), model) +opt = Flux.setup(Momentum(0.01, 0.9), model) for data in train_set grads = [...] @@ -222,7 +224,7 @@ callback API. Here is an example, in which it may be helpful to note: ```julia opt = Flux.setup(Adam(), model) -log = [] +my_log = [] for epoch in 1:100 losses = Float32[] for (i, data) in enumerate(train_set) @@ -247,9 +249,9 @@ for epoch in 1:100 Flux.update!(opt, model, grads[1]) end - # Compute some accuracy, and save details to log + # Compute some accuracy, and save details as a NamedTuple acc = my_accuracy(model, train_set) - push!(log, (; acc, losses)) + push!(my_log, (; acc, losses)) # Stop training when some criterion is reached if acc > 0.95 From 28091df22d23b6628e258d89d1271b4761f08a49 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 29 Nov 2022 12:00:22 -0500 Subject: [PATCH 17/28] move a sentence --- docs/src/training/training.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index d278ff87b2..ba3bba071d 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -29,7 +29,6 @@ for data in train_set end ``` -It is important that every `update!` step receives a newly gradient computed gradient. This loop can also be written using the function [`train!`](@ref Flux.Train.train!), but it's helpful to undersand the pieces first: @@ -43,8 +42,8 @@ end Fist recall from the section on [taking gradients](@ref man-taking-gradients) that `Flux.gradient(f, a, b)` always calls `f(a, b)`, and returns a tuple `(∂f_∂a, ∂f_∂b)`. -In the code above, the function `f` is an anonymous function with one argument, -created by the `do` block, hence `grads` is a tuple with one element. +In the code above, the function `f` passed to `gradient` is an anonymous function with +one argument, created by the `do` block, hence `grads` is a tuple with one element. Instead of a `do` block, we could have written: ```julia @@ -58,6 +57,9 @@ structures are what Zygote calls "explicit" gradients. It is important that the execution of the model takes place inside the call to `gradient`, in order for the influence of the model's parameters to be observed by Zygote. +It is also important that every `update!` step receives a newly gradient computed gradient, +as this will be change whenever the model's parameters are changed, and for each new data point. + !!! compat "Explicit vs implicit gradients" Flux ≤ 0.13 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. It looks like this: From 11e4825cbc1894de67da8055f1bffa2fbf3988c2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 7 Dec 2022 18:41:47 -0500 Subject: [PATCH 18/28] change opt to state --- README.md | 4 +-- docs/src/models/quickstart.md | 10 ++++---- docs/src/training/reference.md | 2 +- docs/src/training/training.md | 47 ++++++++++++++++++---------------- 4 files changed, 33 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 5d2faf255d..9f7efe7ee7 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2] model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only) -optim = Flux.setup(Adam(), model) +state = Flux.setup(Adam(), model) for epoch in 1:1000 - Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim) + Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, state) end plot(x -> 2x-x^3, -2, 2, legend=false) diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index 8cad3fe839..cef6d72d33 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -27,7 +27,7 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true); # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix) -optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. +state = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. # Training loop, using the whole data set 1000 times: losses = [] @@ -38,12 +38,12 @@ losses = [] y_hat = m(x) Flux.crossentropy(y_hat, y) end - Flux.update!(optim, model, grads[1]) + Flux.update!(state, model, grads[1]) push!(losses, loss) # logging, outside gradient context end end -optim # parameters, momenta and output have all changed +state # parameters, momenta and output have all changed out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false) mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! @@ -95,7 +95,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux. ```julia for epoch in 1:1_000 - Flux.train!(model, loader, optim) do m, x, y + Flux.train!(model, loader, state) do m, x, y y_hat = m(x) Flux.crossentropy(y_hat, y) end @@ -110,7 +110,7 @@ end ``` (gradient of a zero-argument function) or ``` - train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim) + train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt) ``` (with `Flux.params`) is in the old "implicit" style. This still works on Flux 0.13, but will be removed from Flux 0.14. diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index a6054515ea..a440a9a641 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -16,7 +16,7 @@ see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for det ```@docs Flux.Train.setup -Flux.Train.train!(loss, model, data, opt; cb) +Flux.Train.train!(loss, model, data, state; cb) Optimisers.update! ``` diff --git a/docs/src/training/training.md b/docs/src/training/training.md index ba3bba071d..b4ce944249 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -12,6 +12,9 @@ are handled one-by-one. One *epoch* of training means that each example is used something like this: ```julia +# Initialise the optimiser for this model: +state = Flux.setup(rule, model) + for data in train_set # Unpack this element (for supervised training): input, label = data @@ -24,8 +27,8 @@ for data in train_set end # Update the parameters so as to reduce the objective, - # according to a particular optimiser: - Flux.update!(opt, model, grads[1]) + # according the chosen optimisation rule: + Flux.update!(state, model, grads[1]) end ``` @@ -33,7 +36,7 @@ This loop can also be written using the function [`train!`](@ref Flux.Train.trai but it's helpful to undersand the pieces first: ```julia -train!(model, train_set, opt) do m, x, y +train!(model, train_set, state) do m, x, y loss(m(x), y) end ``` @@ -113,7 +116,7 @@ fmap(model, grads[1]) do p, g end ``` -A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`. +A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(state, model, grads[1])`. And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct. However, there are many other optimisation rules, which adjust the step size and @@ -126,13 +129,13 @@ first argument of `update!`. Like this: ```julia # Initialise momentum -opt = Flux.setup(Momentum(0.01, 0.9), model) +state = Flux.setup(Momentum(0.01, 0.9), model) for data in train_set grads = [...] # Update both model parameters and optimiser state: - Flux.update!(opt, model, grads[1]) + Flux.update!(state, model, grads[1]) end ``` @@ -192,17 +195,17 @@ Simple training loops like the one above can be written compactly using the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads: ```julia -opt = Flux.setup(Adam(), model) +state = Flux.setup(Adam(), model) for epoch in 1:100 - Flux.train!(model, train_set, opt) do m, x, y + Flux.train!(model, train_set, state) do m, x, y loss(m(x), y) end end ``` Or explicitly writing the anonymous function which this `do` block creates, -`train!((m,x,y) -> loss(m(x),y), model, train_set, opt)` is exactly equivalent. +`train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent. !!! compat "Implicit-style `train!`" This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. @@ -224,7 +227,7 @@ callback API. Here is an example, in which it may be helpful to note: * Julia's `break` and `continue` keywords let you exit from parts of the loop. ```julia -opt = Flux.setup(Adam(), model) +state = Flux.setup(Adam(), model) my_log = [] for epoch in 1:100 @@ -248,7 +251,7 @@ for epoch in 1:100 continue end - Flux.update!(opt, model, grads[1]) + Flux.update!(state, model, grads[1]) end # Compute some accuracy, and save details as a NamedTuple @@ -300,13 +303,13 @@ So there is a simpler way to implement exactly the same thing, by modifying the instead of the loss function. This is done by replacing this: ```julia -opt = Flux.setup(Adam(0.1), model) +state = Flux.setup(Adam(0.1), model) ``` with this: ```julia -decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) +decay_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) ``` Flux's optimisers are really modifications applied to the gradient before using it to update @@ -328,12 +331,12 @@ Finer control of training, you may wish to alter the learning rate mid-way throu This can be done with [`adjust!`](@ref Flux.adjust!), like this: ```julia -opt = Flux.setup(Adam(0.1), model) # initialise once +state = Flux.setup(Adam(0.1), model) # initialise once for epoch in 1:1000 - train!([...], opt) # Train with η = 0.1 for first 100, + train!([...], state) # Train with η = 0.1 for first 100, if epoch == 100 # then change to use η = 0.01 for the rest. - Flux.adjust!(opt, 0.01) + Flux.adjust!(state, 0.01) end end ``` @@ -342,7 +345,7 @@ end With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to directly mutate the `Adam` struct, `opt.eta = 0.001`. -Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt, beta = (0.8, 0.99))`. +Other hyper-parameters can also be adjusted, such as `Flux.adjust!(state, beta = (0.8, 0.99))`. And such modifications can be applied to just one part of the model. For instance, this sets a different learning rate for the encoder and the decoder: @@ -351,23 +354,23 @@ For instance, this sets a different learning rate for the encoder and the decode bimodel = Chain(enc = [...], dec = [...]) # This returns a tree whose structure matches the model: -opt = Flux.setup(Adam(0.02), bimodel) +state = Flux.setup(Adam(0.02), bimodel) # Adjust the learning rate to be used for bimodel.layers.enc -Flux.adjust!(opt.layers.enc, 0.03) +Flux.adjust!(state.layers.enc, 0.03) ``` To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). This is a temporary modification, reversed by `thaw!`: ```julia -Flux.freeze!(opt.layers.enc) +Flux.freeze!(state.layers.enc) # Now training won't update parameters in bimodel.layers.enc -train!(loss, bimodel, data, opt) +train!(loss, bimodel, data, state) # Un-freeze the entire model: -Flux.thaw!(opt) +Flux.thaw!(state) ``` !!! compat "Flux ≤ 0.13" From aef852780eceefff65ae5e3491e221c64b6138c1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 7 Dec 2022 23:42:56 -0500 Subject: [PATCH 19/28] new page lost in rebase --- docs/make.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/make.jl b/docs/make.jl index 70256de407..c4aa49570f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,6 +30,7 @@ makedocs( "Activation Functions" => "models/activation.md", "Weight Initialisation" => "utilities.md", "Loss Functions" => "models/losses.md", + "Training API" => "training/reference.md", "Optimisation Rules" => "training/optimisers.md", "Shape Inference" => "outputsize.md", "Flat vs. Nested" => "destructure.md", From cc4c71b25f5d124d5eb211db428f3cd1b7e297da Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 8 Dec 2022 09:09:37 -0500 Subject: [PATCH 20/28] don't say "explicit" so often --- docs/src/models/quickstart.md | 2 +- docs/src/training/training.md | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index cef6d72d33..3295701a70 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -102,7 +102,7 @@ for epoch in 1:1_000 end ``` -!!! note "Implicit-style training" +!!! compat "Implicit-style training, Flux ≤ 0.13" Until recently Flux's training worked a bit differently. Any code which looks like ``` diff --git a/docs/src/training/training.md b/docs/src/training/training.md index b4ce944249..f66f0e6728 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -54,8 +54,9 @@ grads = Flux.gradient(m -> loss(m(input), label), model) ``` Since the model is some nested set of layers, `grads[1]` is a similarly nested set of -`NamedTuple`s, ultimately containing gradient components. These matching tree-like -structures are what Zygote calls "explicit" gradients. +`NamedTuple`s, ultimately containing gradient components. If (for example) +`θ = model.layers[1].weight[2,3]` is one scalar parameter, an entry in a matrix of weights, +then the derivative of the loss with respect to it is `∂f_∂θ = grads[1].layers[1].weight[2,3]`. It is important that the execution of the model takes place inside the call to `gradient`, in order for the influence of the model's parameters to be observed by Zygote. @@ -63,7 +64,7 @@ in order for the influence of the model's parameters to be observed by Zygote. It is also important that every `update!` step receives a newly gradient computed gradient, as this will be change whenever the model's parameters are changed, and for each new data point. -!!! compat "Explicit vs implicit gradients" +!!! compat "Implicit gradients" Flux ≤ 0.13 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. It looks like this: ``` @@ -208,7 +209,7 @@ Or explicitly writing the anonymous function which this `do` block creates, `train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent. !!! compat "Implicit-style `train!`" - This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument. + This is a new method of `train!`, which takes the result of `setup` as its 4th argument. The 1st argument is a function which accepts the model itself. Flux versions ≤ 0.13 provided a method of `train!` for "implicit" parameters, which works like this: From 278b3a00a4a3abc617c017a6b2c848fe3e15c03e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 8 Dec 2022 09:35:47 -0500 Subject: [PATCH 21/28] opt to state in a few more places --- docs/src/training/training.md | 2 +- src/train.jl | 27 +++++++++++++-------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index f66f0e6728..a188e7f3da 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -336,7 +336,7 @@ state = Flux.setup(Adam(0.1), model) # initialise once for epoch in 1:1000 train!([...], state) # Train with η = 0.1 for first 100, - if epoch == 100 # then change to use η = 0.01 for the rest. + if epoch == 100 # then change to use η = 0.01 for the rest. Flux.adjust!(state, 0.01) end end diff --git a/src/train.jl b/src/train.jl index c3f9b91be1..dc0e2a60ce 100644 --- a/src/train.jl +++ b/src/train.jl @@ -12,7 +12,7 @@ using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote, Params """ - opt = setup(rule, model) + state = setup(rule, model) This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!). It differs from `Optimisers.setup` in that it: @@ -29,12 +29,12 @@ It differs from `Optimisers.setup` in that it: ```jldoctest julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); -julia> opt = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state +julia> state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state (weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ()) julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: -julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y +julia> Flux.train!(model, [(x1, y1), (x1, y1)], state) do m, x, y sum(abs.(m(x) .- y)) * 100 end @@ -42,7 +42,7 @@ julia> model.bias # was zero, mutated by Flux.train! 1-element Vector{Float32}: 10.190001 -julia> opt # mutated by Flux.train! +julia> state # mutated by Flux.train! (weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ()) ``` """ @@ -56,12 +56,12 @@ function setup(rule::Optimisers.AbstractRule, model) end """ - train!(loss, model, data, opt) + train!(loss, model, data, state) Uses a `loss` function and training `data` to improve the `model`'s parameters -according to a particular optimisation rule `opt`. Iterates through `data` once, -evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`, -or else `loss(model, d)` for other `d`. +according to a particular optimisation rule encoded in `state`. +Iterates through `data` once, evaluating for each `d in data` either +`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`. For example, with these definitions... ``` @@ -69,14 +69,13 @@ data = [(x1, y1), (x2, y2), (x3, y3)] loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument -opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta +state = Flux.setup(Adam(), model) # explicit setup of optimiser momenta ``` -...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this, -using Zygote's "explicit" mode for the gradient: +...calling `Flux.train!(loss3, model, data, state)` runs a loop much like this: ``` for d in data ∂L∂m = gradient(loss3, model, d...)[1] - update!(opt, model, ∂L∂m) # method for "explicit" gradient + update!(state, model, ∂L∂m) # method for "explicit" gradient end ``` You can also write this loop yourself, if you need more flexibility. @@ -94,10 +93,10 @@ It adds only a few features to the loop above: (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, now it must also accept the `model` itself, as the first argument. - * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser + * `state` should be the result of [`Flux.setup`](@ref). Using an optimiser such as `Adam()` without this step should give you a warning. * Callback functions are not supported. - But any code can be included in the above `for` loop. + (But any code can be included in the above `for` loop.) """ function train!(loss, model, data, opt; cb = nothing) isnothing(cb) || error("""train! does not support callback functions. From 629cecd517c492236e3dd51734e9611ef32b6b3c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 8 Dec 2022 16:29:02 -0500 Subject: [PATCH 22/28] add three compat boxes about common errors / problems re old versions --- docs/src/destructure.md | 6 ++++++ docs/src/models/layers.md | 7 ++++++- docs/src/training/reference.md | 7 +++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/docs/src/destructure.md b/docs/src/destructure.md index eeb3036ae9..6e9eac191e 100644 --- a/docs/src/destructure.md +++ b/docs/src/destructure.md @@ -49,6 +49,12 @@ julia> Flux.destructure(grad) # acts on non-models, too (Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5)) ``` +!!! compat "Flux ≤ 0.12" + Old versions of Flux had an entirely different implementation of `destructure`, which + had many bugs (and almost no tests). Many comments online still refer to that now-deleted + function, or to memories of it. + + ### All Parameters The function `destructure` now lives in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl). diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 3714f434e4..b8345c3f78 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -15,7 +15,7 @@ The `Dense` exemplifies several features: * It is annotated with [`@functor`](@ref Functors.@functor), which means that [`params`](@ref Flux.params) will see the contents, and [`gpu`](@ref Flux.gpu) will move their arrays to the GPU. By contrast, `Chain` itself contains no parameters, but connects other layers together. -The section on [dataflow layers](@ref man-dataflow-layers) introduces others like this, +The section on [dataflow layers](@ref man-dataflow-layers) introduces others like this. ## Fully Connected @@ -27,6 +27,11 @@ Flux.Scale Perhaps `Scale` isn't quite fully connected, but it may be thought of as `Dense(Diagonal(s.weights), s.bias)`, and LinearAlgebra's `Diagonal` is a matrix which just happens to contain many zeros. +!!! compat "Flux ≤ 0.12" + Old versions of Flux accepted only `Dense(in, out, act)` and not `Dense(in => out, act)`. + This notation makes a `Pair` object. If you get an error like `MethodError: no method matching Dense(::Pair{Int64,Int64})`, this means that you should upgrade to Flux 0.13. + + ## Convolution Models These layers are used to build convolutional neural networks (CNNs). diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index a440a9a641..2a5f27c696 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -49,6 +49,13 @@ Flux 0.13 is the transitional version which supports both; Flux 0.14 will remove For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). +!!! compat "Flux ≤ 0.12" + Earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` + after `using Flux`. This conflicted with too many other packages, and was removed in Flux 0.13. + If you get an error ``UndefVarError: `params` not defined``, this probably means that you are + following code for Flux 0.12 or earlier on a more recent version. + + ```@docs Flux.params Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Flux.Params, gs) From ee9a53ffc42d1a84a25d2e84b67fd2a0e1520413 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 10 Dec 2022 19:46:17 -0500 Subject: [PATCH 23/28] change to opt_state --- README.md | 4 ++-- docs/src/models/quickstart.md | 8 ++++---- docs/src/training/training.md | 38 +++++++++++++++++------------------ src/train.jl | 20 +++++++++--------- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 9f7efe7ee7..5d2faf255d 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2] model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only) -state = Flux.setup(Adam(), model) +optim = Flux.setup(Adam(), model) for epoch in 1:1000 - Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, state) + Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim) end plot(x -> 2x-x^3, -2, 2, legend=false) diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index 3295701a70..3e4939bc2a 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -27,7 +27,7 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true); # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix) -state = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. +optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. # Training loop, using the whole data set 1000 times: losses = [] @@ -38,12 +38,12 @@ losses = [] y_hat = m(x) Flux.crossentropy(y_hat, y) end - Flux.update!(state, model, grads[1]) + Flux.update!(optim, model, grads[1]) push!(losses, loss) # logging, outside gradient context end end -state # parameters, momenta and output have all changed +optim # parameters, momenta and output have all changed out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false) mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! @@ -95,7 +95,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux. ```julia for epoch in 1:1_000 - Flux.train!(model, loader, state) do m, x, y + Flux.train!(model, loader, optim) do m, x, y y_hat = m(x) Flux.crossentropy(y_hat, y) end diff --git a/docs/src/training/training.md b/docs/src/training/training.md index a188e7f3da..1e2e2d3c15 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -13,7 +13,7 @@ something like this: ```julia # Initialise the optimiser for this model: -state = Flux.setup(rule, model) +opt_state = Flux.setup(rule, model) for data in train_set # Unpack this element (for supervised training): @@ -28,7 +28,7 @@ for data in train_set # Update the parameters so as to reduce the objective, # according the chosen optimisation rule: - Flux.update!(state, model, grads[1]) + Flux.update!(opt_state, model, grads[1]) end ``` @@ -36,7 +36,7 @@ This loop can also be written using the function [`train!`](@ref Flux.Train.trai but it's helpful to undersand the pieces first: ```julia -train!(model, train_set, state) do m, x, y +train!(model, train_set, opt_state) do m, x, y loss(m(x), y) end ``` @@ -117,7 +117,7 @@ fmap(model, grads[1]) do p, g end ``` -A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(state, model, grads[1])`. +A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt_state, model, grads[1])`. And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct. However, there are many other optimisation rules, which adjust the step size and @@ -130,13 +130,13 @@ first argument of `update!`. Like this: ```julia # Initialise momentum -state = Flux.setup(Momentum(0.01, 0.9), model) +opt_state = Flux.setup(Momentum(0.01, 0.9), model) for data in train_set grads = [...] # Update both model parameters and optimiser state: - Flux.update!(state, model, grads[1]) + Flux.update!(opt_state, model, grads[1]) end ``` @@ -196,17 +196,17 @@ Simple training loops like the one above can be written compactly using the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads: ```julia -state = Flux.setup(Adam(), model) +opt_state = Flux.setup(Adam(), model) for epoch in 1:100 - Flux.train!(model, train_set, state) do m, x, y + Flux.train!(model, train_set, opt_state) do m, x, y loss(m(x), y) end end ``` Or explicitly writing the anonymous function which this `do` block creates, -`train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent. +`train!((m,x,y) -> loss(m(x),y), model, train_set, opt_state)` is exactly equivalent. !!! compat "Implicit-style `train!`" This is a new method of `train!`, which takes the result of `setup` as its 4th argument. @@ -228,7 +228,7 @@ callback API. Here is an example, in which it may be helpful to note: * Julia's `break` and `continue` keywords let you exit from parts of the loop. ```julia -state = Flux.setup(Adam(), model) +opt_state = Flux.setup(Adam(), model) my_log = [] for epoch in 1:100 @@ -252,7 +252,7 @@ for epoch in 1:100 continue end - Flux.update!(state, model, grads[1]) + Flux.update!(opt_state, model, grads[1]) end # Compute some accuracy, and save details as a NamedTuple @@ -304,13 +304,13 @@ So there is a simpler way to implement exactly the same thing, by modifying the instead of the loss function. This is done by replacing this: ```julia -state = Flux.setup(Adam(0.1), model) +opt_state = Flux.setup(Adam(0.1), model) ``` with this: ```julia -decay_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) +decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) ``` Flux's optimisers are really modifications applied to the gradient before using it to update @@ -332,12 +332,12 @@ Finer control of training, you may wish to alter the learning rate mid-way throu This can be done with [`adjust!`](@ref Flux.adjust!), like this: ```julia -state = Flux.setup(Adam(0.1), model) # initialise once +opt_state = Flux.setup(Adam(0.1), model) # initialise once for epoch in 1:1000 train!([...], state) # Train with η = 0.1 for first 100, if epoch == 100 # then change to use η = 0.01 for the rest. - Flux.adjust!(state, 0.01) + Flux.adjust!(opt_state, 0.01) end end ``` @@ -346,7 +346,7 @@ end With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to directly mutate the `Adam` struct, `opt.eta = 0.001`. -Other hyper-parameters can also be adjusted, such as `Flux.adjust!(state, beta = (0.8, 0.99))`. +Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt_state, beta = (0.8, 0.99))`. And such modifications can be applied to just one part of the model. For instance, this sets a different learning rate for the encoder and the decoder: @@ -355,17 +355,17 @@ For instance, this sets a different learning rate for the encoder and the decode bimodel = Chain(enc = [...], dec = [...]) # This returns a tree whose structure matches the model: -state = Flux.setup(Adam(0.02), bimodel) +opt_state = Flux.setup(Adam(0.02), bimodel) # Adjust the learning rate to be used for bimodel.layers.enc -Flux.adjust!(state.layers.enc, 0.03) +Flux.adjust!(opt_state.layers.enc, 0.03) ``` To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). This is a temporary modification, reversed by `thaw!`: ```julia -Flux.freeze!(state.layers.enc) +Flux.freeze!(opt_state.layers.enc) # Now training won't update parameters in bimodel.layers.enc train!(loss, bimodel, data, state) diff --git a/src/train.jl b/src/train.jl index dc0e2a60ce..63d95258b9 100644 --- a/src/train.jl +++ b/src/train.jl @@ -12,7 +12,7 @@ using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote, Params """ - state = setup(rule, model) + opt_state = setup(rule, model) This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!). It differs from `Optimisers.setup` in that it: @@ -29,12 +29,12 @@ It differs from `Optimisers.setup` in that it: ```jldoctest julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); -julia> state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state +julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state (weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ()) julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: -julia> Flux.train!(model, [(x1, y1), (x1, y1)], state) do m, x, y +julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y sum(abs.(m(x) .- y)) * 100 end @@ -42,7 +42,7 @@ julia> model.bias # was zero, mutated by Flux.train! 1-element Vector{Float32}: 10.190001 -julia> state # mutated by Flux.train! +julia> opt_state # mutated by Flux.train! (weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ()) ``` """ @@ -56,10 +56,10 @@ function setup(rule::Optimisers.AbstractRule, model) end """ - train!(loss, model, data, state) + train!(loss, model, data, opt_state) Uses a `loss` function and training `data` to improve the `model`'s parameters -according to a particular optimisation rule encoded in `state`. +according to a particular optimisation rule encoded in `opt_state`. Iterates through `data` once, evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`. @@ -69,13 +69,13 @@ data = [(x1, y1), (x2, y2), (x3, y3)] loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument -state = Flux.setup(Adam(), model) # explicit setup of optimiser momenta +opt_state = Flux.setup(Adam(), model) # explicit setup of optimiser momenta ``` -...calling `Flux.train!(loss3, model, data, state)` runs a loop much like this: +...calling `Flux.train!(loss3, model, data, opt_state)` runs a loop much like this: ``` for d in data ∂L∂m = gradient(loss3, model, d...)[1] - update!(state, model, ∂L∂m) # method for "explicit" gradient + update!(opt_state, model, ∂L∂m) end ``` You can also write this loop yourself, if you need more flexibility. @@ -93,7 +93,7 @@ It adds only a few features to the loop above: (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, now it must also accept the `model` itself, as the first argument. - * `state` should be the result of [`Flux.setup`](@ref). Using an optimiser + * `opt_state` should be the result of [`Flux.setup`](@ref). Using an optimiser such as `Adam()` without this step should give you a warning. * Callback functions are not supported. (But any code can be included in the above `for` loop.) From 3615b96f576a215c73af72f61e3c716f615af4b8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 11 Dec 2022 20:43:57 -0500 Subject: [PATCH 24/28] fixes --- docs/src/models/basics.md | 2 +- docs/src/training/reference.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index f142292621..ed5beac373 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -36,7 +36,7 @@ Instead of passing them to `gradient` individually, we can store them together i The simplest example is a named tuple, created by the following syntax: ```jldoctest basics -julia> nt = (a = [2, 1], b = [2, 0], c = abs2); +julia> nt = (a = [2, 1], b = [2, 0], c = tanh); julia> g(x::NamedTuple) = sum(abs2, x.a .- x.b); diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index 2a5f27c696..b30bbef2e2 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -52,7 +52,7 @@ For full details on the interface for implicit-style optimisers, see the [Flux 0 !!! compat "Flux ≤ 0.12" Earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` after `using Flux`. This conflicted with too many other packages, and was removed in Flux 0.13. - If you get an error ``UndefVarError: `params` not defined``, this probably means that you are + If you get an error `UndefVarError: \`params\` not defined`, this probably means that you are following code for Flux 0.12 or earlier on a more recent version. From 5c628909be741e575536447854bc123897a90802 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 11 Dec 2022 20:45:48 -0500 Subject: [PATCH 25/28] fixup --- docs/src/training/optimisers.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 600cae4d88..a17bbd5b00 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -2,7 +2,7 @@ CurrentModule = Flux ``` -# [Optimisers](@id man-optimisers) +# [Optimisation Rules](@id man-optimisers) Flux builds in many optimisation rules for use with [`train!`](@ref Flux.Optimise.train!) and other training functions. @@ -12,9 +12,9 @@ from "implicit" dictionary-based to "explicit" tree-like structures. At present, the same struct (such as `Adam`) can be used with either form, and will be automatically translated. -For full details of how the new "explicit" interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). +For full details of how the new interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). -For full details on how the "implicit" interface worked, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/optimisers/#Optimiser-Interface). +For full details on how the old "implicit" interface worked, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/optimisers/#Optimiser-Interface). ## Optimiser Reference From ede33b0f2e32efc3702be7253c5bdec9371a02c2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 13 Dec 2022 13:13:05 -0500 Subject: [PATCH 26/28] fixup --- docs/src/training/reference.md | 2 +- docs/src/training/training.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index b30bbef2e2..8a4ef97689 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -52,7 +52,7 @@ For full details on the interface for implicit-style optimisers, see the [Flux 0 !!! compat "Flux ≤ 0.12" Earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` after `using Flux`. This conflicted with too many other packages, and was removed in Flux 0.13. - If you get an error `UndefVarError: \`params\` not defined`, this probably means that you are + If you get an error `UndefVarError: params not defined`, this probably means that you are following code for Flux 0.12 or earlier on a more recent version. diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 1e2e2d3c15..e54bc9720f 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -368,10 +368,10 @@ This is a temporary modification, reversed by `thaw!`: Flux.freeze!(opt_state.layers.enc) # Now training won't update parameters in bimodel.layers.enc -train!(loss, bimodel, data, state) +train!(loss, bimodel, data, opt_state) # Un-freeze the entire model: -Flux.thaw!(state) +Flux.thaw!(opt_state) ``` !!! compat "Flux ≤ 0.13" From cb5a742eab770bd2490a7852a84368a9a7e30638 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 13 Dec 2022 15:33:37 -0500 Subject: [PATCH 27/28] fixup --- docs/src/models/basics.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index ed5beac373..ca95dc747d 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -61,7 +61,7 @@ julia> gradient((x, y) -> sum(abs2, x.a ./ y .- x.b), nt, [1, 2]) julia> gradient(nt, [1, 2]) do x, y z = x.a ./ y - sum(x.c, z .- x.b) + sum(abs2, z .- x.b) end ((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25]) ``` From c5446b05aec409ee71b8605f64bf9871bee63437 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 15 Dec 2022 10:53:38 -0500 Subject: [PATCH 28/28] spelling & indentation --- docs/src/training/training.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index e54bc9720f..362b6ae8f8 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -117,7 +117,7 @@ fmap(model, grads[1]) do p, g end ``` -A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt_state, model, grads[1])`. +A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(opt_state, model, grads[1])`. And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct. However, there are many other optimisation rules, which adjust the step size and @@ -335,10 +335,10 @@ This can be done with [`adjust!`](@ref Flux.adjust!), like this: opt_state = Flux.setup(Adam(0.1), model) # initialise once for epoch in 1:1000 - train!([...], state) # Train with η = 0.1 for first 100, - if epoch == 100 # then change to use η = 0.01 for the rest. - Flux.adjust!(opt_state, 0.01) - end + train!([...], state) # Train with η = 0.1 for first 100, + if epoch == 100 # then change to use η = 0.01 for the rest. + Flux.adjust!(opt_state, 0.01) + end end ```