* divide sac into functions
* bump version
* implement CQL
* create OfflineAgent (does not collect online data)
* working state
* experiments working
* typo
* Tests pass
* add finetuning
* write doc
* Update src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
* Update src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
* Apply suggestions from code review
* add review suggestions
* remove finetuning
* fix a ProgressMeter deprecation warning

---------

Co-authored-by: Jeremiah <[email protected]>
1 parent 8f5ea30 · commit e1d9e9e
Showing 17 changed files with 357 additions and 34 deletions.
@@ -1,3 +1,4 @@
include("agent_base.jl") | ||
include("agent_srt_cache.jl") | ||
include("multi_agent.jl") | ||
include("offline_agent.jl") |
src/ReinforcementLearningCore/src/policies/agent/offline_agent.jl (76 changes: 76 additions & 0 deletions)
@@ -0,0 +1,76 @@
export OfflineAgent, OfflineBehavior

"""
    OfflineBehavior(; agent::Union{<:Agent, Nothing}, steps::Int, reset_condition)

Used to provide an `OfflineAgent` with a "behavior agent" that will generate the training data
at the `PreExperimentStage`. If `agent` is `nothing` (the default), nothing is done. The `trajectory` of the
behavior agent should be the same as that of the parent `OfflineAgent`.
`steps` is the number of data elements to generate; it defaults to the capacity of the trajectory.
`reset_condition` is the episode reset condition for the data generation (defaults to `ResetAtTerminal()`).
The behavior agent will interact with the main environment of the experiment to generate the data.
"""
struct OfflineBehavior{A <: Union{<:Agent, Nothing}, R}
    agent::A
    steps::Int
    reset_condition::R
end

OfflineBehavior() = OfflineBehavior(nothing, 0, ResetAtTerminal())

function OfflineBehavior(agent; steps = ReinforcementLearningTrajectories.capacity(agent.trajectory.container.traces), reset_condition = ResetAtTerminal())
    if steps == Inf
        @error "`steps` is infinite, please provide a finite integer."
    end
    OfflineBehavior(agent, steps, reset_condition)
end

"""
    OfflineAgent(policy::AbstractPolicy, trajectory::Trajectory, offline_behavior::OfflineBehavior = OfflineBehavior()) <: AbstractAgent

`OfflineAgent` is an `AbstractAgent` that, unlike the usual online `Agent`, does not interact with the environment
during training in order to collect data. Just like `Agent`, it contains an `AbstractPolicy` to be trained and a `Trajectory`
that contains the training data. The difference is that the trajectory is filled prior to training and is not updated afterwards.
An `OfflineBehavior` can optionally be provided to supply a second "behavior agent" that will
generate the training data at the `PreExperimentStage`. By default it does nothing.
"""
struct OfflineAgent{P<:AbstractPolicy,T<:Trajectory,B<:OfflineBehavior} <: AbstractAgent
    policy::P
    trajectory::T
    offline_behavior::B
    function OfflineAgent(policy::P, trajectory::T, offline_behavior = OfflineBehavior()) where {P<:AbstractPolicy, T<:Trajectory}
        agent = new{P,T,typeof(offline_behavior)}(policy, trajectory, offline_behavior)
        if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
            bind(trajectory, @spawn(optimise!(policy, trajectory)))
        end
        agent
    end
end

OfflineAgent(; policy, trajectory, offline_behavior = OfflineBehavior()) = OfflineAgent(policy, trajectory, offline_behavior)

@functor OfflineAgent (policy,)

Base.push!(::OfflineAgent{P,T, <:OfflineBehavior{Nothing}}, ::PreExperimentStage, env::AbstractEnv) where {P,T} = nothing

# Fills the trajectory with interactions generated by the behavior agent at the PreExperimentStage.
function Base.push!(agent::OfflineAgent{P,T, <:OfflineBehavior{<:Agent}}, ::PreExperimentStage, env::AbstractEnv) where {P,T}
    is_stop = false
    policy = agent.offline_behavior.agent
    steps = 0
    while !is_stop
        reset!(env)
        push!(policy, PreEpisodeStage(), env)
        while !agent.offline_behavior.reset_condition(policy, env) # one episode
            steps += 1
            push!(policy, PreActStage(), env)
            action = RLBase.plan!(policy, env)
            act!(env, action)
            push!(policy, PostActStage(), env, action)
            if steps >= agent.offline_behavior.steps
                is_stop = true
                break
            end
        end # end of an episode
        push!(policy, PostEpisodeStage(), env)
    end
end
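For orientation, here is a minimal usage sketch of the new types. It is not part of the diff: `my_offline_policy` is a placeholder name, and the trajectory layout simply mirrors the Pendulum experiment added further down.

# Hypothetical sketch: the behavior Agent and the OfflineAgent share the same trajectory,
# so the data collected at PreExperimentStage is exactly what the policy later trains on.
trajectory = Trajectory(
    CircularArraySARTSTraces(capacity = 10_000, state = Float32 => (3,), action = Float32 => (1,)),
    BatchSampler{SS′ART}(64),
    InsertSampleRatioController(ratio = 1/1, threshold = 0),
)
agent = OfflineAgent(
    policy = my_offline_policy,  # placeholder for an offline policy such as CQLSACPolicy
    trajectory = trajectory,
    # collect 1_000 random transitions once, before training starts
    offline_behavior = OfflineBehavior(Agent(RandomPolicy(-1.0 .. 1.0), trajectory); steps = 1_000),
)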
...cementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_CQLSAC_Pendulum.jl (98 changes: 98 additions & 0 deletions)
@@ -0,0 +1,98 @@
# ---
# title: JuliaRL\_CQLSAC\_Pendulum
# cover: assets/JuliaRL_CQLSAC_Pendulum.png
# description: CQL-SAC applied to Pendulum
# date: 2021-05-22
# author: "[Roman Bange](https://github.com/rbange)"
# ---
#+ tangle=true
using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo, ReinforcementLearningEnvironments
using StableRNGs
using Flux
using Flux.Losses
using IntervalSets

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:CQLSAC},
    ::Val{:Pendulum},
    dummy = nothing;
    save_dir=nothing,
    seed=123
)
    rng = StableRNG(seed)
    inner_env = PendulumEnv(T=Float32, rng=rng)
    action_dims = inner_env.n_actions
    A = action_space(inner_env)
    low = A.left
    high = A.right
    ns = length(state(inner_env))
    na = 1

    env = ActionTransformedEnv(
        inner_env;
        action_mapping=x -> low + (x[1] + 1) * 0.5 * (high - low)
    )
    init = Flux.glorot_uniform(rng)

    create_policy_net() = Approximator(
        SoftGaussianNetwork(
            Chain(
                Dense(ns, 30, relu, init=init),
                Dense(30, 30, relu, init=init),
            ),
            Chain(Dense(30, na, init=init)),
            Chain(Dense(30, na, softplus, init=init)),
        ),
        Adam(0.003),
    )

    create_q_net() = TargetNetwork(
        Approximator(
            Chain(
                Dense(ns + na, 30, relu; init=init),
                Dense(30, 30, relu; init=init),
                Dense(30, 1; init=init),
            ),
            Adam(0.003),
        ),
        ρ = 0.99f0
    )

    trajectory = Trajectory(
        CircularArraySARTSTraces(capacity = 10000, state = Float32 => (ns,), action = Float32 => (na,)),
        BatchSampler{SS′ART}(64),
        InsertSampleRatioController(ratio = 1/1, threshold = 0)) # There are no insertions in offline RL, so the controller is not used.

    hook = TotalRewardPerEpisode()

    agent = OfflineAgent(
        policy = CQLSACPolicy(
            sac = SACPolicy(
                policy=create_policy_net(),
                qnetwork1=create_q_net(),
                qnetwork2=create_q_net(),
                γ=0.99f0,
                α=0.2f0,
                start_steps=0,
                automatic_entropy_tuning=true,
                lr_alpha=0.003f0,
                action_dims=action_dims,
                rng=rng,
                device_rng=rng)
        ),
        trajectory = trajectory,
        offline_behavior = OfflineBehavior(Agent(RandomPolicy(-1.0 .. 1.0; rng=rng), trajectory))
    )

    stop_condition = StopAfterStep(5_000, is_show_progress=!haskey(ENV, "CI"))
    Experiment(agent, env, stop_condition, hook)
end
#+ tangle=false
using Plots
pyplot() #hide
ex = E`JuliaRL_CQLSAC_Pendulum`
run(ex)
plot(ex.hook.rewards)
savefig("assets/JuliaRL_CQLSAC_Pendulum.png") #hide

# ![](assets/JuliaRL_CQLSAC_Pendulum.png)
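As a usage note (a sketch, not part of the diff): since the method above is defined on the Val symbols :JuliaRL, :CQLSAC and :Pendulum, the experiment can also be constructed directly, assuming ReinforcementLearningExperiments (which ships this script) is loaded so the method is available.

ex = Experiment(Val(:JuliaRL), Val(:CQLSAC), Val(:Pendulum); seed = 2)  # seed keyword as in the signature above
run(ex)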