ComposedHooks, MultiHook fixes (#874)
jeremiahpslewis authored May 3, 2023
1 parent 6653304 commit 18714fc
Showing 13 changed files with 196 additions and 108 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -154,7 +154,8 @@ subpackages. The relationship between them is depicted below:
## ✋ Getting Help
Are you looking for help with ReinforcementLearning.jl? Here are ways to find help:
1. Read the online documentation! Most likely the answer is already provided in an example or in the API documents. Search using the search bar in the upper left.
2. Chat with us in [Julia Slack](https://julialang.org/slack/) in the #reinforcement-learnin channel.
<!-- cspell:disable-next -->
2. Chat with us in [Julia Slack](https://julialang.org/slack/) in the #reinforcement-learnin channel.
3. Post a question in the [Julia discourse](https://discourse.julialang.org/) forum in the category "Machine Learning" and use "reinforcement-learning" as a tag.
4. For issues with unexpected behavior or defects in ReinforcementLearning.jl, please open an issue on the ReinforcementLearning [GitHub page](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl) with a minimal working example and steps to reproduce.

73 changes: 70 additions & 3 deletions src/ReinforcementLearningCore/src/core/hooks.jl
@@ -1,5 +1,6 @@
export AbstractHook,
EmptyHook,
ComposedHook,
StepsPerEpisode,
RewardsPerEpisode,
TotalRewardPerEpisode,
@@ -31,16 +32,24 @@ abstract type AbstractHook end

(hook::AbstractHook)(args...) = nothing

struct ComposedHook{H} <: AbstractHook
hooks::H
struct ComposedHook{T<:Tuple} <: AbstractHook
hooks::T
ComposedHook(hooks...) = new{typeof(hooks)}(hooks)
end

Base.:(+)(h1::AbstractHook, h2::AbstractHook) = ComposedHook(h1, h2)
Base.:(+)(h1::ComposedHook, h2::AbstractHook) = ComposedHook(h1.hooks..., h2)
Base.:(+)(h1::AbstractHook, h2::ComposedHook) = ComposedHook(h1, h2.hooks...)
Base.:(+)(h1::ComposedHook, h2::ComposedHook) = ComposedHook(h1.hooks..., h2.hooks...)

(h::ComposedHook)(args...) = map(h -> h(args...), h.hooks)
function (hook::ComposedHook)(stage::AbstractStage, args...; kw...)
for h in hook.hooks
h(stage, args...; kw...)
end
return
end

Base.getindex(hook::ComposedHook, inds...) = getindex(hook.hooks, inds...)

#####
# EmptyHook
@@ -71,6 +80,8 @@ Base.getindex(h::StepsPerEpisode) = h.steps

(hook::StepsPerEpisode)(::PostActStage, args...) = hook.count += 1

(hook::StepsPerEpisode)(stage::Union{PostEpisodeStage,PostExperimentStage}, agent, env, ::Symbol) = hook(stage, agent, env)

function (hook::StepsPerEpisode)(::Union{PostEpisodeStage,PostExperimentStage}, agent, env)
push!(hook.steps, hook.count)
hook.count = 0
@@ -101,7 +112,10 @@ end
Base.getindex(h::RewardsPerEpisode) = h.rewards

(h::RewardsPerEpisode)(::PreEpisodeStage, agent, env) = push!(h.rewards, h.empty_vect)
(h::RewardsPerEpisode)(::PreEpisodeStage, agent, env, ::Symbol) = h(PreEpisodeStage(), agent, env)

(h::RewardsPerEpisode)(::PostActStage, agent, env) = push!(h.rewards[end], reward(env))
(h::RewardsPerEpisode)(::PostActStage, agent, env, player::Symbol) = push!(h.rewards[end], reward(env, player))

#####
# TotalRewardPerEpisode
@@ -130,6 +144,7 @@ end
Base.getindex(h::TotalRewardPerEpisode) = h.rewards

(h::TotalRewardPerEpisode)(::PostActStage, agent, env) = h.reward += reward(env)
(h::TotalRewardPerEpisode)(::PostActStage, agent, env, player::Symbol) = h.reward += reward(env, player)

function (hook::TotalRewardPerEpisode)(
::PostEpisodeStage,
@@ -161,6 +176,20 @@ function (hook::TotalRewardPerEpisode{true, F})(
display(hook)
end

# Pass through, as no multiplayer customization is needed
function (hook::TotalRewardPerEpisode)(
stage::Union{PostEpisodeStage, PostExperimentStage},
agent,
env,
player::Symbol
)
hook(
stage,
agent,
env,
)
end

#####
# TotalBatchRewardPerEpisode
#####
@@ -202,6 +231,16 @@ function (hook::TotalBatchRewardPerEpisode)(
return
end

function (hook::TotalBatchRewardPerEpisode)(
::PostActStage,
agent,
env,
player::Symbol,
)
hook.reward .+= reward(env, player)
return
end

function (hook::TotalBatchRewardPerEpisode)(::PostEpisodeStage, agent, env)
push!.(hook.rewards, hook.reward)
hook.reward .= 0
@@ -235,6 +274,20 @@ function (hook::TotalBatchRewardPerEpisode{true, F})(
display(hook)
end

# Pass through, as no multiplayer customization is needed
function (hook::TotalBatchRewardPerEpisode)(
stage::Union{PostEpisodeStage, PostExperimentStage},
agent,
env,
player::Symbol
)
hook(
stage,
agent,
env,
)
end

#####
# BatchStepsPerEpisode
#####
@@ -270,6 +323,20 @@ function (hook::BatchStepsPerEpisode)(
end
end

# Pass through, as no multiplayer customization is needed
function (hook::BatchStepsPerEpisode)(
stage::PostActStage,
agent,
env,
player::Symbol
)
hook(
stage,
agent,
env,
)
end

#####
# TimePerStep
#####
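For context on the hooks.jl changes above: `ComposedHook` now stores its hooks in a tuple, forwards every stage call to each of them in order, and supports indexing; the `+` overloads build and extend composed hooks incrementally. A minimal usage sketch (the zero-argument hook constructors are assumed from their keyword definitions, which are not shown in this diff):

using ReinforcementLearningCore

# `+` on two hooks yields a ComposedHook over a tuple of hooks.
hook = TotalRewardPerEpisode() + StepsPerEpisode()

# During a run, each stage call is forwarded to every wrapped hook in order,
# e.g. hook(PostActStage(), policy, env) invokes both hooks.

# Indexing passes through to the underlying tuple of hooks:
hook[1]   # the TotalRewardPerEpisode hook
hook[2]   # the StepsPerEpisode hook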
5 changes: 4 additions & 1 deletion src/ReinforcementLearningCore/src/core/run.jl
@@ -59,7 +59,10 @@ end

Base.show(io::IO, m::MIME"text/plain", t::Experiment{S}) where {S} = show(io, m, convert(AnnotatedStructTree, t; description=string(S)))

Base.run(ex::Experiment) = run(ex.policy, ex.env, ex.stop_condition, ex.hook)
function Base.run(ex::Experiment)
run(ex.policy, ex.env, ex.stop_condition, ex.hook)
return ex
end

function Base.run(
policy::AbstractPolicy,
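For context on the run.jl change above: `run(ex::Experiment)` now returns the experiment itself rather than the value of the inner `run` call, so the hook and anything it collected stay reachable from the return value. A minimal sketch, assuming `ex` is an `Experiment` constructed elsewhere; its constructor is not part of this diff:

finished = run(ex)    # now returns the Experiment itself
finished.hook[]       # e.g. the per-episode rewards for a TotalRewardPerEpisode hook
finished.policy       # the policy used in the run is likewise still accessible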
3 changes: 2 additions & 1 deletion src/ReinforcementLearningCore/src/core/stages.jl
@@ -16,5 +16,6 @@ struct PreActStage <: AbstractStage end
struct PostActStage <: AbstractStage end

(p::AbstractPolicy)(::AbstractStage, ::AbstractEnv) = nothing
(p::AbstractPolicy)(::AbstractStage, ::AbstractEnv, ::Symbol) = nothing

RLBase.optimise!(::AbstractPolicy) = nothing
RLBase.optimise!(::AbstractPolicy) = nothing
8 changes: 6 additions & 2 deletions src/ReinforcementLearningCore/src/policies/agent/base.jl
@@ -67,11 +67,15 @@ function (agent::Agent)(env::AbstractEnv, args...; kwargs...)
action
end

function (agent::Agent)(::PostActStage, env::E) where {E <: AbstractEnv}
function (agent::Agent)(::PostActStage, env::AbstractEnv)
update!(agent.cache, reward(env), is_terminated(env))
end

function (agent::Agent)(::PostExperimentStage, env::E) where {E <: AbstractEnv}
function (agent::Agent)(::PostActStage, p::Symbol, env::AbstractEnv)
update!(agent.cache, reward(env, p), is_terminated(env))
end

function (agent::Agent)(::PostExperimentStage, env::AbstractEnv)
RLBase.reset!(agent.cache)
end

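For context on the agent/base.jl change above: in the multiplayer form the player `Symbol` comes before the environment, and the agent caches that player's reward. A minimal sketch (the `:Cross` player name is an illustrative assumption):

# Inside a multi-agent rollout, after the environment has stepped:
agent(PostActStage(), :Cross, env)   # caches reward(env, :Cross) and is_terminated(env)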
26 changes: 4 additions & 22 deletions src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
@@ -177,7 +177,7 @@ end

function (multiagent::MultiAgentPolicy)(::PreEpisodeStage, env::E) where {E<:AbstractEnv}
for player in players(env)
multiagent[player](PreEpisodeStage(), env)
multiagent[player](PreEpisodeStage(), env, player)
end
end

@@ -195,31 +195,13 @@ end

function (multiagent::MultiAgentPolicy)(::PostEpisodeStage, env::E) where {E<:AbstractEnv}
for player in players(env)
multiagent[player](PostEpisodeStage(), env)
multiagent[player](PostEpisodeStage(), env, player)
end
end

function (hook::MultiAgentHook)(::PreEpisodeStage, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
function (hook::MultiAgentHook)(stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv,S<:AbstractStage}
for player in players(env)
hook[player](PreEpisodeStage(), multiagent[player], env)
end
end

function (hook::MultiAgentHook)(::PreActStage, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
for player in players(env)
hook[player](PreActStage(), multiagent[player], env)
end
end

function (hook::MultiAgentHook)(::PostActStage, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
for player in players(env)
hook[player](PostActStage(), multiagent[player], env)
end
end

function (hook::MultiAgentHook)(::PostEpisodeStage, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
for player in players(env)
hook[player](PostEpisodeStage(), multiagent[player], env)
hook[player](stage, multiagent[player], env, player)
end
end

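For context on the multi_agent.jl change above: the four near-identical per-stage `MultiAgentHook` methods collapse into one generic method that forwards the stage, the player's policy, the environment, and the player `Symbol` to each per-player hook. A sketch of a custom per-player hook that uses the extra `Symbol` argument (the `PerPlayerReward` name and its fields are illustrative assumptions, not part of this commit):

using ReinforcementLearningCore, ReinforcementLearningBase

# Records one reward entry per step for the player it is attached to.
struct PerPlayerReward <: AbstractHook
    rewards::Vector{Float64}
    PerPlayerReward() = new(Float64[])
end

# The four-argument form receives the player Symbol and can query the
# per-player reward from the environment.
(h::PerPlayerReward)(::PostActStage, agent, env, player::Symbol) =
    push!(h.rewards, reward(env, player))

# Every other stage falls back to the no-op `(hook::AbstractHook)(args...) = nothing`
# defined at the top of hooks.jl.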
57 changes: 57 additions & 0 deletions src/ReinforcementLearningCore/test/core/base.jl
@@ -0,0 +1,57 @@
using ReinforcementLearningCore: SRT
using ReinforcementLearningBase

@testset "core" begin
@testset "simple workflow" begin
@testset "StopAfterStep" begin
agent = Agent(
RandomPolicy(),
Trajectory(
CircularArraySARTTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
)
env = RandomWalk1D()
stop_condition = StopAfterStep(123)
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)

@test sum(hook[]) == length(agent.trajectory.container)
end

@testset "StopAfterEpisode" begin
agent = Agent(
RandomPolicy(),
Trajectory(
CircularArraySARTTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
)
env = RandomWalk1D()
stop_condition = StopAfterEpisode(10)
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)

@test sum(hook[]) == length(agent.trajectory.container)
end

@testset "StopAfterStep, use type stable Agent" begin
env = RandomWalk1D()
agent = Agent(
RandomPolicy(legal_action_space(env)),
Trajectory(
CircularArraySARTTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
SRT{Any, Any, Any}(),
)
stop_condition = StopAfterStep(123; is_show_progress=false)
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)
@test sum(hook[]) == length(agent.trajectory.container)
end
end
end
59 changes: 2 additions & 57 deletions src/ReinforcementLearningCore/test/core/core.jl
@@ -1,57 +1,2 @@
using ReinforcementLearningCore: SRT
using ReinforcementLearningBase

@testset "core" begin
@testset "simple workflow" begin
@testset "StopAfterStep" begin
agent = Agent(
RandomPolicy(),
Trajectory(
CircularArraySARTTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
)
env = RandomWalk1D()
stop_condition = StopAfterStep(123)
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)

@test sum(hook[]) == length(agent.trajectory.container)
end

@testset "StopAfterEpisode" begin
agent = Agent(
RandomPolicy(),
Trajectory(
CircularArraySARTTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
)
env = RandomWalk1D()
stop_condition = StopAfterEpisode(10)
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)

@test sum(hook[]) == length(agent.trajectory.container)
end

@testset "StopAfterStep, use type stable Agent" begin
env = RandomWalk1D()
agent = Agent(
RandomPolicy(legal_action_space(env)),
Trajectory(
CircularArraySARTTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
SRT{Any, Any, Any}(),
)
stop_condition = StopAfterStep(123; is_show_progress=false)
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)
@test sum(hook[]) == length(agent.trajectory.container)
end
end
end
include("base.jl")
include("hooks.jl")
22 changes: 14 additions & 8 deletions src/ReinforcementLearningCore/test/core/hooks.jl
@@ -10,14 +10,20 @@ function test_noop!(hook::AbstractHook; stages=[PreActStage(), PostActStage(), P
policy = RandomPolicy(legal_action_space(env))

hook_fieldnames = fieldnames(typeof(hook))
for stage in stages
hook_copy = deepcopy(hook)
hook_copy(stage, policy, env)
for field_ in hook_fieldnames
if getfield(hook, field_) isa Ref
@test getfield(hook, field_)[] == getfield(hook_copy, field_)[]
else
@test getfield(hook, field_) == getfield(hook_copy, field_)
for mode in [:MultiAgent, :SingleAgent]
for stage in stages
hook_copy = deepcopy(hook)
if mode == :SingleAgent
hook_copy(stage, policy, env)
elseif mode == :MultiAgent
hook_copy(stage, policy, env, :player_i)
end
for field_ in hook_fieldnames
if getfield(hook, field_) isa Ref
@test getfield(hook, field_)[] == getfield(hook_copy, field_)[]
else
@test getfield(hook, field_) == getfield(hook_copy, field_)
end
end
end
end