Skip to content

Commit

Permalink
fix: #2867 - Ensure type_encoders are passed to exception response (#…
Browse files Browse the repository at this point in the history
…2941)

* Pass type_encoders to exception handling response

---------

Signed-off-by: Janek Nouvertné <[email protected]>
  • Loading branch information
provinzkraut authored Jan 3, 2024
1 parent 11b4353 commit 2409574
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
18 changes: 17 additions & 1 deletion litestar/middleware/exceptions/_debug_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from inspect import FrameInfo

from litestar.connection import Request
from litestar.types import TypeEncodersMap

tpl_dir = Path(__file__).parent / "templates"

Expand Down Expand Up @@ -191,4 +192,19 @@ def create_debug_response(request: Request, exc: Exception) -> Response:
content = create_plain_text_response_content(exc)
media_type = MediaType.TEXT

return Response(content=content, media_type=media_type, status_code=HTTP_500_INTERNAL_SERVER_ERROR)
return Response(
content=content,
media_type=media_type,
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
type_encoders=_get_type_encoders_for_request(request) if request is not None else None,
)


def _get_type_encoders_for_request(request: Request) -> TypeEncodersMap | None:
try:
return request.route_handler.resolve_type_encoders()
# we might be in a 404, or before we could resolve the handler, so this
# could potentially error out. In this case we fall back on the application
# type encoders
except (KeyError, AttributeError):
return request.app.type_encoders
7 changes: 4 additions & 3 deletions litestar/middleware/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from litestar.enums import MediaType, ScopeType
from litestar.exceptions import WebSocketException
from litestar.middleware.cors import CORSMiddleware
from litestar.middleware.exceptions._debug_response import create_debug_response
from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response
from litestar.serialization import encode_json
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from litestar.utils.deprecation import warn_deprecation
Expand Down Expand Up @@ -85,7 +85,7 @@ class ExceptionResponseContent:
extra: dict[str, Any] | list[Any] | None = field(default=None)
"""An extra mapping to attach to the exception."""

def to_response(self) -> Response:
def to_response(self, request: Request | None = None) -> Response:
"""Create a response from the model attributes.
Returns:
Expand All @@ -103,6 +103,7 @@ def to_response(self) -> Response:
headers=self.headers,
status_code=self.status_code,
media_type=self.media_type,
type_encoders=_get_type_encoders_for_request(request) if request is not None else None,
)


Expand Down Expand Up @@ -139,7 +140,7 @@ def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -
extra=getattr(exc, "extra", None),
media_type=media_type,
)
return content.to_response()
return content.to_response(request=request)


class ExceptionHandlerMiddleware:
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_middleware/test_exception_handler_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,20 @@ def method(self) -> None:
if exc is not None and exc.__traceback__ is not None:
frame = getinnerframes(exc.__traceback__, 2)[-1]
assert get_symbol_name(frame) == "Test.method"


def test_serialize_custom_types() -> None:
# ensure type encoders are passed down to the created response so custom types that
# might end up as part of a ValidationException are handled properly
# https://github.com/litestar-org/litestar/issues/2867
class Foo:
def __init__(self, value: str) -> None:
self.value = value

@get()
def handler() -> None:
raise ValidationException(extra={"foo": Foo("bar")})

with create_test_client([handler], type_encoders={Foo: lambda f: f.value}) as client:
res = client.get("/")
assert res.json()["extra"] == {"foo": "bar"}

0 comments on commit 2409574

Please sign in to comment.