Skip to content

Commit

Permalink
Refactor target network optimization and update test assertions for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiahpslewis committed Dec 17, 2024
1 parent d5ec5cc commit de3652d
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ function RLBase.optimise!(tn::TargetNetwork, grad::NamedTuple)
tn.n_optimise += 1

if tn.n_optimise % tn.sync_freq == 0
# polyak averaging
zip(Flux.params(RLCore.model(tn)), Flux.params(RLCore.target(tn)))
# Polyak averaging
src_layers = RLCore.model(tn)
dest_layers = RLCore.target(tn)
for i in 1:length(src_layers)
dest_layers[i].weight .= tn.ρ .* dest_layers[i].weight .+ (1 - tn.ρ) .* src_layers[i].weight
dest_layers[i].bias .= tn.ρ .* dest_layers[i].bias .+ (1 - tn.ρ) .* src_layers[i].bias
end
tn.n_optimise = 0
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ end
@test p1 != Flux.destructure(RLCore.model(tn))[1]
@test p1 == Flux.destructure(target(tn))[1]
RLCore.optimise!(tn, grad_model)
@test Flux.destructure(target(tn))[1] == Flux.destructure(RLCore.model(tn))[1]
@test Flux.destructure(RLCore.target(tn))[1] == Flux.destructure(RLCore.model(tn))[1]
@test p1 != Flux.destructure(target(tn))[1]
p2 = Flux.destructure(RLCore.model(tn))[1]
RLCore.optimise!(tn, grad_model)
Expand Down
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using Preferences

if Sys.isapple() && Sys.ARCH === :aarch64
flux_uuid = UUID("587475ba-b771-5e3f-ad9e-33799f191a9c")
# set_preferences!(flux_uuid, "gpu_backend" => "Metal")
set_preferences!(flux_uuid, "gpu_backend" => "Metal")

# using Metal
using Metal
else
using CUDA, cuDNN
end
Expand Down
2 changes: 0 additions & 2 deletions src/ReinforcementLearningCore/test/utils/networks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ import ReinforcementLearningBase: RLBase
gn = GaussianNetwork(Dense(20,15), Dense(15,10), Dense(15,10, softplus)) |> gpu
state = rand(Float32, 20,3) |> gpu #batch of 3 states
@testset "Forward pass compatibility" begin
@test Flux.trainable(gn) == Flux.Params([gn.pre.weight, gn.pre.bias, gn.μ.weight, gn.μ.bias, gn.σ.weight, gn.σ.bias])
m, L = gn(state)
@test size(m) == size(L) == (10,3)
a, logp = gn(CUDA.CURAND.RNG(), state, is_sampling = true, is_return_log_prob = true)
Expand Down Expand Up @@ -271,7 +270,6 @@ import ReinforcementLearningBase: RLBase
μ = Dense(15,10) |> gpu
Σ = Dense(15,10*11÷2) |> gpu
gn = CovGaussianNetwork(pre, μ, Σ)
@test Flux.trainable(gn) == Flux.Params([pre.weight, pre.bias, μ.weight, μ.bias, Σ.weight, Σ.bias])
state = rand(Float32, 20,3)|> gpu #batch of 3 states
m, L = gn(Flux.unsqueeze(state,dims = 2))
@test size(m) == (10,1,3)
Expand Down

0 comments on commit de3652d

Please sign in to comment.