Commit

Update CQL_SAC.jl (#1003)
HenriDeh authored Nov 27, 2023
1 parent e1d9e9e commit ecadf3b
Showing 1 changed file with 1 addition and 1 deletion.
@@ -69,7 +69,7 @@ function update_critic!(p::CQLSACPolicy, batch::NamedTuple{SS′ART})
y = soft_q_learning_target(p.sac, r, t, s′)

states = MLUtils.unsqueeze(s, dims = 2) #(state_size x 1 x batchsize)
- a_policy, logp_policy = RLCore.forward(p.sac.policy, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)
+ a_policy, logp_policy = RLCore.forward(p.sac.policy, p.sac.device_rng, states, p.action_sample_size) #(action_size x action_sample_size x batchsize), (1 x action_sample_size x batchsize)

a_unif = (rand(p.sac.rng, Float32, size(a_policy)...) .- 0.5f0) .* 2 # Uniform sampling between -1 and 1: (action_size x action_sample_size x batchsize)
logp_unif = fill!(similar(a_unif, 1, size(a_unif)[2:end]...), 0.5^size(a_unif)[1]) #(1 x action_sample_size x batchsize)
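
The change above passes an RNG explicitly to the sampling call. The commit does not state its motivation, so the following is only a minimal sketch of the general pattern, not the RLCore.forward implementation: `sample_actions` and its `action_size` keyword are hypothetical stand-ins, and the tanh-squashed Gaussian is an assumption chosen for illustration. It shows that a sampler taking the RNG as an argument gives reproducible draws with the shapes noted in the comments, and it repeats the uniform sampling in [-1, 1) from the context lines.

```julia
using Random

# Hypothetical stand-in for a policy's sampling routine (NOT the RLCore.forward API).
# Draws `n` actions per state from a tanh-squashed standard Gaussian.
function sample_actions(rng::AbstractRNG, states::AbstractArray{Float32,3}, n::Int;
                        action_size::Int = 2)
    batchsize = size(states, 3)
    noise = randn(rng, Float32, action_size, n, batchsize)
    return tanh.(noise)  # (action_size x n x batchsize)
end

states = zeros(Float32, 4, 1, 8)  # (state_size x 1 x batchsize)

# Two RNGs with the same seed yield identical samples.
a1 = sample_actions(MersenneTwister(123), states, 10)
a2 = sample_actions(MersenneTwister(123), states, 10)

# Uniform sampling in [-1, 1) with an explicit RNG, mirroring the context line above.
rng = MersenneTwister(123)
a_unif = (rand(rng, Float32, size(a1)...) .- 0.5f0) .* 2
```

Because the RNG is an argument rather than hidden global state, the caller decides which generator (and, in a GPU setting, which device's generator) produces the samples.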