Skip to content

Commit

Permalink
Allow plugins to be found via the butler.cli entry point
Browse files Browse the repository at this point in the history
This is in addition to the environment variable.
The plugins return the actual commands rather than any
links to YAML files.
  • Loading branch information
timj committed Oct 28, 2024
1 parent 00b770d commit 80b168e
Showing 1 changed file with 72 additions and 33 deletions.
105 changes: 72 additions & 33 deletions python/lsst/daf/butler/cli/butler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,22 @@
)

import abc
import dataclasses
import functools
import logging
import os
import traceback
import types
from collections import defaultdict
from functools import cache
from importlib.metadata import entry_points
from typing import Any

import click
import yaml
from lsst.resources import ResourcePath
from lsst.utils import doImport
from lsst.utils.introspection import get_full_type_name
from lsst.utils.timer import time_this

from .cliLog import CliLog
Expand Down Expand Up @@ -87,6 +91,12 @@ def _importPlugin(pluginName: str) -> types.ModuleType | type | None | click.Com
return None


@dataclasses.dataclass(frozen=True)
class PluginCommand:
command: click.Command
source: str


class LoaderCLI(click.MultiCommand, abc.ABC):
"""Extends `click.MultiCommand`, which dispatches to subcommands, to load
subcommands at runtime.
Expand Down Expand Up @@ -119,14 +129,14 @@ def localCmdPkg(self) -> str:
"""
raise NotImplementedError()

def getLocalCommands(self) -> defaultdict[str, list[str]]:
def getLocalCommands(self) -> defaultdict[str, list[str | PluginCommand]]:
"""Get the commands offered by the local package. This assumes that the
commands can be found in `localCmdPkg.__all__`, if this is not the case
then this function should be overridden.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
"""
Expand All @@ -135,8 +145,13 @@ def getLocalCommands(self) -> defaultdict[str, list[str]]:
# _importPlugins logs an error, don't need to do it again here.
return defaultdict(list)
assert hasattr(commandsLocation, "__all__"), f"Must define __all__ in {commandsLocation}"
commands = [getattr(commandsLocation, name) for name in commandsLocation.__all__]
return defaultdict(
list, {self._funcNameToCmdName(f): [self.localCmdPkg] for f in commandsLocation.__all__}
list,
{
command.name: [PluginCommand(command, get_full_type_name(commandsLocation))]
for command in commands
},
)

def list_commands(self, ctx: click.Context) -> list[str]:
Expand Down Expand Up @@ -182,14 +197,18 @@ def get_command(self, ctx: click.Context, name: str) -> click.Command | None:
if name not in commands:
return None
self._raiseIfDuplicateCommands(commands)
module_str = commands[name][0] + "." + self._cmdNameToFuncName(name)
# The click.command decorator returns an instance of a class, which
# is something that doImport is not expecting. We add it in as an
# option here to appease mypy.
with time_this(log, msg="Importing command %s (via %s)", args=(name, module_str)):
plugin = _importPlugin(module_str)
if not plugin:
return None
command = commands[name][0]
if isinstance(command, str):
module_str = command + "." + self._cmdNameToFuncName(name)
# The click.command decorator returns an instance of a class, which
# is something that doImport is not expecting. We add it in as an
# option here to appease mypy.
with time_this(log, msg="Importing command %s (via %s)", args=(name, module_str)):
plugin = _importPlugin(module_str)
if not plugin:
return None
else:
plugin = command.command
if not isinstance(plugin, click.Command):
raise RuntimeError(
f"Command {name!r} loaded from {module_str} is not a click Command, is {type(plugin)}"
Expand Down Expand Up @@ -221,21 +240,22 @@ def _setupLogging(self, ctx: click.Context | None) -> None:
)

@classmethod
def getPluginList(cls) -> list[str]:
def getPluginList(cls) -> list[ResourcePath]:
"""Get the list of importable yaml files that contain cli data for this
command.
Returns
-------
`list` [`str`]
`list` [`lsst.resources.ResourcePath`]
The list of files that contain yaml data about a cli plugin.
"""
if not hasattr(cls, "pluginEnvVar"):
return []
pluginModules = os.environ.get(cls.pluginEnvVar)
if pluginModules:
return [p for p in pluginModules.split(":") if p != ""]
return []
yaml_files = []
if hasattr(cls, "pluginEnvVar"):
pluginModules = os.environ.get(cls.pluginEnvVar)
if pluginModules:
yaml_files.extend([ResourcePath(p) for p in pluginModules.split(":") if p != ""])

