Skip to content

Commit

Permalink
Merge pull request #1107 from lsst/tickets/DM-47143
Browse files Browse the repository at this point in the history
DM-47143: Allow entry points to be used to discover cli plugins
  • Loading branch information
timj authored Oct 29, 2024
2 parents 6333457 + 3b3ded1 commit ef8b14b
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 37 deletions.
3 changes: 3 additions & 0 deletions doc/changes/DM-47143.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The ``DAF_BUTLER_PLUGINS`` environment variable should no longer be set if packages use ``pip install`` and have been upgraded to use entry points.
Butler can now read the subcommands from ``pipe_base`` and ``daf_butler_migrate`` automatically.
Setting the environment variable for these packages will result in an error.
125 changes: 93 additions & 32 deletions python/lsst/daf/butler/cli/butler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,24 @@
"main",
)


import abc
import dataclasses
import functools
import logging
import os
import traceback
import types
from collections import defaultdict
from functools import cache
from importlib.metadata import entry_points
from typing import Any

import click
import yaml
from lsst.resources import ResourcePath
from lsst.utils import doImport
from lsst.utils.introspection import get_full_type_name
from lsst.utils.timer import time_this

from .cliLog import CliLog
from .opt import log_file_option, log_label_option, log_level_option, log_tty_option, long_log_option
Expand Down Expand Up @@ -86,6 +91,16 @@ def _importPlugin(pluginName: str) -> types.ModuleType | type | None | click.Com
return None


@dataclasses.dataclass(frozen=True)
class PluginCommand:
"""A click Command and the plugin it came from."""

command: click.Command
"""The command (`click.Command`)."""
source: str
"""Where the command came from (`str`)."""


class LoaderCLI(click.MultiCommand, abc.ABC):
"""Extends `click.MultiCommand`, which dispatches to subcommands, to load
subcommands at runtime.
Expand Down Expand Up @@ -118,14 +133,14 @@ def localCmdPkg(self) -> str:
"""
raise NotImplementedError()

def getLocalCommands(self) -> defaultdict[str, list[str]]:
def getLocalCommands(self) -> defaultdict[str, list[str | PluginCommand]]:
"""Get the commands offered by the local package. This assumes that the
commands can be found in `localCmdPkg.__all__`, if this is not the case
then this function should be overridden.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
"""
Expand All @@ -134,8 +149,13 @@ def getLocalCommands(self) -> defaultdict[str, list[str]]:
# _importPlugins logs an error, don't need to do it again here.
return defaultdict(list)
assert hasattr(commandsLocation, "__all__"), f"Must define __all__ in {commandsLocation}"
commands = [getattr(commandsLocation, name) for name in commandsLocation.__all__]
return defaultdict(
list, {self._funcNameToCmdName(f): [self.localCmdPkg] for f in commandsLocation.__all__}
list,
{
command.name: [PluginCommand(command, get_full_type_name(commandsLocation))]
for command in commands
},
)

def list_commands(self, ctx: click.Context) -> list[str]:
Expand Down Expand Up @@ -181,13 +201,18 @@ def get_command(self, ctx: click.Context, name: str) -> click.Command | None:
if name not in commands:
return None
self._raiseIfDuplicateCommands(commands)
module_str = commands[name][0] + "." + self._cmdNameToFuncName(name)
# The click.command decorator returns an instance of a class, which
# is something that doImport is not expecting. We add it in as an
# option here to appease mypy.
plugin = _importPlugin(module_str)
if not plugin:
return None
command = commands[name][0]
if isinstance(command, str):
module_str = command + "." + self._cmdNameToFuncName(name)
# The click.command decorator returns an instance of a class, which
# is something that doImport is not expecting. We add it in as an
# option here to appease mypy.
with time_this(log, msg="Importing command %s (via %s)", args=(name, module_str)):
plugin = _importPlugin(module_str)
if not plugin:
return None
else:
plugin = command.command
if not isinstance(plugin, click.Command):
raise RuntimeError(
f"Command {name!r} loaded from {module_str} is not a click Command, is {type(plugin)}"
Expand Down Expand Up @@ -219,21 +244,22 @@ def _setupLogging(self, ctx: click.Context | None) -> None:
)

@classmethod
def getPluginList(cls) -> list[str]:
def getPluginList(cls) -> list[ResourcePath]:
"""Get the list of importable yaml files that contain cli data for this
command.
Returns
-------
`list` [`str`]
`list` [`lsst.resources.ResourcePath`]
The list of files that contain yaml data about a cli plugin.
"""
if not hasattr(cls, "pluginEnvVar"):
return []
pluginModules = os.environ.get(cls.pluginEnvVar)
if pluginModules:
return [p for p in pluginModules.split(":") if p != ""]
return []
yaml_files = []
if hasattr(cls, "pluginEnvVar"):
pluginModules = os.environ.get(cls.pluginEnvVar)
if pluginModules:
yaml_files.extend([ResourcePath(p) for p in pluginModules.split(":") if p != ""])

