Skip to content

Commit

Permalink
feat: Refactor base HTTP client to enable request limits to work
Browse files Browse the repository at this point in the history
  • Loading branch information
akalex committed Jan 17, 2024
1 parent 5df05bd commit 46978fb
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 104 deletions.
12 changes: 0 additions & 12 deletions .isort.cfg

This file was deleted.

9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,27 @@ repos:
- id: check-yaml

- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.1.0
hooks:
- id: black
language_version: python3.8
args: [--line-length=120, --skip-string-normalization]

- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.11.5
hooks:
- id: isort
stages: [commit]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
rev: v1.1.1
hooks:
- id: mypy
args: [--no-error-summary, --hide-error-codes, --follow-imports=skip]
files: ^async_firebase/
additional_dependencies: [types-setuptools]
5 changes: 4 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Changelog

## 3.4.0
* Refactored ``async_firebase.base.AsyncClientBase`` to take advantage of connection pool. So the HTTP client will be created once during class ``async_firebase.client.AsyncFirebaseClient`` instantiation.

## 3.3.0
* async_firebase now works with python 3.12
* `async_firebase` now works with python 3.12

## 3.2.0
* ``AsyncFirebaseClient`` empower with advanced features to configure request behaviour such as timeout, or connection pooling.
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ help:
install:
@echo "$(BOLD)Installing package$(RESET)"
@poetry config virtualenvs.create false
@poetry install --only main --no-root
@poetry install --only main
@echo "$(BOLD)Done!$(RESET)"

.PHONY: update
Expand Down
1 change: 1 addition & 0 deletions async_firebase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .client import AsyncFirebaseClient # noqa


root_logger = logging.getLogger("async_firebase")
if root_logger.level == logging.NOTSET:
root_logger.setLevel(logging.WARN)
90 changes: 46 additions & 44 deletions async_firebase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,16 @@
import uuid
from datetime import datetime, timedelta
from email.mime.nonmultipart import MIMENonMultipart
from importlib.metadata import version
from pathlib import PurePath
from urllib.parse import urlencode, urljoin
from urllib.parse import urlencode

import httpx
from google.oauth2 import service_account # type: ignore

import pkg_resources # type: ignore
from async_firebase._config import (
DEFAULT_REQUEST_LIMITS,
DEFAULT_REQUEST_TIMEOUT,
RequestLimits,
RequestTimeout,
)
from async_firebase._config import DEFAULT_REQUEST_LIMITS, DEFAULT_REQUEST_TIMEOUT, RequestLimits, RequestTimeout
from async_firebase.messages import FCMBatchResponse, FCMResponse
from async_firebase.utils import (
FCMBatchResponseHandler,
FCMResponseHandler,
serialize_mime_message,
)
from async_firebase.utils import FCMBatchResponseHandler, FCMResponseHandler, join_url, serialize_mime_message


