NFQ (#897)
* NFQ before refactor

* NFQ after refactor

* Move to dqns

* Refactor

* Add NFQ to RLZoo

* Set up experiment

* Update algorithm for refactor

* rng and loss type

* remove duplicate

* dispatch on trajectory

* optimise is dummy by default

* optimise! is dispatched on traj and loops it

* Fix precompilation warnings

* Avoid running post episode optimise! multiple times

* Tune experiment

* Remove commented code

* Drop gpu call

Co-authored-by: Henri Dehaybe <[email protected]>

* Use `sample` to get batch from trajectory

* optimise! for AbstractLearner

* NFQ optimise! calls at the correct time

* Remove superfluous function due to main merge

* Anonymous loop variable

* Update NFQ docs

* Update julia_words.txt

---------

Co-authored-by: Henri Dehaybe <[email protected]>
CasBex and HenriDeh authored Jun 26, 2023
1 parent 6de371f commit 72d6766
Showing 7 changed files with 162 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .cspell/julia_words.txt
@@ -5294,4 +5294,5 @@ sqmahal
 logdpf
 devmode
 logpdfs
-kldivs
+kldivs
+Riedmiller
@@ -0,0 +1,95 @@
# ---
# title: JuliaRL\_NFQ\_CartPole
# cover: assets/JuliaRL_NFQ_CartPole.png
# description: NFQ applied to the CartPole environment
# date: 2023-06
# author: "[Lucas Bex](https://github.com/CasBex)"
# ---

#+ tangle=true
using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo
using ReinforcementLearningEnvironments
using Flux
using Flux: glorot_uniform

using StableRNGs: StableRNG
using Flux.Losses: mse

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:NFQ},
    ::Val{:CartPole},
    seed = 123,
)
    rng = StableRNG(seed)
    env = CartPoleEnv(; T=Float32, rng=rng)
    # na is the length of a single action (1 for CartPole's scalar actions);
    # the Q-network takes the state with the encoded action appended as input.
    ns, na = length(state(env)), length(first(action_space(env)))

    agent = Agent(
        policy=QBasedPolicy(
            learner=NFQ(
                action_space=action_space(env),
                approximator=Approximator(
                    model=Chain(
                        Dense(ns+na, 5, σ; init=glorot_uniform(rng)),
                        Dense(5, 5, σ; init=glorot_uniform(rng)),
                        Dense(5, 1; init=glorot_uniform(rng)),
                    ),
                    optimiser=RMSProp()
                ),
                loss_function=mse,
                epochs=100,
                num_iterations=10,
                γ=0.95f0
            ),
            explorer=EpsilonGreedyExplorer(
                kind=:exp,
                ϵ_stable=0.001,
                warmup_steps=500,
                rng=rng,
            ),
        ),
        trajectory=Trajectory(
            container=CircularArraySARTTraces(
                capacity=10_000,
                state=Float32 => (ns,),
                action=Float32 => (na,),
            ),
            sampler=BatchSampler{SS′ART}(
                batch_size=10_000,
                rng=rng
            ),
            controller=InsertSampleRatioController(
                threshold=100,
                n_inserted=-1
            )
        )
    )

    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook)
end

#+ tangle=false
using Plots
# pyplot() # hide
ex = E`JuliaRL_NFQ_CartPole`
run(ex)
plot(ex.hook.rewards)
savefig("assets/JuliaRL_NFQ_CartPole.png") #hide

#=
## Watch a demo episode with the trained agent
```julia
demo = Experiment(ex.policy,
                  CartPoleEnv(),
                  StopWhenDone(),
                  RolloutHook(plot, closeall),
                  "NFQ <-> Demo")
run(demo)
```
=#

# ![](assets/JuliaRL_NFQ_CartPole.png)
@@ -1,6 +1,7 @@
{
"description": "DQN related experiments.",
"order": [
"JuliaRL_NFQ_CartPole.jl",
"JuliaRL_BasicDQN_CartPole.jl",
"JuliaRL_BasicDQN_MountainCar.jl",
"JuliaRL_BasicDQN_PendulumDiscrete.jl",
@@ -8,6 +8,7 @@ const EXPERIMENTS_DIR = joinpath(@__DIR__, "experiments")
# for f in readdir(EXPERIMENTS_DIR)
# include(joinpath(EXPERIMENTS_DIR, f))
# end
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_NFQ_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_BasicDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_DQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_PrioritizedDQN_CartPole.jl"))
1 change: 1 addition & 0 deletions src/ReinforcementLearningExperiments/test/runtests.jl
@@ -3,6 +3,7 @@ using CUDA

CUDA.allowscalar(false)

run(E`JuliaRL_NFQ_CartPole`)
run(E`JuliaRL_BasicDQN_CartPole`)
run(E`JuliaRL_DQN_CartPole`)
run(E`JuliaRL_PrioritizedDQN_CartPole`)
60 changes: 60 additions & 0 deletions src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl
@@ -0,0 +1,60 @@
export NFQ

using Flux
using Functors: @functor

"""
NFQ{A<:AbstractApproximator, F, R} <: AbstractLearner
NFQ(action_space::AbstractVector, approximator::A, num_iterations::Integer epochs::Integer, loss_function::F, rng::R, γ::Float32) where {A, F, R}
Neural Fitted Q-iteration as implemented in [1]
# Keyword arguments
- `action_space::AbstractVector` Action space of the environment (necessary in the optimise! step)
- `approximator::A` Q-function approximator (typically a neural network)
- `num_iterations::Integer` number of value iteration iterations in FQI loop (i.e. the outer loop)
- `epochs::Integer` number of epochs to train neural network per iteration
- `loss_function::F` loss function of the NN
- `rng::R` random number generator
- `γ::Float32` discount rate
# References
[1] Riedmiller, M. (2005). Neural Fitted Q Iteration – First Experiences with a Data Efficient Neural Reinforcement Learning Method. In: Gama, J., Camacho, R., Brazdil, P.B., Jorge, A.M., Torgo, L. (eds) Machine Learning: ECML 2005. ECML 2005. Lecture Notes in Computer Science(), vol 3720. Springer, Berlin, Heidelberg. https://doi.org/10.1007/11564096_32
"""
Base.@kwdef struct NFQ{A, R, F} <: AbstractLearner
    action_space::AbstractVector
    approximator::A
    num_iterations::Integer = 20
    epochs::Integer = 100
    loss_function::F = mse
    rng::R = Random.default_rng()
    γ::Float32 = 0.9f0
end

@functor NFQ (approximator,)

RLCore.forward(L::NFQ, s::AbstractArray) = RLCore.forward(L.approximator, s)

function RLCore.forward(learner::NFQ, env::AbstractEnv)
    as = action_space(env)
    # Pair the current state with each candidate action (one column per action),
    # evaluate the Q-approximator on its device, and return one Q-value per action.
    input = vcat(repeat(state(env), inner=(1, length(as))), transpose(as))
    input = send_to_device(device(learner.approximator), input)
    return RLCore.forward(learner, input) |> send_to_host |> vec
end

function RLBase.optimise!(learner::NFQ, ::PostEpisodeStage, trajectory::Trajectory)
    Q = learner.approximator
    γ = learner.γ
    loss_func = learner.loss_function
    as = learner.action_space
    las = length(as)
    batch = ReinforcementLearningTrajectories.sample(trajectory)

    (s, a, r, ss) = batch[[:state, :action, :reward, :next_state]]
    a = Float32.(a)
    s, a, r, ss = map(x -> send_to_device(device(Q), x), (s, a, r, ss))
    for _ = 1:learner.num_iterations
        # Pair each next state with every action: an input x samples x |action space|
        # array, evaluated by Q and reduced with a maximum over the action axis.
        next_input = cat(
            repeat(ss, inner=(1, 1, las)),
            reshape(repeat(as, outer=(1, size(ss, 2))), (1, size(ss, 2), las)),
            dims=1,
        )
        # Fitted-Q targets: G = r + γ * max_a' Q(s', a')
        G = r .+ γ .* vec(maximum(RLCore.forward(Q, next_input), dims=3))
        # Re-fit the network on the whole batch for a fixed number of epochs.
        for _ = 1:learner.epochs
            Flux.train!((x, y) -> loss_func(RLCore.forward(Q, x), y), params(Q.model), [(vcat(s, a), transpose(G))], Q.optimiser)
        end
    end
end
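
For readers unfamiliar with the shapes involved, the following toy snippet (not part of the commit) mirrors the target construction in `optimise!` above. It assumes made-up dimensions — 2-dimensional states, a batch of 3 transitions, two scalar actions — and a stand-in `Q` in place of `RLCore.forward(Q, x)`.

```julia
# Toy illustration of the fitted-Q target shapes used in optimise! above.
ss = rand(Float32, 2, 3)        # next states: 2-dim state x 3 samples
r  = rand(Float32, 3)           # rewards, one per sample
as = Float32[1, 2]              # two scalar actions
las = length(as)
γ = 0.9f0
Q(x) = sum(x; dims=1)           # stand-in for the network: returns 1 x samples x |A|

# Pair every next state with every action: a (2 + 1) x 3 x 2 array.
next_input = cat(
    repeat(ss, inner=(1, 1, las)),
    reshape(repeat(as, outer=(1, size(ss, 2))), (1, size(ss, 2), las)),
    dims=1,
)

# Max over the action axis gives one bootstrap value per sample,
# so G = r + γ * max_a' Q(s', a') has length 3.
G = r .+ γ .* vec(maximum(Q(next_input), dims=3))
```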
3 changes: 2 additions & 1 deletion src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
@@ -1,8 +1,9 @@
 include("basic_dqn.jl")
+include("NFQ.jl")
 include("dqn.jl")
 include("prioritized_dqn.jl")
 include("qr_dqn.jl")
 include("rem_dqn.jl")
 include("iqn.jl")
 include("rainbow.jl")
-# include("common.jl")
+# include("common.jl")
