From 19afa7a181d55d803974969920890351d2b723e7 Mon Sep 17 00:00:00 2001 From: Karoline Pauls Date: Wed, 15 Jan 2025 00:26:58 +0000 Subject: [PATCH] transport: cache default ssl context --- httpx/_transports/default.py | 53 ++++++++++++++++++++++++++++----- tests/client/test_transports.py | 42 ++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 tests/client/test_transports.py diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index d5aa05ff23..9395679ddc 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -28,6 +28,7 @@ import contextlib import typing +from functools import lru_cache from types import TracebackType if typing.TYPE_CHECKING: @@ -70,6 +71,34 @@ HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {} +_DEFAULT_VERIFY = True +_DEFAULT_CERT = None +_DEFAULT_TRUST_ENV = True + + +@lru_cache(maxsize=1) +def _create_default_ssl_context() -> ssl.SSLContext: + return create_ssl_context( + verify=_DEFAULT_VERIFY, + cert=_DEFAULT_CERT, + trust_env=_DEFAULT_TRUST_ENV, + ) + + +def _create_or_reuse_ssl_context( + verify: ssl.SSLContext | str | bool = _DEFAULT_VERIFY, + cert: CertTypes | None = _DEFAULT_CERT, + trust_env: bool = _DEFAULT_TRUST_ENV, +) -> ssl.SSLContext: + if (verify, cert, trust_env) == ( + _DEFAULT_VERIFY, + _DEFAULT_CERT, + _DEFAULT_TRUST_ENV, + ): + return _create_default_ssl_context() + else: + return create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]: import httpcore @@ -135,9 +164,9 @@ def close(self) -> None: class HTTPTransport(BaseTransport): def __init__( self, - verify: ssl.SSLContext | str | bool = True, - cert: CertTypes | None = None, - trust_env: bool = True, + verify: ssl.SSLContext | str | bool = _DEFAULT_VERIFY, + cert: CertTypes | None = _DEFAULT_CERT, + trust_env: bool = _DEFAULT_TRUST_ENV, http1: bool = True, http2: bool = False, limits: Limits = DEFAULT_LIMITS, @@ -150,7 +179,11 @@ def __init__( import httpcore proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy - ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + ssl_context = _create_or_reuse_ssl_context( + verify=verify, + cert=cert, + trust_env=trust_env, + ) if proxy is None: self._pool = httpcore.ConnectionPool( @@ -279,9 +312,9 @@ async def aclose(self) -> None: class AsyncHTTPTransport(AsyncBaseTransport): def __init__( self, - verify: ssl.SSLContext | str | bool = True, - cert: CertTypes | None = None, - trust_env: bool = True, + verify: ssl.SSLContext | str | bool = _DEFAULT_VERIFY, + cert: CertTypes | None = _DEFAULT_CERT, + trust_env: bool = _DEFAULT_TRUST_ENV, http1: bool = True, http2: bool = False, limits: Limits = DEFAULT_LIMITS, @@ -294,7 +327,11 @@ def __init__( import httpcore proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy - ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + ssl_context = _create_or_reuse_ssl_context( + verify=verify, + cert=cert, + trust_env=trust_env, + ) if proxy is None: self._pool = httpcore.AsyncConnectionPool( diff --git a/tests/client/test_transports.py b/tests/client/test_transports.py new file mode 100644 index 0000000000..e4f1f9a834 --- /dev/null +++ b/tests/client/test_transports.py @@ -0,0 +1,42 @@ +import pytest + +from httpx import AsyncHTTPTransport, HTTPTransport +from httpx._transports.default import _DEFAULT_TRUST_ENV, _DEFAULT_VERIFY + +DIFFERENT_VERIFY = {"verify": not _DEFAULT_VERIFY} +DIFFERENT_CERT_ENV = {"cert": ()} +DIFFERENT_TRUST_ENV = {"trust_env": not _DEFAULT_TRUST_ENV} + + +@pytest.mark.parametrize("transport", [HTTPTransport, AsyncHTTPTransport]) +def test_default_ssl_config_cached(transport): + transport_1 = transport() + transport_2 = transport() + assert transport_1._pool._ssl_context is not None + assert transport_2._pool._ssl_context is not None + + assert transport_1._pool._ssl_context is transport_2._pool._ssl_context + + +@pytest.mark.parametrize("transport", [HTTPTransport, AsyncHTTPTransport]) +@pytest.mark.parametrize( + ("kwargs_1", "kwargs_2"), + [ + ({}, DIFFERENT_VERIFY), + (DIFFERENT_VERIFY, {}), + (DIFFERENT_VERIFY, DIFFERENT_VERIFY), + ({}, DIFFERENT_CERT_ENV), + (DIFFERENT_CERT_ENV, {}), + (DIFFERENT_CERT_ENV, DIFFERENT_CERT_ENV), + ({}, DIFFERENT_TRUST_ENV), + (DIFFERENT_TRUST_ENV, {}), + (DIFFERENT_TRUST_ENV, DIFFERENT_TRUST_ENV), + ], +) +def test_non_default_ssl_config_not_cached(transport, kwargs_1, kwargs_2): + transport_1 = transport(**kwargs_1) + transport_2 = transport(**kwargs_2) + assert transport_1._pool._ssl_context is not None + assert transport_2._pool._ssl_context is not None + + assert transport_1._pool._ssl_context is not transport_2._pool._ssl_context