diff --git a/openapi_python_client/helpers/__init__.py b/openapi_python_client/helpers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openapi_python_client/helpers/authenticators.py b/openapi_python_client/helpers/authenticators.py new file mode 100644 index 000000000..54861cc17 --- /dev/null +++ b/openapi_python_client/helpers/authenticators.py @@ -0,0 +1,299 @@ +"""Models for API authenticators.""" + +from __future__ import annotations + +import base64 +import math +import typing as t +from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit + +import jwt +import requests +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from pydantic import BaseModel, root_validator, validator + +from dlt.common import logger, pendulum, timedelta + + +class APIAuthenticator(BaseModel): + """Base class for API authenticators.""" + + auth_headers: t.Dict[str, str] = {} + auth_params: t.Dict[str, str] = {} + + @staticmethod + def add_parameters(initial_url: str, extra_parameters: dict) -> str: + """Add parameters to an URL and return the new URL.""" + scheme, netloc, path, query_string, fragment = urlsplit(initial_url) + query_params = parse_qs(query_string) + query_params.update( + { + parameter_name: [parameter_value] + for parameter_name, parameter_value in extra_parameters.items() + }, + ) + + new_query_string = urlencode(query_params, doseq=True) + + return urlunsplit((scheme, netloc, path, new_query_string, fragment)) + + def authenticate_request(self, request: requests.PreparedRequest) -> requests.PreparedRequest: + """Authenticate a request.""" + if self.auth_headers: + request.headers.update(self.auth_headers) + if request.url and self.auth_params: + request.url = APIAuthenticator.add_parameters(request.url, self.auth_params) + return request + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + """Authenticate a request.""" + return self.authenticate_request(r) + + +class APIKeyAuthenticator(APIAuthenticator): + """API authenticator using an API key.""" + + key: str + value: str + location: t.Literal["headers", "params"] = "headers" + + @root_validator + def post_root(cls, values: dict) -> "APIKeyAuthenticator": + """Add the API key to the authentication parameters.""" + if values["location"] == "headers": + values["auth_headers"][values["key"]] = values["value"] + elif values["location"] == "params": + values["auth_params"][values["key"]] = values["value"] + else: + raise ValueError("Invalid location for API key, must be 'headers' or 'params'") + return values + + +class BearerTokenAuthenticator(APIAuthenticator): + """API authenticator using a bearer token.""" + + token: str + base64_encode: bool = False + + @root_validator + def post_root(cls, values: dict) -> "BearerTokenAuthenticator": + """Add the bearer token to the authentication headers.""" + if values["base64_encode"]: + values["token"] = base64.b64encode(values["token"].encode("utf-8")).decode("utf-8") + values["auth_headers"]["Authorization"] = f"Bearer {values['token']}" + return values + + +class BasicAuthenticator(APIAuthenticator): + """API authenticator using basic authentication.""" + + username: str + password: str + + @root_validator + def post_root(cls, values: dict) -> "BasicAuthenticator": + """Add the basic authentication to the authentication headers.""" + credentials = f"{values['username']}:{values['password']}".encode("utf-8") + token = base64.b64encode(credentials).decode("utf-8") + 
values["auth_headers"]["Authorization"] = f"Basic {token}" + return values + + +class NoAuthAuthenticator(APIAuthenticator): + """API authenticator using no authentication.""" + + pass + + +class _OAuthAuthenticator(APIAuthenticator): + """API authenticator using OAuth 2.0.""" + + auth_endpoint: str + oauth_scopes: str | t.List[str] + oauth_headers: dict = {} + default_expiration: int | None = None + + # Internal tracking attributes + access_token: str | None = None + refresh_token: str | None = None + last_refreshed: pendulum.DateTime | None = None + expires_in: int | None = None + + @validator("oauth_scopes", pre=True) + def validate_oauth_scopes(cls, value: str | t.List[str]) -> str: + """Validate the OAuth scopes.""" + if isinstance(value, list): + return " ".join(value) + return value + + @property + def auth_headers(self) -> dict: + """Return a dictionary of auth headers to be applied.""" + if not self.is_token_valid(): + self.update_access_token() + result = super().auth_headers + result["Authorization"] = f"Bearer {self.access_token}" + return result + + def is_token_valid(self) -> bool: + """Check if the access token is valid.""" + if self.access_token is None: + return False + if self.expires_in is None: + return True + if self.last_refreshed is None: + return True + if self.default_expiration is None: + return True + if self.last_refreshed + timedelta(seconds=self.expires_in) < pendulum.now(): + return False + return True + + def update_access_token(self) -> None: + """Update the access token.""" + if self.auth_endpoint is None: + raise ValueError("No auth endpoint specified") + if self.oauth_scopes is None: + raise ValueError("No OAuth scopes specified") + if self.oauth_headers is None: + raise ValueError("No OAuth headers specified") + + logger.debug("Updating access token") + response = requests.post( + url=self.auth_endpoint, + headers=self.oauth_headers, + data={ + "grant_type": "client_credentials", + "scope": self.oauth_scopes, + }, + ) + response.raise_for_status() + response_data = response.json() + self.access_token = response_data["access_token"] + self.refresh_token = response_data.get("refresh_token") + self.expires_in = response_data.get("expires_in") + self.last_refreshed = pendulum.now() + + +class OAuthJWTAuthenticator(_OAuthAuthenticator): + """API authenticator using OAuth 2.0 with JWT.""" + + client_id: str + private_key: str + private_key_passphrase: str | None = None + + @property + def oauth_request_body(self) -> dict: + """Return request body for OAuth request.""" + request_time: pendulum.DateTime = pendulum.utcnow() + return { + "iss": self.client_id, + "scope": self.oauth_scopes, + "aud": self.auth_endpoint, + "exp": math.floor((request_time + timedelta(hours=1)).timestamp()), + "iat": math.floor(request_time.timestamp()), + } + + @property + def oauth_request_payload(self) -> dict: + """Return request paytload for OAuth request.""" + private_key: bytes | t.Any = bytes(self.private_key, "UTF-8") + if self.private_key_passphrase: + passphrase = bytes(self.private_key_passphrase, "UTF-8") + private_key = serialization.load_pem_private_key( + private_key, + password=passphrase, + backend=default_backend(), + ) + private_key_string: str | t.Any = private_key.decode("UTF-8") + return { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": jwt.encode( + self.oauth_request_body, + private_key_string, + "RS256", + ), + } + + def update_access_token(self) -> None: + """Update `access_token` along with: `last_refreshed` and `expires_in`. 
+ + Raises: + RuntimeError: When OAuth login fails. + """ + request_time: pendulum.DateTime = pendulum.utcnow() + auth_request_payload = self.oauth_request_payload + token_response = requests.post( + self.auth_endpoint, + headers=self.oauth_headers, + data=auth_request_payload, + timeout=60, + ) + try: + token_response.raise_for_status() + except requests.HTTPError as e: + raise RuntimeError( + "Failed OAuth login, response was '%s'. %s", token_response.json(), e + ) from e + + logger.info("OAuth authorization attempt was successful.") + + token_json: dict = token_response.json() + self.access_token = token_json["access_token"] + expiration = token_json.get("expires_in", self.default_expiration) + self.expires_in = int(expiration) if expiration else None + if self.expires_in is None: + logger.debug( + ( + "No expires_in receied in OAuth response and no " + "default_expiration set. Token will be treated as if it never " + "expires." + ), + ) + self.last_refreshed = request_time + + +class APIAuthenticatorChain(APIAuthenticator): + """API authenticator using a chain of authenticators.""" + + authenticators: t.List[APIAuthenticator] = [] + + def authenticate_request(self, request: requests.PreparedRequest) -> requests.PreparedRequest: + """Authenticate a request.""" + for authenticator in self.authenticators: + request = authenticator.authenticate_request(request) + return request + + +class APIAuthenticatorFactory: + """Factory for API authenticators.""" + + @staticmethod + def create_authenticator(authenticator: t.Union[APIAuthenticator, dict]) -> APIAuthenticator: + """Create an API authenticator from a dict or an APIAuthenticator.""" + if isinstance(authenticator, APIAuthenticator): + return authenticator + elif isinstance(authenticator, dict): + authenticator_type = authenticator.pop("type", "NoAuthAuthenticator") + if authenticator_type == "NoAuthAuthenticator": + return NoAuthAuthenticator(**authenticator) + elif authenticator_type == "APIKeyAuthenticator": + return APIKeyAuthenticator(**authenticator) + elif authenticator_type == "BearerTokenAuthenticator": + return BearerTokenAuthenticator(**authenticator) + elif authenticator_type == "BasicAuthenticator": + return BasicAuthenticator(**authenticator) + elif authenticator_type == "OAuthAuthenticator": + return OAuthJWTAuthenticator(**authenticator) + elif authenticator_type == "APIAuthenticatorChain": + return APIAuthenticatorChain( + authenticators=[ + APIAuthenticatorFactory.create_authenticator(authenticator) + for authenticator in authenticator["authenticators"] + ] + ) + else: + raise ValueError(f"Unknown authenticator type: {authenticator_type}") + else: + raise ValueError("Unknown authenticator type") diff --git a/openapi_python_client/helpers/exceptions.py b/openapi_python_client/helpers/exceptions.py new file mode 100644 index 000000000..267f09a0d --- /dev/null +++ b/openapi_python_client/helpers/exceptions.py @@ -0,0 +1,26 @@ +"""Defines a common set of exceptions which developers can raise and/or catch.""" + +from __future__ import annotations + +import typing as t + +if t.TYPE_CHECKING: + import requests + + +class FatalAPIError(Exception): + """Exception raised when a failed request should not be considered retriable.""" + + +class RetriableAPIError(Exception): + """Exception raised when a failed request can be safely retried.""" + + def __init__(self, message: str, response: "requests.Response" | None = None) -> None: + """Extends the default with the failed response as an attribute. 
+ + Args: + message (str): The Error Message + response (requests.Response): The response object. + """ + super().__init__(message) + self.response = response diff --git a/openapi_python_client/helpers/jsonpath.py b/openapi_python_client/helpers/jsonpath.py new file mode 100644 index 000000000..e944e2fe7 --- /dev/null +++ b/openapi_python_client/helpers/jsonpath.py @@ -0,0 +1,44 @@ +"""JSONPath helpers.""" + +from __future__ import annotations + +import typing as t +from functools import lru_cache + +from jsonpath_ng.ext import parse + +if t.TYPE_CHECKING: + import jsonpath_ng + + +def extract_jsonpath( + expression: str, + input: dict | list, # noqa: A002 +) -> t.Generator[t.Any, None, None]: + """Extract records from an input based on a JSONPath expression. + + Args: + expression: JSONPath expression to match against the input. + input: JSON object or array to extract records from. + + Yields: + Records matched with JSONPath expression. + """ + compiled_jsonpath = _compile_jsonpath(expression) + + match: jsonpath_ng.DatumInContext + for match in compiled_jsonpath.find(input): + yield match.value + + +@lru_cache(maxsize=128) +def _compile_jsonpath(expression: str) -> jsonpath_ng.JSONPath: + """Parse a JSONPath expression and cache the result. + + Args: + expression: A string representing a JSONPath expression. + + Returns: + A compiled JSONPath object. + """ + return parse(expression) diff --git a/openapi_python_client/helpers/metrics.py b/openapi_python_client/helpers/metrics.py new file mode 100644 index 000000000..bfded451e --- /dev/null +++ b/openapi_python_client/helpers/metrics.py @@ -0,0 +1,364 @@ +"""Singer metrics logging.""" + +from __future__ import annotations + +import abc +import enum +import json +import logging +import logging.config +import typing as t +from dataclasses import dataclass, field +from time import time + +from dlt.common import logger + +if t.TYPE_CHECKING: + from types import TracebackType + +DEFAULT_LOG_INTERVAL = 60.0 +METRICS_LOGGER_NAME = __name__ + +T = t.TypeVar("T") + + +class Status(str, enum.Enum): + """Constants for commonly used status values.""" + + SUCCEEDED = "succeeded" + FAILED = "failed" + + +class Tag(str, enum.Enum): + """Constants for commonly used tags.""" + + RESOURCE = "resource" + CONTEXT = "context" + ENDPOINT = "endpoint" + JOB_TYPE = "job_type" + HTTP_STATUS_CODE = "http_status_code" + STATUS = "status" + + +class Metric(str, enum.Enum): + """Common metric types.""" + + RECORD_COUNT = "record_count" + BATCH_COUNT = "batch_count" + HTTP_REQUEST_DURATION = "http_request_duration" + HTTP_REQUEST_COUNT = "http_request_count" + JOB_DURATION = "job_duration" + SYNC_DURATION = "sync_duration" + + +@dataclass +class Point(t.Generic[T]): + """An individual metric measurement.""" + + metric_type: str + metric: Metric + value: T + tags: dict[str, t.Any] = field(default_factory=dict) + + def __str__(self) -> str: + """Get string representation of this measurement. + + Returns: + A string representation of this measurement. + """ + return self.to_json() + + def to_json(self) -> str: + """Convert this measure to a JSON object. + + Returns: + A JSON object. + """ + return json.dumps( + { + "type": self.metric_type, + "metric": self.metric.value, + "value": self.value, + "tags": self.tags, + }, + default=str, + ) + + +def log(logger: logging.Logger, point: Point) -> None: + """Log a measurement. + + Args: + logger: An logger instance. + point: A measurement. 
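+
+    Example (illustrative only; metric value and logger name are hypothetical):
+        >>> log(logging.getLogger(__name__), Point("counter", Metric.RECORD_COUNT, 42))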
+ """ + logger.info("METRIC: %s", point) + + +class Meter(metaclass=abc.ABCMeta): + """Base class for all meters.""" + + def __init__(self, metric: Metric, tags: dict | None = None) -> None: + """Initialize a meter. + + Args: + metric: The metric type. + tags: Tags to add to the measurement. + """ + self.metric = metric + self.tags = tags or {} + self.logger = logger + + @property + def context(self) -> dict | None: + """Get the context for this meter. + + Returns: + A context dictionary. + """ + return self.tags.get(Tag.CONTEXT) + + @context.setter + def context(self, value: dict | None) -> None: + """Set the context for this meter. + + Args: + value: A context dictionary. + """ + if value is None: + self.tags.pop(Tag.CONTEXT, None) + else: + self.tags[Tag.CONTEXT] = value + + @abc.abstractmethod + def __enter__(self) -> Meter: + """Enter the meter context.""" + ... + + @abc.abstractmethod + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the meter context. + + Args: + exc_type: The exception type. + exc_val: The exception value. + exc_tb: The exception traceback. + """ + ... + + +class Counter(Meter): + """A meter for counting things.""" + + def __init__( + self, + metric: Metric, + tags: dict | None = None, + log_interval: float = DEFAULT_LOG_INTERVAL, + ) -> None: + """Initialize a counter. + + Args: + metric: The metric type. + tags: Tags to add to the measurement. + log_interval: The interval at which to log the count. + """ + super().__init__(metric, tags) + self.value = 0 + self.log_interval = log_interval + self.last_log_time = time() + + def __enter__(self) -> Counter: + """Enter the counter context. + + Returns: + The counter instance. + """ + self.last_log_time = time() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the counter context. + + Args: + exc_type: The exception type. + exc_val: The exception value. + exc_tb: The exception traceback. + """ + self._pop() + + def _pop(self) -> None: + """Log and reset the counter.""" + log(self.logger, Point("counter", self.metric, self.value, self.tags)) + self.value = 0 + self.last_log_time = time() + + def increment(self, value: int = 1) -> None: + """Increment the counter. + + Args: + value: The value to increment by. + """ + self.value += value + if self._ready_to_log(): + self._pop() + + def _ready_to_log(self) -> bool: + """Check if the counter is ready to log. + + Returns: + True if the counter is ready to log. + """ + return time() - self.last_log_time > self.log_interval + + +class Timer(Meter): + """A meter for timing things.""" + + def __init__(self, metric: Metric, tags: dict | None = None) -> None: + """Initialize a timer. + + Args: + metric: The metric type. + tags: Tags to add to the measurement. + """ + super().__init__(metric, tags) + self.start_time = time() + + def __enter__(self) -> Timer: + """Enter the timer context. + + Returns: + The timer instance. + """ + self.start_time = time() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the timer context. + + Args: + exc_type: The exception type. + exc_val: The exception value. + exc_tb: The exception traceback. 
+ """ + if Tag.STATUS not in self.tags: + if exc_type is None: + self.tags[Tag.STATUS] = Status.SUCCEEDED + else: + self.tags[Tag.STATUS] = Status.FAILED + log(self.logger, Point("timer", self.metric, self.elapsed(), self.tags)) + + def elapsed(self) -> float: + """Get the elapsed time. + + Returns: + The elapsed time. + """ + return time() - self.start_time + + +def record_counter( + stream: str, + endpoint: str | None = None, + log_interval: float = DEFAULT_LOG_INTERVAL, + **tags: t.Any, +) -> Counter: + """Use for counting records retrieved from the source. + + with record_counter("my_stream", endpoint="/users") as counter: + for record in my_records: + # Do something with the record + counter.increment() + + Args: + stream: The stream name. + endpoint: The endpoint name. + log_interval: The interval at which to log the count. + tags: Tags to add to the measurement. + + Returns: + A counter for counting records. + """ + tags[Tag.RESOURCE] = stream + if endpoint: + tags[Tag.ENDPOINT] = endpoint + return Counter(Metric.RECORD_COUNT, tags, log_interval=log_interval) + + +def batch_counter(stream: str, **tags: t.Any) -> Counter: + """Use for counting batches sent to the target. + + with batch_counter("my_stream") as counter: + for batch in my_batches: + # Do something with the batch + counter.increment() + + Args: + stream: The stream name. + tags: Tags to add to the measurement. + + Returns: + A counter for counting batches. + """ + tags[Tag.RESOURCE] = stream + return Counter(Metric.BATCH_COUNT, tags) + + +def http_request_counter( + stream: str, + endpoint: str, + log_interval: float = DEFAULT_LOG_INTERVAL, + **tags: t.Any, +) -> Counter: + """Use for counting HTTP requests. + + with http_request_counter() as counter: + for record in my_records: + # Do something with the record + counter.increment() + + Args: + stream: The stream name. + endpoint: The endpoint name. + log_interval: The interval at which to log the count. + tags: Tags to add to the measurement. + + Returns: + A counter for counting HTTP requests. + """ + tags.update({Tag.RESOURCE: stream, Tag.ENDPOINT: endpoint}) + return Counter(Metric.HTTP_REQUEST_COUNT, tags, log_interval=log_interval) + + +def sync_timer(resource: str, **tags: t.Any) -> Timer: + """Use for timing the sync of a resource. + + with singer.metrics.sync_timer() as timer: + # Do something + print(f"Sync took {timer.elapsed()} seconds") + + Args: + resource: The resource name. + tags: Tags to add to the measurement. + + Returns: + A timer for timing the sync of a resource. 
+ """ + tags[Tag.RESOURCE] = resource + return Timer(Metric.SYNC_DURATION, tags) diff --git a/openapi_python_client/helpers/paginators.py b/openapi_python_client/helpers/paginators.py new file mode 100644 index 000000000..1a98673e4 --- /dev/null +++ b/openapi_python_client/helpers/paginators.py @@ -0,0 +1,182 @@ +"""Generic paginator classes.""" + +from __future__ import annotations + +import typing as t +from abc import ABCMeta, abstractmethod +from urllib.parse import ParseResult, urlparse + +from .jsonpath import extract_jsonpath + +if t.TYPE_CHECKING: + from requests import Response + + +T = t.TypeVar("T") +TPageToken = t.TypeVar("TPageToken") + + +# TODO: move to common.utils +def first(iterable: t.Iterable[T]) -> T: + """Return the first element of an iterable or raise an exception.""" + return next(iter(iterable)) + + +class BaseAPIPaginator(t.Generic[TPageToken], metaclass=ABCMeta): + """An API paginator object.""" + + def __init__(self, start_value: TPageToken) -> None: + """Create a new paginator.""" + self._value: TPageToken = start_value + self._page_count = 0 + self._finished = False + self._last_seen_record: dict | None = None + + @property + def current_value(self) -> TPageToken: + """Get the current pagination value.""" + return self._value + + @property + def finished(self) -> bool: + """Get a flag that indicates if the last page of data has been reached.""" + return self._finished + + @property + def count(self) -> int: + """Count the number of pages traversed so far.""" + return self._page_count + + def advance(self, response: Response) -> None: + """Get a new page value and advance the current one.""" + self._page_count += 1 + + if not self.has_more(response): + self._finished = True + return + + new_value = self.get_next(response) + + if new_value and new_value == self._value: + raise RuntimeError( + "Loop detected in pagination. Pagination token %s is identical to prior token." 
+ % new_value + ) + + if not new_value: + self._finished = True + else: + self._value = new_value + + def has_more(self, response: Response) -> bool: + """Override this method to check if the endpoint has any pages left.""" + return True + + @abstractmethod + def get_next(self, response: Response) -> TPageToken | None: + """Get the next pagination token or index from the API response.""" + pass + + def __str__(self) -> str: + return f"{self.__class__.__name__}<{self.current_value}>" + + def __repr__(self) -> str: + return str(self) + + +class SinglePagePaginator(BaseAPIPaginator[None]): + """A paginator that works with single-page endpoints.""" + + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + """Create a new paginator.""" + super().__init__(None, *args, **kwargs) + + def get_next(self, response: Response) -> None: + """Returns `None` to indicate the end of pagination.""" + return None + + +class BaseHATEOASPaginator( + BaseAPIPaginator[t.Optional[ParseResult]], + metaclass=ABCMeta, +): + """Paginator class for APIs supporting HATEOAS links in their response bodies.""" + + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + """Create a new paginator.""" + super().__init__(None, *args, **kwargs) + + @abstractmethod + def get_next_url(self, response: Response) -> str | None: + """Override this method to extract a HATEOAS link from the response.""" + + def get_next(self, response: Response) -> ParseResult | None: + """Get the next pagination token or index from the API response.""" + next_url = self.get_next_url(response) + return urlparse(next_url) if next_url else None + + +class HeaderLinkPaginator(BaseHATEOASPaginator): + """Paginator class for APIs supporting HATEOAS links in their headers.""" + + def get_next_url(self, response: Response) -> str | None: + """Override this method to extract a HATEOAS link from the response.""" + url: str | None = response.links.get("next", {}).get("url") + return url + + +class JSONPathPaginator(BaseAPIPaginator[t.Optional[str]]): + """Paginator class for APIs returning a pagination token in the response body.""" + + def __init__(self, jsonpath: str, *args: t.Any, **kwargs: t.Any) -> None: + """Create a new paginator.""" + super().__init__(None, *args, **kwargs) + self._jsonpath = jsonpath + + def get_next(self, response: Response) -> str | None: + """Get the next page token.""" + matches = extract_jsonpath(self._jsonpath, response.json()) + return next(matches, None) + + +class SimpleHeaderPaginator(BaseAPIPaginator[t.Optional[str]]): + """Paginator class for APIs returning a pagination token in the response headers.""" + + def __init__(self, key: str, *args: t.Any, **kwargs: t.Any) -> None: + """Create a new paginator.""" + super().__init__(None, *args, **kwargs) + self._key = key + + def get_next(self, response: Response) -> str | None: + """Get the next page token.""" + return response.headers.get(self._key, None) + + +class BasePageNumberPaginator(BaseAPIPaginator[int], metaclass=ABCMeta): + """Paginator class for APIs that use page number.""" + + @abstractmethod + def has_more(self, response: Response) -> bool: + """Override this method to check if the endpoint has any pages left.""" + + def get_next(self, response: Response) -> int | None: # noqa: ARG002 + """Get the next page number.""" + return self._value + 1 + + +class BaseOffsetPaginator(BaseAPIPaginator[int], metaclass=ABCMeta): + """Paginator class for APIs that use page offset.""" + + def __init__(self, start_value: int, page_size: int, *args: t.Any, **kwargs: 
t.Any) -> None: + """Create a new paginator.""" + super().__init__(start_value, *args, **kwargs) + self._page_size = page_size + + @abstractmethod + def has_more(self, response: Response) -> bool: + """Override this method to check if the endpoint has any pages left.""" + ... + + def get_next(self, response: Response) -> int | None: # noqa: ARG002 + """Get the next page offset.""" + return self._value + self._page_size diff --git a/openapi_python_client/helpers/rest.py b/openapi_python_client/helpers/rest.py new file mode 100644 index 000000000..7a8bb3248 --- /dev/null +++ b/openapi_python_client/helpers/rest.py @@ -0,0 +1,484 @@ +"""Abstract base class for API resources.""" + +from __future__ import annotations + +import abc +import copy +import logging +import time +import typing as t +from http import HTTPStatus +from urllib.parse import urlparse + +import backoff +import requests +from dlt.sources.helpers.requests import Client + +from dlt.common import logger +from . import metrics +from .authenticators import APIAuthenticator +from .exceptions import FatalAPIError, RetriableAPIError +from .jsonpath import extract_jsonpath +from .paginators import ( + BaseAPIPaginator, + BaseOffsetPaginator, + SimpleHeaderPaginator, + SinglePagePaginator, +) + +if t.TYPE_CHECKING: + from backoff.types import Details + + +DEFAULT_PAGE_SIZE = 1000 +DEFAULT_REQUEST_TIMEOUT = 300 # 5 minutes + +TPaginatorToken = t.TypeVar("TPaginatorToken") +TEndpoint = t.TypeVar("TEndpoint", bound="RestAPIEndpoint", covariant=True) +TPaginator = t.TypeVar("TPaginator", bound="BaseAPIPaginator", covariant=True) +TAuthenticator = t.TypeVar("TAuthenticator", bound="APIAuthenticator", covariant=True) +TDefEndpoint = t.TypeVar("TDefEndpoint", bound="RestAPIEndpoint") +TDefPaginator = t.TypeVar("TDefPaginator", bound="BaseAPIPaginator") +TDefAuthenticator = t.TypeVar("TDefAuthenticator", bound="APIAuthenticator") + +# add covariant +# overloads for ALL paginators +# overloads for ALL authenticators + + +class BaseAPIEndpointFactory: + pass + + +class BaseAPIEndpoint: + pass + + +class RestAPIEndpoint(BaseAPIEndpoint, t.Generic[TPaginatorToken]): + """Base class for a REST API endpoint.""" + + records_jsonpath: str | None = None + + def __init__( + self, + api: RestAPIEndpointFactory, + path: str, + records_jsonpath: str | None = None, + method: str = "GET", + paginator: BaseAPIPaginator | None = None, + http_headers: dict[str, str] | None = None, + request_body: dict | None = None, + request_params: dict | None = None, + ) -> None: + self.api = api + self.path = path + # If records_jsonpath is provided at init, use it + if records_jsonpath: + self.records_jsonpath = records_jsonpath + # Otherwise, use the class-level records_jsonpath + if not self.records_jsonpath: + # If the class-level records_jsonpath is None, use the default from the API + self.records_jsonpath = api.records_jsonpath + self.method = method + if paginator: + paginator = copy.deepcopy(paginator) + self.paginator = paginator or copy.deepcopy(api.default_paginator) + self.http_headers = api.headers + self.http_headers.update(http_headers or {}) + self.request_body = request_body + self.request_params = request_params + + @property + def resolved_path(self) -> str: + """Get entity URL.""" + return self.api.base_url + self.path + + def _request(self, prepared_request: requests.PreparedRequest) -> requests.Response: + """Send a prepared request and return the response.""" + response = self.api.requests_session.send(prepared_request, timeout=self.timeout) + 
self._write_request_duration_log( + endpoint=self.path, + response=response, + extra_tags={"url": prepared_request.path_url}, + ) + self.api.validate_response(response) + logger.debug("Response received successfully.") + return response + + def get_request_params( + self, next_page_token: TPaginatorToken | None + ) -> dict[str, t.Any] | str | None: + """Create or update the request params for the REST API request.""" + return self.request_params + + def get_request_body(self, next_page_token: TPaginatorToken | None) -> dict | None: + """Create or update the request body for the REST API request.""" + return self.request_body + + def get_request_url(self, next_page_token: TPaginatorToken | None) -> str | None: + """Create or update the request url for the REST API request.""" + return None + + def get_prepared_request(self, *args: t.Any, **kwargs: t.Any) -> requests.PreparedRequest: + """Build a generic but authenticated request.""" + request = requests.Request(*args, **kwargs) + return self.api.authenticator(self.api.requests_session.prepare_request(request)) + + def prepare_request(self, next_page_token: TPaginatorToken | None) -> requests.PreparedRequest: + """Prepare a request object for this endpoint.""" + http_method = self.method + headers = self.http_headers + + url = self.get_request_url(next_page_token) or self.resolved_path + params = self.get_request_params(next_page_token) or {} + request_data = self.get_request_body(next_page_token) + + return self.get_prepared_request( + method=http_method, + url=url, + params=params, + headers=headers, + json=request_data, + ) + + def request_records(self) -> t.Iterable[dict]: + """Request records from REST endpoint(s), returning response records.""" + decorated_request = self.api.request_decorator(self._request) + + with metrics.http_request_counter("cdf", self.path) as request_counter: + while not self.paginator.finished: + prepared_request = self.prepare_request( + next_page_token=self.paginator.current_value + ) + resp = decorated_request(prepared_request) + request_counter.increment() + self.update_sync_costs(prepared_request, resp) + yield from self.parse_response(resp) + self.paginator.advance(resp) + + def _write_request_duration_log( + self, endpoint: str, response: requests.Response, extra_tags: dict | None + ) -> None: + """Write a log entry for the request duration.""" + extra_tags = extra_tags or {} + point = metrics.Point( + "timer", + metric=metrics.Metric.HTTP_REQUEST_DURATION, + value=response.elapsed.total_seconds(), + tags={ + metrics.Tag.ENDPOINT: endpoint, + metrics.Tag.HTTP_STATUS_CODE: response.status_code, + metrics.Tag.STATUS: ( + metrics.Status.SUCCEEDED + if response.status_code < HTTPStatus.BAD_REQUEST + else metrics.Status.FAILED + ), + **extra_tags, + }, + ) + metrics.log(logger, point=point) + + def update_sync_costs( + self, request: requests.PreparedRequest, response: requests.Response + ) -> dict[str, int]: + """Update internal calculation of sync costs.""" + call_costs = self.calculate_sync_cost(request, response) + self._sync_costs = { + k: self._sync_costs.get(k, 0) + call_costs.get(k, 0) for k in call_costs + } + return self._sync_costs + + def calculate_sync_cost( + self, request: requests.PreparedRequest, response: requests.Response + ) -> t.Dict[str, int]: + """Calculate the cost of the last API call made.""" + return {} + + @property + def timeout(self) -> int: + """Return the request timeout limit in seconds.""" + return DEFAULT_REQUEST_TIMEOUT + + def get_records(self) -> t.Iterable[dict[str, 
t.Any]]: + """Return a generator of record-type dictionary objects.""" + for record in self.request_records(): + transformed_record = self.post_process(record) + if transformed_record is None: + continue + yield transformed_record + + def parse_response(self, response: requests.Response) -> t.Iterable[dict]: + """Parse the response and return an iterator of result records.""" + yield from extract_jsonpath(self.records_jsonpath, input=response.json()) + + def post_process(self, row: dict) -> dict | None: + """As needed, append or transform raw data to match expected structure.""" + return row + + def __iter__(self) -> t.Iterable[t.Dict[str, t.Any]]: + """Lazily consume the endpoint and return a generator of records.""" + yield from self.get_records() + + def __call__(self) -> t.List[t.Dict[str, t.Any]]: + """Immediately consume the endpoint and return a list of records.""" + return list(self) + + +class RestAPIEndpointFactory( + BaseAPIEndpointFactory, + t.Generic[TDefEndpoint, TDefPaginator, TDefAuthenticator], + metaclass=abc.ABCMeta, +): + """Abstract base class for REST APIs. This can be thought of as a api for endpoints.""" + + extra_retry_statuses: t.Sequence[int] = [HTTPStatus.TOO_MANY_REQUESTS] + """HTTP statuses that should be retried.""" + records_jsonpath: str = "$[*]" + """JSONPath expression to extract records from the response.""" + endpoint_klass: t.Type[TDefEndpoint] = RestAPIEndpoint + """The default endpoint class to use for this API.""" + paginator_klass: t.Type[TDefPaginator] = SinglePagePaginator + """The default paginator class to use for this API.""" + default_paginator = SinglePagePaginator() + """The default paginator instance to use for this API. + + If no paginator is provided, a deepcopy of this one will be used. + """ + authenticator: TDefAuthenticator | None = None + """The authenticator to use for this API. + + This can be overridden by passing an authenticator to the constructor. An authenticator + is a callable that takes a prepared request and returns a prepared request. It is required. + A ValueError will be raised if no authenticator is provided. + """ + + @property + @abc.abstractmethod + def base_url(self) -> str: + """Return the base url, e.g. ``https://api.gong.io/v2/``.""" + + def __init__( + self, + authenticator: APIAuthenticator | None = None, + headers: dict[str, str] | None = None, + ) -> None: + """Initialize the REST endpoint.""" + if authenticator: + # Override the class-level authenticator with the provided one. 
+ self.authenticator = authenticator + if self.authenticator is None: + raise ValueError("No authenticator provided.") + self.headers = headers or {} + self.client = Client(session_attrs={"headers": self.headers}) + + @property + def requests_session(self) -> requests.Session: + """Get requests session.""" + return self.client.session + + def validate_response(self, response: requests.Response) -> None: + """Validate HTTP response.""" + if ( + response.status_code in self.extra_retry_statuses + or HTTPStatus.INTERNAL_SERVER_ERROR <= response.status_code <= max(HTTPStatus) + ): + raise RetriableAPIError(self.response_error_message(response), response) + if HTTPStatus.BAD_REQUEST <= response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR: + raise FatalAPIError(self.response_error_message(response)) + + def response_error_message(self, response: requests.Response) -> str: + """Build error message for invalid http statuses.""" + full_path = urlparse(response.url).path or self.base_url + error_type = ( + "Client" + if HTTPStatus.BAD_REQUEST <= response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR + else "Server" + ) + return f"{response.status_code} {error_type} Error: {response.reason} for path: {full_path}" + + def request_decorator(self, func: t.Callable) -> t.Callable: + """Instantiate a decorator for handling request failures.""" + decorator: t.Callable = backoff.on_exception( + self.backoff_wait_generator, + ( + ConnectionResetError, + RetriableAPIError, + requests.exceptions.ReadTimeout, + requests.exceptions.ConnectionError, + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ContentDecodingError, + ), + max_tries=self.backoff_max_tries, + on_backoff=self.backoff_handler, + jitter=self.backoff_jitter, + )(func) + return decorator + + @t.overload + def endpoint_factory( + self, + path: str, + klass: None = None, + /, + **kwargs, + ) -> TDefEndpoint: + ... + + @t.overload + def endpoint_factory( + self, + path: str, + klass: t.Type[RestAPIEndpoint], + /, + *, + records_jsonpath: str | None = None, + method: str = "GET", + paginator: BaseAPIPaginator | None = None, + http_headers: dict[str, str] | None = None, + request_body: dict | None = None, + request_params: dict | None = None, + ) -> TEndpoint: + ... + + @t.overload + def endpoint_factory( + self, + path: str, + klass: t.Type[TEndpoint], + /, + **kwargs, + ) -> TEndpoint: + ... + + def endpoint_factory( + self, path: str, klass: t.Type[TEndpoint] | None = None, /, **kwargs + ) -> TDefEndpoint | TEndpoint: + """Return a new endpoint object. Kwargs are passed to the endpoint constructor.""" + return (klass or self.endpoint_klass)(api=self, path=path, **kwargs) + + @t.overload + def paginator_factory( + self, + klass: None = None, + /, + **kwargs, + ) -> TDefPaginator: + ... + + @t.overload + def paginator_factory( + self, + klass: t.Type[SinglePagePaginator], + /, + ) -> SinglePagePaginator: + ... + + @t.overload + def paginator_factory( + self, + klass: t.Type[SimpleHeaderPaginator], + /, + key: str, + ) -> SimpleHeaderPaginator: + ... + + @t.overload + def paginator_factory( + self, + klass: t.Type[BaseOffsetPaginator], + /, + start_value: int, + page_size: int, + ) -> BaseOffsetPaginator: + ... + + @t.overload + def paginator_factory( + self, + klass: t.Type[TPaginator], + /, + **kwargs, + ) -> TPaginator: + ... 
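+    # Note: the overloads above only narrow the static return type for the known
+    # paginator classes; the runtime implementation below simply instantiates the
+    # requested class (or the default ``paginator_klass``) with the supplied kwargs.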
+ + def paginator_factory( + self, + klass: t.Type[TPaginator] | None = None, + /, + **kwargs, + ) -> TDefPaginator | TPaginator: + return (klass or self.paginator_klass)(**kwargs) + + def backoff_wait_generator(self) -> t.Generator[float, None, None]: + """The wait generator used by the backoff decorator on request failure.""" + return backoff.expo(factor=2) + + def backoff_max_tries(self) -> int: + """The number of attempts before giving up when retrying requests.""" + return 5 + + def backoff_jitter(self, value: float) -> float: + """Amount of jitter to add.""" + return backoff.random_jitter(value) + + def backoff_handler(self, details: "Details") -> None: + """Adds additional behaviour prior to retry.""" + e = details.get("exception") + if ( + isinstance(e, RetriableAPIError) + and e.response.status_code == HTTPStatus.TOO_MANY_REQUESTS + ): + retry_after = int(e.response.headers.get("Retry-After", 15)) + logging.warning( + "429 Too Many Requests: backing off %0.2f seconds after %d tries. Retry-After: %s", + details.get("wait"), + details.get("tries"), + retry_after, + ) + time.sleep(retry_after) + else: + logging.error( + ( + "Backing off %0.2f seconds after %d tries " + "calling function %s with args %s and kwargs " + "%s" + ), + details.get("wait"), + details.get("tries"), + details.get("target"), + details.get("args"), + details.get("kwargs"), + ) + + def backoff_runtime(self, *, value: t.Callable[[t.Any], int]) -> t.Generator[int, None, None]: + """Optional backoff wait generator that can replace the default `backoff.expo`.""" + exception = yield # type: ignore[misc] + while True: + exception = yield value(exception) + + +if __name__ == "__main__": + + class JsonPlaceholderAPIEndpoint(RestAPIEndpoint[None]): + """JsonPlaceholder API endpoint.""" + + class JsonPlaceholderAPIEndpointSpecialized(JsonPlaceholderAPIEndpoint): + """JsonPlaceholder API endpoint specialized.""" + + class JsonPlaceholderAPIFactory( + RestAPIEndpointFactory[JsonPlaceholderAPIEndpoint, SinglePagePaginator, APIAuthenticator] + ): + """Factory for the JsonPlaceholder API.""" + + base_url = "https://jsonplaceholder.typicode.com/" + endpoint_klass = JsonPlaceholderAPIEndpoint + paginator_klass = SinglePagePaginator + default_paginator = SinglePagePaginator() + authenticator = APIAuthenticator() + records_jsonpath = "$.[*]" + + # Example usage with overloaded type signatures + api = JsonPlaceholderAPIFactory() + posts = api.endpoint_factory("posts") + todos = api.endpoint_factory("todos") + users = api.endpoint_factory("users")
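+
+    # A minimal, illustrative sketch of how the helpers compose end to end. The
+    # bearer token, "X-Next-Page" header and "postId" query parameter below are
+    # made-up demonstration values: jsonplaceholder requires neither authentication
+    # nor pagination, so this shows the wiring rather than a real requirement.
+    from .authenticators import APIAuthenticatorFactory
+
+    authed_api = JsonPlaceholderAPIFactory(
+        authenticator=APIAuthenticatorFactory.create_authenticator(
+            {"type": "BearerTokenAuthenticator", "token": "example-token"}
+        ),
+    )
+    comments = authed_api.endpoint_factory(
+        "comments",
+        request_params={"postId": "1"},
+        paginator=authed_api.paginator_factory(SimpleHeaderPaginator, key="X-Next-Page"),
+    )
+    # Iterating the endpoint (``for record in comments`` or calling ``comments()``)
+    # would issue authenticated GET requests and yield the records extracted with
+    # ``records_jsonpath``.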