From 24095740e7c12220a8ca5cc14766cf0ec95764ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Wed, 3 Jan 2024 20:16:31 +0100 Subject: [PATCH] fix: #2867 - Ensure `type_encoders` are passed to exception response (#2941) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Pass type_encoders to exception handling response --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- .../middleware/exceptions/_debug_response.py | 18 +++++++++++++++++- litestar/middleware/exceptions/middleware.py | 7 ++++--- .../test_exception_handler_middleware.py | 17 +++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/litestar/middleware/exceptions/_debug_response.py b/litestar/middleware/exceptions/_debug_response.py index 3cdd86dc6f..d80ea4bde9 100644 --- a/litestar/middleware/exceptions/_debug_response.py +++ b/litestar/middleware/exceptions/_debug_response.py @@ -26,6 +26,7 @@ from inspect import FrameInfo from litestar.connection import Request + from litestar.types import TypeEncodersMap tpl_dir = Path(__file__).parent / "templates" @@ -191,4 +192,19 @@ def create_debug_response(request: Request, exc: Exception) -> Response: content = create_plain_text_response_content(exc) media_type = MediaType.TEXT - return Response(content=content, media_type=media_type, status_code=HTTP_500_INTERNAL_SERVER_ERROR) + return Response( + content=content, + media_type=media_type, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + type_encoders=_get_type_encoders_for_request(request) if request is not None else None, + ) + + +def _get_type_encoders_for_request(request: Request) -> TypeEncodersMap | None: + try: + return request.route_handler.resolve_type_encoders() + # we might be in a 404, or before we could resolve the handler, so this + # could potentially error out. In this case we fall back on the application + # type encoders + except (KeyError, AttributeError): + return request.app.type_encoders diff --git a/litestar/middleware/exceptions/middleware.py b/litestar/middleware/exceptions/middleware.py index 97dacbeb99..4828cfd8d5 100644 --- a/litestar/middleware/exceptions/middleware.py +++ b/litestar/middleware/exceptions/middleware.py @@ -11,7 +11,7 @@ from litestar.enums import MediaType, ScopeType from litestar.exceptions import WebSocketException from litestar.middleware.cors import CORSMiddleware -from litestar.middleware.exceptions._debug_response import create_debug_response +from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response from litestar.serialization import encode_json from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from litestar.utils.deprecation import warn_deprecation @@ -85,7 +85,7 @@ class ExceptionResponseContent: extra: dict[str, Any] | list[Any] | None = field(default=None) """An extra mapping to attach to the exception.""" - def to_response(self) -> Response: + def to_response(self, request: Request | None = None) -> Response: """Create a response from the model attributes. Returns: @@ -103,6 +103,7 @@ def to_response(self) -> Response: headers=self.headers, status_code=self.status_code, media_type=self.media_type, + type_encoders=_get_type_encoders_for_request(request) if request is not None else None, ) @@ -139,7 +140,7 @@ def create_exception_response(request: Request[Any, Any, Any], exc: Exception) - extra=getattr(exc, "extra", None), media_type=media_type, ) - return content.to_response() + return content.to_response(request=request) class ExceptionHandlerMiddleware: diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index 7f2cae85d1..cb53adc88a 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -362,3 +362,20 @@ def method(self) -> None: if exc is not None and exc.__traceback__ is not None: frame = getinnerframes(exc.__traceback__, 2)[-1] assert get_symbol_name(frame) == "Test.method" + + +def test_serialize_custom_types() -> None: + # ensure type encoders are passed down to the created response so custom types that + # might end up as part of a ValidationException are handled properly + # https://github.com/litestar-org/litestar/issues/2867 + class Foo: + def __init__(self, value: str) -> None: + self.value = value + + @get() + def handler() -> None: + raise ValidationException(extra={"foo": Foo("bar")}) + + with create_test_client([handler], type_encoders={Foo: lambda f: f.value}) as client: + res = client.get("/") + assert res.json()["extra"] == {"foo": "bar"}