Skip to content

Commit

Permalink
Improve typing in prefect.server.events (#16692)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Jan 11, 2025
1 parent d069600 commit eb44a8d
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 86 deletions.
9 changes: 6 additions & 3 deletions src/prefect/_internal/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def retry_async_fn(
retry_on_exceptions: tuple[type[Exception], ...] = (Exception,),
operation_name: Optional[str] = None,
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, Optional[R]]]
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
"""A decorator for retrying an async function.
Expand All @@ -48,9 +48,9 @@ def retry_async_fn(

def decorator(
func: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, Optional[R]]]:
) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
name = operation_name or func.__name__
for attempt in range(max_attempts):
try:
Expand All @@ -67,6 +67,9 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
f"Retrying in {delay:.2f} seconds..."
)
await asyncio.sleep(delay)
# Technically unreachable, but this raise helps pyright know that this function
# won't return None.
raise Exception(f"Function {name!r} failed after {max_attempts} attempts")

return wrapper

Expand Down
22 changes: 12 additions & 10 deletions src/prefect/server/events/schemas/automations.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ class EventTrigger(ResourceTrigger):
@model_validator(mode="before")
@classmethod
def enforce_minimum_within_for_proactive_triggers(
cls, data: Dict[str, Any]
cls, data: Dict[str, Any] | Any
) -> Dict[str, Any]:
if not isinstance(data, dict):
return data
Expand All @@ -342,7 +342,7 @@ def enforce_minimum_within_for_proactive_triggers(

return data

def covers(self, event: ReceivedEvent):
def covers(self, event: ReceivedEvent) -> bool:
if not self.covers_resources(event.resource, event.related):
return False

Expand All @@ -356,10 +356,10 @@ def immediate(self) -> bool:
"""Does this reactive trigger fire immediately for all events?"""
return self.posture == Posture.Reactive and self.within == timedelta(0)

_event_pattern: Optional[re.Pattern] = PrivateAttr(None)
_event_pattern: Optional[re.Pattern[str]] = PrivateAttr(None)

@property
def event_pattern(self) -> re.Pattern:
def event_pattern(self) -> re.Pattern[str]:
"""A regular expression which may be evaluated against any event string to
determine if this trigger would be interested in the event"""
if self._event_pattern:
Expand Down Expand Up @@ -625,13 +625,15 @@ class Firing(PrefectBaseModel):

id: UUID = Field(default_factory=uuid4)

trigger: ServerTriggerTypes = Field(..., description="The trigger that is firing")
trigger: Union[ServerTriggerTypes, CompositeTrigger] = Field(
default=..., description="The trigger that is firing"
)
trigger_states: Set[TriggerState] = Field(
...,
default=...,
description="The state changes represented by this Firing",
)
triggered: DateTime = Field(
...,
default=...,
description=(
"The time at which this trigger fired, which may differ from the "
"occurred time of the associated event (as events processing may always "
Expand All @@ -654,16 +656,16 @@ class Firing(PrefectBaseModel):
),
)
triggering_event: Optional[ReceivedEvent] = Field(
None,
default=None,
description=(
"The most recent event associated with this Firing. This may be the "
"event that caused the trigger to fire (for Reactive triggers), or the "
"last event to match the trigger (for Proactive triggers), or the state "
"change event (for a Metric trigger)."
),
)
triggering_value: Any = Field(
None,
triggering_value: Optional[Any] = Field(
default=None,
description=(
"A value associated with this firing of a trigger. Maybe used to "
"convey additional information at the point of firing, like the value of "
Expand Down
19 changes: 12 additions & 7 deletions src/prefect/server/events/services/actions.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from __future__ import annotations

import asyncio
from typing import Optional
from typing import TYPE_CHECKING, NoReturn

from prefect.logging import get_logger
from prefect.server.events import actions
from prefect.server.utilities.messaging import create_consumer
from prefect.server.utilities.messaging import Consumer, create_consumer

if TYPE_CHECKING:
import logging

logger = get_logger(__name__)
logger: "logging.Logger" = get_logger(__name__)


class Actions:
"""Runs actions triggered by Automatinos"""

name: str = "Actions"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Actions already started"
self.consumer = create_consumer("actions")
self.consumer: Consumer = create_consumer("actions")

async with actions.consumer() as handler:
self.consumer_task = asyncio.create_task(self.consumer.run(handler))
Expand All @@ -28,7 +33,7 @@ async def start(self):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Actions not started"
self.consumer_task.cancel()
try:
Expand Down
19 changes: 12 additions & 7 deletions src/prefect/server/events/services/event_logger.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from __future__ import annotations

import asyncio
from typing import Optional
from typing import TYPE_CHECKING, NoReturn

import pendulum
import rich

from prefect.logging import get_logger
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.utilities.messaging import Message, create_consumer
from prefect.server.utilities.messaging import Consumer, Message, create_consumer

if TYPE_CHECKING:
import logging

logger = get_logger(__name__)
logger: "logging.Logger" = get_logger(__name__)


class EventLogger:
"""A debugging service that logs events to the console as they arrive."""

name: str = "EventLogger"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Logger already started"
self.consumer = create_consumer("events")
self.consumer: Consumer = create_consumer("events")

console = rich.console.Console()

Expand All @@ -46,7 +51,7 @@ async def handler(message: Message):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Logger not started"
self.consumer_task.cancel()
try:
Expand Down
26 changes: 18 additions & 8 deletions src/prefect/server/events/services/event_persister.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
storage as fast as it can. Never gets tired.
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator, List, Optional
from typing import TYPE_CHECKING, AsyncGenerator, List, NoReturn

import pendulum
import sqlalchemy as sa
Expand All @@ -15,25 +17,33 @@
from prefect.server.database import provide_database_interface
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.events.storage.database import write_events
from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer
from prefect.server.utilities.messaging import (
Consumer,
Message,
MessageHandler,
create_consumer,
)
from prefect.settings import (
PREFECT_API_SERVICES_EVENT_PERSISTER_BATCH_SIZE,
PREFECT_API_SERVICES_EVENT_PERSISTER_FLUSH_INTERVAL,
PREFECT_EVENTS_RETENTION_PERIOD,
)

logger = get_logger(__name__)
if TYPE_CHECKING:
import logging

logger: "logging.Logger" = get_logger(__name__)


class EventPersister:
"""A service that persists events to the database as they arrive."""

name: str = "EventLogger"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

def __init__(self):
self._started_event: Optional[asyncio.Event] = None
self._started_event: asyncio.Event | None = None

@property
def started_event(self) -> asyncio.Event:
Expand All @@ -45,9 +55,9 @@ def started_event(self) -> asyncio.Event:
def started_event(self, value: asyncio.Event) -> None:
self._started_event = value

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Event persister already started"
self.consumer = create_consumer("events")
self.consumer: Consumer = create_consumer("events")

async with create_handler(
batch_size=PREFECT_API_SERVICES_EVENT_PERSISTER_BATCH_SIZE.value(),
Expand All @@ -64,7 +74,7 @@ async def start(self):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Event persister not started"
self.consumer_task.cancel()
try:
Expand Down
23 changes: 14 additions & 9 deletions src/prefect/server/events/services/triggers.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from __future__ import annotations

import asyncio
from typing import Optional
from typing import TYPE_CHECKING, Any, NoReturn, Optional

from prefect.logging import get_logger
from prefect.server.events import triggers
from prefect.server.services.loop_service import LoopService
from prefect.server.utilities.messaging import create_consumer
from prefect.server.utilities.messaging import Consumer, create_consumer
from prefect.settings import PREFECT_EVENTS_PROACTIVE_GRANULARITY

logger = get_logger(__name__)
if TYPE_CHECKING:
import logging

logger: "logging.Logger" = get_logger(__name__)


class ReactiveTriggers:
"""Runs the reactive triggers consumer"""

name: str = "ReactiveTriggers"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Reactive triggers already started"
self.consumer = create_consumer("events")
self.consumer: Consumer = create_consumer("events")

async with triggers.consumer() as handler:
self.consumer_task = asyncio.create_task(self.consumer.run(handler))
Expand All @@ -30,7 +35,7 @@ async def start(self):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Reactive triggers not started"
self.consumer_task.cancel()
try:
Expand All @@ -43,7 +48,7 @@ async def stop(self):


class ProactiveTriggers(LoopService):
def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any):
super().__init__(
loop_seconds=(
loop_seconds
Expand All @@ -52,5 +57,5 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
**kwargs,
)

async def run_once(self):
async def run_once(self) -> None:
await triggers.evaluate_proactive_triggers()
Loading

0 comments on commit eb44a8d

Please sign in to comment.