Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a way to handle non-episodic environments #613

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .cspell/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@
"logps",
"trilcol",
"mvnormlogpdf",
"mvnormals"
"mvnormals",
"Optimise"
],
"ignoreWords": [],
"minWordLength": 5,
Expand Down
1 change: 1 addition & 0 deletions src/ReinforcementLearningCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
Expand Down
27 changes: 10 additions & 17 deletions src/ReinforcementLearningCore/src/policies/agents/agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,6 @@ function RLBase.update!(
::AbstractStage,
) end

"""
    RLBase.update!(trajectory, policy, env, ::PreEpisodeStage)

Trim the trajectory before a new episode starts.

When `trajectory` is non-empty, the most recent entry of the `:state` and
`:action` traces is popped, and likewise for `:legal_actions_mask` when that
trace exists. The policy and environment arguments are only used for dispatch.
"""
function RLBase.update!(
    trajectory::AbstractTrajectory,
    ::AbstractPolicy,
    ::AbstractEnv,
    ::PreEpisodeStage,
)
    if length(trajectory) > 0
        # NOTE(review): presumably these are the trailing entries left over from
        # the end of the previous episode — confirm against the PostEpisodeStage method.
        foreach(name -> pop!(trajectory[name]), (:state, :action))
        # The mask trace is optional, so only touch it when it is present.
        if haskey(trajectory, :legal_actions_mask)
            pop!(trajectory[:legal_actions_mask])
        end
    end
end

