diff --git a/python/ray/_private/runtime_env/conda.py b/python/ray/_private/runtime_env/conda.py index a28cc02cb046e..12022b14c8584 100644 --- a/python/ray/_private/runtime_env/conda.py +++ b/python/ray/_private/runtime_env/conda.py @@ -18,7 +18,8 @@ create_conda_env_if_needed, delete_conda_env, get_conda_activate_commands, - get_conda_env_list, + get_conda_info_json, + get_conda_envs, ) from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.packaging import Protocol, parse_uri @@ -342,13 +343,15 @@ def _create(): if result in self._validated_named_conda_env: return 0 - conda_env_list = get_conda_env_list() - envs = [Path(env).name for env in conda_env_list] - if result not in envs: + conda_info = get_conda_info_json() + envs = get_conda_envs(conda_info) + + # We accept `result` as a conda name or full path. + if not any(result == env[0] or result == env[1] for env in envs): raise ValueError( f"The given conda environment '{result}' " f"from the runtime env {runtime_env} doesn't " - "exist from the output of `conda env list --json`. " + "exist from the output of `conda info --json`. " "You can only specify an env that already exists. " f"Please make sure to create an env {result} " ) diff --git a/python/ray/_private/runtime_env/conda_utils.py b/python/ray/_private/runtime_env/conda_utils.py index 757ca4a998dd3..a33916c5e46ff 100644 --- a/python/ray/_private/runtime_env/conda_utils.py +++ b/python/ray/_private/runtime_env/conda_utils.py @@ -161,7 +161,7 @@ def delete_conda_env(prefix: str, logger: Optional[logging.Logger] = None) -> bo def get_conda_env_list() -> list: """ - Get conda env list. + Get conda env list in full paths. """ conda_path = get_conda_bin_executable("conda") try: @@ -173,6 +173,33 @@ def get_conda_env_list() -> list: return envs +def get_conda_info_json() -> dict: + """ + Get `conda info --json` output. + """ + conda_path = get_conda_bin_executable("conda") + try: + exec_cmd([conda_path, "--help"], throw_on_error=False) + except EnvironmentError: + raise ValueError(f"Could not find Conda executable at {conda_path}.") + _, stdout, _ = exec_cmd([conda_path, "info", "--json"]) + return json.loads(stdout) + + +def get_conda_envs(conda_info: dict) -> List[Tuple[str, str]]: + """ + Gets the conda environments, as a list of (name, path) tuples. + """ + prefix = conda_info["conda_prefix"] + ret = [] + for env in conda_info["envs"]: + if env == prefix: + ret.append(("base", env)) + else: + ret.append((os.path.basename(env), env)) + return ret + + class ShellCommandException(Exception): pass diff --git a/python/ray/tests/test_runtime_env_complicated.py b/python/ray/tests/test_runtime_env_complicated.py index 0122a06dbc144..a71de6cf4a70c 100644 --- a/python/ray/tests/test_runtime_env_complicated.py +++ b/python/ray/tests/test_runtime_env_complicated.py @@ -19,7 +19,11 @@ _current_py_version, ) -from ray._private.runtime_env.conda_utils import get_conda_env_list +from ray._private.runtime_env.conda_utils import ( + get_conda_env_list, + get_conda_info_json, + get_conda_envs, +) from ray._private.test_utils import ( run_string_as_driver, run_string_as_driver_nonblocking, @@ -213,6 +217,79 @@ def wrapped_version(self): assert ray.get(actor.wrapped_version.remote()) == package_version +@pytest.mark.skipif( + os.environ.get("CONDA_DEFAULT_ENV") is None, + reason="must be run from within a conda environment", +) +def test_base_full_path(conda_envs, shutdown_only): + """ + Test that `base` and its absolute path prefix can both work. + """ + ray.init() + + conda_info = get_conda_info_json() + prefix = conda_info["conda_prefix"] + + test_conda_envs = ["base", prefix] + + @ray.remote + def get_conda_env_name(): + return os.environ.get("CONDA_DEFAULT_ENV") + + # Basic conda runtime env + for conda_env in test_conda_envs: + runtime_env = {"conda": conda_env} + + task = get_conda_env_name.options(runtime_env=runtime_env) + assert ray.get(task.remote()) == "base" + + +@pytest.mark.skipif( + os.environ.get("CONDA_DEFAULT_ENV") is None, + reason="must be run from within a conda environment", +) +def test_task_actor_conda_env_full_path(conda_envs, shutdown_only): + ray.init() + + conda_info = get_conda_info_json() + prefix = conda_info["conda_prefix"] + + test_conda_envs = { + package_version: f"{prefix}/envs/package-{package_version}" + for package_version in EMOJI_VERSIONS + } + + # Basic conda runtime env + for package_version, conda_full_path in test_conda_envs.items(): + runtime_env = {"conda": conda_full_path} + print(f"Testing {package_version}, runtime env: {runtime_env}") + + task = get_emoji_version.options(runtime_env=runtime_env) + assert ray.get(task.remote()) == package_version + + actor = VersionActor.options(runtime_env=runtime_env).remote() + assert ray.get(actor.get_emoji_version.remote()) == package_version + + # Runtime env should inherit to nested task + @ray.remote + def wrapped_version(): + return ray.get(get_emoji_version.remote()) + + @ray.remote + class Wrapper: + def wrapped_version(self): + return ray.get(get_emoji_version.remote()) + + for package_version, conda_full_path in test_conda_envs.items(): + runtime_env = {"conda": conda_full_path} + + task = wrapped_version.options(runtime_env=runtime_env) + assert ray.get(task.remote()) == package_version + + actor = Wrapper.options(runtime_env=runtime_env).remote() + assert ray.get(actor.wrapped_version.remote()) == package_version + + @pytest.mark.skipif( os.environ.get("CONDA_DEFAULT_ENV") is None, reason="must be run from within a conda environment", @@ -329,6 +406,22 @@ def test_get_conda_env_dir(tmp_path): assert env_dir == str(tmp_path / "envs" / "tf2") +@pytest.mark.skipif( + os.environ.get("CONDA_DEFAULT_ENV") is None, + reason="must be run from within a conda environment", +) +def test_get_conda_envs(conda_envs): + """ + Tests that we can at least find 3 conda envs: base, and two envs we created. + """ + conda_info = get_conda_info_json() + envs = get_conda_envs(conda_info) + prefix = conda_info["conda_prefix"] + assert ("base", prefix) in envs + assert ("package-2.1.0", prefix + "/envs/package-2.1.0") in envs + assert ("package-2.2.0", prefix + "/envs/package-2.2.0") in envs + + @pytest.mark.skipif( os.environ.get("CONDA_EXE") is None, reason="Requires properly set-up conda shell",