Skip to content

Commit

Permalink
Fix hook issues (#887)
Browse files Browse the repository at this point in the history
* Fix hook issues

* Drop is_terminated check from run function

* Version bump

* Update NEWS.md

---------

Co-authored-by: Henri Dehaybe <[email protected]>
  • Loading branch information
jeremiahpslewis and HenriDeh authored May 23, 2023
1 parent 6797f89 commit 0503496
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 12 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@

### ReinforcementLearningCore.jl

#### v0.10.1

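- Fix hook issue where an 'extra' entry was recorded at the end of an experiment; `push!` with `PostEpisodeStage` now always runs at the end of each episode, regardless of whether the run was stopped or the environment terminated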

#### v0.10.0

- Transition to `RLCore.forward`, `RLBase.act!`, `RLBase.plan!` and `Base.push!` syntax instead of functional objects for hooks, policies and environments
Expand Down
2 changes: 1 addition & 1 deletion src/ReinforcementLearningCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.10.0"
version = "0.10.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/src/core/hooks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ Base.getindex(h::StepsPerEpisode) = h.steps

# Count one step: each push with a `PostActStage` increments `hook.count` by one.
# Any further arguments are accepted and ignored. The updated count is returned.
function Base.push!(hook::StepsPerEpisode, ::PostActStage, args...)
    return hook.count += 1
end

Base.push!(hook::StepsPerEpisode, stage::Union{PostEpisodeStage,PostExperimentStage}, agent, env, ::Symbol) = Base.push!(hook, stage, agent, env)
Base.push!(hook::StepsPerEpisode, stage::PostEpisodeStage, agent, env, ::Symbol) = Base.push!(hook, stage, agent, env)

function Base.push!(hook::StepsPerEpisode, ::Union{PostEpisodeStage,PostExperimentStage}, agent, env)
function Base.push!(hook::StepsPerEpisode, ::PostEpisodeStage, agent, env)
Base.push!(hook.steps, hook.count)
hook.count = 0
end
Expand Down
6 changes: 2 additions & 4 deletions src/ReinforcementLearningCore/src/core/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,8 @@ function _run(policy::AbstractPolicy,
end
end # end of an episode

if is_terminated(env)
push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
push!(hook, PostEpisodeStage(), policy, env)
end
push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
push!(hook, PostEpisodeStage(), policy, env)
end
push!(policy, PostExperimentStage(), env)
push!(hook, PostExperimentStage(), policy, env)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,8 @@ function Base.run(
end
end # end of an episode

if is_terminated(env)
push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
end
push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
end
push!(multiagent_policy, PostExperimentStage(), env)
push!(multiagent_hook, PostExperimentStage(), multiagent_policy, env)
Expand Down
17 changes: 16 additions & 1 deletion src/ReinforcementLearningCore/test/core/hooks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ function test_noop!(hook::AbstractHook; stages=[PreActStage(), PostActStage(), P
end
end

"""
    test_run!(hook::AbstractHook)

Run a `RandomPolicy` on a `RandomWalk1D` environment with a `StopAfterEpisode(10)`
stop condition, recording into a deep copy of `hook`, and return that copy.

The `hook` passed in is not modified, so callers can keep using it for further
stage-by-stage assertions.
"""
function test_run!(hook::AbstractHook)
    hook_copy = deepcopy(hook)
    run(RandomPolicy(), RandomWalk1D(), StopAfterEpisode(10), hook_copy)
    return hook_copy
end

@testset "TotalRewardPerEpisode" begin
h_1 = TotalRewardPerEpisode(; is_display_on_exit=true)
h_2 = TotalRewardPerEpisode(; is_display_on_exit=false)
Expand All @@ -42,6 +48,9 @@ end
policy = RandomPolicy(legal_action_space(env))

for h in (h_1, h_2, h_3, h_4, h_5)
h_ = test_run!(h)
@test length(h_.rewards) == 10

push!(h, PostActStage(), policy, env)
@test h.reward == 1
push!(h, PostEpisodeStage(), policy, env)
Expand All @@ -65,6 +74,9 @@ end
h_5 = TotalBatchRewardPerEpisode(10)

for h in (h_1, h_2, h_3, h_4, h_5)
h_ = test_run!(h)
@test length(h_.rewards) == 10

push!(h, PostActStage(), policy, env)
@test h.reward == fill(1, 10)
push!(h, PostEpisodeStage(), policy, env)
Expand Down Expand Up @@ -119,7 +131,7 @@ end
@test h.steps == [100]

push!(h, PostExperimentStage(), agent, env)
@test h.steps == [100, 0]
@test h.steps == [100]

test_noop!(h, stages=[PreActStage(), PreEpisodeStage(), PreExperimentStage()])
end
Expand All @@ -133,6 +145,9 @@ end
h_3 = RewardsPerEpisode{Float16}()

for h in (h_1, h_2, h_3)
h_ = test_run!(h)
@test length(h_.rewards) == 10

push!(h, PreEpisodeStage(), agent, env)
@test h.rewards == [[]]

Expand Down

0 comments on commit 0503496

Please sign in to comment.