diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index 5b80a61712..1bf725e2e8 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -7,6 +7,8 @@ import sys from typing import Any, Mapping, Optional, Union, cast +import openai + import evals import evals.api import evals.base @@ -266,6 +268,9 @@ def main() -> None: filename=args.log_to_file if args.log_to_file else None, ) logging.getLogger("openai").setLevel(logging.WARN) + + if hasattr(openai.error, "set_display_cause"): # type: ignore + openai.error.set_display_cause() # type: ignore run(args) diff --git a/evals/elsuite/make_me_say/autoeval.py b/evals/elsuite/make_me_say/autoeval.py index 9df3fd393a..a25bba7a7e 100644 --- a/evals/elsuite/make_me_say/autoeval.py +++ b/evals/elsuite/make_me_say/autoeval.py @@ -2,7 +2,7 @@ from typing import Callable, Optional, Union import backoff -from openai import BadRequestError +from openai.error import InvalidRequestError from evals.api import CompletionFn, CompletionResult from evals.elsuite.make_me_say.core import Game, Message, Player @@ -17,7 +17,7 @@ ) -@backoff.on_exception(backoff.constant, BadRequestError, max_tries=3) +@backoff.on_exception(backoff.constant, InvalidRequestError, max_tries=3) def run( codeword: str, manipulator_completion_fn: CompletionFn, diff --git a/evals/elsuite/make_me_say/utils.py b/evals/elsuite/make_me_say/utils.py index 11eb38e72b..bb982b9b5c 100644 --- a/evals/elsuite/make_me_say/utils.py +++ b/evals/elsuite/make_me_say/utils.py @@ -3,6 +3,7 @@ import backoff import openai +import openai.error import urllib3.exceptions from evals.api import CompletionResult @@ -11,12 +12,9 @@ @backoff.on_exception( backoff.expo, ( - openai.APIError, - openai.APIStatusError, - openai.RateLimitError, - openai.APITimeoutError, - openai.APIConnectionError, - openai.InternalServerError, + openai.error.RateLimitError, + openai.error.ServiceUnavailableError, + openai.error.TryAgain, urllib3.exceptions.TimeoutError, ), ) diff --git a/evals/registry.py b/evals/registry.py index e36da7e9e8..48a62116f7 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -99,7 +99,7 @@ def add_registry_paths(self, paths: Sequence[Union[str, Path]]) -> None: def api_model_ids(self) -> list[str]: try: return [m["id"] for m in openai.Model.list()["data"]] - except openai.OpenAIError as err: # type: ignore + except openai.error.OpenAIError as err: # type: ignore # Errors can happen when running eval with completion function that uses custom # API endpoints and authentication mechanisms. logger.warning(f"Could not fetch API model IDs from OpenAI API: {err}") diff --git a/evals/utils/api_utils.py b/evals/utils/api_utils.py index cc0d6bb891..7f178a70b7 100644 --- a/evals/utils/api_utils.py +++ b/evals/utils/api_utils.py @@ -14,12 +14,11 @@ @backoff.on_exception( wait_gen=backoff.expo, exception=( - openai.APIError, - openai.APIStatusError, - openai.RateLimitError, - openai.APITimeoutError, - openai.APIConnectionError, - openai.InternalServerError, + openai.error.ServiceUnavailableError, + openai.error.APIError, + openai.error.RateLimitError, + openai.error.APIConnectionError, + openai.error.Timeout, ), max_value=60, factor=1.5, @@ -32,7 +31,7 @@ def openai_completion_create_retrying(*args, **kwargs): result = openai.Completion.create(*args, **kwargs) if "error" in result: logging.warning(result) - raise openai.APIError(result["error"]) + raise openai.error.APIError(result["error"]) return result @@ -53,12 +52,11 @@ def request_with_timeout(func, *args, timeout=EVALS_THREAD_TIMEOUT, **kwargs): @backoff.on_exception( wait_gen=backoff.expo, exception=( - openai.APIError, - openai.APIStatusError, - openai.RateLimitError, - openai.APITimeoutError, - openai.APIConnectionError, - openai.InternalServerError, + openai.error.ServiceUnavailableError, + openai.error.APIError, + openai.error.RateLimitError, + openai.error.APIConnectionError, + openai.error.Timeout, ), max_value=60, factor=1.5, @@ -71,5 +69,5 @@ def openai_chat_completion_create_retrying(*args, **kwargs): result = request_with_timeout(openai.ChatCompletion.create, *args, **kwargs) if "error" in result: logging.warning(result) - raise openai.APIError(result["error"]) + raise openai.error.APIError(result["error"]) return result