diff --git a/src/ReinforcementLearningCore/Project.toml b/src/ReinforcementLearningCore/Project.toml
index af7544a37..b50960272 100644
--- a/src/ReinforcementLearningCore/Project.toml
+++ b/src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
-version = "0.15"
+version = "0.15.1"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -39,7 +39,6 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
 Flux = "0.13, 0.14"
 Functors = "0.1, 0.2, 0.3, 0.4"
 GPUArrays = "8, 9"
-Parsers = "2"
 ProgressMeter = "1"
 Reexport = "1"
 ReinforcementLearningBase = "0.12"
@@ -57,4 +56,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["CommonRLInterface","DomainSets", "Test", "Random"]
+test = ["CommonRLInterface", "DomainSets", "Test", "Random"]
diff --git a/src/ReinforcementLearningCore/src/core/stop_conditions.jl b/src/ReinforcementLearningCore/src/core/stop_conditions.jl
index 410aa881c..87fe711e9 100644
--- a/src/ReinforcementLearningCore/src/core/stop_conditions.jl
+++ b/src/ReinforcementLearningCore/src/core/stop_conditions.jl
@@ -44,7 +44,7 @@ end
 
 function StopAfterStep(step; cur = 1, is_show_progress = true)
     if is_show_progress
-        progress = ProgressMeter.Progress(step, 1)
+        progress = ProgressMeter.Progress(step, dt = 1)
         ProgressMeter.update!(progress, cur)
     else
         progress = nothing
@@ -83,7 +83,7 @@ end
 
 function StopAfterEpisode(episode; cur = 0, is_show_progress = true)
     if is_show_progress
-        progress = ProgressMeter.Progress(episode, 1)
+        progress = ProgressMeter.Progress(episode, dt = 1)
         ProgressMeter.update!(progress, cur)
     else
         progress = nothing
diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent.jl b/src/ReinforcementLearningCore/src/policies/agent/agent.jl
index 55f11198d..375f12886 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/agent.jl
@@ -1,3 +1,4 @@
 include("agent_base.jl")
 include("agent_srt_cache.jl")
 include("multi_agent.jl")
+include("offline_agent.jl")
\ No newline at end of file
diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
index 79eeead9a..c57afcdde 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
@@ -4,6 +4,9 @@ using Base.Threads: @spawn
 using Functors: @functor
 import Base.push!
+
+abstract type AbstractAgent <: AbstractPolicy end
+
 """
     Agent(;policy, trajectory) <: AbstractPolicy
 
@@ -13,7 +16,7 @@ A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but to
 update the trajectory and policy appropriately in different stages. Agent
 is a Callable and its call method accepts varargs and keyword arguments to be
 passed to the policy.
""" -mutable struct Agent{P,T} <: AbstractPolicy +mutable struct Agent{P,T} <: AbstractAgent policy::P trajectory::T @@ -29,11 +32,11 @@ end Agent(;policy, trajectory) = Agent(policy, trajectory) -RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage) -RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory) +RLBase.optimise!(agent::AbstractAgent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage) +RLBase.optimise!(::SyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory) # already spawn a task to optimise inner policy when initializing the agent -RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing +RLBase.optimise!(::AsyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S<:AbstractStage} = nothing #by default, optimise does nothing at all stage function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end @@ -47,7 +50,7 @@ end # !!! TODO: In async scenarios, parameters of the policy may still be updating # (partially), which will result to incorrect action. This should be addressed # in Oolong.jl with a wrapper -function RLBase.plan!(agent::Agent, env::AbstractEnv) +function RLBase.plan!(agent::AbstractAgent, env::AbstractEnv) RLBase.plan!(agent.policy, env) end diff --git a/src/ReinforcementLearningCore/src/policies/agent/offline_agent.jl b/src/ReinforcementLearningCore/src/policies/agent/offline_agent.jl new file mode 100644 index 000000000..65e23d2d0 --- /dev/null +++ b/src/ReinforcementLearningCore/src/policies/agent/offline_agent.jl @@ -0,0 +1,76 @@ +export OfflineAgent, OfflineBehavior + +""" + OfflineBehavior(; agent:: Union{<:Agent, Nothing}, steps::Int, reset_condition) + +Used to provide an OfflineAgent with a "behavior agent" that will generate the training data +at the `PreExperimentStage`. If `agent` is `nothing` (by default), does nothing. The `trajectory` of agent should +be the same as that of the parent `OfflineAgent`. +`steps` is the number of data elements to generate, defautls to the capacity of the trajectory. +`reset_condition` is the episode reset condition for the data generation (defaults to `ResetAtTerminal()`). + +The behavior agent will interact with the main environment of the experiment to generate the data. +""" +struct OfflineBehavior{A <: Union{<:Agent, Nothing}, R} + agent::A + steps::Int + reset_condition::R +end + +OfflineBehavior() = OfflineBehavior(nothing, 0, ResetAtTerminal()) + +function OfflineBehavior(agent; steps = ReinforcementLearningTrajectories.capacity(agent.trajectory.container.traces), reset_condition = ResetAtTerminal()) + if steps == Inf + @error "`steps` is infinite, please provide a finite integer." + end + OfflineBehavior(agent, steps, reset_condition) +end + +""" + OfflineAgent(policy::AbstractPolicy, trajectory::Trajectory, offline_behavior::OfflineBehavior = OfflineBehavior()) <: AbstractAgent + +`OfflineAgent` is an `AbstractAgent` that, unlike the usual online `Agent`, does not interact with the environment +during training in order to collect data. Just like `Agent`, it contains an `AbstractPolicy` to be trained an a `Trajectory` +that contains the training data. 
+An `OfflineBehavior` can optionally be provided to supply a second "behavior agent" that will
+generate the training data at the `PreExperimentStage`. Does nothing by default.
+"""
+struct OfflineAgent{P<:AbstractPolicy,T<:Trajectory,B<:OfflineBehavior} <: AbstractAgent
+    policy::P
+    trajectory::T
+    offline_behavior::B
+    function OfflineAgent(policy::P, trajectory::T, offline_behavior = OfflineBehavior()) where {P<:AbstractPolicy, T<:Trajectory}
+        agent = new{P,T, typeof(offline_behavior)}(policy, trajectory, offline_behavior)
+        if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
+            bind(trajectory, @spawn(optimise!(policy, trajectory)))
+        end
+        agent
+    end
+end
+
+OfflineAgent(;policy, trajectory, offline_behavior = OfflineBehavior()) = OfflineAgent(policy, trajectory, offline_behavior)
+@functor OfflineAgent (policy,)
+
+Base.push!(::OfflineAgent{P,T, <: OfflineBehavior{Nothing}}, ::PreExperimentStage, env::AbstractEnv) where {P,T} = nothing
+# fills the trajectory with interactions generated with the behavior_agent at the PreExperimentStage.
+function Base.push!(agent::OfflineAgent{P,T, <: OfflineBehavior{<:Agent}}, ::PreExperimentStage, env::AbstractEnv) where {P,T}
+    is_stop = false
+    policy = agent.offline_behavior.agent
+    steps = 0
+    while !is_stop
+        reset!(env)
+        push!(policy, PreEpisodeStage(), env)
+        while !agent.offline_behavior.reset_condition(policy, env) # one episode
+            steps += 1
+            push!(policy, PreActStage(), env)
+            action = RLBase.plan!(policy, env)
+            act!(env, action)
+            push!(policy, PostActStage(), env, action)
+            if steps >= agent.offline_behavior.steps
+                is_stop = true
+                break
+            end
+        end # end of an episode
+        push!(policy, PostEpisodeStage(), env)
+    end
+end
\ No newline at end of file
diff --git a/src/ReinforcementLearningCore/test/policies/agent.jl b/src/ReinforcementLearningCore/test/policies/agent.jl
index cacff4f3e..37a826827 100644
--- a/src/ReinforcementLearningCore/test/policies/agent.jl
+++ b/src/ReinforcementLearningCore/test/policies/agent.jl
@@ -39,4 +39,42 @@ import ReinforcementLearningCore.SRT
             end
         end
     end
+    @testset "OfflineAgent" begin
+        env = RandomWalk1D()
+        a_1 = OfflineAgent(
+            policy = RandomPolicy(),
+            trajectory = Trajectory(
+                CircularArraySARTSTraces(; capacity = 1_000),
+                DummySampler(),
+            ),
+        )
+        @test a_1.offline_behavior.agent === nothing
+        push!(a_1, PreExperimentStage(), env)
+        @test isempty(a_1.trajectory.container)
+
+        trajectory = Trajectory(
+            CircularArraySARTSTraces(; capacity = 1_000),
+            DummySampler(),
+        )
+
+        a_2 = OfflineAgent(
+            policy = RandomPolicy(),
+            trajectory = trajectory,
+            offline_behavior = OfflineBehavior(
+                Agent(RandomPolicy(), trajectory),
+                steps = 5,
+            )
+        )
+        push!(a_2, PreExperimentStage(), env)
+        @test length(a_2.trajectory.container) == 5
+
+        for agent in [a_1, a_2]
+            action = RLBase.plan!(agent, env)
+            @test action in (1,2)
+            for stage in [PreEpisodeStage(), PreActStage(), PostActStage(), PostEpisodeStage()]
+                push!(agent, stage, env)
+                @test length(agent.trajectory.container) in (0,5)
+            end
+        end
+    end
 end
diff --git a/src/ReinforcementLearningExperiments/Project.toml b/src/ReinforcementLearningExperiments/Project.toml
index a19d3ab64..8c4124746 100644
--- a/src/ReinforcementLearningExperiments/Project.toml
+++ b/src/ReinforcementLearningExperiments/Project.toml
@@ -27,7 +27,7 @@ Reexport = "1"
 ReinforcementLearningBase = "0.12"
 ReinforcementLearningCore = "0.15"
 ReinforcementLearningEnvironments = "0.8"
-ReinforcementLearningZoo = "0.10"
+ReinforcementLearningZoo = "0.10.1"
 StableRNGs = "1"
 Weave = "0.10"
 cuDNN = "1"
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_CQLSAC_Pendulum.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_CQLSAC_Pendulum.jl
new file mode 100644
index 000000000..5e30b816d
--- /dev/null
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_CQLSAC_Pendulum.jl
@@ -0,0 +1,98 @@
+# ---
+# title: JuliaRL\_CQLSAC\_Pendulum
+# cover: assets/JuliaRL_CQLSAC_Pendulum.png
+# description: CQL-SAC applied to Pendulum
+# date: 2021-05-22
+# author: "[Roman Bange](https://github.com/rbange)"
+# ---
+
+#+ tangle=true
+using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo, ReinforcementLearningEnvironments
+using StableRNGs
+using Flux
+using Flux.Losses
+using IntervalSets
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:CQLSAC},
+    ::Val{:Pendulum},
+    dummy = nothing;
+    save_dir=nothing,
+    seed=123
+)
+    rng = StableRNG(seed)
+    inner_env = PendulumEnv(T=Float32, rng=rng)
+    action_dims = inner_env.n_actions
+    A = action_space(inner_env)
+    low = A.left
+    high = A.right
+    ns = length(state(inner_env))
+    na = 1
+
+    env = ActionTransformedEnv(
+        inner_env;
+        action_mapping=x -> low + (x[1] + 1) * 0.5 * (high - low)
+    )
+    init = Flux.glorot_uniform(rng)
+
+    create_policy_net() = Approximator(
+        SoftGaussianNetwork(
+            Chain(
+                Dense(ns, 30, relu, init=init),
+                Dense(30, 30, relu, init=init),
+            ),
+            Chain(Dense(30, na, init=init)),
+            Chain(Dense(30, na, softplus, init=init)),
+        ),
+        Adam(0.003),
+    )
+
+    create_q_net() = TargetNetwork(
+        Approximator(
+            Chain(
+                Dense(ns + na, 30, relu; init=init),
+                Dense(30, 30, relu; init=init),
+                Dense(30, 1; init=init),
+            ),
+            Adam(0.003),),
+        ρ = 0.99f0
+    )
+    trajectory= Trajectory(
+        CircularArraySARTSTraces(capacity = 10000, state = Float32 => (ns,), action = Float32 => (na,)),
+        BatchSampler{SS′ART}(64),
+        InsertSampleRatioController(ratio = 1/1, threshold = 0)) # There are no insertions in Offline RL, the controller is not used.
+    hook = TotalRewardPerEpisode()
+
+    agent = OfflineAgent(
+        policy = CQLSACPolicy(
+            sac = SACPolicy(
+                policy=create_policy_net(),
+                qnetwork1=create_q_net(),
+                qnetwork2=create_q_net(),
+                γ=0.99f0,
+                α=0.2f0,
+                start_steps=0,
+                automatic_entropy_tuning=true,
+                lr_alpha=0.003f0,
+                action_dims=action_dims,
+                rng=rng,
+                device_rng= rng)
+        ),
+        trajectory = trajectory,
+        offline_behavior = OfflineBehavior(Agent(RandomPolicy(-1.0 .. 1.0; rng=rng), trajectory))
+    )
+
+    stop_condition = StopAfterStep(5_000, is_show_progress=!haskey(ENV, "CI"))
+    Experiment(agent, env, stop_condition, hook)
+end
+
+#+ tangle=false
+using Plots
+pyplot() #hide
+ex = E`JuliaRL_CQLSAC_Pendulum`
+run(ex)
+plot(ex.hook.rewards)
+savefig("assets/JuliaRL_CQLSAC_Pendulum.png") #hide
+
+# ![](assets/JuliaRL_CQLSAC_Pendulum.png)
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
index 2b3e0524b..760adf242 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
@@ -76,7 +76,7 @@ function RLCore.Experiment(
         ),
         trajectory= Trajectory(
             CircularArraySARTSTraces(capacity = 10000, state = Float32 => (ns,), action = Float32 => (na,)),
-            BatchSampler{SS′ART}(128),
+            BatchSampler{SS′ART}(64),
             InsertSampleRatioController(ratio = 1/1, threshold = 1000))
     )
diff --git a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
index 4c903691b..ccaead84c 100644
--- a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
+++ b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
@@ -23,6 +23,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_TRPO_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_MPO_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "DQN_CartPoleGPU.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_SAC_Pendulum.jl"))
+include(joinpath(EXPERIMENTS_DIR, "JuliaRL_CQLSAC_Pendulum.jl"))
 
 # dynamic loading environments
 function __init__() end
diff --git a/src/ReinforcementLearningExperiments/test/runtests.jl b/src/ReinforcementLearningExperiments/test/runtests.jl
index 69471a995..e1440763f 100644
--- a/src/ReinforcementLearningExperiments/test/runtests.jl
+++ b/src/ReinforcementLearningExperiments/test/runtests.jl
@@ -15,6 +15,7 @@ run(E`JuliaRL_Rainbow_CartPole`)
 #run(E`JuliaRL_VPG_CartPole`)
 #run(E`JuliaRL_TRPO_CartPole`)
 run(E`JuliaRL_SAC_Pendulum`)
+run(E`JuliaRL_CQLSAC_Pendulum`)
 run(E`JuliaRL_MPODiscrete_CartPole`)
 run(E`JuliaRL_MPOContinuous_CartPole`)
 run(E`JuliaRL_MPOCovariance_CartPole`)
diff --git a/src/ReinforcementLearningZoo/Project.toml b/src/ReinforcementLearningZoo/Project.toml
index 631de0b20..68a0b033b 100644
--- a/src/ReinforcementLearningZoo/Project.toml
+++ b/src/ReinforcementLearningZoo/Project.toml
@@ -1,6 +1,6 @@
 name = "ReinforcementLearningZoo"
 uuid = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
-version = "0.10"
+version = "0.10.1"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -29,7 +30,7 @@ LogExpFunctions = "0.3"
 NNlib = "0.8, 0.9"
 Optim = "1"
 ReinforcementLearningBase = "0.12"
-ReinforcementLearningCore = "0.15"
+ReinforcementLearningCore = "0.15.1"
 StatsBase = "0.33, 0.34"
 Zygote = "0.6"
 cuDNN = "1"
diff --git a/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl b/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl
index 9a4546afc..f86255640 100644
--- a/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl
+++ b/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl
@@ -5,6 +5,7 @@ using ReinforcementLearningCore
 import ReinforcementLearningCore.forward
 const RLZoo = ReinforcementLearningZoo
 export RLZoo
+import MLUtils
 
 include("algorithms/algorithms.jl")
 # include("hooks/hooks.jl") # TotalBatchRewardPerEpisode is broken, need to ensure vector copy works!
diff --git a/src/ReinforcementLearningZoo/src/algorithms/algorithms.jl b/src/ReinforcementLearningZoo/src/algorithms/algorithms.jl
index 0ad177db7..611d08a39 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/algorithms.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/algorithms.jl
@@ -3,6 +3,6 @@ include("dqns/dqns.jl")
 include("policy_gradient/policy_gradient.jl")
 # include("searching/searching.jl")
 # include("cfr/cfr.jl")
-# include("offline_rl/offline_rl.jl")
+include("offline_rl/offline_rl.jl")
 # include("nfsp/abstract_nfsp.jl")
 # include("exploitability_descent/exploitability_descent.jl")
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
new file mode 100644
index 000000000..c137da7f1
--- /dev/null
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
@@ -0,0 +1,93 @@
+import LogExpFunctions
+export CQLSACPolicy
+
+"""
+    CQLSACPolicy(
+        sac::SACPolicy,
+        action_sample_size::Int = 10,
+        α_cql::Float32 = 0f0,
+        α_lr::Float32 = 1f-3,
+        τ_cql::Float32 = 5f0,
+        α_cql_autotune::Bool = true,
+        cons_weight::Float32 = 1f0, # multiplies the Q difference before subtracting tau, used to scale with the rewards.
+    )
+
+    Implements the Conservative Q-Learning algorithm [1] in its continuous variant on top of the SAC algorithm [2]. `CQLSACPolicy` wraps a classic `SACPolicy` whose networks will be trained normally, except for the additional conservative loss.
+    `CQLSACPolicy` contains the additional hyperparameters that are specific to this method. `α_cql` is the Lagrange penalty for the conservative loss; it will be automatically tuned if `α_cql_autotune = true`. `cons_weight` is a scaling parameter
+    which may be necessary to decrease if the scale of the Q-values is large. `τ_cql` is the threshold of the Lagrange conservative penalty.
+    See SACPolicy for all the other hyperparameters related to SAC.
+
+    As this is an offline algorithm, it must be wrapped in an `OfflineAgent` which will not update the trajectory as the training progresses. However, it _will_ interact with the supplied environment, which may be useful to record the progress.
+    This can be avoided by supplying a dummy environment.
+
+    [1] Kumar, A., Zhou, A., Tucker, G., & Levine, S. (2020). Conservative q-learning for offline reinforcement learning. Advances in Neural Information Processing Systems, 33, 1179-1191.
+    [2] Haarnoja, T. et al. (2018). Soft actor-critic algorithms and applications. arXiv preprint arXiv:1812.05905.
+"""
+Base.@kwdef mutable struct CQLSACPolicy{P<:SACPolicy} <: AbstractPolicy
+    sac::P
+    action_sample_size::Int = 10
+    α_cql::Float32 = 0f0
+    α_lr::Float32 = 1f-3
+    τ_cql::Float32 = 5f0
+    α_cql_autotune::Bool = true
+    cons_weight::Float32 = 1f0 # multiplies the Q difference before subtracting tau, used to scale with the rewards.
+end
+
+function RLBase.plan!(p::CQLSACPolicy, env)
+    RLBase.plan!(p.sac, env)
+end
+
+function RLBase.optimise!(p::CQLSACPolicy, ::PostActStage, traj::Trajectory)
+    batch = ReinforcementLearningTrajectories.StatsBase.sample(traj)
+    update_critic!(p, batch)
+    update_actor!(p.sac, batch) # uses the implemented SACPolicy actor update, as it is identical
+end
+
+function conservative_loss(p::CQLSACPolicy, t_qnetwork, q_policy_inputs, logps, s, a, y)
+    qnetwork = model(t_qnetwork)
+    q_policy = vec(LogExpFunctions.logsumexp(qnetwork(q_policy_inputs) .- logps, dims = 2)) #(1 x 1 x batchsize) -> (batchsize,) Note: some python public implementations use a temperature.
+
+    q_beta = vec(qnetwork(vcat(s, a))) #(batchsize,)
+
+    diff = mean(q_policy .- q_beta)*p.cons_weight - p.τ_cql
+
+    if p.α_cql_autotune
+        p.α_cql += p.α_lr*diff
+        p.α_cql = clamp(p.α_cql, 0f0, 1f6)
+    end
+
+    conservative_loss = p.α_cql*diff
+
+    q_learning_loss = mse(q_beta, y)
+
+    return conservative_loss + q_learning_loss
+end
+
+function update_critic!(p::CQLSACPolicy, batch::NamedTuple{SS′ART})
+    s, s′, a, r, t = send_to_device(device(p.sac.qnetwork1), batch)
+
+    y = soft_q_learning_target(p.sac, r, t, s′)
+
+    states = MLUtils.unsqueeze(s, dims = 2) #(state_size x 1 x batchsize)
+    a_policy, logp_policy = RLCore.forward(p.sac.policy, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)
+
+    a_unif = (rand(p.sac.rng, Float32, size(a_policy)...) .- 0.5f0) .* 2 # Uniform sampling between -1 and 1: (action_size x action_sample_size x batchsize)
+    logp_unif = fill!(similar(a_unif, 1, size(a_unif)[2:end]...), 0.5^size(a_unif)[1]) #(1 x action_sample_size x batchsize)
+
+    repeated_states = reduce(hcat, Iterators.repeated(states, p.action_sample_size*2)) #(state_size x action_sample_size*2 x batchsize)
+    actions = hcat(a_policy, a_unif)#, a_policy′) #(action_size x action_sample_size*2 x batchsize)
+
+    q_policy_inputs = vcat(repeated_states, actions)
+    logps = hcat(logp_policy, logp_unif)#, logp_policy′) #(1 x action_sample_size*2 x batchsize)
+
+    # Train Q Networks
+    q_grad_1 = gradient(Flux.params(model(p.sac.qnetwork1))) do
+        conservative_loss(p, p.sac.qnetwork1, q_policy_inputs, logps, s, a, y)
+    end
+    RLBase.optimise!(p.sac.qnetwork1, q_grad_1)
+
+    q_grad_2 = gradient(Flux.params(model(p.sac.qnetwork2))) do
+        conservative_loss(p, p.sac.qnetwork2, q_policy_inputs, logps, s, a, y)
+    end
+    RLBase.optimise!(p.sac.qnetwork2, q_grad_2)
+end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl
index 570a6bbe8..d3ba98a59 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl
@@ -1,4 +1,4 @@
-include("BCQ.jl")
+#=include("BCQ.jl")
 include("BEAR.jl")
 include("behavior_cloning.jl")
 include("CRR.jl")
@@ -7,3 +7,5 @@ include("FisherBRC.jl")
 include("PLAS.jl")
 include("ope/ope.jl")
 include("common.jl")
+=#
+include("CQL_SAC.jl")
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
index 0a29da38b..968c8360c 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
@@ -120,34 +120,45 @@ function RLBase.optimise!(
     traj::Trajectory
 )
     for batch in traj
-        update!(p, batch)
+        update_critic!(p, batch)
+        update_actor!(p, batch)
     end
 end
 
-function update!(p::SACPolicy, batch::NamedTuple{SS′ART})
-    s, s′, a, r, t = send_to_device(device(p.qnetwork1), batch)
-
-    γ, α = p.γ, p.α
-
+function soft_q_learning_target(p::SACPolicy, r, t, s′)
     a′, log_π = RLCore.forward(p.policy,p.device_rng, s′; is_sampling=true, is_return_log_prob=true)
     q′_input = vcat(s′, a′)
     q′ = min.(target(p.qnetwork1)(q′_input), target(p.qnetwork2)(q′_input))
-    y = r .+ γ .* (1 .- t) .* dropdims(q′ .- α .* log_π, dims=1)
+    r .+ p.γ .* (1 .- t) .* dropdims(q′ .- p.α .* log_π, dims=1)
+end
 
-    # Train Q Networks
+function q_learning_loss(qnetwork, s, a, y)
     q_input = vcat(s, a)
+    q = dropdims(model(qnetwork)(q_input), dims=1)
+    mse(q, y)
+end
+
+function update_critic!(p::SACPolicy, batch::NamedTuple{SS′ART})
+    s, s′, a, r, t = send_to_device(device(p.qnetwork1), batch)
+    y = soft_q_learning_target(p, r, t, s′)
+
+    # Train Q Networks
     q_grad_1 = gradient(Flux.params(model(p.qnetwork1))) do
-        q1 = dropdims(model(p.qnetwork1)(q_input), dims=1)
-        mse(q1, y)
+        q_learning_loss(p.qnetwork1, s, a, y)
     end
     RLBase.optimise!(p.qnetwork1, q_grad_1)
+
     q_grad_2 = gradient(Flux.params(model(p.qnetwork2))) do
-        q2 = dropdims(model(p.qnetwork2)(q_input), dims=1)
-        mse(q2, y)
+        q_learning_loss(p.qnetwork2, s, a, y)
     end
     RLBase.optimise!(p.qnetwork2, q_grad_2)
+end
+
+function update_actor!(p::SACPolicy, batch::NamedTuple{SS′ART})
+    s = send_to_device(device(p.qnetwork1), batch[:state])
+    a = send_to_device(device(p.qnetwork1), batch[:action])
 
     # Train Policy
     p_grad = gradient(Flux.params(p.policy)) do
@@ -159,13 +170,11 @@ function update!(p::SACPolicy, batch::NamedTuple{SS′ART})
         ignore_derivatives() do
             p.reward_term = reward
            p.entropy_term = entropy
+            if p.automatic_entropy_tuning # Tune entropy automatically
+                p.α -= p.lr_alpha * mean(-log_π .- p.target_entropy)
+            end
         end
-        α * entropy - reward
+        p.α * entropy - reward
     end
     RLBase.optimise!(p.policy, p_grad)
-
-    # Tune entropy automatically
-    if p.automatic_entropy_tuning
-        p.α -= p.lr_alpha * mean(-log_π .- p.target_entropy)
-    end
 end
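
Usage sketch accompanying this patch: the snippet below shows how the new pieces fit together, mirroring the `@testset "OfflineAgent"` added above. An `OfflineBehavior` wrapping a behavior `Agent` pre-fills a shared trajectory once at `PreExperimentStage`; afterwards the `OfflineAgent` trains from that fixed data without touching the environment again. `RandomWalk1D`, `CircularArraySARTSTraces`, and `DummySampler` come from the wider ReinforcementLearning.jl ecosystem as used in the test and experiment files of this patch, and `RandomPolicy` merely stands in for a trainable policy such as `CQLSACPolicy`. Treat it as an illustrative sketch, not a tested script.

    using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningEnvironments

    env = RandomWalk1D()

    # One trajectory shared by the behavior agent and the offline learner.
    trajectory = Trajectory(
        CircularArraySARTSTraces(; capacity = 1_000),
        DummySampler(),
    )

    agent = OfflineAgent(
        policy = RandomPolicy(),          # stand-in for e.g. CQLSACPolicy
        trajectory = trajectory,
        offline_behavior = OfflineBehavior(
            Agent(RandomPolicy(), trajectory);
            steps = 100,                  # number of transitions to generate
        ),
    )

    # Fills the trajectory once; the data set then stays fixed during training.
    push!(agent, PreExperimentStage(), env)
    length(trajectory.container)          # 100, analogous to the `== 5` assertion in the test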