diff --git a/python/lsst/daf/butler/cli/butler.py b/python/lsst/daf/butler/cli/butler.py index 0e00393bae..9dce172845 100755 --- a/python/lsst/daf/butler/cli/butler.py +++ b/python/lsst/daf/butler/cli/butler.py @@ -34,6 +34,7 @@ ) import abc +import dataclasses import functools import logging import os @@ -41,11 +42,14 @@ import types from collections import defaultdict from functools import cache +from importlib.metadata import entry_points from typing import Any import click import yaml +from lsst.resources import ResourcePath from lsst.utils import doImport +from lsst.utils.introspection import get_full_type_name from lsst.utils.timer import time_this from .cliLog import CliLog @@ -87,6 +91,12 @@ def _importPlugin(pluginName: str) -> types.ModuleType | type | None | click.Com return None +@dataclasses.dataclass(frozen=True) +class PluginCommand: + command: click.Command + source: str + + class LoaderCLI(click.MultiCommand, abc.ABC): """Extends `click.MultiCommand`, which dispatches to subcommands, to load subcommands at runtime. @@ -119,14 +129,14 @@ def localCmdPkg(self) -> str: """ raise NotImplementedError() - def getLocalCommands(self) -> defaultdict[str, list[str]]: + def getLocalCommands(self) -> defaultdict[str, list[str | PluginCommand]]: """Get the commands offered by the local package. This assumes that the commands can be found in `localCmdPkg.__all__`, if this is not the case then this function should be overridden. Returns ------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. """ @@ -135,8 +145,13 @@ def getLocalCommands(self) -> defaultdict[str, list[str]]: # _importPlugins logs an error, don't need to do it again here. return defaultdict(list) assert hasattr(commandsLocation, "__all__"), f"Must define __all__ in {commandsLocation}" + commands = [getattr(commandsLocation, name) for name in commandsLocation.__all__] return defaultdict( - list, {self._funcNameToCmdName(f): [self.localCmdPkg] for f in commandsLocation.__all__} + list, + { + command.name: [PluginCommand(command, get_full_type_name(commandsLocation))] + for command in commands + }, ) def list_commands(self, ctx: click.Context) -> list[str]: @@ -182,14 +197,18 @@ def get_command(self, ctx: click.Context, name: str) -> click.Command | None: if name not in commands: return None self._raiseIfDuplicateCommands(commands) - module_str = commands[name][0] + "." + self._cmdNameToFuncName(name) - # The click.command decorator returns an instance of a class, which - # is something that doImport is not expecting. We add it in as an - # option here to appease mypy. - with time_this(log, msg="Importing command %s (via %s)", args=(name, module_str)): - plugin = _importPlugin(module_str) - if not plugin: - return None + command = commands[name][0] + if isinstance(command, str): + module_str = command + "." + self._cmdNameToFuncName(name) + # The click.command decorator returns an instance of a class, which + # is something that doImport is not expecting. We add it in as an + # option here to appease mypy. + with time_this(log, msg="Importing command %s (via %s)", args=(name, module_str)): + plugin = _importPlugin(module_str) + if not plugin: + return None + else: + plugin = command.command if not isinstance(plugin, click.Command): raise RuntimeError( f"Command {name!r} loaded from {module_str} is not a click Command, is {type(plugin)}" @@ -221,21 +240,22 @@ def _setupLogging(self, ctx: click.Context | None) -> None: ) @classmethod - def getPluginList(cls) -> list[str]: + def getPluginList(cls) -> list[ResourcePath]: """Get the list of importable yaml files that contain cli data for this command. Returns ------- - `list` [`str`] + `list` [`lsst.resources.ResourcePath`] The list of files that contain yaml data about a cli plugin. """ - if not hasattr(cls, "pluginEnvVar"): - return [] - pluginModules = os.environ.get(cls.pluginEnvVar) - if pluginModules: - return [p for p in pluginModules.split(":") if p != ""] - return [] + yaml_files = [] + if hasattr(cls, "pluginEnvVar"): + pluginModules = os.environ.get(cls.pluginEnvVar) + if pluginModules: + yaml_files.extend([ResourcePath(p) for p in pluginModules.split(":") if p != ""]) + + return yaml_files @classmethod def _funcNameToCmdName(cls, functionName: str) -> str: @@ -257,14 +277,14 @@ def _cmdNameToFuncName(cls, commandName: str) -> str: @staticmethod def _mergeCommandLists( - a: defaultdict[str, list[str]], b: defaultdict[str, list[str]] - ) -> defaultdict[str, list[str]]: + a: defaultdict[str, list[str | PluginCommand]], b: defaultdict[str, list[str | PluginCommand]] + ) -> defaultdict[str, list[str | PluginCommand]]: """Combine two dicts whose keys are strings (command name) and values are list of string (the package(s) that provide the named command). Parameters ---------- - a : `defaultdict` [`str`, `list` [`str`]] + a : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. b : (same as a) @@ -280,7 +300,7 @@ def _mergeCommandLists( return a @classmethod - def _getPluginCommands(cls) -> defaultdict[str, list[str]]: + def _getPluginCommands(cls) -> defaultdict[str, list[str | PluginCommand]]: """Get the commands offered by plugin packages. Returns @@ -288,12 +308,16 @@ def _getPluginCommands(cls) -> defaultdict[str, list[str]]: commands : `defaultdict` [`str`, `list` [`str`]] The key is the command name. The value is a list of package(s) that contains the command. + + Notes + ----- + Assumes that if entry points are defined, the plugin environment + variable will not be defined for that same package. """ - commands: defaultdict[str, list[str]] = defaultdict(list) + commands: defaultdict[str, list[str | PluginCommand]] = defaultdict(list) for pluginName in cls.getPluginList(): try: - with open(pluginName) as resourceFile: - resources = defaultdict(list, yaml.safe_load(resourceFile)) + resources = defaultdict(list, yaml.safe_load(pluginName.read())) except Exception as err: log.warning("Error loading commands from %s, skipping. %s", pluginName, err) continue @@ -302,28 +326,37 @@ def _getPluginCommands(cls) -> defaultdict[str, list[str]]: continue pluginCommands = {cmd: [resources["cmd"]["import"]] for cmd in resources["cmd"]["commands"]} cls._mergeCommandLists(commands, defaultdict(list, pluginCommands)) + + if hasattr(cls, "entryPoint"): + plugins = entry_points(group=cls.entryPoint) + for p in plugins: + func = p.load() + func_name = get_full_type_name(func) + pluginCommands = {cmd.name: [PluginCommand(cmd, func_name)] for cmd in func()} + cls._mergeCommandLists(commands, defaultdict(list, pluginCommands)) + return commands @cache - def _getCommands(self) -> defaultdict[str, list[str]]: + def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]: """Get the commands offered by daf_butler and plugin packages. Returns ------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. """ return self._mergeCommandLists(self.getLocalCommands(), self._getPluginCommands()) @staticmethod - def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None: + def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str | PluginCommand]]) -> None: """If any provided command is offered by more than one package raise an exception. Parameters ---------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PLuginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. @@ -336,7 +369,12 @@ def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None: msg = "" for command, packages in commands.items(): if len(packages) > 1: - msg += f"Command '{command}' exists in packages {', '.join(packages)}. " + pkg_names: list[str] = [] + for p in packages: + if not isinstance(p, str): + p = p.source + pkg_names.append(p) + msg += f"Command '{command}' exists in packages {', '.join(pkg_names)}. " if msg: raise click.ClickException(message=msg + "Duplicate commands are not supported, aborting.") @@ -347,6 +385,7 @@ class ButlerCLI(LoaderCLI): localCmdPkg = "lsst.daf.butler.cli.cmd" pluginEnvVar = "DAF_BUTLER_PLUGINS" + entryPoint = "butler.cli" @classmethod def _funcNameToCmdName(cls, functionName: str) -> str: @@ -377,12 +416,12 @@ def _cmdNameToFuncName(cls, commandName: str) -> str: class UncachedButlerCLI(ButlerCLI): """ButlerCLI that can be used where caching of the commands is disabled.""" - def _getCommands(self) -> defaultdict[str, list[str]]: # type: ignore[override] + def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]: # type: ignore[override] """Get the commands offered by daf_butler and plugin packages. Returns ------- - commands : `defaultdict` [`str`, `list` [`str`]] + commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]] The key is the command name. The value is a list of package(s) that contains the command. """