From 0f7016e861624733e24be8e53fd4560bcd92d73a Mon Sep 17 00:00:00 2001 From: Michael Fraenkel Date: Wed, 3 May 2023 09:57:59 -0600 Subject: [PATCH] Add the correct path when it's a package --- integration/_support/package/tasks/module.py | 2 ++ integration/_support/package/tasks/pytest.py | 0 invoke/loader.py | 2 ++ tests/loader.py | 6 ++++++ 4 files changed, 10 insertions(+) create mode 100644 integration/_support/package/tasks/pytest.py diff --git a/integration/_support/package/tasks/module.py b/integration/_support/package/tasks/module.py index 05f37ee50..5bca5e809 100644 --- a/integration/_support/package/tasks/module.py +++ b/integration/_support/package/tasks/module.py @@ -1,4 +1,6 @@ from invoke import task +from . import pytest as pt +from pytest import Testdir @task diff --git a/integration/_support/package/tasks/pytest.py b/integration/_support/package/tasks/pytest.py new file mode 100644 index 000000000..e69de29bb diff --git a/invoke/loader.py b/invoke/loader.py index ba3003dd2..214da8f8b 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -73,6 +73,8 @@ def load(self, name: Optional[str] = None) -> Tuple[ModuleType, str]: # being imported is trying to load local-to-it names. if os.path.isfile(spec.origin): path = os.path.dirname(spec.origin) + if spec.origin.endswith("__init__.py"): + path = os.path.dirname(path) if path not in sys.path: sys.path.insert(0, path) # Actual import diff --git a/tests/loader.py b/tests/loader.py index aef31709d..fb04afa85 100644 --- a/tests/loader.py +++ b/tests/loader.py @@ -51,6 +51,12 @@ def adds_module_parent_dir_to_sys_path(self): # Crummy doesn't-explode test. _BasicLoader().load("namespacing") + def adds_package_dir_to_sys_path(self): + config = Config({"tasks": {"collection_name": "module"}}) + _BasicLoader(config).load("package") + package = Path(support) / "package" + assert str(package) not in sys.path + def doesnt_duplicate_parent_dir_addition(self): _BasicLoader().load("namespacing") _BasicLoader().load("namespacing")