return yaml_files

@classmethod
def _funcNameToCmdName(cls, functionName: str) -> str:
Expand All @@ -255,21 +281,21 @@ def _cmdNameToFuncName(cls, commandName: str) -> str:

@staticmethod
def _mergeCommandLists(
a: defaultdict[str, list[str]], b: defaultdict[str, list[str]]
) -> defaultdict[str, list[str]]:
a: defaultdict[str, list[str | PluginCommand]], b: defaultdict[str, list[str | PluginCommand]]
) -> defaultdict[str, list[str | PluginCommand]]:
"""Combine two dicts whose keys are strings (command name) and values
are list of string (the package(s) that provide the named command).
Parameters
----------
a : `defaultdict` [`str`, `list` [`str`]]
a : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
b : (same as a)
Returns
-------
commands : `defaultdict` [`str`: [`str`]]
commands : `defaultdict` [`str`: [`str` | `PluginCommand` ]]
For convenience, returns a extended with b. ('a' is modified in
place.)
"""
Expand All @@ -278,20 +304,24 @@ def _mergeCommandLists(
return a

@classmethod
def _getPluginCommands(cls) -> defaultdict[str, list[str]]:
def _getPluginCommands(cls) -> defaultdict[str, list[str | PluginCommand]]:
"""Get the commands offered by plugin packages.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
The key is the command name. The value is a list of package(s) that
contains the command.
Notes
-----
Assumes that if entry points are defined, the plugin environment
variable will not be defined for that same package.
"""
commands: defaultdict[str, list[str]] = defaultdict(list)
commands: defaultdict[str, list[str | PluginCommand]] = defaultdict(list)
for pluginName in cls.getPluginList():
try:
with open(pluginName) as resourceFile:
resources = defaultdict(list, yaml.safe_load(resourceFile))
resources = defaultdict(list, yaml.safe_load(pluginName.read()))
except Exception as err:
log.warning("Error loading commands from %s, skipping. %s", pluginName, err)
continue
Expand All @@ -300,27 +330,37 @@ def _getPluginCommands(cls) -> defaultdict[str, list[str]]:
continue
pluginCommands = {cmd: [resources["cmd"]["import"]] for cmd in resources["cmd"]["commands"]}
cls._mergeCommandLists(commands, defaultdict(list, pluginCommands))

if hasattr(cls, "entryPoint"):
plugins = entry_points(group=cls.entryPoint)
for p in plugins:
func = p.load()
func_name = get_full_type_name(func)
pluginCommands = {cmd.name: [PluginCommand(cmd, func_name)] for cmd in func()}
cls._mergeCommandLists(commands, defaultdict(list, pluginCommands))

return commands

def _getCommands(self) -> defaultdict[str, list[str]]:
@cache
def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]:
"""Get the commands offered by daf_butler and plugin packages.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
"""
return self._mergeCommandLists(self.getLocalCommands(), self._getPluginCommands())

