Commit

Iterate refactor
tomchristie committed Nov 11, 2024
1 parent 228e791 commit ee9cfe1
Showing 16 changed files with 532 additions and 821 deletions.
25 changes: 11 additions & 14 deletions httpcore/_async/connection.py
@@ -9,12 +9,12 @@
from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._models import Origin, Request
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._synchronization import AsyncSemaphore
from .._trace import Trace
from .http11 import AsyncHTTP11Connection
from .interfaces import AsyncConnectionInterface
from .interfaces import AsyncConnectionInterface, StartResponse

RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.

@@ -63,10 +63,10 @@ def __init__(
)
self._connection: AsyncConnectionInterface | None = None
self._connect_failed: bool = False
self._request_lock = AsyncLock()
self._request_lock = AsyncSemaphore(bound=1)
self._socket_options = socket_options

async def handle_async_request(self, request: Request) -> Response:
async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]:
if not self.can_handle_request(request.url.origin):
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
@@ -100,7 +100,11 @@ async def handle_async_request(self, request: Request) -> Response:
self._connect_failed = True
raise exc

return await self._connection.handle_async_request(request)
iterator = self._connection.iterate_response(request)
start_response = await anext(iterator)
yield start_response
async for body in iterator:
yield body

async def _connect(self, request: Request) -> AsyncNetworkStream:
timeouts = request.extensions.get("timeout", {})
@@ -174,14 +178,7 @@ async def aclose(self) -> None:

def is_available(self) -> bool:
if self._connection is None:
# If HTTP/2 support is enabled, and the resulting connection could
# end up as HTTP/2 then we should indicate the connection as being
# available to service multiple requests.
return (
self._http2
and (self._origin.scheme == b"https" or not self._http1)
and not self._connect_failed
)
return False
return self._connection.is_available()

def has_expired(self) -> bool:
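For reference, the connection-level API above replaces the single handle_async_request() call with an async generator that first yields a response-start event and then the response body as bytes chunks. A minimal consumption sketch (the exact shape of StartResponse is not shown in this hunk, so nothing beyond "first item, then body bytes" is assumed):

    # Minimal sketch: consuming AsyncHTTPConnection.iterate_response().
    # The first item is the StartResponse event; the remaining items are body bytes.
    async def read_response(connection, request):
        iterator = connection.iterate_response(request)
        start_response = await anext(iterator)   # response start event
        body = bytearray()
        async for chunk in iterator:             # body bytes, chunk by chunk
            body.extend(chunk)
        return start_response, bytes(body)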
272 changes: 54 additions & 218 deletions httpcore/_async/connection_pool.py
@@ -2,42 +2,15 @@

import ssl
import sys
import types
import typing

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Proxy, Request, Response
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
from .._exceptions import UnsupportedProtocol
from .._models import Origin, Proxy, Request
from .._synchronization import AsyncSemaphore
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface


class AsyncPoolRequest:
def __init__(self, request: Request) -> None:
self.request = request
self.connection: AsyncConnectionInterface | None = None
self._connection_acquired = AsyncEvent()

def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
self.connection = connection
self._connection_acquired.set()

def clear_connection(self) -> None:
self.connection = None
self._connection_acquired = AsyncEvent()

async def wait_for_connection(
self, timeout: float | None = None
) -> AsyncConnectionInterface:
if self.connection is None:
await self._connection_acquired.wait(timeout=timeout)
assert self.connection is not None
return self.connection

def is_queued(self) -> bool:
return self.connection is None
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface, StartResponse


class AsyncConnectionPool(AsyncRequestInterface):
@@ -49,6 +22,7 @@ def __init__(
self,
ssl_context: ssl.SSLContext | None = None,
proxy: Proxy | None = None,
concurrency_limit: int = 100,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
@@ -102,6 +76,7 @@ def __init__(
self._max_keepalive_connections = min(
self._max_connections, self._max_keepalive_connections
)
self._limits = AsyncSemaphore(bound=concurrency_limit)

self._keepalive_expiry = keepalive_expiry
self._http1 = http1
@@ -123,7 +98,7 @@ def __init__(
# We only mutate the state of the connection pool within an 'optional_thread_lock'
# context. This holds a threading lock unless we're running in async mode,
# in which case it is a no-op.
self._optional_thread_lock = AsyncThreadLock()
# self._optional_thread_lock = AsyncThreadLock()

def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
if self._proxy is not None:
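Locking also changes in this commit: the connection's AsyncLock becomes an AsyncSemaphore(bound=1), and the pool gains a semaphore bounded by the new concurrency_limit argument, acquired around each request (see the async with self._limits block further down). A minimal sketch of that pattern, using asyncio primitives as a stand-in for the internal AsyncSemaphore (an assumption, not httpcore code):

    import asyncio

    pool_limit = asyncio.Semaphore(100)   # caps requests in flight across the pool
    request_lock = asyncio.Semaphore(1)   # bound=1 behaves like a mutual-exclusion lock

    async def send_bounded(send_request):
        async with pool_limit:            # wait for pool-wide capacity
            async with request_lock:      # exclusive use of a single connection
                return await send_request()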
@@ -196,7 +171,7 @@ def connections(self) -> list[AsyncConnectionInterface]:
"""
return list(self._connections)

async def handle_async_request(self, request: Request) -> Response:
async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]:
"""
Send an HTTP request, and return an HTTP response.
@@ -212,145 +187,50 @@ async def handle_async_request(self, request: Request) -> Response:
f"Request URL has an unsupported protocol '{scheme}://'."
)

timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)

with self._optional_thread_lock:
# Add the incoming request to our request queue.
pool_request = AsyncPoolRequest(request)
self._requests.append(pool_request)

try:
while True:
with self._optional_thread_lock:
# Assign incoming requests to available connections,
# closing or creating new connections as required.
closing = self._assign_requests_to_connections()
await self._close_connections(closing)

# Wait until this request has an assigned connection.
connection = await pool_request.wait_for_connection(timeout=timeout)

try:
# Send the request on the assigned connection.
response = await connection.handle_async_request(
pool_request.request
)
except ConnectionNotAvailable:
# In some cases a connection may initially be available to
# handle a request, but then become unavailable.
#
# In this case we clear the connection and try again.
pool_request.clear_connection()
else:
break # pragma: nocover

except BaseException as exc:
with self._optional_thread_lock:
# For any exception or cancellation we remove the request from
# the queue, and then re-assign requests to connections.
self._requests.remove(pool_request)
closing = self._assign_requests_to_connections()

await self._close_connections(closing)
raise exc from None

# Return the response. Note that in this case we still have to manage
# the point at which the response is closed.
assert isinstance(response.stream, typing.AsyncIterable)
return Response(
status=response.status,
headers=response.headers,
content=PoolByteStream(
stream=response.stream, pool_request=pool_request, pool=self
),
extensions=response.extensions,
)

def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
"""
Manage the state of the connection pool, assigning incoming
requests to connections as available.
Called whenever a new request is added or removed from the pool.
Any closing connections are returned, allowing the I/O for closing
those connections to be handled separately.
"""
closing_connections = []

# First we handle cleaning up any connections that are closed,
# have expired their keep-alive, or surplus idle connections.
for connection in list(self._connections):
if connection.is_closed():
# log: "removing closed connection"
self._connections.remove(connection)
elif connection.has_expired():
# log: "closing expired connection"
self._connections.remove(connection)
closing_connections.append(connection)
elif (
connection.is_idle()
and len([connection.is_idle() for connection in self._connections])
> self._max_keepalive_connections
):
# log: "closing idle connection"
self._connections.remove(connection)
closing_connections.append(connection)

# Assign queued requests to connections.
queued_requests = [request for request in self._requests if request.is_queued()]
for pool_request in queued_requests:
origin = pool_request.request.url.origin
available_connections = [
connection
for connection in self._connections
if connection.can_handle_request(origin) and connection.is_available()
]
idle_connections = [
connection for connection in self._connections if connection.is_idle()
]

# There are three cases for how we may be able to handle the request:
#
# 1. There is an existing connection that can handle the request.
# 2. We can create a new connection to handle the request.
# 3. We can close an idle connection and then create a new connection
# to handle the request.
if available_connections:
# log: "reusing existing connection"
connection = available_connections[0]
pool_request.assign_to_connection(connection)
elif len(self._connections) < self._max_connections:
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
elif idle_connections:
# log: "closing idle connection"
connection = idle_connections[0]
self._connections.remove(connection)
closing_connections.append(connection)
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)

return closing_connections

async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
# Close connections which have been removed from the pool.
with AsyncShieldCancellation():
for connection in closing:
await connection.aclose()
# timeouts = request.extensions.get("timeout", {})
# timeout = timeouts.get("pool", None)

async with self._limits:
connection = self._get_connection(request)
iterator = connection.iterate_response(request)
try:
response_start = await anext(iterator)
# Return the response status and headers.
yield response_start
# Return the response.
async for event in iterator:
yield event
finally:
await iterator.aclose()
closing = self._close_connections()
for conn in closing:
await conn.aclose()

def _get_connection(self, request):
origin = request.url.origin
for connection in self._connections:
if connection.can_handle_request(origin) and connection.is_available():
return connection

connection = self.create_connection(origin)
self._connections.append(connection)
return connection

def _close_connections(self):
closing = [conn for conn in self._connections if conn.has_expired()]
self._connections = [
conn for conn in self._connections
if not (conn.has_expired() or conn.is_closed())
]
return closing

async def aclose(self) -> None:
# Explicitly close the connection pool.
# Clears all existing requests and connections.
with self._optional_thread_lock:
closing_connections = list(self._connections)
self._connections = []
await self._close_connections(closing_connections)
closing = list(self._connections)
self._connections = []
for conn in closing:
await conn.aclose()

async def __aenter__(self) -> AsyncConnectionPool:
return self
@@ -365,56 +245,12 @@ async def __aexit__(

def __repr__(self) -> str:
class_name = self.__class__.__name__
with self._optional_thread_lock:
request_is_queued = [request.is_queued() for request in self._requests]
connection_is_idle = [
connection.is_idle() for connection in self._connections
]

num_active_requests = request_is_queued.count(False)
num_queued_requests = request_is_queued.count(True)
num_active_connections = connection_is_idle.count(False)
num_idle_connections = connection_is_idle.count(True)

requests_info = (
f"Requests: {num_active_requests} active, {num_queued_requests} queued"
)
connection_is_idle = [
connection.is_idle() for connection in self._connections
]
num_active_connections = connection_is_idle.count(False)
num_idle_connections = connection_is_idle.count(True)
connection_info = (
f"Connections: {num_active_connections} active, {num_idle_connections} idle"
)

return f"<{class_name} [{requests_info} | {connection_info}]>"


class PoolByteStream:
def __init__(
self,
stream: typing.AsyncIterable[bytes],
pool_request: AsyncPoolRequest,
pool: AsyncConnectionPool,
) -> None:
self._stream = stream
self._pool_request = pool_request
self._pool = pool
self._closed = False

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
try:
async for part in self._stream:
yield part
except BaseException as exc:
await self.aclose()
raise exc from None

async def aclose(self) -> None:
if not self._closed:
self._closed = True
with AsyncShieldCancellation():
if hasattr(self._stream, "aclose"):
await self._stream.aclose()

with self._pool._optional_thread_lock:
self._pool._requests.remove(self._pool_request)
closing = self._pool._assign_requests_to_connections()

await self._pool._close_connections(closing)
return f"<{class_name} [{connection_info}]>"
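Taken together, the pool now exposes the same iterator-style interface as the individual connections. A rough end-to-end sketch of driving the refactored pool (AsyncConnectionPool and Request are used as in this diff; StartResponse's attributes are not shown here, so the event is left unexamined):

    import asyncio
    from httpcore import AsyncConnectionPool, Request

    async def fetch(url: str) -> bytes:
        async with AsyncConnectionPool() as pool:
            request = Request("GET", url)
            iterator = pool.iterate_response(request)
            body = bytearray()
            try:
                start_response = await anext(iterator)   # response start event
                async for chunk in iterator:             # body bytes
                    body.extend(chunk)
            finally:
                await iterator.aclose()
            return bytes(body)

    # asyncio.run(fetch("https://www.example.com/"))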