diff --git a/doc/changes/DM-47143.feature.rst b/doc/changes/DM-47143.feature.rst new file mode 100644 index 0000000000..e4dbe287cb --- /dev/null +++ b/doc/changes/DM-47143.feature.rst @@ -0,0 +1,3 @@ +The ``DAF_BUTLER_PLUGINS`` environment variable should no longer be set if packages use ``pip install`` and have been upgraded to use entry points. +Butler can now read the subcommands from ``pipe_base`` and ``daf_butler_migrate`` automatically. +Setting the environment variable for these packages will result in an error. diff --git a/python/lsst/daf/butler/cli/butler.py b/python/lsst/daf/butler/cli/butler.py index 7905a91f5e..8a0efdf8c1 100755 --- a/python/lsst/daf/butler/cli/butler.py +++ b/python/lsst/daf/butler/cli/butler.py @@ -33,19 +33,24 @@ "main", ) - import abc +import dataclasses import functools import logging import os import traceback import types from collections import defaultdict +from functools import cache +from importlib.metadata import entry_points from typing import Any import click import yaml +from lsst.resources import ResourcePath from lsst.utils import doImport +from lsst.utils.introspection import get_full_type_name +from lsst.utils.timer import time_this from .cliLog import CliLog from .opt import log_file_option, log_label_option, log_level_option, log_tty_option, long_log_option @@ -86,6 +91,16 @@ def _importPlugin(pluginName: str) -> types.ModuleType | type | None | click.Com return None +@dataclasses.dataclass(frozen=True) +class PluginCommand: + """A click Command and the plugin it came from.""" + + command: click.Command + """The command (`click.Command`).""" + source: str + """Where the command came from (`str`).""" + + class LoaderCLI(click.MultiCommand, abc.ABC): """Extends `click.MultiCommand`, which dispatches to subcommands, to load subcommands at runtime. @@ -118,14 +133,14 @@ def localCmdPkg(self) -> str: """ raise NotImplementedError() - def getLocalCommands(self) -> defaultdict[str, list[str]]: + def getLocalCommands(self) -> defaultdict[str, list[str | PluginCommand]]: """Get the commands offered by the local package. This assumes that the commands can be found in `localCmdPkg.__all__`, if this is not the case then this function should be overridden. Returns ------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. """ @@ -134,8 +149,13 @@ def getLocalCommands(self) -> defaultdict[str, list[str]]: # _importPlugins logs an error, don't need to do it again here. return defaultdict(list) assert hasattr(commandsLocation, "__all__"), f"Must define __all__ in {commandsLocation}" + commands = [getattr(commandsLocation, name) for name in commandsLocation.__all__] return defaultdict( - list, {self._funcNameToCmdName(f): [self.localCmdPkg] for f in commandsLocation.__all__} + list, + { + command.name: [PluginCommand(command, get_full_type_name(commandsLocation))] + for command in commands + }, ) def list_commands(self, ctx: click.Context) -> list[str]: @@ -181,13 +201,18 @@ def get_command(self, ctx: click.Context, name: str) -> click.Command | None: if name not in commands: return None self._raiseIfDuplicateCommands(commands) - module_str = commands[name][0] + "." + self._cmdNameToFuncName(name) - # The click.command decorator returns an instance of a class, which - # is something that doImport is not expecting. We add it in as an - # option here to appease mypy. - plugin = _importPlugin(module_str) - if not plugin: - return None + command = commands[name][0] + if isinstance(command, str): + module_str = command + "." + self._cmdNameToFuncName(name) + # The click.command decorator returns an instance of a class, which + # is something that doImport is not expecting. We add it in as an + # option here to appease mypy. + with time_this(log, msg="Importing command %s (via %s)", args=(name, module_str)): + plugin = _importPlugin(module_str) + if not plugin: + return None + else: + plugin = command.command if not isinstance(plugin, click.Command): raise RuntimeError( f"Command {name!r} loaded from {module_str} is not a click Command, is {type(plugin)}" @@ -219,21 +244,22 @@ def _setupLogging(self, ctx: click.Context | None) -> None: ) @classmethod - def getPluginList(cls) -> list[str]: + def getPluginList(cls) -> list[ResourcePath]: """Get the list of importable yaml files that contain cli data for this command. Returns ------- - `list` [`str`] + `list` [`lsst.resources.ResourcePath`] The list of files that contain yaml data about a cli plugin. """ - if not hasattr(cls, "pluginEnvVar"): - return [] - pluginModules = os.environ.get(cls.pluginEnvVar) - if pluginModules: - return [p for p in pluginModules.split(":") if p != ""] - return [] + yaml_files = [] + if hasattr(cls, "pluginEnvVar"): + pluginModules = os.environ.get(cls.pluginEnvVar) + if pluginModules: + yaml_files.extend([ResourcePath(p) for p in pluginModules.split(":") if p != ""]) + + return yaml_files @classmethod def _funcNameToCmdName(cls, functionName: str) -> str: @@ -255,21 +281,21 @@ def _cmdNameToFuncName(cls, commandName: str) -> str: @staticmethod def _mergeCommandLists( - a: defaultdict[str, list[str]], b: defaultdict[str, list[str]] - ) -> defaultdict[str, list[str]]: + a: defaultdict[str, list[str | PluginCommand]], b: defaultdict[str, list[str | PluginCommand]] + ) -> defaultdict[str, list[str | PluginCommand]]: """Combine two dicts whose keys are strings (command name) and values are list of string (the package(s) that provide the named command). Parameters ---------- - a : `defaultdict` [`str`, `list` [`str`]] + a : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. b : (same as a) Returns ------- - commands : `defaultdict` [`str`: [`str`]] + commands : `defaultdict` [`str`: [`str` | `PluginCommand` ]] For convenience, returns a extended with b. ('a' is modified in place.) """ @@ -278,7 +304,7 @@ def _mergeCommandLists( return a @classmethod - def _getPluginCommands(cls) -> defaultdict[str, list[str]]: + def _getPluginCommands(cls) -> defaultdict[str, list[str | PluginCommand]]: """Get the commands offered by plugin packages. Returns @@ -286,12 +312,16 @@ def _getPluginCommands(cls) -> defaultdict[str, list[str]]: commands : `defaultdict` [`str`, `list` [`str`]] The key is the command name. The value is a list of package(s) that contains the command. + + Notes + ----- + Assumes that if entry points are defined, the plugin environment + variable will not be defined for that same package. """ - commands: defaultdict[str, list[str]] = defaultdict(list) + commands: defaultdict[str, list[str | PluginCommand]] = defaultdict(list) for pluginName in cls.getPluginList(): try: - with open(pluginName) as resourceFile: - resources = defaultdict(list, yaml.safe_load(resourceFile)) + resources = defaultdict(list, yaml.safe_load(pluginName.read())) except Exception as err: log.warning("Error loading commands from %s, skipping. %s", pluginName, err) continue @@ -300,27 +330,37 @@ def _getPluginCommands(cls) -> defaultdict[str, list[str]]: continue pluginCommands = {cmd: [resources["cmd"]["import"]] for cmd in resources["cmd"]["commands"]} cls._mergeCommandLists(commands, defaultdict(list, pluginCommands)) + + if hasattr(cls, "entryPoint"): + plugins = entry_points(group=cls.entryPoint) + for p in plugins: + func = p.load() + func_name = get_full_type_name(func) + pluginCommands = {cmd.name: [PluginCommand(cmd, func_name)] for cmd in func()} + cls._mergeCommandLists(commands, defaultdict(list, pluginCommands)) + return commands - def _getCommands(self) -> defaultdict[str, list[str]]: + @cache + def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]: """Get the commands offered by daf_butler and plugin packages. Returns ------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. """ return self._mergeCommandLists(self.getLocalCommands(), self._getPluginCommands()) @staticmethod - def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None: + def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str | PluginCommand]]) -> None: """If any provided command is offered by more than one package raise an exception. Parameters ---------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PLuginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. @@ -333,7 +373,12 @@ def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None: msg = "" for command, packages in commands.items(): if len(packages) > 1: - msg += f"Command '{command}' exists in packages {', '.join(packages)}. " + pkg_names: list[str] = [] + for p in packages: + if not isinstance(p, str): + p = p.source + pkg_names.append(p) + msg += f"Command '{command}' exists in packages {', '.join(pkg_names)}. " if msg: raise click.ClickException(message=msg + "Duplicate commands are not supported, aborting.") @@ -344,6 +389,7 @@ class ButlerCLI(LoaderCLI): localCmdPkg = "lsst.daf.butler.cli.cmd" pluginEnvVar = "DAF_BUTLER_PLUGINS" + entryPoint = "butler.cli" @classmethod def _funcNameToCmdName(cls, functionName: str) -> str: @@ -371,6 +417,21 @@ def _cmdNameToFuncName(cls, commandName: str) -> str: return super()._cmdNameToFuncName(commandName) +class UncachedButlerCLI(ButlerCLI): + """ButlerCLI that can be used where caching of the commands is disabled.""" + + def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]: # type: ignore[override] + """Get the commands offered by daf_butler and plugin packages. + + Returns + ------- + commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] + The key is the command name. The value is a list of package(s) that + contains the command. + """ + return self._mergeCommandLists(self.getLocalCommands(), self._getPluginCommands()) + + @click.command(cls=ButlerCLI, context_settings=dict(help_option_names=["-h", "--help"])) @log_level_option() @long_log_option() diff --git a/python/lsst/daf/butler/tests/cliLogTestBase.py b/python/lsst/daf/butler/tests/cliLogTestBase.py index 456678b326..3137f68c9c 100644 --- a/python/lsst/daf/butler/tests/cliLogTestBase.py +++ b/python/lsst/daf/butler/tests/cliLogTestBase.py @@ -49,8 +49,15 @@ from typing import TYPE_CHECKING, Any import click -from lsst.daf.butler.cli.butler import cli as butlerCli +from lsst.daf.butler.cli.butler import UncachedButlerCLI from lsst.daf.butler.cli.cliLog import CliLog +from lsst.daf.butler.cli.opt import ( + log_file_option, + log_label_option, + log_level_option, + log_tty_option, + long_log_option, +) from lsst.daf.butler.cli.utils import LogCliRunner, clickResultMsg, command_test_env from lsst.daf.butler.logging import ButlerLogRecords from lsst.utils.logging import TRACE @@ -68,6 +75,34 @@ lsstLog_WARN = 0 +@click.command(cls=UncachedButlerCLI) +@log_level_option() +@long_log_option() +@log_file_option() +@log_tty_option() +@log_label_option() +def butlerCli(log_level: str, long_log: bool, log_file: str, log_tty: bool, log_label: str) -> None: + """Uncached ButlerCLI. + + Parameters + ---------- + log_level : `str` + The log level to use by default. ``log_level`` is handled by + ``get_command`` and ``list_commands``, and is called in + one of those functions before this is called. + long_log : `bool` + Enable extended log output. ``long_log`` is handled by + ``setup_logging``. + log_file : `str` + The log file name. + log_tty : `bool` + Whether to send logs to standard output. + log_label : `str` + Log labels. + """ + pass + + @click.command() @click.option("--expected-pyroot-level", type=int) @click.option("--expected-pylsst-level", type=int) diff --git a/tests/test_cliPluginLoader.py b/tests/test_cliPluginLoader.py index c48847047d..9da535f176 100644 --- a/tests/test_cliPluginLoader.py +++ b/tests/test_cliPluginLoader.py @@ -37,6 +37,7 @@ import click import yaml from lsst.daf.butler.cli import butler, cmd +from lsst.daf.butler.cli.butler import UncachedButlerCLI from lsst.daf.butler.cli.utils import LogCliRunner, command_test_env @@ -64,6 +65,12 @@ def duplicate_command_test_env(runner): yield +@click.command(cls=UncachedButlerCLI) +def uncached_cli(): + """ButlerCLI that does not cache the commands.""" + pass + + class FailedLoadTest(unittest.TestCase): """Test failed plugin loading.""" @@ -73,7 +80,7 @@ def setUp(self): def test_unimportablePlugin(self): with command_test_env(self.runner, "test_cliPluginLoader", "non-existant-command-function"): with self.assertLogs() as cm: - result = self.runner.invoke(butler.cli, "--help") + result = self.runner.invoke(uncached_cli, "--help") self.assertEqual(result.exit_code, 0, f"output: {result.output!r} exception: {result.exception}") expectedErrMsg = ( "Could not import plugin from test_cliPluginLoader.non_existant_command_function, skipping." @@ -104,7 +111,7 @@ def setUp(self): def test_loadAndExecutePluginCommand(self): """Test that a plugin command can be loaded and executed.""" with command_test_env(self.runner, "test_cliPluginLoader", "command-test"): - result = self.runner.invoke(butler.cli, "command-test") + result = self.runner.invoke(uncached_cli, "command-test") self.assertEqual(result.exit_code, 0, f"output: {result.output} exception: {result.exception}") self.assertEqual(result.stdout, "test command\n") @@ -118,7 +125,7 @@ def test_loadAndExecuteLocalCommand(self): def test_loadTopHelp(self): """Test that an expected command is produced by 'butler --help'""" with command_test_env(self.runner, "test_cliPluginLoader", "command-test"): - result = self.runner.invoke(butler.cli, "--help") + result = self.runner.invoke(uncached_cli, "--help") self.assertEqual(result.exit_code, 0, f"output: {result.output} exception: {result.exception}") self.assertIn("command-test", result.stdout) @@ -146,7 +153,7 @@ def test_listCommands_duplicate(self): """ self.maxDiff = None with duplicate_command_test_env(self.runner): - result = self.runner.invoke(butler.cli, ["create", "test_repo"]) + result = self.runner.invoke(uncached_cli, ["create", "test_repo"]) self.assertEqual(result.exit_code, 1, f"output: {result.output} exception: {result.exception}") self.assertEqual( result.output,