Commit
* NFQ before refactor
* NFQ after refactor
* Move to dqns
* Refactor
* Add NFQ to RLZoo
* Set up experiment
* Update algorithm for refactor
* rng and loss type
* remove duplicate
* dispatch on trajectory
* optimise is dummy by default
* optimise! is dispatched on traj and loops it
* Fix precompilation warnings
* Avoid running post episode optimise! multiple times
* Tune experiment
* Remove commented code
* Drop gpu call
  Co-authored-by: Henri Dehaybe <[email protected]>
* Use `sample` to get batch from trajectory
* optimise! for AbstractLearner
* NFQ optimise! calls at the correct time
* Remove superfluous function due to main merge
* Anonymous loop variable
* Update NFQ docs
* Update julia_words.txt

---------

Co-authored-by: Henri Dehaybe <[email protected]>
Showing 7 changed files with 162 additions and 2 deletions.
@@ -5294,4 +5294,5 @@ sqmahal
logdpf
devmode
logpdfs
kldivs
Riedmiller
95 changes: 95 additions & 0 deletions
...ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl
@@ -0,0 +1,95 @@
# ---
# title: JuliaRL\_NFQ\_CartPole
# cover: assets/JuliaRL_BasicDQN_CartPole.png
# description: NFQ applied to the cartpole environment
# date: 2023-06
# author: "[Lucas Bex](https://github.com/CasBex)"
# ---

#+ tangle=true
using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo
using ReinforcementLearningEnvironments
using Flux
using Flux: glorot_uniform

using StableRNGs: StableRNG
using Flux.Losses: mse

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:NFQ},
    ::Val{:CartPole},
    seed = 123,
)
    rng = StableRNG(seed)
    env = CartPoleEnv(; T=Float32, rng=rng)
    ns, na = length(state(env)), length(first(action_space(env)))

    agent = Agent(
        policy=QBasedPolicy(
            learner=NFQ(
                action_space=action_space(env),
                approximator=Approximator(
                    model=Chain(
                        Dense(ns + na, 5, σ; init=glorot_uniform(rng)),
                        Dense(5, 5, σ; init=glorot_uniform(rng)),
                        Dense(5, 1; init=glorot_uniform(rng)),
                    ),
                    optimiser=RMSProp()
                ),
                loss_function=mse,
                epochs=100,
                num_iterations=10,
                γ = 0.95f0
            ),
            explorer=EpsilonGreedyExplorer(
                kind=:exp,
                ϵ_stable=0.001,
                warmup_steps=500,
                rng=rng,
            ),
        ),
        trajectory=Trajectory(
            container=CircularArraySARTTraces(
                capacity=10_000,
                state=Float32 => (ns,),
                action=Float32 => (na,),
            ),
            sampler=BatchSampler{SS′ART}(
                batch_size=10_000,
                rng=rng
            ),
            controller=InsertSampleRatioController(
                threshold=100,
                n_inserted=-1
            )
        )
    )

    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook)
end

#+ tangle=false
using Plots
# pyplot() # hide
ex = E`JuliaRL_NFQ_CartPole`
run(ex)
plot(ex.hook.rewards)
savefig("assets/JuliaRL_NFQ_CartPole.png") #hide

#=
## Watch a demo episode with the trained agent
```julia
demo = Experiment(ex.policy,
                  CartPoleEnv(),
                  StopWhenDone(),
                  RolloutHook(plot, closeall),
                  "DQN <-> Demo")
run(demo)
```
=#

# ![](assets/JuliaRL_NFQ_CartPole.png)
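After `run(ex)` finishes, the trained learner can be queried directly as a quick sanity check. The snippet below is illustrative and not part of the committed file; in particular, the field path `ex.policy.policy.learner` assumes the usual `Experiment` → `Agent` → `QBasedPolicy` layout.

```julia
# Illustrative only: query the trained NFQ learner for a fresh state.
# Assumes ex.policy is the Agent and its policy field is the QBasedPolicy.
env = CartPoleEnv(; T=Float32)
reset!(env)
learner = ex.policy.policy.learner        # the NFQ learner inside the agent
q_values = RLCore.forward(learner, env)   # one Q-value per discrete action
```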
1 change: 1 addition & 0 deletions
src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/config.json
@@ -0,0 +1,60 @@
export NFQ

using Flux
using Functors: @functor

"""
    NFQ{A<:AbstractApproximator, F, R} <: AbstractLearner
    NFQ(action_space::AbstractVector, approximator::A, num_iterations::Integer, epochs::Integer, loss_function::F, rng::R, γ::Float32) where {A, F, R}

Neural Fitted Q-iteration as implemented in [1].

# Keyword arguments

- `action_space::AbstractVector`: action space of the environment (needed in the `optimise!` step)
- `approximator::A`: Q-function approximator (typically a neural network)
- `num_iterations::Integer`: number of value-iteration steps in the FQI loop (i.e. the outer loop)
- `epochs::Integer`: number of epochs to train the neural network per iteration
- `loss_function::F`: loss function of the neural network
- `rng::R`: random number generator
- `γ::Float32`: discount factor

# References

[1] Riedmiller, M. (2005). Neural Fitted Q Iteration – First Experiences with a Data Efficient Neural Reinforcement Learning Method. In: Gama, J., Camacho, R., Brazdil, P. B., Jorge, A. M., Torgo, L. (eds) Machine Learning: ECML 2005. Lecture Notes in Computer Science, vol 3720. Springer, Berlin, Heidelberg. https://doi.org/10.1007/11564096_32
"""
Base.@kwdef struct NFQ{A, R, F} <: AbstractLearner
    action_space::AbstractVector
    approximator::A
    num_iterations::Integer = 20
    epochs::Integer = 100
    loss_function::F = mse
    rng::R = Random.default_rng()
    γ::Float32 = 0.9f0
end

@functor NFQ (approximator,)

RLCore.forward(L::NFQ, s::AbstractArray) = RLCore.forward(L.approximator, s)

function RLCore.forward(learner::NFQ, env::AbstractEnv)
    as = action_space(env)
    return vcat(repeat(state(env), inner=(1, length(as))), transpose(as)) |> x -> send_to_device(device(learner.approximator), x) |> x -> RLCore.forward(learner, x) |> send_to_host |> vec
end
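The one-liner above packs the current state together with every candidate action into a single matrix before calling the approximator. A tiny standalone sketch of that layout (toy numbers, not taken from the commit):

```julia
# Toy illustration of the state-action layout built by forward(learner, env).
s  = Float32[0.1, -0.2, 0.03, 0.5]   # a CartPole-like state, ns = 4
as = [1, 2]                          # discrete action space

x = vcat(repeat(s, inner=(1, length(as))), transpose(as))
# x is a 5×2 matrix: column j is [s; as[j]], so a single forward pass through
# the (ns + na)-input network scores every action for the current state at once.
```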

function RLBase.optimise!(learner::NFQ, ::PostEpisodeStage, trajectory::Trajectory)
    Q = learner.approximator
    γ = learner.γ
    loss_func = learner.loss_function
    as = learner.action_space
    las = length(as)
    batch = ReinforcementLearningTrajectories.sample(trajectory)

    (s, a, r, ss) = batch[[:state, :action, :reward, :next_state]]
    a = Float32.(a)
    s, a, r, ss = map(x -> send_to_device(device(Q), x), (s, a, r, ss))
    for i = 1:learner.num_iterations
        # Build an (input × samples × |action space|) array of next-state/action pairs,
        # pass it through Q, and take the per-sample maximum over the action dimension.
        G = r .+ γ .* (cat(repeat(ss, inner=(1, 1, las)), reshape(repeat(as, outer=(1, size(ss, 2))), (1, size(ss, 2), las)), dims=1) |> x -> maximum(RLCore.forward(Q, x), dims=3) |> vec)
        for _ = 1:learner.epochs
            Flux.train!((x, y) -> loss_func(RLCore.forward(Q, x), y), params(Q.model), [(vcat(s, a), transpose(G))], Q.optimiser)
        end
    end
end
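For reference, here is a standalone sketch of the fitted-Q regression target that the loop above builds from a sampled batch. The toy network, batch size, and action encoding are illustrative assumptions, not code from the commit.

```julia
using Flux

ns, na = 4, 1                                   # state size; action stored as one scalar
q  = Chain(Dense(ns + na, 8, σ), Dense(8, 1))   # stand-in for learner.approximator
γ  = 0.95f0

s′ = rand(Float32, ns, 16)                      # batch of 16 next states
r  = rand(Float32, 16)                          # batch of rewards
as = Float32[1, 2]                              # discrete actions, as in CartPole

# Q(s′, a′) for every candidate action a′, then the per-sample maximum:
qs′ = [vec(q(vcat(s′, fill(a, 1, size(s′, 2))))) for a in as]
G   = r .+ γ .* max.(qs′...)                    # fitted-Q targets, one per sample
# The inner `epochs` loop then regresses Q(s, a) toward G with the chosen loss.
```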
@@ -1,8 +1,9 @@
include("basic_dqn.jl")
include("NFQ.jl")
include("dqn.jl")
include("prioritized_dqn.jl")
include("qr_dqn.jl")
include("rem_dqn.jl")
include("iqn.jl")
include("rainbow.jl")
# include("common.jl")