Skip to content

Commit

Permalink
Merge pull request #752 from FlorentinD/arrow-client-set-useragent
Browse files Browse the repository at this point in the history
Set user agent for Arrow client
  • Loading branch information
FlorentinD authored Sep 18, 2024
2 parents 4117305 + 955c04b commit 0455f75
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions graphdatascience/query_runner/gds_arrow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from neo4j.exceptions import ClientError
from pandas import DataFrame
from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight
from pyarrow import __version__ as arrow_version
from pyarrow._flight import FlightStreamReader, FlightStreamWriter
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
from pyarrow.types import is_dictionary

from ..server_version.server_version import ServerVersion
from ..version import __version__
from .arrow_endpoint_version import ArrowEndpointVersion
from .arrow_info import ArrowInfo
from .query_runner import QueryRunner
Expand Down Expand Up @@ -75,7 +77,8 @@ def __init__(
client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
if auth:
self._auth_middleware = AuthMiddleware(auth)
client_options["middleware"] = [AuthFactory(self._auth_middleware)]
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)]
if tls_root_certs:
client_options["tls_root_certs"] = tls_root_certs

Expand Down Expand Up @@ -219,12 +222,33 @@ def handle_flight_error(e: Exception):
raise e


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None:
class UserAgentFactory(ClientMiddlewareFactory):
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._middleware = UserAgentMiddleware(useragent)

def start_call(self, info: Any) -> ClientMiddleware:
return self._middleware


class UserAgentMiddleware(ClientMiddleware):
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._useragent = useragent

def sending_headers(self) -> Dict[str, str]:
return {"x-gds-user-agent": self._useragent}

def received_headers(self, headers: Dict[str, Any]) -> None:
pass


class AuthFactory(ClientMiddlewareFactory):
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._middleware = middleware

def start_call(self, info: Any) -> "AuthMiddleware":
def start_call(self, info: Any) -> AuthMiddleware:
return self._middleware


Expand Down

0 comments on commit 0455f75

Please sign in to comment.