Conservative Q-Learning (#995)
* divide sac into functions

* bump version

* implement CQL

* create OfflineAgent (does not collect online data)

* working state

* experiments working

* typo

* Tests pass

* add finetuning

* write doc

* Update src/ReinforcementLearningCore/src/policies/agent/agent_base.jl

* Update src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl

* Apply suggestions from code review

* add review suggestions

* remove finetuning

* fix a ProgressMeter deprecation warning

---------

Co-authored-by: Jeremiah <[email protected]>
HenriDeh and jeremiahpslewis authored Oct 26, 2023
1 parent 8f5ea30 commit e1d9e9e
Showing 17 changed files with 357 additions and 34 deletions.
5 changes: 2 additions & 3 deletions src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.15"
version = "0.15.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -39,7 +39,6 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
Flux = "0.13, 0.14"
Functors = "0.1, 0.2, 0.3, 0.4"
GPUArrays = "8, 9"
Parsers = "2"
ProgressMeter = "1"
Reexport = "1"
ReinforcementLearningBase = "0.12"
@@ -57,4 +56,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CommonRLInterface","DomainSets", "Test", "Random"]
test = ["CommonRLInterface", "DomainSets", "Test", "Random"]
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/src/core/stop_conditions.jl
@@ -44,7 +44,7 @@ end

function StopAfterStep(step; cur = 1, is_show_progress = true)
if is_show_progress
progress = ProgressMeter.Progress(step, 1)
progress = ProgressMeter.Progress(step, dt = 1)
ProgressMeter.update!(progress, cur)
else
progress = nothing
@@ -83,7 +83,7 @@ end

function StopAfterEpisode(episode; cur = 0, is_show_progress = true)
if is_show_progress
progress = ProgressMeter.Progress(episode, 1)
progress = ProgressMeter.Progress(episode, dt = 1)
ProgressMeter.update!(progress, cur)
else
progress = nothing
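The two changes above track ProgressMeter.jl's deprecation of the positional `Progress(n, dt)` form in favour of the keyword argument. A minimal sketch of the updated usage, assuming ProgressMeter 1.x:

```julia
using ProgressMeter

# Positional `Progress(steps, 1)` is deprecated; pass the refresh interval as a keyword.
progress = ProgressMeter.Progress(100; dt = 1)   # redraw the bar at most once per second
for step in 1:100
    ProgressMeter.update!(progress, step)        # same update! call as in the constructors above
end
```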
1 change: 1 addition & 0 deletions src/ReinforcementLearningCore/src/policies/agent/agent.jl
@@ -1,3 +1,4 @@
include("agent_base.jl")
include("agent_srt_cache.jl")
include("multi_agent.jl")
include("offline_agent.jl")
13 changes: 8 additions & 5 deletions src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
@@ -4,6 +4,9 @@ using Base.Threads: @spawn

using Functors: @functor
import Base.push!

abstract type AbstractAgent <: AbstractPolicy end

"""
Agent(;policy, trajectory) <: AbstractPolicy
@@ -13,7 +16,7 @@ is a Callable and its call method accepts varargs and keyword arguments to be
passed to the policy.
"""
mutable struct Agent{P,T} <: AbstractPolicy
mutable struct Agent{P,T} <: AbstractAgent
policy::P
trajectory::T

@@ -29,11 +32,11 @@ end

Agent(;policy, trajectory) = Agent(policy, trajectory)

RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory)
RLBase.optimise!(agent::AbstractAgent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(::SyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory)

# a task to optimise the inner policy is already spawned when the agent is initialized
RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing
RLBase.optimise!(::AsyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S<:AbstractStage} = nothing

#by default, optimise does nothing at all stage
function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end
@@ -47,7 +50,7 @@ end
# !!! TODO: In async scenarios, parameters of the policy may still be updating
# (partially), which will result in incorrect actions. This should be addressed
# in Oolong.jl with a wrapper
function RLBase.plan!(agent::Agent, env::AbstractEnv)
function RLBase.plan!(agent::AbstractAgent, env::AbstractEnv)
RLBase.plan!(agent.policy, env)
end

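The new `AbstractAgent` supertype is what lets other agent types reuse the generic `optimise!`/`plan!` methods above instead of duplicating them; `OfflineAgent` in the next file is the first such subtype. A minimal sketch of a hypothetical custom subtype (the name `ReplayAgent` is illustrative, not part of the package):

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

# Any AbstractAgent that exposes `policy` and `trajectory` fields picks up the
# generic fallbacks defined above: plan! forwards to the inner policy, and
# optimise! dispatches on TrajectoryStyle(agent.trajectory). Stage-specific
# push! methods can be added where needed, as OfflineAgent does below.
# AbstractAgent is qualified here in case it is not exported.
struct ReplayAgent{P<:AbstractPolicy,T<:Trajectory} <: ReinforcementLearningCore.AbstractAgent
    policy::P
    trajectory::T
end
```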
76 changes: 76 additions & 0 deletions src/ReinforcementLearningCore/src/policies/agent/offline_agent.jl
@@ -0,0 +1,76 @@
export OfflineAgent, OfflineBehavior

"""
OfflineBehavior(; agent::Union{<:Agent, Nothing}, steps::Int, reset_condition)
Used to provide an `OfflineAgent` with a "behavior agent" that will generate the training data
at the `PreExperimentStage`. If `agent` is `nothing` (the default), nothing is done. The `trajectory` of `agent` should
be the same as that of the parent `OfflineAgent`.
`steps` is the number of data elements to generate; defaults to the capacity of the trajectory.
`reset_condition` is the episode reset condition for the data generation (defaults to `ResetAtTerminal()`).
The behavior agent will interact with the main environment of the experiment to generate the data.
"""
struct OfflineBehavior{A <: Union{<:Agent, Nothing}, R}
agent::A
steps::Int
reset_condition::R
end

OfflineBehavior() = OfflineBehavior(nothing, 0, ResetAtTerminal())

function OfflineBehavior(agent; steps = ReinforcementLearningTrajectories.capacity(agent.trajectory.container.traces), reset_condition = ResetAtTerminal())
if steps == Inf
@error "`steps` is infinite, please provide a finite integer."
end
OfflineBehavior(agent, steps, reset_condition)
end

"""
OfflineAgent(policy::AbstractPolicy, trajectory::Trajectory, offline_behavior::OfflineBehavior = OfflineBehavior()) <: AbstractAgent
`OfflineAgent` is an `AbstractAgent` that, unlike the usual online `Agent`, does not interact with the environment
during training in order to collect data. Just like `Agent`, it contains an `AbstractPolicy` to be trained and a `Trajectory`
that contains the training data. The difference is that the trajectory is filled prior to training and is not updated afterwards.
An `OfflineBehavior` can optionally be supplied to provide a second "behavior agent" that will
generate the training data at the `PreExperimentStage`. By default, no data is generated.
"""
struct OfflineAgent{P<:AbstractPolicy,T<:Trajectory,B<:OfflineBehavior} <: AbstractAgent
policy::P
trajectory::T
offline_behavior::B
function OfflineAgent(policy::P, trajectory::T, offline_behavior = OfflineBehavior()) where {P<:AbstractPolicy, T<:Trajectory}
agent = new{P,T, typeof(offline_behavior)}(policy, trajectory, offline_behavior)
if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
bind(trajectory, @spawn(optimise!(policy, trajectory)))
end
agent
end
end

OfflineAgent(;policy, trajectory, offline_behavior = OfflineBehavior()) = OfflineAgent(policy, trajectory, offline_behavior)
@functor OfflineAgent (policy,)

Base.push!(::OfflineAgent{P,T, <: OfflineBehavior{Nothing}}, ::PreExperimentStage, env::AbstractEnv) where {P,T} = nothing
# Fills the trajectory with interactions generated by the behavior agent at the `PreExperimentStage`.
function Base.push!(agent::OfflineAgent{P,T, <: OfflineBehavior{<:Agent}}, ::PreExperimentStage, env::AbstractEnv) where {P,T}
is_stop = false
policy = agent.offline_behavior.agent
steps = 0
while !is_stop
reset!(env)
push!(policy, PreEpisodeStage(), env)
while !agent.offline_behavior.reset_condition(policy, env) # one episode
steps += 1
push!(policy, PreActStage(), env)
action = RLBase.plan!(policy, env)
act!(env, action)
push!(policy, PostActStage(), env, action)
if steps >= agent.offline_behavior.steps
is_stop = true
break
end
end # end of an episode
push!(policy, PostEpisodeStage(), env)
end
end
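Putting the pieces together, a minimal usage sketch; it mirrors the `OfflineAgent` test added below, and in a real experiment the trained policy would be something like `CQLSACPolicy` rather than a `RandomPolicy`:

```julia
using ReinforcementLearningBase, ReinforcementLearningCore, ReinforcementLearningEnvironments

env = RandomWalk1D()

# The behavior agent and the OfflineAgent must share the same trajectory.
trajectory = Trajectory(
    CircularArraySARTSTraces(; capacity = 1_000),
    DummySampler(),
)

agent = OfflineAgent(
    policy = RandomPolicy(),                # placeholder for the policy to be trained offline
    trajectory = trajectory,
    offline_behavior = OfflineBehavior(
        Agent(RandomPolicy(), trajectory),  # generates the offline dataset
        steps = 100,
    ),
)

# At the PreExperimentStage the behavior agent interacts with `env` and fills the
# trajectory with 100 transitions; no further data is collected during training.
push!(agent, PreExperimentStage(), env)
length(trajectory.container)  # == 100
```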
38 changes: 38 additions & 0 deletions src/ReinforcementLearningCore/test/policies/agent.jl
@@ -39,4 +39,42 @@ import ReinforcementLearningCore.SRT
end
end
end
@testset "OfflineAgent" begin
env = RandomWalk1D()
a_1 = OfflineAgent(
policy = RandomPolicy(),
trajectory = Trajectory(
CircularArraySARTSTraces(; capacity = 1_000),
DummySampler(),
),
)
@test a_1.offline_behavior.agent === nothing
push!(a_1, PreExperimentStage(), env)
@test isempty(a_1.trajectory.container)

trajectory = Trajectory(
CircularArraySARTSTraces(; capacity = 1_000),
DummySampler(),
)

a_2 = OfflineAgent(
policy = RandomPolicy(),
trajectory = trajectory,
offline_behavior = OfflineBehavior(
Agent(RandomPolicy(), trajectory),
steps = 5,
)
)
push!(a_2, PreExperimentStage(), env)
@test length(a_2.trajectory.container) == 5

for agent in [a_1, a_2]
action = RLBase.plan!(agent, env)
@test action in (1,2)
for stage in [PreEpisodeStage(), PreActStage(), PostActStage(), PostEpisodeStage()]
push!(agent, stage, env)
@test length(agent.trajectory.container) in (0,5)
end
end
end
end
2 changes: 1 addition & 1 deletion src/ReinforcementLearningExperiments/Project.toml
@@ -27,7 +27,7 @@ Reexport = "1"
ReinforcementLearningBase = "0.12"
ReinforcementLearningCore = "0.15"
ReinforcementLearningEnvironments = "0.8"
ReinforcementLearningZoo = "0.10"
ReinforcementLearningZoo = "0.10.1"
StableRNGs = "1"
Weave = "0.10"
cuDNN = "1"
98 changes: 98 additions & 0 deletions .../JuliaRL_CQLSAC_Pendulum.jl
@@ -0,0 +1,98 @@
# ---
# title: JuliaRL\_CQLSAC\_Pendulum
# cover: assets/JuliaRL_CQLSAC_Pendulum.png
# description: CQL-SAC applied to Pendulum
# date: 2021-05-22
# author: "[Roman Bange](https://github.com/rbange)"
# ---

#+ tangle=true
using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo, ReinforcementLearningEnvironments
using StableRNGs
using Flux
using Flux.Losses
using IntervalSets

function RLCore.Experiment(
::Val{:JuliaRL},
::Val{:CQLSAC},
::Val{:Pendulum},
dummy = nothing;
save_dir=nothing,
seed=123
)
rng = StableRNG(seed)
inner_env = PendulumEnv(T=Float32, rng=rng)
action_dims = inner_env.n_actions
A = action_space(inner_env)
low = A.left
high = A.right
ns = length(state(inner_env))
na = 1

env = ActionTransformedEnv(
inner_env;
action_mapping=x -> low + (x[1] + 1) * 0.5 * (high - low)
)
init = Flux.glorot_uniform(rng)

create_policy_net() = Approximator(
SoftGaussianNetwork(
Chain(
Dense(ns, 30, relu, init=init),
Dense(30, 30, relu, init=init),
),
Chain(Dense(30, na, init=init)),
Chain(Dense(30, na, softplus, init=init)),
),
Adam(0.003),
)

create_q_net() = TargetNetwork(
Approximator(
Chain(
Dense(ns + na, 30, relu; init=init),
Dense(30, 30, relu; init=init),
Dense(30, 1; init=init),
),
Adam(0.003),),
ρ = 0.99f0
)
trajectory= Trajectory(
CircularArraySARTSTraces(capacity = 10000, state = Float32 => (ns,), action = Float32 => (na,)),
BatchSampler{SS′ART}(64),
InsertSampleRatioController(ratio = 1/1, threshold = 0)) # There are no insertions in Offline RL, the controller is not used.
hook = TotalRewardPerEpisode()

agent = OfflineAgent(
policy = CQLSACPolicy(
sac = SACPolicy(
policy=create_policy_net(),
qnetwork1=create_q_net(),
qnetwork2=create_q_net(),
γ=0.99f0,
α=0.2f0,
start_steps=0,
automatic_entropy_tuning=true,
lr_alpha=0.003f0,
action_dims=action_dims,
rng=rng,
device_rng= rng)
),
trajectory = trajectory,
offline_behavior = OfflineBehavior(Agent(RandomPolicy(-1.0 .. 1.0; rng=rng), trajectory))
)

stop_condition = StopAfterStep(5_000, is_show_progress=!haskey(ENV, "CI"))
Experiment(agent, env, stop_condition, hook)
end

#+ tangle=false
using Plots
pyplot() #hide
ex = E`JuliaRL_CQLSAC_Pendulum`
run(ex)
plot(ex.hook.rewards)
savefig("assets/JuliaRL_CQLSAC_Pendulum.png") #hide

# ![](assets/JuliaRL_CQLSAC_Pendulum.png)
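Once the include added below registers this file with ReinforcementLearningExperiments, the new experiment can be run the same way as the test-suite entry further down; a minimal sketch:

```julia
using ReinforcementLearningExperiments

run(E`JuliaRL_CQLSAC_Pendulum`)
```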
.../JuliaRL_SAC_Pendulum.jl
@@ -76,7 +76,7 @@ function RLCore.Experiment(
),
trajectory= Trajectory(
CircularArraySARTSTraces(capacity = 10000, state = Float32 => (ns,), action = Float32 => (na,)),
BatchSampler{SS′ART}(128),
BatchSampler{SS′ART}(64),
InsertSampleRatioController(ratio = 1/1, threshold = 1000))
)

src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
@@ -23,6 +23,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_TRPO_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_MPO_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "DQN_CartPoleGPU.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_SAC_Pendulum.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_CQLSAC_Pendulum.jl"))

# dynamic loading environments
function __init__() end
1 change: 1 addition & 0 deletions src/ReinforcementLearningExperiments/test/runtests.jl
@@ -15,6 +15,7 @@ run(E`JuliaRL_Rainbow_CartPole`)
#run(E`JuliaRL_VPG_CartPole`)
#run(E`JuliaRL_TRPO_CartPole`)
run(E`JuliaRL_SAC_Pendulum`)
run(E`JuliaRL_CQLSAC_Pendulum`)
run(E`JuliaRL_MPODiscrete_CartPole`)
run(E`JuliaRL_MPOContinuous_CartPole`)
run(E`JuliaRL_MPOCovariance_CartPole`)
5 changes: 3 additions & 2 deletions src/ReinforcementLearningZoo/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningZoo"
uuid = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
version = "0.10"
version = "0.10.1"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -29,7 +30,7 @@ LogExpFunctions = "0.3"
NNlib = "0.8, 0.9"
Optim = "1"
ReinforcementLearningBase = "0.12"
ReinforcementLearningCore = "0.15"
ReinforcementLearningCore = "0.15.1"
StatsBase = "0.33, 0.34"
Zygote = "0.6"
cuDNN = "1"
src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl
@@ -5,6 +5,7 @@ using ReinforcementLearningCore
import ReinforcementLearningCore.forward
const RLZoo = ReinforcementLearningZoo
export RLZoo
import MLUtils

include("algorithms/algorithms.jl")
# include("hooks/hooks.jl") # TotalBatchRewardPerEpisode is broken, need to ensure vector copy works!
2 changes: 1 addition & 1 deletion src/ReinforcementLearningZoo/src/algorithms/algorithms.jl
@@ -3,6 +3,6 @@ include("dqns/dqns.jl")
include("policy_gradient/policy_gradient.jl")
# include("searching/searching.jl")
# include("cfr/cfr.jl")
# include("offline_rl/offline_rl.jl")
include("offline_rl/offline_rl.jl")
# include("nfsp/abstract_nfsp.jl")
# include("exploitability_descent/exploitability_descent.jl")