From ecadf3b8b1d8f0d9e5f4f8bb7d5a8e40bad54b1f Mon Sep 17 00:00:00 2001
From: Henri Dehaybe <47037088+HenriDeh@users.noreply.github.com>
Date: Mon, 27 Nov 2023 16:02:29 +0100
Subject: [PATCH] Update CQL_SAC.jl (#1003)

---
 .../src/algorithms/offline_rl/CQL_SAC.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
index c137da7f1..7f6f68601 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
@@ -69,7 +69,7 @@ function update_critic!(p::CQLSACPolicy, batch::NamedTuple{SS′ART})
     y = soft_q_learning_target(p.sac, r, t, s′)

     states = MLUtils.unsqueeze(s, dims = 2) #(state_size x 1 x batchsize)
-    a_policy, logp_policy = RLCore.forward(p.sac.policy, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)
+    a_policy, logp_policy = RLCore.forward(p.sac.policy, p.sac.device_rng, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)
     a_unif = (rand(p.sac.rng, Float32, size(a_policy)...) .- 0.5f0) .* 2 # Uniform sampling between -1 and 1: (action_size x action_sample_size x batchsize)
     logp_unif = fill!(similar(a_unif, 1, size(a_unif)[2:end]...), 0.5^size(a_unif)[1]) #(1 x action_sample_size x batchsize)
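
The one-line change threads p.sac.device_rng into RLCore.forward so that the
policy's action sampling draws its noise from an RNG that matches the device
the policy network runs on. A minimal sketch of that pattern, assuming only the
standard Random.jl API; the noise_like helper and the surrounding setup are
illustrative, not the library's internals:

    using Random

    # Draw standard-normal sampling noise with an explicit, device-appropriate
    # RNG: e.g. Random.default_rng() for CPU arrays, a CUDA.RNG for CuArrays.
    # Pairing a CPU RNG with GPU arrays would typically error or force a slow
    # scalar-indexing fallback, which is presumably why the rng is passed in.
    noise_like(rng::AbstractRNG, x::AbstractArray) = randn!(rng, similar(x))

    # Usage, mirroring the shape of the patched call
    # a_policy, logp_policy = RLCore.forward(p.sac.policy, p.sac.device_rng,
    #                                        states, p.action_sample_size):
    states = rand(Float32, 4, 1, 32)             # (state_size x 1 x batchsize)
    ϵ = noise_like(Random.default_rng(), states) # same shape and device as `states`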