diff --git a/patchback/_utils.py b/patchback/_utils.py new file mode 100644 index 0000000..63077a8 --- /dev/null +++ b/patchback/_utils.py @@ -0,0 +1,20 @@ +""" +Misc. utility functions +""" + +from __future__ import annotations + +from typing import TypeVar + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +def strip_nones(mapping: dict[_T1, _T2 | None]) -> dict[_T1, _T2]: + """ + Remove keys set to None from a dictionary + + Returns: + A new dictionary instance + """ + return {key: value for key, value in mapping.items() if value is not None} diff --git a/patchback/config.py b/patchback/config.py index a550f87..0606dde 100644 --- a/patchback/config.py +++ b/patchback/config.py @@ -11,6 +11,7 @@ DEFAULT_BACKPORT_BRANCH_PREFIX = 'patchback/backports/' DEFAULT_BACKPORT_LABEL_PREFIX = 'backport-' DEFAULT_TARGET_BRANCH_PREFIX = '' +DEFAULT_FAILED_BACKPORT_LABEL_PREFIX = 'failed-backport-' @attr.dataclass @@ -30,6 +31,24 @@ class PatchbackConfig: ) """Prefix that the older/stable version branch has.""" + failed_label_prefix: str | None = attr.ib(default=DEFAULT_FAILED_BACKPORT_LABEL_PREFIX) + """ + Add {failed_label_prefix}-{target_branch} label when backport fails. + {target_branch_prefix} is stripped from {target_branch}. + Set to None to disable adding a label on failure. + """ + + @failed_label_prefix.validator + def _v_failed_label_prefix(self, _, value: str | None) -> None: + """ + Ensure backport_label_prefix and failed_label_prefix are different + to avoid infinite loops + """ + if value == self.backport_label_prefix: + raise ValueError( + 'failed_label_prefix and backport_label_prefix must be unique values' + ) + async def get_patchback_config( *, diff --git a/patchback/event_handlers.py b/patchback/event_handlers.py index f67ce7f..9898b6c 100644 --- a/patchback/event_handlers.py +++ b/patchback/event_handlers.py @@ -1,10 +1,14 @@ """Webhook event handlers.""" +from __future__ import annotations + import http +import functools import logging import pathlib import tempfile from subprocess import CalledProcessError, check_output, check_call +from typing import Any from anyio import run_in_thread from gidgethub import BadRequest, ValidationError @@ -18,6 +22,7 @@ from .locking_api import LockingAPI from .config import get_patchback_config from .github_reporter import PullRequestReporter +from .labels_api import IssueLabelsAPI, RepoLabelsAPI logger = logging.getLogger(__name__) @@ -282,6 +287,9 @@ async def on_merge_of_labeled_pr( repository['pulls_url'], repository['full_name'], repository['clone_url'], + repo_config.backport_label_prefix, + repo_config.target_branch_prefix, + repo_config.failed_label_prefix, ) @@ -332,6 +340,9 @@ async def on_label_added_to_merged_pr( repository['pulls_url'], repository['full_name'], repository['clone_url'], + repo_config.backport_label_prefix, + repo_config.target_branch_prefix, + repo_config.failed_label_prefix, ) @@ -348,6 +359,9 @@ async def process_pr_backport_labels( backport_branch_prefix, pr_api_url, repo_slug, git_url, + backport_label_prefix: str, + target_branch_prefix: str, + failed_label_prefix: str | None, ) -> None: gh_api = RUNTIME_CONTEXT.app_installation_client checks_api = ChecksAPI( @@ -366,6 +380,17 @@ async def process_pr_backport_labels( locking_api=locking_api, branch_name=target_branch, ) + labels_api = RepoLabelsAPI(api=gh_api, repo_slug=repo_slug) + issue_labels_api = IssueLabelsAPI(api=gh_api, repo_slug=repo_slug, number=pr_number) + failed_label_cb = functools.partial( + add_failure_label, + labels_api=labels_api, + issue_labels_api=issue_labels_api, + backport_label_prefix=backport_label_prefix, + failed_label_prefix=failed_label_prefix, + target_branch_prefix=target_branch_prefix, + target_branch=target_branch, + ) await pr_reporter.start_reporting(pr_head_sha, pr_number, pr_merge_commit) @@ -396,6 +421,7 @@ async def process_pr_backport_labels( subtitle='💔 cherry-picking failed — target branch does not exist', summary=f'❌ {lu_err!s}', ) + await failed_label_cb() return except ValueError as val_err: logger.info( @@ -409,6 +435,7 @@ async def process_pr_backport_labels( text=manual_backport_guide, summary=f'❌ {val_err!s}', ) + await failed_label_cb() return except PermissionError as perm_err: logger.info( @@ -423,6 +450,7 @@ async def process_pr_backport_labels( text=manual_backport_guide, summary=f'❌ {perm_err!s}', ) + await failed_label_cb() return else: logger.info('Backport PR branch: `%s`', backport_pr_branch) @@ -461,6 +489,7 @@ async def process_pr_backport_labels( text=manual_backport_guide, summary=f'❌ {backport_pr_branch_msg}\n\n{val_err!s}', ) + await failed_label_cb() return except BadRequest as bad_req_err: if ( @@ -480,6 +509,7 @@ async def process_pr_backport_labels( text=manual_backport_guide, summary=f'❌ {backport_pr_branch_msg}\n\n{bad_req_err!s}', ) + await failed_label_cb() return else: logger.info('Created a PR @ %s', pr_resp['html_url']) @@ -490,3 +520,31 @@ async def process_pr_backport_labels( text=f'Backported as {pr_resp["html_url"]}', summary=f'✅ {backport_pr_branch_msg!s}', ) + +async def add_failure_label( + *, + labels_api: RepoLabelsAPI, + issue_labels_api: IssueLabelsAPI, + backport_label_prefix: str, + failed_label_prefix: str | None, + target_branch_prefix: str, + target_branch: str, +) -> dict[str, Any] | None: + if failed_label_prefix is None: + return None + stripped_branch = target_branch[len(target_branch_prefix):] + label: str = failed_label_prefix + stripped_branch + backport_label = backport_label_prefix + stripped_branch + # Create label if it doesn't exist + try: + await labels_api.get_label(label) + except BadRequest as exc: + if exc.status_code != 404: + raise + await labels_api.create_label( + label, f'Failed to backport PR to {target_branch}.' + ) + # Add failed label + await issue_labels_api.add_labels(label) + # Delete backport label + await issue_labels_api.remove_label(backport_label) diff --git a/patchback/labels_api.py b/patchback/labels_api.py new file mode 100644 index 0000000..ba87ccc --- /dev/null +++ b/patchback/labels_api.py @@ -0,0 +1,52 @@ +""" +Wrappers around repository and issue label APIs +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from gidgethub.abc import GitHubAPI + +from patchback._utils import strip_nones + + +class RepoLabelsAPI: + def __init__(self, *, api: GitHubAPI, repo_slug: str) -> None: + self._api: GitHubAPI = api + self._labels_api = f"/repos/{repo_slug}/labels" + + async def create_label( + self, name: str, description: str | None = None, color: str | None = None + ) -> dict[str, Any]: + return await self._api.post( + self._labels_api, + data=strip_nones( + {"name": name, "description": description, "color": color} + ), + ) + + def list_labels(self) -> AsyncIterator[dict[str, Any]]: + return self._api.getiter(self._labels_api) + + async def get_label(self, name: str) -> dict[str, Any]: + return await self._api.getitem(f"{self._labels_api}/{name}") + + +class IssueLabelsAPI: + def __init__(self, *, api: GitHubAPI, repo_slug: str, number: int) -> None: + self._api: GitHubAPI = api + self._issue_labels_api = f"/repos/{repo_slug}/issues/{number}/labels" + + async def add_labels(self, *labels: str) -> list[dict[str, Any]]: + return await self._api.post( + self._issue_labels_api, + data={"labels": labels}, + ) + + def list_labels(self) -> AsyncIterator[dict[str, Any]]: + return self._api.getiter(self._issue_labels_api) + + async def remove_label(self, label: str) -> None: + return await self._api.delete(f"{self._issue_labels_api}/{label}")