return yaml_files

@classmethod
def _funcNameToCmdName(cls, functionName: str) -> str:
Expand All @@ -257,14 +277,14 @@ def _cmdNameToFuncName(cls, commandName: str) -> str:

@staticmethod
def _mergeCommandLists(
a: defaultdict[str, list[str]], b: defaultdict[str, list[str]]
) -> defaultdict[str, list[str]]:
a: defaultdict[str, list[str | PluginCommand]], b: defaultdict[str, list[str | PluginCommand]]
) -> defaultdict[str, list[str | PluginCommand]]:
"""Combine two dicts whose keys are strings (command name) and values
are list of string (the package(s) that provide the named command).
Parameters
----------
a : `defaultdict` [`str`, `list` [`str`]]
a : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
b : (same as a)
Expand All @@ -280,20 +300,24 @@ def _mergeCommandLists(
return a

@classmethod
def _getPluginCommands(cls) -> defaultdict[str, list[str]]:
def _getPluginCommands(cls) -> defaultdict[str, list[str | PluginCommand]]:
"""Get the commands offered by plugin packages.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
The key is the command name. The value is a list of package(s) that
contains the command.
Notes
-----
Assumes that if entry points are defined, the plugin environment
variable will not be defined for that same package.
"""
commands: defaultdict[str, list[str]] = defaultdict(list)
commands: defaultdict[str, list[str | PluginCommand]] = defaultdict(list)
for pluginName in cls.getPluginList():
try:
with open(pluginName) as resourceFile:
resources = defaultdict(list, yaml.safe_load(resourceFile))
resources = defaultdict(list, yaml.safe_load(pluginName.read()))
except Exception as err:
log.warning("Error loading commands from %s, skipping. %s", pluginName, err)
continue
Expand All @@ -302,28 +326,37 @@ def _getPluginCommands(cls) -> defaultdict[str, list[str]]:
continue
pluginCommands = {cmd: [resources["cmd"]["import"]] for cmd in resources["cmd"]["commands"]}
cls._mergeCommandLists(commands, defaultdict(list, pluginCommands))

if hasattr(cls, "entryPoint"):
plugins = entry_points(group=cls.entryPoint)
for p in plugins:
func = p.load()
func_name = get_full_type_name(func)
pluginCommands = {cmd.name: [PluginCommand(cmd, func_name)] for cmd in func()}
cls._mergeCommandLists(commands, defaultdict(list, pluginCommands))

return commands

@cache
def _getCommands(self) -> defaultdict[str, list[str]]:
def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]:
"""Get the commands offered by daf_butler and plugin packages.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
"""
return self._mergeCommandLists(self.getLocalCommands(), self._getPluginCommands())

@staticmethod
def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None:
def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str | PluginCommand]]) -> None:
"""If any provided command is offered by more than one package raise an
exception.
Parameters
----------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PLuginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
Expand All @@ -336,7 +369,12 @@ def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None:
msg = ""
for command, packages in commands.items():
if len(packages) > 1:
msg += f"Command '{command}' exists in packages {', '.join(packages)}. "
pkg_names: list[str] = []
for p in packages:
if not isinstance(p, str):
p = p.source
pkg_names.append(p)
msg += f"Command '{command}' exists in packages {', '.join(pkg_names)}. "
if msg:
raise click.ClickException(message=msg + "Duplicate commands are not supported, aborting.")

Expand All @@ -347,6 +385,7 @@ class ButlerCLI(LoaderCLI):
localCmdPkg = "lsst.daf.butler.cli.cmd"

pluginEnvVar = "DAF_BUTLER_PLUGINS"
entryPoint = "butler.cli"

@classmethod
def _funcNameToCmdName(cls, functionName: str) -> str:
Expand Down Expand Up @@ -377,12 +416,12 @@ def _cmdNameToFuncName(cls, commandName: str) -> str:
class UncachedButlerCLI(ButlerCLI):
"""ButlerCLI that can be used where caching of the commands is disabled."""

def _getCommands(self) -> defaultdict[str, list[str]]: # type: ignore[override]
def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]: # type: ignore[override]
"""Get the commands offered by daf_butler and plugin packages.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
"""
Expand Down

0 comments on commit 80b168e

Please sign in to comment.