Skip to content

Commit

Permalink
Use run_in_executor for getaddrinfo
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Jan 15, 2025
1 parent 39ade36 commit 5e3bc65
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
20 changes: 16 additions & 4 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import builtins
import functools
import socket
import sys
from typing import (
Expand All @@ -26,6 +27,7 @@
cast,
)

from pymongo._asyncio_executor import _PYMONGO_EXECUTOR
from pymongo.errors import (
OperationFailure,
)
Expand Down Expand Up @@ -70,14 +72,24 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


async def getaddrinfo(host, port, **kwargs):
async def getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if not _IS_SYNC:
loop = asyncio.get_running_loop()
return await loop.getaddrinfo( # type: ignore[assignment]
host, port, **kwargs
return await loop.run_in_executor( # type: ignore[return-value]
_PYMONGO_EXECUTOR, functools.partial(socket.getaddrinfo, host, port, **kwargs)
)
else:
return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment]
return socket.getaddrinfo(host, port, **kwargs)


if sys.version_info >= (3, 10):
Expand Down
20 changes: 16 additions & 4 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import builtins
import functools
import socket
import sys
from typing import (
Expand All @@ -26,6 +27,7 @@
cast,
)

from pymongo._asyncio_executor import _PYMONGO_EXECUTOR
from pymongo.errors import (
OperationFailure,
)
Expand Down Expand Up @@ -70,14 +72,24 @@ def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


def getaddrinfo(host, port, **kwargs):
def getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if not _IS_SYNC:
loop = asyncio.get_running_loop()
return loop.getaddrinfo( # type: ignore[assignment]
host, port, **kwargs
return loop.run_in_executor( # type: ignore[return-value]
_PYMONGO_EXECUTOR, functools.partial(socket.getaddrinfo, host, port, **kwargs)
)
else:
return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment]
return socket.getaddrinfo(host, port, **kwargs)


if sys.version_info >= (3, 10):
Expand Down

0 comments on commit 5e3bc65

Please sign in to comment.