diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index f4bbbedde..54a8ad4a2 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -10,11 +10,13 @@ from neo4j.exceptions import ClientError from pandas import DataFrame from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight +from pyarrow import __version__ as arrow_version from pyarrow._flight import FlightStreamReader, FlightStreamWriter from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory from pyarrow.types import is_dictionary from ..server_version.server_version import ServerVersion +from ..version import __version__ from .arrow_endpoint_version import ArrowEndpointVersion from .arrow_info import ArrowInfo from .query_runner import QueryRunner @@ -75,7 +77,8 @@ def __init__( client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} if auth: self._auth_middleware = AuthMiddleware(auth) - client_options["middleware"] = [AuthFactory(self._auth_middleware)] + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" + client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)] if tls_root_certs: client_options["tls_root_certs"] = tls_root_certs @@ -219,12 +222,33 @@ def handle_flight_error(e: Exception): raise e -class AuthFactory(ClientMiddlewareFactory): # type: ignore - def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None: +class UserAgentFactory(ClientMiddlewareFactory): + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._middleware = UserAgentMiddleware(useragent) + + def start_call(self, info: Any) -> ClientMiddleware: + return self._middleware + + +class UserAgentMiddleware(ClientMiddleware): + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._useragent = useragent + + def sending_headers(self) -> Dict[str, str]: + return {"x-gds-user-agent": self._useragent} + + def received_headers(self, headers: Dict[str, Any]) -> None: + pass + + +class AuthFactory(ClientMiddlewareFactory): + def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._middleware = middleware - def start_call(self, info: Any) -> "AuthMiddleware": + def start_call(self, info: Any) -> AuthMiddleware: return self._middleware