Skip to content

Commit

Permalink
[DEV-19025] middleware refactoring (#4)
Browse files Browse the repository at this point in the history
* DEV-19025: All counters in base middleware class

* DEV-19025: codestyle

* DEV-19025: codestyle

* DEV-19025: fix

* DEV-19025: fix
  • Loading branch information
alexreznikoff authored Feb 12, 2025
1 parent a495ead commit d272feb
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 135 deletions.
90 changes: 81 additions & 9 deletions src/huntflow_base_metrics/web_frameworks/_middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import abc
import time
from dataclasses import dataclass
from typing import Generic, Optional, Set, TypeVar

from huntflow_base_metrics import apply_labels
from huntflow_base_metrics._context import METRIC_CONTEXT
from huntflow_base_metrics.web_frameworks._request_metrics import (
EXCEPTIONS,
REQUESTS,
REQUESTS_IN_PROGRESS,
REQUESTS_PROCESSING_TIME,
RESPONSES,
)


@dataclass(frozen=True)
Expand All @@ -11,30 +20,93 @@ class PathTemplate:
is_handled: bool


@dataclass
class RequestContext:
method: str
path_template: PathTemplate
start_time: float
end_time: float = 0
status_code: int = 200

@property
def duration(self) -> float:
return self.end_time - self.start_time


RequestType = TypeVar("RequestType")


class PrometheusMiddleware(abc.ABC, Generic[RequestType]):
include_routes: Optional[Set[str]] = None
exclude_routes: Optional[Set[str]] = None

@staticmethod
@abc.abstractmethod
def get_method(request: RequestType) -> str:
pass

@staticmethod
@abc.abstractmethod
def get_path_template(request: RequestType) -> PathTemplate:
pass

@classmethod
def is_excluded(cls, path_template: PathTemplate) -> bool:
if cls.include_routes:
return path_template.value not in cls.include_routes
if cls.exclude_routes:
return path_template.value in cls.exclude_routes
return False
def get_request_context(cls, request: RequestType) -> RequestContext:
return RequestContext(
method=cls.get_method(request),
path_template=cls.get_path_template(request),
start_time=time.perf_counter(),
)

@classmethod
def need_process(cls, path_template: PathTemplate) -> bool:
def need_process(cls, ctx: RequestContext) -> bool:
return (
METRIC_CONTEXT.enable_metrics
and path_template.is_handled
and not cls.is_excluded(path_template)
and ctx.path_template.is_handled
and not cls._is_excluded(ctx.path_template)
)

@classmethod
def count_request_before(cls, ctx: RequestContext) -> None:
apply_labels(
REQUESTS_IN_PROGRESS,
method=ctx.method,
path_template=ctx.path_template.value,
).inc()
apply_labels(REQUESTS, method=ctx.method, path_template=ctx.path_template.value).inc()

@classmethod
def count_request_after(cls, ctx: RequestContext) -> None:
apply_labels(
REQUESTS_PROCESSING_TIME,
method=ctx.method,
path_template=ctx.path_template.value,
).observe(ctx.duration)
apply_labels(
RESPONSES,
method=ctx.method,
path_template=ctx.path_template.value,
status_code=str(ctx.status_code),
).inc()
apply_labels(
REQUESTS_IN_PROGRESS,
method=ctx.method,
path_template=ctx.path_template.value,
).dec()

@classmethod
def count_request_exceptions(cls, ctx: RequestContext, exception_type: str) -> None:
apply_labels(
EXCEPTIONS,
method=ctx.method,
path_template=ctx.path_template.value,
exception_type=exception_type,
).inc()

@classmethod
def _is_excluded(cls, path_template: PathTemplate) -> bool:
if cls.include_routes:
return path_template.value not in cls.include_routes
if cls.exclude_routes:
return path_template.value in cls.exclude_routes
return False
50 changes: 12 additions & 38 deletions src/huntflow_base_metrics/web_frameworks/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,8 @@

from aiohttp.web import Application, Request, Response, middleware

from huntflow_base_metrics.base import apply_labels
from huntflow_base_metrics.export import export_to_http_response
from huntflow_base_metrics.web_frameworks._middleware import PathTemplate, PrometheusMiddleware
from huntflow_base_metrics.web_frameworks._request_metrics import (
EXCEPTIONS,
REQUESTS,
REQUESTS_IN_PROGRESS,
REQUESTS_PROCESSING_TIME,
RESPONSES,
)

__all__ = ["add_middleware", "get_http_response_metrics"]

Expand All @@ -22,48 +14,30 @@ class _PrometheusMiddleware(PrometheusMiddleware[Request]):
@classmethod
@middleware
async def dispatch(cls, request: Request, handler: Callable) -> Response:
method = request.method
path_template = cls.get_path_template(request)

if not cls.need_process(path_template):
ctx = cls.get_request_context(request)
if not cls.need_process(ctx):
return await handler(request)

apply_labels(REQUESTS_IN_PROGRESS, method=method, path_template=path_template.value).inc()
apply_labels(REQUESTS, method=method, path_template=path_template.value).inc()
cls.count_request_before(ctx)

before_time = time.perf_counter()
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
try:
response = await handler(request)
except BaseException as e:
apply_labels(
EXCEPTIONS,
method=method,
path_template=path_template.value,
exception_type=type(e).__name__,
).inc()
ctx.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
cls.count_request_exceptions(ctx, type(e).__name__)
raise
else:
status_code = response.status
after_time = time.perf_counter()
apply_labels(
REQUESTS_PROCESSING_TIME, method=method, path_template=path_template.value
).observe(after_time - before_time)
ctx.status_code = response.status
finally:
apply_labels(
RESPONSES,
method=method,
path_template=path_template.value,
status_code=str(status_code),
).inc()
apply_labels(
REQUESTS_IN_PROGRESS,
method=method,
path_template=path_template.value,
).dec()
ctx.end_time = time.perf_counter()
cls.count_request_after(ctx)

return response

@staticmethod
def get_method(request: Request) -> str:
return request.method

@staticmethod
def get_path_template(request: Request) -> PathTemplate:
match_info = request.match_info
Expand Down
50 changes: 12 additions & 38 deletions src/huntflow_base_metrics/web_frameworks/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,64 +8,38 @@
from starlette.routing import Match
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

from huntflow_base_metrics.base import apply_labels
from huntflow_base_metrics.export import export_to_http_response
from huntflow_base_metrics.web_frameworks._middleware import PathTemplate, PrometheusMiddleware
from huntflow_base_metrics.web_frameworks._request_metrics import (
EXCEPTIONS,
REQUESTS,
REQUESTS_IN_PROGRESS,
REQUESTS_PROCESSING_TIME,
RESPONSES,
)

__all__ = ["add_middleware", "get_http_response_metrics"]


class _PrometheusMiddleware(PrometheusMiddleware[Request], BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
method = request.method
path_template = self.get_path_template(request)

if not self.need_process(path_template):
ctx = self.get_request_context(request)
if not self.need_process(ctx):
return await call_next(request)

apply_labels(REQUESTS_IN_PROGRESS, method=method, path_template=path_template.value).inc()
apply_labels(REQUESTS, method=method, path_template=path_template.value).inc()
self.count_request_before(ctx)

before_time = time.perf_counter()
status_code = HTTP_500_INTERNAL_SERVER_ERROR
try:
response = await call_next(request)
except BaseException as e:
apply_labels(
EXCEPTIONS,
method=method,
path_template=path_template.value,
exception_type=type(e).__name__,
).inc()
ctx.status_code = HTTP_500_INTERNAL_SERVER_ERROR
self.count_request_exceptions(ctx, type(e).__name__)
raise
else:
status_code = response.status_code
after_time = time.perf_counter()
apply_labels(
REQUESTS_PROCESSING_TIME, method=method, path_template=path_template.value
).observe(after_time - before_time)
ctx.status_code = response.status_code
finally:
apply_labels(
RESPONSES,
method=method,
path_template=path_template.value,
status_code=str(status_code),
).inc()
apply_labels(
REQUESTS_IN_PROGRESS,
method=method,
path_template=path_template.value,
).dec()
ctx.end_time = time.perf_counter()
self.count_request_after(ctx)

return response

@staticmethod
def get_method(request: Request) -> str:
return request.method

@staticmethod
def get_path_template(request: Request) -> PathTemplate:
for route in request.app.routes:
Expand Down
67 changes: 17 additions & 50 deletions src/huntflow_base_metrics/web_frameworks/litestar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import time
from contextvars import ContextVar
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Iterable, Optional, Type

Expand All @@ -9,15 +8,11 @@
from litestar.middleware import AbstractMiddleware
from litestar.types import ASGIApp, Message, Receive, Scope, Send

from huntflow_base_metrics.base import apply_labels
from huntflow_base_metrics.export import export_to_http_response
from huntflow_base_metrics.web_frameworks._middleware import PathTemplate, PrometheusMiddleware
from huntflow_base_metrics.web_frameworks._request_metrics import (
EXCEPTIONS,
REQUESTS,
REQUESTS_IN_PROGRESS,
REQUESTS_PROCESSING_TIME,
RESPONSES,
from huntflow_base_metrics.web_frameworks._middleware import (
PathTemplate,
PrometheusMiddleware,
RequestContext,
)

__all__ = ["exception_context", "get_http_response_metrics", "get_middleware"]
Expand All @@ -36,14 +31,6 @@ def set(self, value: str) -> None:
exception_context = _ExceptionContext()


@dataclass
class _RequestSpan:
start_time: float
end_time: float = 0
duration: float = 0
status_code: int = 200


class _PrometheusMiddleware(PrometheusMiddleware[Request], AbstractMiddleware):
scopes = {ScopeType.HTTP}

Expand All @@ -53,62 +40,42 @@ def __init__(self, app: ASGIApp, *args: Any, **kwargs: Any) -> None:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = Request[Any, Any, Any](scope, receive)
method = request.method
path_template = self.get_path_template(request)
ctx = self.get_request_context(request)

if not self.need_process(path_template):
if not self.need_process(ctx):
await self.app(scope, receive, send)
return

apply_labels(REQUESTS_IN_PROGRESS, method=method, path_template=path_template.value).inc()
apply_labels(REQUESTS, method=method, path_template=path_template.value).inc()
self.count_request_before(ctx)

span = _RequestSpan(start_time=time.perf_counter())
send_wrapper = self._get_send_wrapper(send, span)
send_wrapper = self._get_send_wrapper(send, ctx)

try:
await self.app(scope, receive, send_wrapper)
finally:
status_code = span.status_code
apply_labels(
REQUESTS_PROCESSING_TIME, method=method, path_template=path_template.value
).observe(span.duration)
apply_labels(
RESPONSES,
method=method,
path_template=path_template.value,
status_code=str(status_code),
).inc()
apply_labels(
REQUESTS_IN_PROGRESS,
method=method,
path_template=path_template.value,
).dec()
self.count_request_after(ctx)
exception_type = exception_context.get()
if exception_type:
apply_labels(
EXCEPTIONS,
method=method,
path_template=path_template.value,
exception_type=exception_type,
).inc()
self.count_request_exceptions(ctx, exception_type)

@staticmethod
def _get_send_wrapper(send: Send, span: _RequestSpan) -> Callable:
def _get_send_wrapper(send: Send, ctx: RequestContext) -> Callable:
@wraps(send)
async def wrapped_send(message: Message) -> None:
if message["type"] == "http.response.start":
span.status_code = message["status"]
ctx.status_code = message["status"]

if message["type"] == "http.response.body":
end = time.perf_counter()
span.duration = end - span.start_time
span.end_time = end
ctx.end_time = time.perf_counter()

await send(message)

return wrapped_send

@staticmethod
def get_method(request: Request) -> str:
return request.method

@staticmethod
def get_path_template(request: Request) -> PathTemplate:
return PathTemplate(value=request.scope["path_template"], is_handled=True)
Expand Down

0 comments on commit d272feb

Please sign in to comment.