Skip to content

Commit

Permalink
allow full path in conda rt env
Browse files Browse the repository at this point in the history
Signed-off-by: Ruiyang Wang <[email protected]>
  • Loading branch information
rynewang committed May 25, 2024
1 parent f9ac050 commit 9dd3367
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 7 deletions.
13 changes: 8 additions & 5 deletions python/ray/_private/runtime_env/conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
create_conda_env_if_needed,
delete_conda_env,
get_conda_activate_commands,
get_conda_env_list,
get_conda_info_json,
get_conda_envs,
)
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import Protocol, parse_uri
Expand Down Expand Up @@ -342,13 +343,15 @@ def _create():
if result in self._validated_named_conda_env:
return 0

conda_env_list = get_conda_env_list()
envs = [Path(env).name for env in conda_env_list]
if result not in envs:
conda_info = get_conda_info_json()
envs = get_conda_envs(conda_info)

# We accept `result` as a conda name or full path.
if not any(result == env[0] or result == env[1] for env in envs):
raise ValueError(
f"The given conda environment '{result}' "
f"from the runtime env {runtime_env} doesn't "
"exist from the output of `conda env list --json`. "
"exist from the output of `conda info --json`. "
"You can only specify an env that already exists. "
f"Please make sure to create an env {result} "
)
Expand Down
29 changes: 28 additions & 1 deletion python/ray/_private/runtime_env/conda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def delete_conda_env(prefix: str, logger: Optional[logging.Logger] = None) -> bo

def get_conda_env_list() -> list:
"""
Get conda env list.
Get conda env list in full paths.
"""
conda_path = get_conda_bin_executable("conda")
try:
Expand All @@ -173,6 +173,33 @@ def get_conda_env_list() -> list:
return envs


def get_conda_info_json() -> dict:
"""
Get `conda info --json` output.
"""
conda_path = get_conda_bin_executable("conda")
try:
exec_cmd([conda_path, "--help"], throw_on_error=False)
except EnvironmentError:
raise ValueError(f"Could not find Conda executable at {conda_path}.")
_, stdout, _ = exec_cmd([conda_path, "info", "--json"])
return json.loads(stdout)


def get_conda_envs(conda_info: dict) -> List[Tuple[str, str]]:
"""
Gets the conda environments, as a list of (name, path) tuples.
"""
prefix = conda_info["conda_prefix"]
ret = []
for env in conda_info["envs"]:
if env == prefix:
ret.append(("base", env))
else:
ret.append((os.path.basename(env), env))
return ret


class ShellCommandException(Exception):
pass

Expand Down
95 changes: 94 additions & 1 deletion python/ray/tests/test_runtime_env_complicated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
_current_py_version,
)

from ray._private.runtime_env.conda_utils import get_conda_env_list
from ray._private.runtime_env.conda_utils import (
get_conda_env_list,
get_conda_info_json,
get_conda_envs,
)
from ray._private.test_utils import (
run_string_as_driver,
run_string_as_driver_nonblocking,
Expand Down Expand Up @@ -213,6 +217,79 @@ def wrapped_version(self):
assert ray.get(actor.wrapped_version.remote()) == package_version


@pytest.mark.skipif(
os.environ.get("CONDA_DEFAULT_ENV") is None,
reason="must be run from within a conda environment",
)
def test_base_full_path(conda_envs, shutdown_only):
"""
Test that `base` and its absolute path prefix can both work.
"""
ray.init()

conda_info = get_conda_info_json()
prefix = conda_info["conda_prefix"]

test_conda_envs = ["base", prefix]

@ray.remote
def get_conda_env_name():
return os.environ.get("CONDA_DEFAULT_ENV")

# Basic conda runtime env
for conda_env in test_conda_envs:
runtime_env = {"conda": conda_env}

task = get_conda_env_name.options(runtime_env=runtime_env)
assert ray.get(task.remote()) == "base"


@pytest.mark.skipif(
os.environ.get("CONDA_DEFAULT_ENV") is None,
reason="must be run from within a conda environment",
)
def test_task_actor_conda_env_full_path(conda_envs, shutdown_only):
ray.init()

conda_info = get_conda_info_json()
prefix = conda_info["conda_prefix"]

test_conda_envs = {
package_version: f"{prefix}/envs/package-{package_version}"
for package_version in EMOJI_VERSIONS
}

# Basic conda runtime env
for package_version, conda_full_path in test_conda_envs.items():
runtime_env = {"conda": conda_full_path}
print(f"Testing {package_version}, runtime env: {runtime_env}")

task = get_emoji_version.options(runtime_env=runtime_env)
assert ray.get(task.remote()) == package_version

actor = VersionActor.options(runtime_env=runtime_env).remote()
assert ray.get(actor.get_emoji_version.remote()) == package_version

# Runtime env should inherit to nested task
@ray.remote
def wrapped_version():
return ray.get(get_emoji_version.remote())

@ray.remote
class Wrapper:
def wrapped_version(self):
return ray.get(get_emoji_version.remote())

for package_version, conda_full_path in test_conda_envs.items():
runtime_env = {"conda": conda_full_path}

task = wrapped_version.options(runtime_env=runtime_env)
assert ray.get(task.remote()) == package_version

actor = Wrapper.options(runtime_env=runtime_env).remote()
assert ray.get(actor.wrapped_version.remote()) == package_version


@pytest.mark.skipif(
os.environ.get("CONDA_DEFAULT_ENV") is None,
reason="must be run from within a conda environment",
Expand Down Expand Up @@ -329,6 +406,22 @@ def test_get_conda_env_dir(tmp_path):
assert env_dir == str(tmp_path / "envs" / "tf2")


@pytest.mark.skipif(
os.environ.get("CONDA_DEFAULT_ENV") is None,
reason="must be run from within a conda environment",
)
def test_get_conda_envs(conda_envs):
"""
Tests that we can at least find 3 conda envs: base, and two envs we created.
"""
conda_info = get_conda_info_json()
envs = get_conda_envs(conda_info)
prefix = conda_info["conda_prefix"]
assert ("base", prefix) in envs
assert ("package-2.1.0", prefix + "/envs/package-2.1.0") in envs
assert ("package-2.2.0", prefix + "/envs/package-2.2.0") in envs


@pytest.mark.skipif(
os.environ.get("CONDA_EXE") is None,
reason="Requires properly set-up conda shell",
Expand Down

0 comments on commit 9dd3367

Please sign in to comment.