Skip to content

Commit

Permalink
fix: add retry to client for ratelimit (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
lorr1 authored Apr 8, 2023
1 parent ee9f166 commit c7906be
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 32 deletions.
63 changes: 37 additions & 26 deletions manifest/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@

import aiohttp
import requests
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential

from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response

logger = logging.getLogger(__name__)


def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
    """Return whether a failed attempt should be retried.

    Used as the ``retry`` predicate for tenacity: retry only when the
    attempt raised an HTTP 429 (rate limit) error.

    Args:
        retry_base: tenacity state describing the last call attempt.

    Returns:
        True if the attempt failed with an HTTP 429 response.
    """
    outcome = retry_base.outcome
    exc = outcome.exception() if outcome is not None else None
    if isinstance(exc, requests.exceptions.HTTPError):
        # An HTTPError can be constructed without a response attached
        # (e.g. when re-raised manually), so guard before reading
        # status_code instead of crashing inside the retry predicate.
        response = exc.response
        if response is not None and response.status_code == 429:
            return True
    return False


class Client(ABC):
"""Client class."""

Expand Down Expand Up @@ -194,6 +203,12 @@ def split_requests(
request_params_list.append(params)
return request_params_list

@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(10),
)
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
Expand All @@ -207,25 +222,25 @@ def _run_completion(
response as dict.
"""
post_str = self.get_generation_url()
res = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
)
try:
res = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error(
f"{self.__class__.__name__} request timed out."
" Increase client_timeout."
)
raise e
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json(), request_params)

@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(10),
)
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
Expand All @@ -240,20 +255,16 @@ async def _arun_completion(
response as dict.
"""
post_str = self.get_generation_url()
try:
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
async with session.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.format_response(res_json, request_params)
except aiohttp.ClientError as e:
logger.error(f"{self.__class__.__name__} request error {e}")
raise e
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
async with session.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.format_response(res_json, request_params)

def run_request(self, request: Request) -> Response:
"""
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
NAME = "manifest-ml"
DESCRIPTION = "Manifest for Prompting Foundation Models."
URL = "https://github.com/HazyResearch/manifest"
EMAIL = "[email protected]"
EMAIL = "[email protected]"
AUTHOR = "Laurel Orr"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = main_ns["__version__"]
Expand All @@ -34,8 +34,9 @@
"requests>=2.27.1",
"aiohttp>=3.8.0",
"sqlitedict>=2.0.0",
"xxhash>=3.0.0",
"tenacity>=8.2.0",
"tiktoken>=0.3.0",
"xxhash>=3.0.0",
]

# What packages are optional?
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Test client.
We just test the dummy client as we don't want to load a model or use OpenAI tokens.
We just test the dummy client.
"""
from manifest.clients.dummy import DummyClient

Expand Down
90 changes: 87 additions & 3 deletions tests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import asyncio
import os
from typing import cast
from unittest.mock import MagicMock, Mock, patch

import pytest
import requests
from requests import HTTPError

from manifest import Manifest, Response
from manifest.caches.noop import NoopCache
Expand Down Expand Up @@ -643,7 +645,7 @@ def test_openaichat(sqlite_cache: str) -> None:
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 22
assert response.get_json_response()["usage"][0]["total_tokens"] == 23

response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
Expand Down Expand Up @@ -674,10 +676,92 @@ def test_openaichat(sqlite_cache: str) -> None:
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 24
assert response.get_json_response()["usage"][1]["total_tokens"] == 22
assert response.get_json_response()["usage"][0]["total_tokens"] == 25
assert response.get_json_response()["usage"][1]["total_tokens"] == 23

response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is True


def test_retry_handling() -> None:
    """Test retry handling.

    Patches ``requests.post`` at the module path the client uses, so no
    real network connection is made: a 429 should be retried until success,
    while any other HTTP error should propagate immediately.
    """
    # We'll mock the response so we won't need a real connection
    client = Manifest(client_name="openai", client_connection="fake")
    # Ordered side effects: first attempt raises a rate-limit error, the
    # retried second attempt returns the successful payload below.
    mock_create = MagicMock(
        side_effect=[
            # raise a 429 error
            HTTPError(
                response=Mock(status_code=429, json=Mock(return_value={})),
                request=Mock(),
            ),
            # get a valid http response with a 200 status code
            Mock(
                status_code=200,
                json=Mock(
                    return_value={
                        "choices": [
                            {
                                "finish_reason": "length",
                                "index": 0,
                                "logprobs": None,
                                "text": " WHATTT.",
                            },
                            {
                                "finish_reason": "length",
                                "index": 1,
                                "logprobs": None,
                                "text": " UH OH.",
                            },
                            {
                                "finish_reason": "length",
                                "index": 2,
                                "logprobs": None,
                                "text": " HARG",
                            },
                        ],
                        "created": 1679469056,
                        "id": "cmpl-6wmuWfmyuzi68B6gfeNC0h5ywxXL5",
                        "model": "text-ada-001",
                        "object": "text_completion",
                        "usage": {
                            "completion_tokens": 30,
                            "prompt_tokens": 24,
                            "total_tokens": 54,
                        },
                    }
                ),
            ),
        ]
    )
    prompts = [
        "The sky is purple. This is because",
        "The sky is magnet. This is because",
        "The sky is fuzzy. This is because",
    ]
    with patch("manifest.clients.client.requests.post", mock_create):
        # Run manifest
        result = client.run(prompts, temperature=0, overwrite_cache=True)
        # Leading whitespace in the mocked choices is stripped by the client
        assert result == ["WHATTT.", "UH OH.", "HARG"]

    # Assert that OpenAI client was called twice
    assert mock_create.call_count == 2

    # Now make sure it errors when not a 429
    mock_create = MagicMock(
        side_effect=[
            # raise a 500 error
            HTTPError(
                response=Mock(status_code=500, json=Mock(return_value={})),
                request=Mock(),
            ),
        ]
    )
    with patch("manifest.clients.client.requests.post", mock_create):
        # Run manifest
        with pytest.raises(HTTPError):
            client.run(prompts, temperature=0, overwrite_cache=True)

    # Assert that OpenAI client was called once
    assert mock_create.call_count == 1

0 comments on commit c7906be

Please sign in to comment.