diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl index cb8038c06..61f01d824 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl @@ -1,12 +1,12 @@ using .PyCall function GymEnv(name::String; seed::Union{Int,Nothing}=nothing) - if !PyCall.pyexists("gym") + if !PyCall.pyexists("gymnasium") error( - "Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n", + "Cannot import module 'gymnasium'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n", ) end - gym = pyimport_conda("gym", "gym") + gym = pyimport_conda("gymnasium", "gymnasium") if PyCall.pyexists("d4rl") pyimport("d4rl") end @@ -33,7 +33,7 @@ function GymEnv(name::String; seed::Union{Int,Nothing}=nothing) elseif obs_space isa Space{<:Dict} PyDict else - error("don't know how to get the observation type from observation space of $obs_space") + error("Don't know how to get the observation type from observation space of $obs_space") end env = GymEnv{obs_type,typeof(act_space),typeof(obs_space),typeof(pyenv)}( pyenv, @@ -69,8 +69,11 @@ RLBase.action_space(env::GymEnv) = env.action_space RLBase.state_space(env::GymEnv) = env.observation_space function RLBase.reward(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 - obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state) + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + _, reward, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + reward + elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 + _, reward, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) reward else 0.0 @@ -78,8 +81,12 @@ function RLBase.reward(env::GymEnv{T}) where {T} end function RLBase.is_terminated(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 - obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state) + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + _, _, isterminated, istruncated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + isterminated || istruncated + elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 + @warn "Gym version outdated. Update gym to obtain termination and truncation info instead of done signal." + _, _, isdone, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) isdone else false @@ -87,8 +94,11 @@ function RLBase.is_terminated(env::GymEnv{T}) where {T} end function RLBase.state(env::GymEnv{T}) where {T} - if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 - obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state) + if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5 + obs, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state) + obs + elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4 + obs, = convert(Tuple{T,Float64,Bool,PyDict}, env.state) obs else convert(T, env.state) @@ -123,19 +133,19 @@ end function list_gym_env_names(; modules=[ - "gym.envs.algorithmic", - "gym.envs.box2d", - "gym.envs.classic_control", - "gym.envs.mujoco", - "gym.envs.mujoco.ant_v3", - "gym.envs.mujoco.half_cheetah_v3", - "gym.envs.mujoco.hopper_v3", - "gym.envs.mujoco.humanoid_v3", - "gym.envs.mujoco.swimmer_v3", - "gym.envs.mujoco.walker2d_v3", - "gym.envs.robotics", - "gym.envs.toy_text", - "gym.envs.unittest", + "gymnasium.envs.algorithmic", + "gymnasium.envs.box2d", + "gymnasium.envs.classic_control", + "gymnasium.envs.mujoco", + "gymnasium.envs.mujoco.ant_v3", + "gymnasium.envs.mujoco.half_cheetah_v3", + "gymnasium.envs.mujoco.hopper_v3", + "gymnasium.envs.mujoco.humanoid_v3", + "gymnasium.envs.mujoco.swimmer_v3", + "gymnasium.envs.mujoco.walker2d_v3", + "gymnasium.envs.robotics", + "gymnasium.envs.toy_text", + "gymnasium.envs.unittest", "d4rl.pointmaze", "d4rl.hand_manipulation_suite", "d4rl.gym_mujoco.gym_envs", @@ -147,14 +157,14 @@ function list_gym_env_names(; if PyCall.pyexists("d4rl") pyimport("d4rl") end - gym = pyimport("gym") + gym = pyimport("gymnasium") [x.id for x in values(gym.envs.registry) if split(x.entry_point, ':')[1] in modules] end """ - install_gym(; packages = ["gym", "pybullet"]) + install_gym(; packages = ["gymnasium", "pybullet"]) """ -function install_gym(; packages=["gym", "pybullet"]) +function install_gym(; packages=["gymnasium", "pybullet"]) # Use eventual proxy info proxy_arg = String[] if haskey(ENV, "http_proxy") diff --git a/src/ReinforcementLearningEnvironments/test/runtests.jl b/src/ReinforcementLearningEnvironments/test/runtests.jl index d8dfe872f..59bafe458 100644 --- a/src/ReinforcementLearningEnvironments/test/runtests.jl +++ b/src/ReinforcementLearningEnvironments/test/runtests.jl @@ -13,7 +13,7 @@ using OrdinaryDiffEq using TimerOutputs using Conda -Conda.add("gym") +Conda.add("gymnasium") Conda.add("numpy") @testset "ReinforcementLearningEnvironments" begin