Skip to content

Commit

Permalink
Merge pull request #751 from FlorentinD/aura-api/retry-auth-requests
Browse files Browse the repository at this point in the history
Retry aura api auth requests
  • Loading branch information
Mats-SX authored Sep 16, 2024
2 parents 1dfbae5 + eaf5777 commit 4117305
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
## Improvements

* The database connection is now validated before a session is created.
* Retry authentication requests.

## Other changes
27 changes: 24 additions & 3 deletions graphdatascience/session/aura_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def __init__(
def _init_request_session(self, credentials: Tuple[str, str]) -> requests.Session:
request_session = requests.Session()
request_session.headers = {"User-agent": f"neo4j-graphdatascience-v{__version__}"}
request_session.auth = AuraApi.Auth(oauth_url=f"{self._base_uri}/oauth/token", credentials=credentials)
request_session.auth = AuraApi.Auth(
oauth_url=f"{self._base_uri}/oauth/token", credentials=credentials, headers=request_session.headers
)
# dont retry on POST as its not idempotent
request_session.mount(
"https://",
Expand Down Expand Up @@ -329,11 +331,28 @@ def __init__(self, json: Dict[str, Any]) -> None:
def should_refresh(self) -> bool:
return self.refresh_at <= int(time.time())

def __init__(self, oauth_url: str, credentials: Tuple[str, str]) -> None:
def __init__(self, oauth_url: str, credentials: Tuple[str, str], headers: Dict[str, Any]) -> None:
self._token: Optional[AuraApi.Auth.Token] = None
self._logger = logging.getLogger()
self._oauth_url = oauth_url
self._credentials = credentials
self._request_session = self._init_request_session(headers)

def _init_request_session(self, headers: Dict[str, Any]) -> requests.Session:
request_session = requests.Session()
request_session.mount(
"https://",
HTTPAdapter(
max_retries=Retry(
allowed_methods=["POST"], # auth POST request is okay to retry
total=5,
status_forcelist=[429, 500, 502, 503, 504],
backoff_factor=0.1,
)
),
)
request_session.headers = headers
return request_session

def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
r.headers["Authorization"] = f"Bearer {self._auth_token()}"
Expand All @@ -351,7 +370,9 @@ def _update_token(self) -> AuraApi.Auth.Token:

self._logger.debug("Updating oauth token")

resp = requests.post(self._oauth_url, data=data, auth=(self._credentials[0], self._credentials[1]))
resp = self._request_session.post(
self._oauth_url, data=data, auth=(self._credentials[0], self._credentials[1])
)

if resp.status_code >= 400:
raise AuraApiError(
Expand Down

0 comments on commit 4117305

Please sign in to comment.