Fix and refactor SAC (#985)
* make softgaussian

* add tanh

* Update docstring

* fixing SAC

* enable tests

* Improve correctness of GaussianNetwork

* update CUDA

* use the new TargetNetwork

* fix test

* fix diaglogpdf

* fix tests

* RLCore

* bump versions and compats

* Core import

* remove DomainSets 0.7 compat

* Update src/ReinforcementLearningDatasets/README.md

Co-authored-by: Jeremiah <[email protected]>

* Update Project.toml

* Update Project.toml

* Update Project.toml

* Update Project.toml

* Bump compat

* Update Project.toml

* Update Project.toml

---------

Co-authored-by: Jeremiah <[email protected]>
HenriDeh and jeremiahpslewis authored Oct 12, 2023
1 parent 3b21982 commit e772d6f
Showing 93 changed files with 588 additions and 536 deletions.
@@ -15129,7 +15129,7 @@ <h2 id="Two-Most-Commonly-Used-Algorithms">Two Most Commonly Used Algorithms<a c
) |> cpu,
optimizer = Adam(),
),
- batch_size = 32,
+ batchsize = 32,
min_replay_history = 100,
loss_func = huber_loss,
rng = rng,
14 changes: 7 additions & 7 deletions docs/homepage/blog/ospp_final_term_report_210370741/index.md
@@ -314,7 +314,7 @@ struct D4RLDataSet{T<:AbstractRNG} <: RLDataSet
dataset::Dict{Symbol, Any}
repo::String
dataset_size::Integer
- batch_size::Integer
+ batchsize::Integer
style::Tuple
rng::T
meta::Dict
@@ -330,7 +330,7 @@ function dataset(dataset::String;
repo = "d4rl",
rng = StableRNG(123),
is_shuffle = true,
- batch_size=256
+ batchsize=256
)
```
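For reference, a minimal usage sketch with the renamed keyword. The dataset name below is purely illustrative and is not taken from this diff:

```julia
# Hypothetical call: "hopper-medium-v0" stands in for any registered D4RL dataset.
ds = dataset("hopper-medium-v0"; repo = "d4rl", batchsize = 128, is_shuffle = true)
```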
@@ -383,7 +383,7 @@ Multi threaded batching using a parallel loop where each thread loads the batche
```julia
res = Channel{AtariRLTransition}(n_preallocations; taskref=taskref, spawn=true) do ch
- Threads.@threads for i in 1:batch_size
+ Threads.@threads for i in 1:batchsize
put!(ch, deepcopy(batch(buffer_template, popfirst!(transitions), i)))
end
end
@@ -472,7 +472,7 @@ end
The datapoints are then put in a `RingBuffer` which is returned.
```julia
res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff
- Threads.@threads for i in 1:batch_size
+ Threads.@threads for i in 1:batchsize
batch!(buff, take!(transitions), i)
end
end
@@ -694,7 +694,7 @@ mutable struct FQE{
target_q_network::C_T
n_evals::Int
γ::Float32
- batch_size::Int
+ batchsize::Int
update_freq::Int
update_step::Int
tar_update_freq::Int
@@ -714,7 +714,7 @@ function RLBase.update!(l::FQE, batch::NamedTuple{SARTS})
D = device(Q)
s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
γ = l.γ
- batch_size = l.batch_size
+ batchsize = l.batchsize

loss_func = Flux.Losses.mse

@@ -723,7 +723,7 @@ function RLBase.update!(l::FQE, batch::NamedTuple{SARTS})
target = r .+ γ .* (1 .- t) .* q′

gs = gradient(params(Q)) do
- q = Q(vcat(s, reshape(a, :, batch_size))) |> vec
+ q = Q(vcat(s, reshape(a, :, batchsize))) |> vec
loss = loss_func(q, target)
Zygote.ignore() do
l.loss = loss
12 changes: 6 additions & 6 deletions docs/homepage/blog/ospp_report_210370190/index.md
@@ -247,7 +247,7 @@ rl_agent = Agent(
),
γ = 1.0f0,
loss_func = huber_loss,
- batch_size = 128,
+ batchsize = 128,
update_freq = 128,
min_replay_history = 1000,
target_update_freq = 1000,
@@ -281,7 +281,7 @@ sl_agent = Agent(
optimizer = Descent(0.01),
),
explorer = WeightedSoftmaxExplorer(),
- batch_size = 128,
+ batchsize = 128,
min_reservoir_history = 1000,
rng = rng,
),
@@ -351,7 +351,7 @@ Given that the [`DDPGPolicy`](https://juliareinforcementlearning.org/docs/rlzoo/
mutable struct MADDPGManager <: AbstractPolicy
agents::Dict{<:Any, <:Agent}
traces
- batch_size::Int
+ batchsize::Int
update_freq::Int
update_step::Int
rng::AbstractRNG
@@ -454,7 +454,7 @@ agents = MADDPGManager(
trajectory = deepcopy(trajectory),
)) for player in players(env) if player != chance_player(env)),
SARTS, # trace's type
- 512, # batch_size
+ 512, # batchsize
100, # update_freq
0, # initial update_step
rng
@@ -508,7 +508,7 @@ create_policy(player) = DDPGPolicy(
na = length(action_space(env, player)),
start_steps = 0,
start_policy = nothing,
- update_after = 512 * env.max_steps, # batch_size * env.max_steps
+ update_after = 512 * env.max_steps, # batchsize * env.max_steps
act_limit = 1.0,
act_noise = 0.,
)
@@ -530,7 +530,7 @@ agents = MADDPGManager(
) for player in (:Speaker, :Listener)
),
SARTS, # trace's type
- 512, # batch_size
+ 512, # batchsize
100, # update_freq
0, # initial update_step
rng
@@ -57,10 +57,10 @@ Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy
learner::L
dataset::T
continuous::Bool
- batch_size::Int
+ batchsize::Int
end
```
- This implementation of `OfflinePolicy` refers to [`QBasePolicy`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.QBasedPolicy). It provides a parameter `continuous` to support different action space types, including continuous and discrete. `learner` is a specific algorithm for learning and providing policy. `dataset` and `batch_size` are used to sample data for learning.
+ This implementation of `OfflinePolicy` refers to [`QBasePolicy`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.QBasedPolicy). It provides a parameter `continuous` to support different action space types, including continuous and discrete. `learner` is a specific algorithm for learning and providing policy. `dataset` and `batchsize` are used to sample data for learning.

Besides, we implement corresponding functions `π`, `update!` and `sample`. `π` is used to select the action, whose form is determined by the type of action space. `update!` can be used in two stages. In `PreExperiment` stage, we can call this function for pre-training algorithms with `pretrain_step` parameters. In `PreAct` stage, we call this function for training the `learner`. In function `update!`, we need to call function `sample` to sample a batch of data from the dataset. With the development of [ReinforcementLearningDataset.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/tree/master/src/ReinforcementLearningDatasets), the `sample` function will be deprecated.
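As a rough illustration of the sample-then-update pattern described above, here is a hypothetical sketch of the `PreAct`-stage branch. It is not code from this commit; device transfer and update-frequency handling are omitted, and the learner is assumed to carry its own `rng`, as in the pre-training loop shown further down:

```julia
# Hypothetical sketch: draw `batchsize` transitions from the offline dataset,
# then let the learner take one optimisation step on that batch.
function RLBase.update!(p::OfflinePolicy, ::AbstractTrajectory, ::AbstractEnv, ::PreActStage)
    _, batch = sample(p.learner.rng, p.dataset, p.batchsize)
    update!(p.learner, batch)
end
```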

@@ -73,7 +73,7 @@ offline_dqn_policy = OfflinePolicy(
),
dataset = dataset,
continuous = false,
- batch_size = 64,
+ batchsize = 64,
)
```

@@ -266,7 +266,7 @@ function RLBase.update!(p::OfflinePolicy, traj::AbstractTrajectory, ::AbstractEn
if in(:pretrain_step, fieldnames(typeof(l)))
println("Pretrain...")
for _ in 1:l.pretrain_step
- inds, batch = sample(l.rng, p.dataset, p.batch_size)
+ inds, batch = sample(l.rng, p.dataset, p.batchsize)
update!(l, batch)
end
end
2 changes: 1 addition & 1 deletion docs/src/How_to_use_hooks.md
@@ -138,7 +138,7 @@ policy = Agent(
) |> cpu,
optimizer = Adam(),
),
- batch_size = 32,
+ batchsize = 32,
min_replay_history = 100,
loss_func = huber_loss,
),
2 changes: 1 addition & 1 deletion src/ReinforcementLearningBase/Project.toml
@@ -1,7 +1,7 @@
name = "ReinforcementLearningBase"
uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
version = "0.12.1"
version = "0.12.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.14.0"
version = "0.15"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -30,7 +30,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[compat]
AbstractTrees = "0.3, 0.4"
Adapt = "3"
CUDA = "4"
CUDA = "4, 5"
ChainRulesCore = "1"
CircularArrayBuffers = "0.1"
Crayons = "4"
6 changes: 3 additions & 3 deletions src/ReinforcementLearningCore/src/core/hooks.jl
@@ -217,13 +217,13 @@ end
Base.getindex(h::BatchStepsPerEpisode) = h.steps

"""
- BatchStepsPerEpisode(batch_size::Int; tag = "TRAINING")
+ BatchStepsPerEpisode(batchsize::Int; tag = "TRAINING")
Similar to [`StepsPerEpisode`](@ref), but is specific to environments
which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
"""
- function BatchStepsPerEpisode(batch_size::Int)
- BatchStepsPerEpisode([Int[] for _ = 1:batch_size], zeros(Int, batch_size))
+ function BatchStepsPerEpisode(batchsize::Int)
+ BatchStepsPerEpisode([Int[] for _ = 1:batchsize], zeros(Int, batchsize))
end
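For instance (illustrative only, not part of the diff), a hook tracking 16 parallel environments is constructed with the renamed argument:

```julia
hook = BatchStepsPerEpisode(16)  # 16 step counters and 16 episode-length vectors, one per environment
```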

function Base.push!(hook::BatchStepsPerEpisode,
8 changes: 4 additions & 4 deletions src/ReinforcementLearningCore/src/utils/distributions.jl
@@ -54,10 +54,10 @@ end
mvnormlogpdf(μ::A, LorU::A, x::A; ϵ = 1f-8) where A <: AbstractArray
Batch version that takes 3D tensors as input where each slice along the 3rd
- dimension is a batch sample. `μ` is a (action_size x 1 x batch_size) matrix,
- `L` is a (action_size x action_size x batch_size), x is a (action_size x
- action_samples x batch_size). Return a 3D matrix of size (1 x action_samples x
- batch_size).
+ dimension is a batch sample. `μ` is a (action_size x 1 x batchsize) matrix,
+ `L` is a (action_size x action_size x batchsize), x is a (action_size x
+ action_samples x batchsize). Return a 3D matrix of size (1 x action_samples x
+ batchsize).
"""
function mvnormlogpdf(μ::A, LorU::A, x::A; ϵ=1.0f-8) where {A<:AbstractArray}
it = zip(eachslice(μ, dims = 3), eachslice(LorU, dims = 3), eachslice(x, dims = 3))
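To make the documented shapes concrete, here is a small hypothetical usage sketch. The array sizes are arbitrary, and it assumes `mvnormlogpdf` is in scope (it lives in ReinforcementLearningCore's distribution utilities):

```julia
using LinearAlgebra

action_size, action_samples, batchsize = 3, 5, 8

μ = randn(Float32, action_size, 1, batchsize)               # one mean vector per batch sample
x = randn(Float32, action_size, action_samples, batchsize)  # points at which to evaluate the logpdf

# Build a valid lower-triangular Cholesky factor for each batch sample.
Ls = map(1:batchsize) do _
    A = randn(Float32, action_size, action_size)
    Matrix(cholesky(Symmetric(A * A' + I)).L)
end
L = cat(Ls...; dims = 3)      # (action_size, action_size, batchsize)

logp = mvnormlogpdf(μ, L, x)  # size (1, action_samples, batchsize)
```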