function RLBase.update!(
trajectory::AbstractTrajectory,
policy::AbstractPolicy,
Expand All @@ -119,6 +104,11 @@ function RLBase.update!(
action,
)
s = policy isa NamedPolicy ? state(env, nameof(policy)) : state(env)
# Drop this slot's index from `last_states_idxs` if present: the state pushed below is not the last state of an episode.
idx = current_idx(trajectory[:state])
if idx in trajectory.last_states_idxs
pop!(trajectory.last_states_idxs, idx)
end
push!(trajectory[:state], s)
push!(trajectory[:action], action)
if haskey(trajectory, :legal_actions_mask)
Expand Down Expand Up @@ -166,13 +156,16 @@ function RLBase.update!(

A = policy isa NamedPolicy ? action_space(env, nameof(policy)) : action_space(env)
a = get_dummy_action(A)

push!(trajectory.last_states_idxs, current_idx(trajectory[:terminal])) #note that this state is terminal and should not be sampled
push!(trajectory[:state], s)
#dummies to keep buffers of the same length
push!(trajectory[:action], a)
push!(trajectory[:reward], zero(eltype(trajectory[:reward])))
push!(trajectory[:terminal], true)
if haskey(trajectory, :legal_actions_mask)
lasm =
policy isa NamedPolicy ? legal_action_space_mask(env, nameof(policy)) :
legal_action_space_mask(env)
push!(trajectory[:legal_actions_mask], lasm)
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ function Base.empty!(t::AbstractTrajectory)
for k in keys(t)
empty!(t[k])
end
empty!(t.last_states_idxs)
end

function Base.push!(t::AbstractTrajectory; kwargs...)
Expand All @@ -40,6 +41,10 @@ function Base.push!(t::AbstractTrajectory; kwargs...)
end

function Base.pop!(t::AbstractTrajectory)
idx = current_idx(t[first(keys(t))])
if idx in t.last_states_idxs
pop!(t.last_states_idxs, idx)
end
for k in keys(t)
pop!(t[k])
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,31 @@ export Trajectory,
using MacroTools: @forward
using ElasticArrays
using CircularArrayBuffers: CircularArrayBuffer, CircularVectorBuffer
import DataStructures.Set

#####
# Trajectory
#####

"""
Trajectory(;[trace_name=trace_container]...)
Trajectory(;[trace_name=trace_container]..., is_episodic = true)

A simple wrapper of `NamedTuple`.
Define our own type here to avoid type piracy with `NamedTuple`
Mainly a simple wrapper of `NamedTuple`.
Set `is_episodic = false` when working with non-episodic environments (i.e. infinite horizon) that stop after a given number of steps, to avoid multiplying the value of the last state by 0 when bootstrapping TD targets.
"""
struct Trajectory{T} <: AbstractTrajectory
# Container holding the named traces; `getindex` and `keys` on a `Trajectory`
# are forwarded to this field.
traces::T
# When `false`, the batch samplers overwrite the sampled `terminal` flags with
# `false` before returning the batch.
is_episodic::Bool
# Buffer indices recorded when an episode ends; these slots hold a final state
# padded with dummy action/reward entries and are meant to be skipped by samplers.
last_states_idxs::Set{Int}
end

Trajectory(; kwargs...) = Trajectory(values(kwargs))
Trajectory(; is_episodic = true, kwargs...) = Trajectory(values(kwargs), is_episodic, Set{Int64}())

@forward Trajectory.traces Base.getindex, Base.keys

Base.merge(a::Trajectory, b::Trajectory) = Trajectory(merge(a.traces, b.traces))
Base.merge(a::Trajectory, b::NamedTuple) = Trajectory(merge(a.traces, b))
Base.merge(a::NamedTuple, b::Trajectory) = Trajectory(merge(a, b.traces))
Base.merge(a::Trajectory, b::Trajectory) = Trajectory(merge(a.traces, b.traces), a.is_episodic, Set{Int64}())
Base.merge(a::Trajectory, b::NamedTuple) = Trajectory(merge(a.traces, b), a.is_episodic, Set{Int64}())
Base.merge(a::NamedTuple, b::Trajectory) = Trajectory(merge(a, b.traces), b.is_episodic, Set{Int64}())

#####

Expand All @@ -53,10 +56,11 @@ underlying buffer.
See also [`CircularArraySARTTrajectory`](@ref),
[`CircularArraySLARTTrajectory`](@ref), [`CircularArrayPSARTTrajectory`](@ref).
"""
function CircularArrayTrajectory(; capacity, kwargs...)
function CircularArrayTrajectory(; capacity, is_episodic = true, kwargs...)
Trajectory(map(values(kwargs)) do x
CircularArrayBuffer{eltype(first(x))}(last(x)..., capacity)
end)
end,
is_episodic, Set{Int64}())
end

"""
Expand All @@ -72,10 +76,11 @@ Similar to [`CircularArrayTrajectory`](@ref), except that the underlying storage

See also [`CircularVectorSARTTrajectory`](@ref), [`CircularVectorSARTSATrajectory`](@ref).
"""
function CircularVectorTrajectory(; capacity, kwargs...)
function CircularVectorTrajectory(; capacity, is_episodic = true, kwargs...)
Trajectory(map(values(kwargs)) do x
CircularVectorBuffer{x}(capacity)
end)
end,
is_episodic, Set{Int64}())
end

#####
Expand Down Expand Up @@ -165,9 +170,9 @@ CircularArraySARTTrajectory(;
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
) = merge(
CircularArrayTrajectory(; capacity = capacity + 1, state = state, action = action),
CircularArrayTrajectory(; capacity = capacity, reward = reward, terminal = terminal),
is_episodic = true) = merge(
CircularArrayTrajectory(; capacity = capacity , state = state, action = action, is_episodic = is_episodic),
CircularArrayTrajectory(; capacity = capacity, reward = reward, terminal = terminal, is_episodic = is_episodic),
)

const CircularArraySLARTTrajectory = Trajectory{
Expand All @@ -191,14 +196,16 @@ CircularArraySLARTTrajectory(;
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
is_episodic = true
) = merge(
CircularArrayTrajectory(;
capacity = capacity + 1,
capacity = capacity ,
state = state,
legal_actions_mask = legal_actions_mask,
action = action,
is_episodic = is_episodic
),
CircularArrayTrajectory(; capacity = capacity, reward = reward, terminal = terminal),
CircularArrayTrajectory(; capacity = capacity, reward = reward, terminal = terminal, is_episodic = is_episodic),
)

#####
Expand Down Expand Up @@ -288,9 +295,10 @@ CircularVectorSARTTrajectory(;
action = Int,
reward = Float32,
terminal = Bool,
is_episodic = true
) = merge(
CircularVectorTrajectory(; capacity = capacity + 1, state = state, action = action),
CircularVectorTrajectory(; capacity = capacity, reward = reward, terminal = terminal),
CircularVectorTrajectory(; capacity = capacity , state = state, action = action, is_episodic = is_episodic),
CircularVectorTrajectory(; capacity = capacity, reward = reward, terminal = terminal, is_episodic = is_episodic),
)

#####
Expand Down Expand Up @@ -318,6 +326,7 @@ CircularVectorSARTSATrajectory(;
terminal = Bool,
next_state = state,
next_action = action,
is_episodic = true
) = CircularVectorTrajectory(;
capacity = capacity,
state = state,
Expand All @@ -326,6 +335,7 @@ CircularVectorSARTSATrajectory(;
terminal = terminal,
next_state = next_state,
next_action = next_action,
is_episodic = is_episodic
)

#####
Expand All @@ -336,10 +346,11 @@ CircularVectorSARTSATrajectory(;
A specialized [`Trajectory`](@ref) which uses [`ElasticArray`](https://github.com/JuliaArrays/ElasticArrays.jl) as the underlying
storage. See also [`ElasticSARTTrajectory`](@ref).
"""
function ElasticArrayTrajectory(; kwargs...)
function ElasticArrayTrajectory(; is_episodic = true, kwargs...)
Trajectory(map(values(kwargs)) do x
ElasticArray{eltype(first(x))}(undef, last(x)..., 0)
end)
end,
is_episodic, Set{Int64}())
end

const ElasticSARTTrajectory = Trajectory{
Expand Down Expand Up @@ -433,12 +444,14 @@ function ElasticSARTTrajectory(;
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
is_episodic = true
)
ElasticArrayTrajectory(;
state = state,
action = action,
reward = reward,
terminal = terminal,
is_episodic = is_episodic
)
end

Expand All @@ -451,10 +464,11 @@ end

A [`Trajectory`](@ref) with each trace using a `Vector` as the storage.
"""
function VectorTrajectory(; kwargs...)
function VectorTrajectory(; is_episodic = true, kwargs...)
Trajectory(map(values(kwargs)) do x
Vector{x}()
end)
end,
is_episodic, Set{Int64}())
end

const VectorSARTTrajectory =
Expand All @@ -477,8 +491,9 @@ function VectorSARTTrajectory(;
action = Int,
reward = Float32,
terminal = Bool,
is_episodic = true
)
VectorTrajectory(; state = state, action = action, reward = reward, terminal = terminal)
VectorTrajectory(; state = state, action = action, reward = reward, terminal = terminal, is_episodic = is_episodic)
end

const VectorSATrajectory =
Expand All @@ -494,8 +509,8 @@ A specialized [`VectorTrajectory`] with traces of `(:state, :action)`.
- `state::DataType = Int`
- `action::DataType = Int`
"""
function VectorSATrajectory(; state = Int, action = Int)
VectorTrajectory(; state = state, action = action)
function VectorSATrajectory(; state = Int, action = Int, is_episodic = true)
VectorTrajectory(; state = state, action = action, is_episodic = is_episodic)
end
#####

Expand All @@ -518,8 +533,8 @@ Base.getindex(t::PrioritizedTrajectory, s::Symbol) =
const CircularArrayPSARTTrajectory =
PrioritizedTrajectory{<:SumTree,<:CircularArraySARTTrajectory}

CircularArrayPSARTTrajectory(; capacity, kwargs...) = PrioritizedTrajectory(
CircularArraySARTTrajectory(; capacity = capacity, kwargs...),
CircularArrayPSARTTrajectory(; capacity, is_episodic = true, kwargs...) = PrioritizedTrajectory(
CircularArraySARTTrajectory(; capacity = capacity, is_episodic = is_episodic, kwargs...),
SumTree(capacity),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,15 @@ BatchSampler{T}(batch_size::Int; cache = nothing, rng = Random.GLOBAL_RNG) where

# TODO: deprecate
"""
    StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler)

Draw `s.batch_size` indices uniformly (with replacement) from `1:length(t)`,
skipping every index stored in `t.last_states_idxs`, then fill the sampler
cache via `fetch!`.

Returns the tuple `(inds, s.cache)`.

Throws an `ArgumentError` when the trajectory is non-empty but every index is
excluded, since rejection sampling could never terminate in that case.
"""
function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler)
    candidates = 1:length(t)
    excluded = t.last_states_idxs
    # `all` short-circuits on the first admissible index, so this guard is cheap
    # in the common case. An empty range is left to `rand`, which reports it.
    if !isempty(candidates) && all(in(excluded), candidates)
        throw(ArgumentError("every index of the trajectory is an end-of-episode index; nothing can be sampled"))
    end
    inds = Vector{Int}(undef, s.batch_size)
    for i in eachindex(inds)
        # Reject end-of-episode indices. This may be inefficient when episodes are
        # very short; the alternative is to materialise the list of admissible
        # indices, which can be long.
        idx = rand(rng, candidates)
        while idx in excluded
            idx = rand(rng, candidates)
        end
        inds[i] = idx
    end
    fetch!(s, t, inds)
    return inds, s.cache
end
Expand Down Expand Up @@ -101,6 +109,10 @@ function fetch!(s::BatchSampler{traces}, t::Union{CircularArraySARTTrajectory, C
@error "unsupported traces $traces"
end

if !t.is_episodic
batch.terminal .= false
end

if isnothing(s.cache)
s.cache = map(batch) do x
convert(Array, x)
Expand All @@ -127,9 +139,17 @@ end

# TODO:deprecate
"""
    StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler)

Draw `s.batch_size` start indices uniformly (with replacement) from the valid
range `first:(length(t) - s.n + 1)`, where `first` is `1` when `s.stack_size`
is `nothing` and `s.stack_size` otherwise. Indices stored in
`t.last_states_idxs` are skipped.

Returns the tuple `(inds, batch)` where `batch` is the result of `fetch!`.

Throws an `ArgumentError` when the valid range is non-empty but every index in
it is excluded, since rejection sampling could never terminate in that case.
"""
function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler)
    first_valid = isnothing(s.stack_size) ? 1 : s.stack_size
    valid_range = first_valid:(length(t) - s.n + 1)
    excluded = t.last_states_idxs
    # `all` short-circuits on the first admissible index, so this guard is cheap
    # in the common case. An empty range is left to `rand`, which reports it.
    if !isempty(valid_range) && all(in(excluded), valid_range)
        throw(ArgumentError("every index of the valid range is an end-of-episode index; nothing can be sampled"))
    end
    inds = Vector{Int}(undef, s.batch_size)
    for i in eachindex(inds)
        # Reject end-of-episode indices. This may be inefficient when episodes are
        # very short; the alternative is to materialise the list of admissible
        # indices, which can be long.
        idx = rand(rng, valid_range)
        while idx in excluded
            idx = rand(rng, valid_range)
        end
        inds[i] = idx
    end
    return inds, fetch!(s, t, inds)
end

Expand Down Expand Up @@ -189,6 +209,10 @@ function fetch!(
@error "unsupported traces $traces"
end

if !t.is_episodic
batch.terminal .= false
end

if isnothing(sampler.cache)
sampler.cache = map(batch) do x
convert(Array, x)
Expand All @@ -199,3 +223,14 @@ function fetch!(
end
end
end

#This could be added to CircularArrayBuffer.jl instead.
"""
    current_idx(cb::CircularArrayBuffer)

Return the index of the frame slot that the next `push!` will write to:
`cb.first` when the buffer is full, and `cb.nframes + 1` otherwise.
"""
function current_idx(cb::CircularArrayBuffer)
    # Frames are stored along the last dimension of the backing array, so the
    # capacity is the size of that dimension, not the total element count.
    # (`length(cb.buffer)` only coincides with it for vector buffers.)
    capacity = size(cb.buffer, ndims(cb.buffer))
    return cb.nframes == capacity ? cb.first : cb.nframes + 1
end

# For a plain `Array`, report the extent of its last dimension.
current_idx(a::Array{T,N}) where {T,N} = size(a, N)
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/test/components/trajectories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@

push!(t; reward = 3.0f0, terminal = false, state = 4 * ones(Int, 4), action = 4)
@test length(t) == 3
@test t[:state] == [j for i in 1:4, j in 1:4]
@test t[:state] == [j for i in 1:4, j in 2:4]
@test t[:reward] == [1, 2, 3]

# test circle works as expected
push!(t; reward = 4.0f0, terminal = true, state = 5 * ones(Int, 4), action = 5)
@test length(t) == 3
@test t[:state] == [j for i in 1:4, j in 2:5]
@test t[:state] == [j for i in 1:4, j in 3:5]
@test t[:reward] == [2, 3, 4]
end

Expand Down