From e7c8efe3fefd3ede816ef95ca3225b1a24602fa4 Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Tue, 22 Oct 2024 15:44:35 +0200 Subject: [PATCH] green tests \o/ --- .../instrumentation/httpx/__init__.py | 49 +++++++++++++++++-- .../tests/test_httpx_integration.py | 37 +++++--------- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py index 53f43a4819..92ac58ffec 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py @@ -733,8 +733,18 @@ def _instrument(self, **kwargs): tracer_provider = kwargs.get("tracer_provider") self._request_hook = kwargs.get("request_hook") self._response_hook = kwargs.get("response_hook") - self._async_request_hook = kwargs.get("async_request_hook") - self._async_response_hook = kwargs.get("async_response_hook") + _async_request_hook = kwargs.get("async_request_hook") + self._async_request_hook = ( + _async_request_hook + if iscoroutinefunction(_async_request_hook) + else None + ) + _async_response_hook = kwargs.get("async_response_hook") + self._async_response_hook = ( + _async_response_hook + if iscoroutinefunction(_async_response_hook) + else None + ) _OpenTelemetrySemanticConventionStability._initialize() self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( @@ -826,7 +836,7 @@ def _handle_request_wrapper(self, wrapped, instance, args, kwargs): span.set_attribute( ERROR_TYPE, type(exception).__qualname__ ) - raise exception.with_traceback(exception.__traceback__) + raise exception return response @@ -895,7 +905,7 @@ async def _handle_async_request_wrapper( span.set_attribute( ERROR_TYPE, type(exception).__qualname__ ) - raise exception.with_traceback(exception.__traceback__) + raise exception return response @@ -927,6 +937,19 @@ def instrument_client( ) return + # FIXME: sharing state in the instrumentor instance maybe it's not that great, need to pass tracer and semconv to each + # instance separately + _OpenTelemetrySemanticConventionStability._initialize() + self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + self._tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(self._sem_conv_opt_in_mode), + ) + if iscoroutinefunction(request_hook): self._async_request_hook = request_hook self._request_hook = None @@ -947,6 +970,13 @@ def instrument_client( "handle_request", self._handle_request_wrapper, ) + for transport in client._mounts.values(): + # FIXME: check it's not wrapped already? + wrap_function_wrapper( + transport, + "handle_request", + self._handle_request_wrapper, + ) client._is_instrumented_by_opentelemetry = True if hasattr(client._transport, "handle_async_request"): wrap_function_wrapper( @@ -954,6 +984,13 @@ def instrument_client( "handle_async_request", self._handle_async_request_wrapper, ) + for transport in client._mounts.values(): + # FIXME: check it's not wrapped already? + wrap_function_wrapper( + transport, + "handle_async_request", + self._handle_async_request_wrapper, + ) client._is_instrumented_by_opentelemetry = True @staticmethod @@ -967,7 +1004,11 @@ def uninstrument_client( """ if hasattr(client._transport, "handle_request"): unwrap(client._transport, "handle_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_request") client._is_instrumented_by_opentelemetry = False elif hasattr(client._transport, "handle_async_request"): unwrap(client._transport, "handle_async_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_async_request") client._is_instrumented_by_opentelemetry = False diff --git a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py index 8149f4db2f..9d88ae8cd4 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py @@ -167,8 +167,6 @@ def setUp(self): ) ) - HTTPXClientInstrumentor().instrument() - def print_spans(self, spans): for span in spans: print(span.name, span.attributes) @@ -751,8 +749,9 @@ def create_proxy_transport(self, url: str): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() self.client = self.create_client() + # FIXME: calling instrument() instead fixes 13*2 tests :( + HTTPXClientInstrumentor().instrument_client(self.client) def tearDown(self): HTTPXClientInstrumentor().uninstrument() @@ -792,7 +791,6 @@ def test_custom_tracer_provider(self): result = self.create_tracer_provider(resource=resource) tracer_provider, exporter = result - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=tracer_provider ) @@ -802,7 +800,6 @@ def test_custom_tracer_provider(self): self.assertEqual(result.text, "Hello!") span = self.assert_span(exporter=exporter) self.assertIs(span.resource, resource) - HTTPXClientInstrumentor().uninstrument() def test_response_hook(self): response_hook_key = ( @@ -811,7 +808,6 @@ def test_response_hook(self): else "response_hook" ) response_hook_kwargs = {response_hook_key: self.response_hook} - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, **response_hook_kwargs, @@ -830,10 +826,8 @@ def test_response_hook(self): HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_response_hook_sync_async_kwargs(self): - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, response_hook=_response_hook, @@ -845,7 +839,7 @@ def test_response_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - dict(span.attributes), + span.attributes, { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -853,7 +847,6 @@ def test_response_hook_sync_async_kwargs(self): HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_request_hook(self): request_hook_key = ( @@ -862,7 +855,6 @@ def test_request_hook(self): else "request_hook" ) request_hook_kwargs = {request_hook_key: self.request_hook} - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, **request_hook_kwargs, @@ -873,10 +865,8 @@ def test_request_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_sync_async_kwargs(self): - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, request_hook=_request_hook, @@ -888,10 +878,8 @@ def test_request_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_no_span_update(self): - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, request_hook=self.no_update_request_hook, @@ -902,10 +890,8 @@ def test_request_hook_no_span_update(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() def test_not_recording(self): - HTTPXClientInstrumentor().uninstrument() with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span: HTTPXClientInstrumentor().instrument( tracer_provider=trace.NoOpTracerProvider() @@ -921,10 +907,8 @@ def test_not_recording(self): self.assertTrue(mock_span.is_recording.called) self.assertFalse(mock_span.set_attribute.called) self.assertFalse(mock_span.set_status.called) - HTTPXClientInstrumentor().uninstrument() def test_suppress_instrumentation_new_client(self): - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument() with suppress_http_instrumentation(): client = self.create_client() @@ -932,10 +916,8 @@ def test_suppress_instrumentation_new_client(self): self.assertEqual(result.text, "Hello!") self.assert_span(num_spans=0) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client(self): - HTTPXClientInstrumentor().uninstrument() client = self.create_client() HTTPXClientInstrumentor().instrument_client(client) result = self.perform_request(self.URL, client=client) @@ -943,6 +925,8 @@ def test_instrument_client(self): self.assert_span(num_spans=1) def test_instrumentation_without_client(self): + + HTTPXClientInstrumentor().instrument() results = [ httpx.get(self.URL), httpx.request("GET", self.URL), @@ -961,6 +945,7 @@ def test_instrumentation_without_client(self): ) def test_uninstrument(self): + HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().uninstrument() client = self.create_client() result = self.perform_request(self.URL, client=client) @@ -970,7 +955,6 @@ def test_uninstrument(self): self.assert_span(num_spans=0) def test_uninstrument_client(self): - HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().uninstrument_client(self.client) result = self.perform_request(self.URL) @@ -979,6 +963,7 @@ def test_uninstrument_client(self): self.assert_span(num_spans=0) def test_uninstrument_new_client(self): + HTTPXClientInstrumentor().instrument() client1 = self.create_client() HTTPXClientInstrumentor().uninstrument_client(client1) @@ -1001,6 +986,7 @@ def test_uninstrument_new_client(self): def test_instrument_proxy(self): proxy_mounts = self.create_proxy_mounts() + HTTPXClientInstrumentor().instrument() client = self.create_client(mounts=proxy_mounts) self.perform_request(self.URL, client=client) self.assert_span(num_spans=1) @@ -1027,7 +1013,6 @@ def print_handler(self, client): return handler def test_instrument_client_with_proxy(self): - HTTPXClientInstrumentor().uninstrument() proxy_mounts = self.create_proxy_mounts() client = self.create_client(mounts=proxy_mounts) self.assert_proxy_mounts( @@ -1047,6 +1032,7 @@ def test_instrument_client_with_proxy(self): def test_uninstrument_client_with_proxy(self): proxy_mounts = self.create_proxy_mounts() + HTTPXClientInstrumentor().instrument() client = self.create_client(mounts=proxy_mounts) self.assert_proxy_mounts( client._mounts.values(), @@ -1109,7 +1095,7 @@ def create_client( transport: typing.Optional[SyncOpenTelemetryTransport] = None, **kwargs, ): - return httpx.Client(**kwargs) + return httpx.Client(transport=transport, **kwargs) def perform_request( self, @@ -1230,6 +1216,7 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): def setUp(self): super().setUp() self.client2 = self.create_client() + HTTPXClientInstrumentor().instrument_client(self.client2) def create_client( self, @@ -1283,7 +1270,6 @@ def test_async_response_hook_does_nothing_if_not_coroutine(self): SpanAttributes.HTTP_STATUS_CODE: 200, }, ) - HTTPXClientInstrumentor().uninstrument() def test_async_request_hook_does_nothing_if_not_coroutine(self): HTTPXClientInstrumentor().instrument( @@ -1296,4 +1282,3 @@ def test_async_request_hook_does_nothing_if_not_coroutine(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument()