diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
index c137da7f1..7f6f68601 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CQL_SAC.jl
@@ -69,7 +69,7 @@ function update_critic!(p::CQLSACPolicy, batch::NamedTuple{SS′ART})
     y = soft_q_learning_target(p.sac, r, t, s′)
 
     states = MLUtils.unsqueeze(s, dims = 2) #(state_size x 1 x batchsize)
-    a_policy, logp_policy = RLCore.forward(p.sac.policy, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)
+    a_policy, logp_policy = RLCore.forward(p.sac.policy, p.sac.device_rng, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)
     a_unif = (rand(p.sac.rng, Float32, size(a_policy)...) .- 0.5f0) .* 2 # Uniform sampling between -1 and 1: (action_size x action_sample_size x batchsize)
     logp_unif = fill!(similar(a_unif, 1, size(a_unif)[2:end]...), 0.5^size(a_unif)[1]) #(1 x action_sample_size x batchsize)
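The only functional change in this hunk is the `RLCore.forward` call: the policy's forward pass now receives `p.sac.device_rng` explicitly, so the action sampling used for the CQL regularizer draws from an RNG that lives on the same device as the networks. A minimal sketch of the pattern, with hypothetical stand-in types (`ToyGaussianPolicy` and `toy_forward` are illustrations, not the real `RLCore.forward` or SAC policy):

```julia
using Random

# Hypothetical stand-in for a squashed-Gaussian SAC policy head.
struct ToyGaussianPolicy
    μ::Matrix{Float32}  # means, (action_size x batchsize)
    σ::Matrix{Float32}  # standard deviations, same shape
end

# Forward pass drawing `n` action samples per state, using whatever RNG the
# caller passes in. Threading the RNG through explicitly (as the diff now does
# with `p.sac.device_rng`) lets the same code run with a CPU RNG or a
# device-specific one.
function toy_forward(policy::ToyGaussianPolicy, rng::AbstractRNG, n::Int)
    d, b = size(policy.μ)
    ε = randn(rng, Float32, d, n, b)   # reparameterization noise
    μ = reshape(policy.μ, d, 1, b)     # broadcast over the sample axis
    σ = reshape(policy.σ, d, 1, b)
    return tanh.(μ .+ σ .* ε)          # (action_size x n x batchsize), in (-1, 1)
end

rng = Xoshiro(0)  # on GPU this could be a device RNG instead
policy = ToyGaussianPolicy(zeros(Float32, 2, 4), ones(Float32, 2, 4))
a = toy_forward(policy, rng, 10)       # 2 x 10 x 4 array of sampled actions
```

Passing the RNG as an argument rather than reaching for a global one keeps the sampling reproducible and avoids generating random numbers on the wrong device when the critic update runs on GPU.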