Commit

Put a short timeout and retry 3 times to make OpenAI calls more robust: their API is not incredibly reliable, and if a call takes too long it's usually a sign you should cut it short and try again, especially since we are using streaming mode.
rogeriochaves committed Jul 25, 2023
1 parent 47f6848 commit 738b707
Showing 2 changed files with 12 additions and 0 deletions.
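For context, the change boils down to the pattern below: wrap the OpenAI request in a retry decorator and give it a short timeout, so a stalled streaming request gets cut off and re-issued instead of hanging. This is a minimal standalone sketch, assuming the pre-1.0 openai SDK and the retry package used in the diff; the model name and prompt are illustrative.

import openai
from retry import retry

@retry(tries=3)  # re-issue the whole request if it raises, up to 3 attempts
def get_completions():
    return openai.ChatCompletion.create(
        model="gpt-3.5-turbo",  # illustrative model
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        timeout=5,          # passed alongside request_timeout, as in the diff
        request_timeout=5,  # HTTP request timeout in the pre-1.0 openai SDK
    )

for chunk in get_completions():
    print(chunk)

Note that only the request setup is wrapped; the stream itself is consumed outside the decorated function, which mirrors how the diff below retries get_completions but not the surrounding generator.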
11 changes: 11 additions & 0 deletions litechain/contrib/llms/open_ai.py
@@ -15,6 +15,7 @@

import openai
from colorama import Fore
from retry import retry

from litechain.core.chain import Chain, ChainOutput

@@ -64,17 +65,22 @@ def __init__(
        model: str,
        temperature: Optional[float] = 0,
        max_tokens: Optional[int] = None,
        timeout: int = 5,
        retries: int = 3,
    ) -> None:
        async def completion(prompt: str) -> AsyncGenerator[U, None]:
            loop = asyncio.get_event_loop()

            @retry(tries=retries)
            def get_completions():
                return openai.Completion.create(
                    model=model,
                    prompt=prompt,
                    temperature=temperature,
                    stream=True,
                    max_tokens=max_tokens,
                    timeout=timeout,
                    request_timeout=timeout,
                )

            completions = await loop.run_in_executor(None, get_completions)
@@ -272,12 +278,15 @@ def __init__(
        function_call: Optional[Union[Literal["none", "auto"], str]] = None,
        temperature: Optional[float] = 0,
        max_tokens: Optional[int] = None,
        timeout: int = 5,
        retries: int = 3,
    ) -> None:
        async def chat_completion(
            messages: List[OpenAIChatMessage],
        ) -> AsyncGenerator[ChainOutput[OpenAIChatDelta], None]:
            loop = asyncio.get_event_loop()

            @retry(tries=retries)
            def get_completions():
                function_kwargs = {}
                if functions is not None:
@@ -286,6 +295,8 @@ def get_completions():
                    function_kwargs["function_call"] = function_call

                return openai.ChatCompletion.create(
                    timeout=timeout,
                    request_timeout=timeout,
                    model=model,
                    messages=[m.to_dict() for m in messages],
                    temperature=temperature,
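For reference, the new keyword arguments can simply be passed at construction time. A rough usage sketch, assuming the chat class in this module is OpenAIChatChain with the (name, call, model, ...) constructor shown in the project README; the chain name, prompt, and model below are illustrative, and the exact signature should be checked against the file above.

from litechain.contrib import OpenAIChatChain, OpenAIChatDelta, OpenAIChatMessage

# Hypothetical chain for illustration only.
greet_chain = OpenAIChatChain[str, OpenAIChatDelta](
    "GreetChain",
    lambda user_name: [
        OpenAIChatMessage(role="user", content=f"Say hello to {user_name}"),
    ],
    model="gpt-3.5-turbo",
    timeout=5,   # cut a stalled streaming request short after ~5s
    retries=3,   # and retry it up to 3 times
)

Since timeout=5 and retries=3 are also the defaults added by this commit, existing callers get the new behavior without any changes.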
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ colorama
# Contrib
openai
gpt4all
retry

# Function parsing
docstring_parser
