fixed DQNLearner GPU issue (#933)
* fixed DQNLearner GPU issue

* rename variables for cspell / conventions
Mytolo authored Jul 25, 2023
1 parent 1f7f347 commit bd78e83
Showing 4 changed files with 79 additions and 6 deletions.
68 changes: 68 additions & 0 deletions DQN_CartPoleGPU.jl
@@ -0,0 +1,68 @@
# ---
# title: JuliaRL_DQNCartPole_GPU
# cover:
# description: DQN applied to CartPole on GPU
# date: 2023-07-24
# author: "[Panajiotis Keßler](mailto:[email protected])"
# ---

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DQNCartPole},
    ::Val{:GPU},
    seed=123,
    cap=100,
    n=12,
    γ=0.99f0
)
    rng = StableRNG(seed)
    env = CartPoleEnv(; T=Float32, rng=rng)
    ns, na = length(state(env)), length(action_space(env))

    policy = Agent(
        QBasedPolicy(
            learner=DQNLearner(
                approximator=Approximator(
                    model=TwinNetwork(
                        Chain(
                            Dense(ns, 128, relu; init=glorot_uniform(rng)),
                            Dense(128, 128, relu; init=glorot_uniform(rng)),
                            Dense(128, na; init=glorot_uniform(rng)),
                        );
                        sync_freq=100
                    ),
                    optimiser=Adam(),
                ) |> gpu,
                n=n,
                γ=γ,
                is_enable_double_DQN=true,
                loss_func=huber_loss,
                rng=rng,
            ),
            explorer=EpsilonGreedyExplorer(
                kind=:exp,
                ϵ_stable=0.01,
                decay_steps=500,
                rng=rng,
            ),
        ),
        Trajectory(
            container=CircularArraySARTTraces(
                capacity=cap,
                state=Float32 => (ns,),
            ),
            sampler=NStepBatchSampler{SS′ART}(
                n=n,
                γ=γ,
                batch_size=32,
                rng=rng
            ),
            controller=InsertSampleRatioController(
                threshold=ceil(1.1 * n),
                n_inserted=0
            )
        ),
    )
    stop_condition = StopAfterEpisode(5, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(policy, env, stop_condition, hook)
end
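
The new experiment is wired into the package and test suite by the two one-line changes below. A minimal usage sketch, assuming ReinforcementLearningExperiments is loaded and a CUDA-capable GPU is available (the E backtick macro builds the experiment registered above, and run executes it):

using ReinforcementLearningExperiments

# Build and run the registered experiment; run is expected to return it,
# so the TotalRewardPerEpisode hook can be inspected afterwards.
ex = run(E`JuliaRL_DQNCartPole_GPU`)
ex.hook.rewards  # total reward for each of the 5 episodes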
1 change: 1 addition & 0 deletions src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
@@ -19,6 +19,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_Rainbow_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_VPG_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_TRPO_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_MPO_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "DQN_CartPoleGPU.jl"))

# dynamic loading environments
function __init__() end
1 change: 1 addition & 0 deletions src/ReinforcementLearningExperiments/test/runtests.jl
@@ -15,6 +15,7 @@ run(E`JuliaRL_VPG_CartPole`)
run(E`JuliaRL_MPODiscrete_CartPole`)
run(E`JuliaRL_MPOContinuous_CartPole`)
run(E`JuliaRL_MPOCovariance_CartPole`)
+run(E`JuliaRL_DQNCartPole_GPU`)
# run(E`JuliaRL_BC_CartPole`)
# run(E`JuliaRL_VMPO_CartPole`)
# run(E`JuliaRL_BasicDQN_MountainCar`)
15 changes: 9 additions & 6 deletions src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
@@ -28,26 +28,29 @@ function RLBase.optimise!(learner::DQNLearner, batch::NamedTuple)
    A = learner.approximator
    Q = A.model.source
    Qₜ = A.model.target
+   @assert device(Q) == device(Qₜ) "Q and target Q function have to be on the same device"

    γ = learner.γ
    loss_func = learner.loss_func
    n = learner.n

-   s, s′, a, r, t = map(x -> batch[x], SS′ART)
+   s, s_next, a, r, t = map(x -> batch[x], SS′ART)
    a = CartesianIndex.(a, 1:length(a))
+   s, s_next, a, r, t = send_to_device(device(Q), (s, s_next, a, r, t))

-   q′ = learner.is_enable_double_DQN ? Q(s′) : Qₜ(s′)
+   q_next = learner.is_enable_double_DQN ? Q(s_next) : Qₜ(s_next)

    if haskey(batch, :next_legal_actions_mask)
-       q′ .+= ifelse.(batch[:next_legal_actions_mask], 0.0f0, typemin(Float32))
+       q_next .+= ifelse.(batch[:next_legal_actions_mask], 0.0f0, typemin(Float32))
    end

-   q′ₐ = learner.is_enable_double_DQN ? Qₜ(s′)[dropdims(argmax(q′, dims=1), dims=1)] : dropdims(maximum(q′; dims=1), dims=1)
+   q_next_action = learner.is_enable_double_DQN ? Qₜ(s_next)[dropdims(argmax(q_next, dims=1), dims=1)] : dropdims(maximum(q_next; dims=1), dims=1)

-   G = r .+ γ^n .* (1 .- t) .* q′ₐ
+   R = r .+ γ^n .* (1 .- t) .* q_next_action

    gs = gradient(params(A)) do
        qₐ = Q(s)[a]
-       loss = loss_func(G, qₐ)
+       loss = loss_func(R, qₐ)
        ignore_derivatives() do
            learner.loss = loss
        end
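
The substance of the GPU fix is the explicit device alignment: batches sampled from the trajectory buffer live in CPU memory, while the approximator may have been moved to the GPU via |> gpu, so every batch tensor has to be sent to the model's device before any forward pass. The n-step target itself is unchanged: R = r + γ^n * (1 - t) * Qₜ(s_next, a*). A minimal sketch of the pattern, reusing the device/send_to_device utilities that appear in the diff above (the function name here is illustrative, not library API):

# Sketch: move inputs to wherever the model's parameters live before
# calling it; send_to_device is a no-op when they already match.
function forward_on_model_device(Q, s)
    s = send_to_device(device(Q), s)
    return Q(s)
end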
