diff --git a/jupyter_server/extension/manager.py b/jupyter_server/extension/manager.py index 308653f656..73ee73199a 100644 --- a/jupyter_server/extension/manager.py +++ b/jupyter_server/extension/manager.py @@ -1,7 +1,17 @@ import importlib from tornado.gen import multi -from traitlets import Any, Bool, Dict, HasTraits, Instance, Unicode, default, observe +from traitlets import ( + Any, + Bool, + Dict, + HasTraits, + Instance, + List, + Unicode, + default, + observe, +) from traitlets import validate as validate_trait from traitlets.config import LoggingConfigurable @@ -158,52 +168,51 @@ class ExtensionPackage(HasTraits): """ name = Unicode(help="Name of the an importable Python package.") - enabled = Bool(False).tag(config=True) + enabled = Bool(False, help="Whether the extension package is enabled.") + + _linked_points = Dict() + extension_points = Dict() + module = Any(allow_none=True, help="The module for this extension package. None if not enabled") + metadata = List(Dict(), help="Extension metadata loaded from the extension package.") + version = Unicode( + help=""" + The version of this extension package, if it can be found. + Otherwise, an empty string. + """, + ) + + @default("version") + def _load_version(self): + if not self.enabled: + return "" + return getattr(self.module, "__version__", "") - def __init__(self, *args, **kwargs): - # Store extension points that have been linked. - self._linked_points = {} - super().__init__(*args, **kwargs) + def __init__(self, **kwargs): + """Initialize an extension package.""" + super().__init__(**kwargs) + if self.enabled: + self._load_metadata() - _linked_points: dict = {} + def _load_metadata(self): + """Import package and load metadata - @validate_trait("name") - def _validate_name(self, proposed): - name = proposed["value"] - self._extension_points = {} + Only used if extension package is enabled + """ + name = self.name try: - self._module, self._metadata = get_metadata(name) + self.module, self.metadata = get_metadata(name, logger=self.log) except ImportError as e: - raise ExtensionModuleNotFound( - "The module '{name}' could not be found ({e}). Are you " - "sure the extension is installed?".format(name=name, e=e) + msg = ( + f"The module '{name}' could not be found ({e}). Are you " + "sure the extension is installed?" ) + raise ExtensionModuleNotFound(msg) from None # Create extension point interfaces for each extension path. - for m in self._metadata: + for m in self.metadata: point = ExtensionPoint(metadata=m) - self._extension_points[point.name] = point + self.extension_points[point.name] = point return name - @property - def module(self): - """Extension metadata loaded from the extension package.""" - return self._module - - @property - def version(self): - """Get the version of this package, if it's given. Otherwise, return an empty string""" - return getattr(self._module, "__version__", "") - - @property - def metadata(self): - """Extension metadata loaded from the extension package.""" - return self._metadata - - @property - def extension_points(self): - """A dictionary of extension points.""" - return self._extension_points - def validate(self): """Validate all extension points in this package.""" for extension in self.extension_points.values(): diff --git a/tests/extension/test_manager.py b/tests/extension/test_manager.py index 2b52fea543..c79f011974 100644 --- a/tests/extension/test_manager.py +++ b/tests/extension/test_manager.py @@ -1,4 +1,5 @@ import os +import sys import unittest.mock as mock import pytest @@ -60,7 +61,7 @@ def test_extension_package_api(): path1 = metadata_list[0] app = path1["app"] - e = ExtensionPackage(name="tests.extension.mockextensions") + e = ExtensionPackage(name="tests.extension.mockextensions", enabled=True) e.extension_points assert hasattr(e, "extension_points") assert len(e.extension_points) == len(metadata_list) @@ -70,7 +71,9 @@ def test_extension_package_api(): def test_extension_package_notfound_error(): with pytest.raises(ExtensionModuleNotFound): - ExtensionPackage(name="nonexistent") + ExtensionPackage(name="nonexistent", enabled=True) + # no raise if not enabled + ExtensionPackage(name="nonexistent", enabled=False) def _normalize_path(path_list): @@ -132,3 +135,23 @@ def test_extension_manager_fail_load(jp_serverapp): jp_serverapp.reraise_server_extension_failures = True with pytest.raises(RuntimeError): manager.load_extension(name) + + +@pytest.mark.parametrize("has_app", [True, False]) +def test_disable_no_import(jp_serverapp, has_app): + # de-import modules so we can detect if they are re-imported + disabled_ext = "tests.extension.mockextensions.mock1" + enabled_ext = "tests.extension.mockextensions.mock2" + sys.modules.pop(disabled_ext, None) + sys.modules.pop(enabled_ext, None) + + manager = ExtensionManager(serverapp=jp_serverapp if has_app else None) + manager.add_extension(disabled_ext, enabled=False) + manager.add_extension(enabled_ext, enabled=True) + assert disabled_ext not in sys.modules + assert enabled_ext in sys.modules + + ext_pkg = manager.extensions[disabled_ext] + assert ext_pkg.extension_points == {} + assert ext_pkg.version == "" + assert ext_pkg.metadata == []