Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fastapi: fix wrapping of middlewares #3012

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#3037](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3037))
- `opentelemetry-instrumentation-sqlalchemy`: Fix a remaining memory leak in EngineTracer
([#3053](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3053))
- `opentelemetry-instrumentation-fastapi`: instrument unhandled exceptions
([#3012](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3012))

### Breaking changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,14 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
from __future__ import annotations

import logging
import types
from typing import Collection, Literal

import fastapi
from starlette.applications import Starlette
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.routing import Match
from starlette.types import ASGIApp

from opentelemetry.instrumentation._semconv import (
_get_schema_url,
Expand All @@ -199,9 +203,9 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
from opentelemetry.instrumentation.fastapi.package import _instruments
from opentelemetry.instrumentation.fastapi.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.metrics import get_meter
from opentelemetry.metrics import MeterProvider, get_meter
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import get_tracer
from opentelemetry.trace import TracerProvider, get_tracer
from opentelemetry.util.http import (
get_excluded_urls,
parse_excluded_urls,
Expand All @@ -226,9 +230,9 @@ def instrument_app(
server_request_hook: ServerRequestHook = None,
client_request_hook: ClientRequestHook = None,
client_response_hook: ClientResponseHook = None,
tracer_provider=None,
meter_provider=None,
excluded_urls=None,
tracer_provider: TracerProvider | None = None,
lzchen marked this conversation as resolved.
Show resolved Hide resolved
meter_provider: MeterProvider | None = None,
excluded_urls: str | None = None,
http_capture_headers_server_request: list[str] | None = None,
http_capture_headers_server_response: list[str] | None = None,
http_capture_headers_sanitize_fields: list[str] | None = None,
Expand Down Expand Up @@ -280,21 +284,40 @@ def instrument_app(
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)

app.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=http_capture_headers_server_request,
http_capture_headers_server_response=http_capture_headers_server_response,
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
exclude_spans=exclude_spans,
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware
# as the outermost middleware.
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do.

def build_middleware_stack(self: Starlette) -> ASGIApp:
app = type(self).build_middleware_stack(self)
app = OpenTelemetryMiddleware(
app,
excluded_urls=excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=http_capture_headers_server_request,
http_capture_headers_server_response=http_capture_headers_server_response,
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
exclude_spans=exclude_spans,
)
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware
# are handled.
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that
# to impact the user's application just because we wrapped the middlewares in this order.
app = ServerErrorMiddleware(app)
return app

app._original_build_middleware_stack = app.build_middleware_stack
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: do you think it would be possible to use wrapt for monkeypatching instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not. This is simpler and works just fine.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xrmx genuine question: what would be the benefit of that in this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xrmx genuine question: what would be the benefit of that in this case?

wrapt monkeypatching tend to generally work better than shuffling classes under the hood, had to move to wrapt in httpx instrumentation because otherwise the instrumentation did not patch stuff loaded before the load of the instrumentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless there's a specific reason I don't think introducing more complexity and code to be executed is helpful here. There are plenty of tests checking that the monkey patching is working correctly. This is also not a method that users would ever call, in fact as someone who's responsible for significant changes to this very method I would have made it a private method if it had not already been public for years.

app.build_middleware_stack = types.MethodType(
build_middleware_stack, app
)

app._is_instrumented_by_opentelemetry = True
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps:
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app)
Expand All @@ -305,11 +328,12 @@ def instrument_app(

@staticmethod
def uninstrument_app(app: fastapi.FastAPI):
app.user_middleware = [
x
for x in app.user_middleware
if x.cls is not OpenTelemetryMiddleware
]
original_build_middleware_stack = getattr(
app, "_original_build_middleware_stack", None
)
if original_build_middleware_stack:
app.build_middleware_stack = original_build_middleware_stack
del app._original_build_middleware_stack
app.middleware_stack = app.build_middleware_stack()
app._is_instrumented_by_opentelemetry = False

Expand Down Expand Up @@ -337,12 +361,7 @@ def _instrument(self, **kwargs):
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = (
kwargs.get("http_capture_headers_sanitize_fields")
)
_excluded_urls = kwargs.get("excluded_urls")
_InstrumentedFastAPI._excluded_urls = (
_excluded_urls_from_env
if _excluded_urls is None
else parse_excluded_urls(_excluded_urls)
)
_InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls")
_InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider")
_InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans")
fastapi.FastAPI = _InstrumentedFastAPI
Expand All @@ -361,43 +380,29 @@ class _InstrumentedFastAPI(fastapi.FastAPI):
_server_request_hook: ServerRequestHook = None
_client_request_hook: ClientRequestHook = None
_client_response_hook: ClientResponseHook = None
_http_capture_headers_server_request: list[str] | None = None
_http_capture_headers_server_response: list[str] | None = None
_http_capture_headers_sanitize_fields: list[str] | None = None
_exclude_spans: list[Literal["receive", "send"]] | None = None

_instrumented_fastapi_apps = set()
_sem_conv_opt_in_mode = _HTTPStabilityMode.DEFAULT

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tracer = get_tracer(
__name__,
__version__,
_InstrumentedFastAPI._tracer_provider,
schema_url=_get_schema_url(
_InstrumentedFastAPI._sem_conv_opt_in_mode
),
)
meter = get_meter(
__name__,
__version__,
_InstrumentedFastAPI._meter_provider,
schema_url=_get_schema_url(
_InstrumentedFastAPI._sem_conv_opt_in_mode
),
)
self.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=_InstrumentedFastAPI._excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=_InstrumentedFastAPI._server_request_hook,
client_request_hook=_InstrumentedFastAPI._client_request_hook,
client_response_hook=_InstrumentedFastAPI._client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request,
http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response,
http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields,
exclude_spans=_InstrumentedFastAPI._exclude_spans,
FastAPIInstrumentor.instrument_app(
self,
server_request_hook=self._server_request_hook,
client_request_hook=self._client_request_hook,
client_response_hook=self._client_response_hook,
tracer_provider=self._tracer_provider,
meter_provider=self._meter_provider,
excluded_urls=self._excluded_urls,
http_capture_headers_server_request=self._http_capture_headers_server_request,
http_capture_headers_server_response=self._http_capture_headers_server_response,
http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields,
exclude_spans=self._exclude_spans,
)
self._is_instrumented_by_opentelemetry = True
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self)

def __del__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# pylint: disable=too-many-lines

import unittest
from contextlib import ExitStack
from timeit import default_timer
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -170,9 +171,14 @@ def setUp(self):
self._instrumentor = otel_fastapi.FastAPIInstrumentor()
self._app = self._create_app()
self._app.add_middleware(HTTPSRedirectMiddleware)
self._client = TestClient(self._app)
self._client = TestClient(self._app, base_url="https://testserver:443")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the root cause of tests failures. It turns out every request was being made twice because it was getting redirected to https. Before this PR that wasn't being instrumented correctly, so this was not being caught! I think that's another major bug this PR is fixing.

# run the lifespan, initialize the middleware stack
# this is more in-line with what happens in a real application when the server starts up
self._exit_stack = ExitStack()
self._exit_stack.enter_context(self._client)
Comment on lines +175 to +178
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The build_middleware_stack this PR patches is called on the startup event - which the TestClient only runs if called within a context manager, that's why @adriangb is using ExitStack here. FYI


def tearDown(self):
self._exit_stack.close()
super().tearDown()
self.env_patch.stop()
self.exclude_patch.stop()
Expand Down Expand Up @@ -205,11 +211,19 @@ async def _(param: str):
async def _():
return {"message": "ok"}

@app.get("/error")
async def _():
raise UnhandledException("This is an unhandled exception")

app.mount("/sub", app=sub_app)

return app


class UnhandledException(Exception):
pass


class TestBaseManualFastAPI(TestBaseFastAPI):
@classmethod
def setUpClass(cls):
Expand All @@ -220,6 +234,27 @@ def setUpClass(cls):

super(TestBaseManualFastAPI, cls).setUpClass()

def test_fastapi_unhandled_exception(self):
"""If the application has an unhandled error the instrumentation should capture that a 500 response is returned."""
try:
resp = self._client.get("/error")
assert (
resp.status_code == 500
), resp.content # pragma: no cover, for debugging this test if an exception is _not_ raised
except UnhandledException:
pass
else:
self.fail("Expected UnhandledException")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
span = spans[0]
assert span.name == "GET /error http send"
assert span.attributes[SpanAttributes.HTTP_STATUS_CODE] == 500
span = spans[2]
assert span.name == "GET /error"
assert span.attributes[SpanAttributes.HTTP_TARGET] == "/error"

def test_sub_app_fastapi_call(self):
"""
This test is to ensure that a span in case of a sub app targeted contains the correct server url
Expand Down Expand Up @@ -976,6 +1011,10 @@ async def _(param: str):
async def _():
return {"message": "ok"}

@app.get("/error")
async def _():
raise UnhandledException("This is an unhandled exception")

app.mount("/sub", app=sub_app)

return app
Expand Down Expand Up @@ -1124,9 +1163,11 @@ def test_request(self):
def test_mulitple_way_instrumentation(self):
self._instrumentor.instrument_app(self._app)
count = 0
for middleware in self._app.user_middleware:
if middleware.cls is OpenTelemetryMiddleware:
app = self._app.middleware_stack
while app is not None:
if isinstance(app, OpenTelemetryMiddleware):
count += 1
app = getattr(app, "app", None)
self.assertEqual(count, 1)

def test_uninstrument_after_instrument(self):
Expand Down
Loading