From 153d750c46911629cf2f90f9044985c3cc5e8d89 Mon Sep 17 00:00:00 2001 From: Simon Haastert <57462510+SimonHashtag@users.noreply.github.com> Date: Thu, 29 Feb 2024 15:58:16 +0100 Subject: [PATCH 1/4] Update gym.jl Make gym.jl compatible with current version of Gymnasium - truncation and termination info instead of done --- .../src/environments/3rd_party/gym.jl | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl index cb8038c06..541d2762f 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl @@ -69,8 +69,11 @@ RLBase.action_space(env::GymEnv) = env.action_space RLBase.state_space(env::GymEnv) = env.observation_space function RLBase.reward(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 - obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state) + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + _, reward, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + reward + elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 + _, reward, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) reward else 0.0 @@ -78,17 +81,33 @@ function RLBase.reward(env::GymEnv{T}) where {T} end function RLBase.is_terminated(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 - obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state) + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + _, _, isterminated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + isterminated + elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 + @warn "Gym version outdated. Update gym to obtain termination and truncation info instead of done signal." + _, _, isdone, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) isdone else false end end +function RLBase.is_truncated(env::GymEnv{T}) where {T} + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + _, _, _, istruncated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + istruncated + else + false + end +end + function RLBase.state(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 - obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state) + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + obs, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + obs + elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 + obs, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) obs else convert(T, env.state) From 6654ff6df5ebd609fd5868366d35740806e0412e Mon Sep 17 00:00:00 2001 From: SimonHashtag <57462510+SimonHashtag@users.noreply.github.com> Date: Fri, 1 Mar 2024 12:36:14 +0100 Subject: [PATCH 2/4] Updated to gymnasium --- .../src/environments/3rd_party/gym.jl | 40 +++++++++---------- .../test/runtests.jl | 2 +- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl index 541d2762f..3f8e97d1d 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl @@ -1,12 +1,12 @@ using .PyCall function GymEnv(name::String; seed::Union{Int,Nothing}=nothing) - if !PyCall.pyexists("gym") + if !PyCall.pyexists("gymnasium") error( - "Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n", + "Cannot import module 'gymnasium'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n", ) end - gym = pyimport_conda("gym", "gym") + gym = pyimport_conda("gymnasium", "gymnasium") if PyCall.pyexists("d4rl") pyimport("d4rl") end @@ -93,7 +93,7 @@ function RLBase.is_terminated(env::GymEnv{T}) where {T} end end -function RLBase.is_truncated(env::GymEnv{T}) where {T} +function is_truncated(env::GymEnv{T}) where {T} if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 _, _, _, istruncated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) istruncated @@ -142,19 +142,19 @@ end function list_gym_env_names(; modules=[ - "gym.envs.algorithmic", - "gym.envs.box2d", - "gym.envs.classic_control", - "gym.envs.mujoco", - "gym.envs.mujoco.ant_v3", - "gym.envs.mujoco.half_cheetah_v3", - "gym.envs.mujoco.hopper_v3", - "gym.envs.mujoco.humanoid_v3", - "gym.envs.mujoco.swimmer_v3", - "gym.envs.mujoco.walker2d_v3", - "gym.envs.robotics", - "gym.envs.toy_text", - "gym.envs.unittest", + "gymnasium.envs.algorithmic", + "gymnasium.envs.box2d", + "gymnasium.envs.classic_control", + "gymnasium.envs.mujoco", + "gymnasium.envs.mujoco.ant_v3", + "gymnasium.envs.mujoco.half_cheetah_v3", + "gymnasium.envs.mujoco.hopper_v3", + "gymnasium.envs.mujoco.humanoid_v3", + "gymnasium.envs.mujoco.swimmer_v3", + "gymnasium.envs.mujoco.walker2d_v3", + "gymnasium.envs.robotics", + "gymnasium.envs.toy_text", + "gymnasium.envs.unittest", "d4rl.pointmaze", "d4rl.hand_manipulation_suite", "d4rl.gym_mujoco.gym_envs", @@ -166,14 +166,14 @@ function list_gym_env_names(; if PyCall.pyexists("d4rl") pyimport("d4rl") end - gym = pyimport("gym") + gym = pyimport("gymnasium") [x.id for x in values(gym.envs.registry) if split(x.entry_point, ':')[1] in modules] end """ - install_gym(; packages = ["gym", "pybullet"]) + install_gym(; packages = ["gymnasium", "pybullet"]) """ -function install_gym(; packages=["gym", "pybullet"]) +function install_gym(; packages=["gymnasium", "pybullet"]) # Use eventual proxy info proxy_arg = String[] if haskey(ENV, "http_proxy") diff --git a/src/ReinforcementLearningEnvironments/test/runtests.jl b/src/ReinforcementLearningEnvironments/test/runtests.jl index d8dfe872f..59bafe458 100644 --- a/src/ReinforcementLearningEnvironments/test/runtests.jl +++ b/src/ReinforcementLearningEnvironments/test/runtests.jl @@ -13,7 +13,7 @@ using OrdinaryDiffEq using TimerOutputs using Conda -Conda.add("gym") +Conda.add("gymnasium") Conda.add("numpy") @testset "ReinforcementLearningEnvironments" begin From 3d3b01ff242556b43265abe743d0e714fce1ff35 Mon Sep 17 00:00:00 2001 From: Simon Haastert <57462510+SimonHashtag@users.noreply.github.com> Date: Fri, 1 Mar 2024 18:22:24 +0100 Subject: [PATCH 3/4] Delete is_truncated function Postpone on how to handle is_truncated/is_terminated to later PR. If isterminated or istruncated return true for function is_terminated --- .../src/environments/3rd_party/gym.jl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl index 3f8e97d1d..cca749098 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl @@ -82,8 +82,8 @@ end function RLBase.is_terminated(env::GymEnv{T}) where {T} if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 - _, _, isterminated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) - isterminated + _, _, isterminated, istruncated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + isterminated || istruncated elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 @warn "Gym version outdated. Update gym to obtain termination and truncation info instead of done signal." _, _, isdone, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) @@ -93,15 +93,6 @@ function RLBase.is_terminated(env::GymEnv{T}) where {T} end end -function is_truncated(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 - _, _, _, istruncated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) - istruncated - else - false - end -end - function RLBase.state(env::GymEnv{T}) where {T} if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 obs, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) From c5cf901c9e36a973afe73ee5c12cc3baf9d3139f Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 2 Mar 2024 08:53:42 +0100 Subject: [PATCH 4/4] trigger build --- .../src/environments/3rd_party/gym.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl index cca749098..61f01d824 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl @@ -33,7 +33,7 @@ function GymEnv(name::String; seed::Union{Int,Nothing}=nothing) elseif obs_space isa Space{<:Dict} PyDict else - error("don't know how to get the observation type from observation space of $obs_space") + error("Don't know how to get the observation type from observation space of $obs_space") end env = GymEnv{obs_type,typeof(act_space),typeof(obs_space),typeof(pyenv)}( pyenv,