class AsyncClientBase:
Expand Down Expand Up @@ -71,6 +62,22 @@ def __init__(

self._request_timeout = request_timeout
self._request_limits = request_limits
self._http_client: t.Optional[httpx.AsyncClient] = None

@property
def _client(self) -> httpx.AsyncClient:
def _create_http_client() -> httpx.AsyncClient:
return httpx.AsyncClient(
timeout=httpx.Timeout(**self._request_timeout.__dict__),
limits=httpx.Limits(**self._request_limits.__dict__),
)

if self._http_client is None:
self._http_client = _create_http_client()
elif self._client.is_closed:
self._http_client = _create_http_client()

return self._http_client

def creds_from_service_account_info(self, service_account_info: t.Dict[str, str]) -> None:
"""
Expand Down Expand Up @@ -109,9 +116,8 @@ async def _get_access_token(self) -> str:
}
).encode("utf-8")

async with httpx.AsyncClient() as client:
response: httpx.Response = await client.post(self.TOKEN_URL, data=data, headers=headers)
response_data = response.json()
response: httpx.Response = await self._client.post(self.TOKEN_URL, content=data, headers=headers)
response_data = response.json()

self._credentials.expiry = datetime.utcnow() + timedelta(seconds=response_data["expires_in"])
self._credentials.token = response_data["access_token"]
Expand Down Expand Up @@ -159,7 +165,7 @@ async def prepare_headers(self) -> t.Dict[str, str]:
"Content-Type": "application/json; UTF-8",
"X-Request-Id": self.get_request_id(),
"X-GOOG-API-FORMAT-VERSION": "2",
"X-FIREBASE-CLIENT": "async-firebase/{0}".format(pkg_resources.get_distribution("async-firebase").version),
"X-FIREBASE-CLIENT": "async-firebase/{0}".format(version("async-firebase")),
}

async def send_request(
Expand All @@ -180,34 +186,30 @@ async def send_request(
:param content: request content
:return: HTTP response
"""
async with httpx.AsyncClient(
base_url=self.BASE_URL,
timeout=httpx.Timeout(**self._request_timeout.__dict__),
limits=httpx.Limits(**self._request_limits.__dict__),
) as client:
url = join_url(self.BASE_URL, self.FCM_ENDPOINT.format(project_id=self._credentials.project_id), uri)
logging.debug(
"Requesting POST %s, payload: %s, content: %s, headers: %s",
url,
json_payload,
content,
headers,
)
try:
raw_fcm_response: httpx.Response = await self._client.post(
url,
json=json_payload,
headers=headers or await self.prepare_headers(),
content=content,
)
raw_fcm_response.raise_for_status()
except httpx.HTTPError as exc:
response = response_handler.handle_error(exc)
else:
logging.debug(
"Requesting POST %s, payload: %s, content: %s, headers: %s",
urljoin(self.BASE_URL, self.FCM_ENDPOINT.format(project_id=self._credentials.project_id)),
json_payload,
content,
headers,
"Response Code: %s, Time spent to make a request: %s",
raw_fcm_response.status_code,
raw_fcm_response.elapsed,
)
try:
raw_fcm_response: httpx.Response = await client.post(
uri,
json=json_payload,
headers=headers or await self.prepare_headers(),
content=content,
)
raw_fcm_response.raise_for_status()
except httpx.HTTPError as exc:
response = response_handler.handle_error(exc)
else:
logging.debug(
"Response Code: %s, Time spent to make a request: %s",
raw_fcm_response.status_code,
raw_fcm_response.elapsed,
)
response = response_handler.handle_response(raw_fcm_response)
response = response_handler.handle_response(raw_fcm_response)

return response
9 changes: 3 additions & 6 deletions async_firebase/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@

import httpx

from async_firebase.base import ( # noqa: F401
AsyncClientBase,
RequestLimits,
RequestTimeout,
)
from async_firebase.base import AsyncClientBase, RequestLimits, RequestTimeout # noqa: F401
from async_firebase.encoders import aps_encoder
from async_firebase.messages import (
AndroidConfig,
Expand All @@ -45,6 +41,7 @@
serialize_mime_message,
)


DEFAULT_TTL = 604800
BATCH_MAX_MESSAGES = MULTICAST_MESSAGE_MAX_DEVICE_TOKENS = 500

Expand Down Expand Up @@ -74,7 +71,7 @@ def assemble_push_notification(
has_apns_config = True if apns_config and apns_config.payload else False
if has_apns_config:
# avoid mutation of active message
message.apns = replace(message.apns)
message.apns = replace(message.apns) # type: ignore
message.apns.payload = aps_encoder(apns_config.payload.aps) # type: ignore

push_notification: t.Dict[str, t.Any] = cleanup_firebase_message(
Expand Down
36 changes: 35 additions & 1 deletion async_firebase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from email.mime.multipart import MIMEMultipart
from email.mime.nonmultipart import MIMENonMultipart
from email.parser import FeedParser
from urllib.parse import quote, urlencode, urljoin

import httpx

Expand All @@ -26,6 +27,40 @@
from async_firebase.messages import FCMBatchResponse, FCMResponse


def join_url(
base: str,
*parts: t.Union[str, int],
params: t.Optional[dict] = None,
leading_slash: bool = False,
trailing_slash: bool = False,
) -> str:
"""Construct a full ("absolute") URL by combining a "base URL" (base) with another URL (url) parts.
:param base: base URL part
:param parts: another url parts that should be joined
:param params: dict with query params
:param leading_slash: flag to force leading slash
:param trailing_slash: flag to force trailing slash
:return: full URL
"""
url = base
if parts:
url = '/'.join([base.strip('/'), quote('/'.join(map(lambda x: str(x).strip('/'), parts)))])

# trailing slash can be important
if trailing_slash:
url = f'{url}/'
# as well as a leading slash
if leading_slash:
url = f'/{url}'

if params:
url = urljoin(url, '?{}'.format(urlencode(params)))

return url


def remove_null_values(dict_value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
"""Remove Falsy values from the dictionary."""
return {k: v for k, v in dict_value.items() if v not in [None, [], {}]}
Expand Down Expand Up @@ -126,7 +161,6 @@ def serialize_mime_message(


class FCMResponseHandlerBase(ABC, t.Generic[FCMResponseType]):

ERROR_CODE_TO_EXCEPTION_TYPE: t.Dict[str, t.Type[AsyncFirebaseError]] = {
FcmErrorCode.INVALID_ARGUMENT.value: errors.InvalidArgumentError,
FcmErrorCode.FAILED_PRECONDITION.value: errors.FailedPreconditionError,
Expand Down
Loading

0 comments on commit 46978fb

Please sign in to comment.