@staticmethod
def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None:
def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str | PluginCommand]]) -> None:
"""If any provided command is offered by more than one package raise an
exception.
Parameters
----------
commands : `defaultdict` [`str`, `list` [`str`]]
commands : `defaultdict` [`str`, `list` [`str` | `PLuginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
Expand All @@ -333,7 +373,12 @@ def _raiseIfDuplicateCommands(commands: defaultdict[str, list[str]]) -> None:
msg = ""
for command, packages in commands.items():
if len(packages) > 1:
msg += f"Command '{command}' exists in packages {', '.join(packages)}. "
pkg_names: list[str] = []
for p in packages:
if not isinstance(p, str):
p = p.source
pkg_names.append(p)
msg += f"Command '{command}' exists in packages {', '.join(pkg_names)}. "
if msg:
raise click.ClickException(message=msg + "Duplicate commands are not supported, aborting.")

Expand All @@ -344,6 +389,7 @@ class ButlerCLI(LoaderCLI):
localCmdPkg = "lsst.daf.butler.cli.cmd"

pluginEnvVar = "DAF_BUTLER_PLUGINS"
entryPoint = "butler.cli"

@classmethod
def _funcNameToCmdName(cls, functionName: str) -> str:
Expand Down Expand Up @@ -371,6 +417,21 @@ def _cmdNameToFuncName(cls, commandName: str) -> str:
return super()._cmdNameToFuncName(commandName)


class UncachedButlerCLI(ButlerCLI):
"""ButlerCLI that can be used where caching of the commands is disabled."""

def _getCommands(self) -> defaultdict[str, list[str | PluginCommand]]: # type: ignore[override]
"""Get the commands offered by daf_butler and plugin packages.
Returns
-------
commands : `defaultdict` [`str`, `list` [`str` | `PluginCommand` ]]
The key is the command name. The value is a list of package(s) that
contains the command.
"""
return self._mergeCommandLists(self.getLocalCommands(), self._getPluginCommands())


@click.command(cls=ButlerCLI, context_settings=dict(help_option_names=["-h", "--help"]))
@log_level_option()
@long_log_option()
Expand Down
37 changes: 36 additions & 1 deletion python/lsst/daf/butler/tests/cliLogTestBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,15 @@
from typing import TYPE_CHECKING, Any

import click
from lsst.daf.butler.cli.butler import cli as butlerCli
from lsst.daf.butler.cli.butler import UncachedButlerCLI
from lsst.daf.butler.cli.cliLog import CliLog
from lsst.daf.butler.cli.opt import (
log_file_option,
log_label_option,
log_level_option,
log_tty_option,
long_log_option,
)
from lsst.daf.butler.cli.utils import LogCliRunner, clickResultMsg, command_test_env
from lsst.daf.butler.logging import ButlerLogRecords
from lsst.utils.logging import TRACE
Expand All @@ -68,6 +75,34 @@
lsstLog_WARN = 0


@click.command(cls=UncachedButlerCLI)
@log_level_option()
@long_log_option()
@log_file_option()
@log_tty_option()
@log_label_option()
def butlerCli(log_level: str, long_log: bool, log_file: str, log_tty: bool, log_label: str) -> None:
"""Uncached ButlerCLI.
Parameters
----------
log_level : `str`
The log level to use by default. ``log_level`` is handled by
``get_command`` and ``list_commands``, and is called in
one of those functions before this is called.
long_log : `bool`
Enable extended log output. ``long_log`` is handled by
``setup_logging``.
log_file : `str`
The log file name.
log_tty : `bool`
Whether to send logs to standard output.
log_label : `str`
Log labels.
"""
pass


@click.command()
@click.option("--expected-pyroot-level", type=int)
@click.option("--expected-pylsst-level", type=int)
Expand Down
15 changes: 11 additions & 4 deletions tests/test_cliPluginLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import click
import yaml
from lsst.daf.butler.cli import butler, cmd
from lsst.daf.butler.cli.butler import UncachedButlerCLI
from lsst.daf.butler.cli.utils import LogCliRunner, command_test_env


Expand Down Expand Up @@ -64,6 +65,12 @@ def duplicate_command_test_env(runner):
yield


@click.command(cls=UncachedButlerCLI)
def uncached_cli():
"""ButlerCLI that does not cache the commands."""
pass


class FailedLoadTest(unittest.TestCase):
"""Test failed plugin loading."""

Expand All @@ -73,7 +80,7 @@ def setUp(self):
def test_unimportablePlugin(self):
with command_test_env(self.runner, "test_cliPluginLoader", "non-existant-command-function"):
with self.assertLogs() as cm:
result = self.runner.invoke(butler.cli, "--help")
result = self.runner.invoke(uncached_cli, "--help")
self.assertEqual(result.exit_code, 0, f"output: {result.output!r} exception: {result.exception}")
expectedErrMsg = (
"Could not import plugin from test_cliPluginLoader.non_existant_command_function, skipping."
Expand Down Expand Up @@ -104,7 +111,7 @@ def setUp(self):
def test_loadAndExecutePluginCommand(self):
"""Test that a plugin command can be loaded and executed."""
with command_test_env(self.runner, "test_cliPluginLoader", "command-test"):
result = self.runner.invoke(butler.cli, "command-test")
result = self.runner.invoke(uncached_cli, "command-test")
self.assertEqual(result.exit_code, 0, f"output: {result.output} exception: {result.exception}")
self.assertEqual(result.stdout, "test command\n")

Expand All @@ -118,7 +125,7 @@ def test_loadAndExecuteLocalCommand(self):
def test_loadTopHelp(self):
"""Test that an expected command is produced by 'butler --help'"""
with command_test_env(self.runner, "test_cliPluginLoader", "command-test"):
result = self.runner.invoke(butler.cli, "--help")
result = self.runner.invoke(uncached_cli, "--help")
self.assertEqual(result.exit_code, 0, f"output: {result.output} exception: {result.exception}")
self.assertIn("command-test", result.stdout)

Expand Down Expand Up @@ -146,7 +153,7 @@ def test_listCommands_duplicate(self):
"""
self.maxDiff = None
with duplicate_command_test_env(self.runner):
result = self.runner.invoke(butler.cli, ["create", "test_repo"])
result = self.runner.invoke(uncached_cli, ["create", "test_repo"])
self.assertEqual(result.exit_code, 1, f"output: {result.output} exception: {result.exception}")
self.assertEqual(
result.output,
Expand Down

0 comments on commit ef8b14b

Please sign in to comment.