From f1503b8afc1721334f6dc6738f27f403253996bc Mon Sep 17 00:00:00 2001 From: "Rossdan Craig rossdan@lastmileai.dev" <> Date: Fri, 5 Jan 2024 02:02:52 -0500 Subject: [PATCH] Make `get_api_key_from_environment` return nullable Sometimes key is not defined or set, we need to do this to unblock https://github.com/lastmile-ai/aiconfig/pull/769 The functionality is still unchanged because the optional param `required` defaults to true ## Test Plan Pass all automated tests (which it didn't before) --- cookbooks/HuggingFace/hf.py | 4 +-- cookbooks/HuggingFace/python/hf.py | 4 +-- .../src/aiconfig_extension_gemini/Gemini.py | 4 +-- .../text_generation.py | 4 +-- python/src/aiconfig/__init__.py | 2 +- python/src/aiconfig/default_parsers/dalle.py | 4 +-- python/src/aiconfig/default_parsers/hf.py | 4 +-- python/src/aiconfig/default_parsers/openai.py | 4 +-- python/src/aiconfig/util/config_utils.py | 26 ++++++++++++++++--- python/tests/test_util/test_config_util.py | 6 ++--- 10 files changed, 40 insertions(+), 22 deletions(-) diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py index 8257b326e..135899181 100644 --- a/cookbooks/HuggingFace/hf.py +++ b/cookbooks/HuggingFace/hf.py @@ -14,7 +14,7 @@ ParameterizedModelParser, Prompt, PromptMetadata, - get_api_key_from_environment, + maybe_get_api_key_from_environment, resolve_prompt, ) @@ -153,7 +153,7 @@ def __init__(self, model_id: str = None, use_api_token=False): token = None if use_api_token: - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN") + token = maybe_get_api_key_from_environment("HUGGING_FACE_API_TOKEN") self.client = InferenceClient(model_id, token=token) diff --git a/cookbooks/HuggingFace/python/hf.py b/cookbooks/HuggingFace/python/hf.py index 8257b326e..135899181 100644 --- a/cookbooks/HuggingFace/python/hf.py +++ b/cookbooks/HuggingFace/python/hf.py @@ -14,7 +14,7 @@ ParameterizedModelParser, Prompt, PromptMetadata, - get_api_key_from_environment, + maybe_get_api_key_from_environment, 
resolve_prompt, ) @@ -153,7 +153,7 @@ def __init__(self, model_id: str = None, use_api_token=False): token = None if use_api_token: - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN") + token = maybe_get_api_key_from_environment("HUGGING_FACE_API_TOKEN") self.client = InferenceClient(model_id, token=token) diff --git a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py index 9766ee4d8..e8a2959f2 100644 --- a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py +++ b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py @@ -10,7 +10,7 @@ from aiconfig import ( AIConfigRuntime, CallbackEvent, - get_api_key_from_environment, + maybe_get_api_key_from_environment, ) from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions @@ -326,7 +326,7 @@ async def run_inference( ) # Auth check. don't need to explicitly set the key as long as this is set as an env var. 
genai.configure() will pick it up - get_api_key_from_environment("GOOGLE_API_KEY") + maybe_get_api_key_from_environment("GOOGLE_API_KEY") genai.configure() # TODO: check and handle api key here diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py index 434648d0d..cc0b448a3 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py @@ -19,7 +19,7 @@ Prompt, PromptMetadata, ) -from aiconfig.util.config_utils import get_api_key_from_environment +from aiconfig.util.config_utils import maybe_get_api_key_from_environment from aiconfig.util.params import resolve_prompt @@ -154,7 +154,7 @@ def __init__(self, model_id: str = None, use_api_token=False): token = None if use_api_token: - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN") + token = maybe_get_api_key_from_environment("HUGGING_FACE_API_TOKEN") self.client = InferenceClient(model_id, token=token) diff --git a/python/src/aiconfig/__init__.py b/python/src/aiconfig/__init__.py index 19e20da94..3d77bbc5e 100644 --- a/python/src/aiconfig/__init__.py +++ b/python/src/aiconfig/__init__.py @@ -26,5 +26,5 @@ PromptMetadata, SchemaVersion, ) -from .util.config_utils import get_api_key_from_environment +from .util.config_utils import maybe_get_api_key_from_environment from .util.params import resolve_prompt diff --git a/python/src/aiconfig/default_parsers/dalle.py b/python/src/aiconfig/default_parsers/dalle.py index 47acccac3..644a8ebf8 100644 --- a/python/src/aiconfig/default_parsers/dalle.py +++ b/python/src/aiconfig/default_parsers/dalle.py @@ -3,7 +3,7 @@ import openai from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser 
-from aiconfig.util.config_utils import get_api_key_from_environment +from aiconfig.util.config_utils import maybe_get_api_key_from_environment from aiconfig.util.params import resolve_prompt from openai import OpenAI @@ -163,7 +163,7 @@ async def run_inference(self, prompt: Prompt, aiconfig, _options, parameters) -> """ # If needed, certify the API key and initialize the OpenAI client if not openai.api_key: - openai.api_key = get_api_key_from_environment("OPENAI_API_KEY") + openai.api_key = maybe_get_api_key_from_environment("OPENAI_API_KEY") if not self.client: self.client = OpenAI(api_key=openai.api_key) diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py index caedf246d..40ddf7ecb 100644 --- a/python/src/aiconfig/default_parsers/hf.py +++ b/python/src/aiconfig/default_parsers/hf.py @@ -4,7 +4,7 @@ from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions -from aiconfig.util.config_utils import get_api_key_from_environment +from aiconfig.util.config_utils import maybe_get_api_key_from_environment from aiconfig.util.params import resolve_prompt # HuggingFace API imports @@ -145,7 +145,7 @@ def __init__(self, model_id: str = None, use_api_token=False): token = None if use_api_token: - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN") + token = maybe_get_api_key_from_environment("HUGGING_FACE_API_TOKEN") self.client = InferenceClient(model_id, token=token) diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py index e3e4d95ac..c067abe61 100644 --- a/python/src/aiconfig/default_parsers/openai.py +++ b/python/src/aiconfig/default_parsers/openai.py @@ -6,7 +6,7 @@ from aiconfig.callback import CallbackEvent from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions -from aiconfig.util.config_utils 
import get_api_key_from_environment +from aiconfig.util.config_utils import maybe_get_api_key_from_environment from aiconfig.util.params import resolve_prompt, resolve_prompt_string, resolve_system_prompt from openai.types.chat import ChatCompletionMessage @@ -235,7 +235,7 @@ async def run_inference( ) if not openai.api_key: - openai.api_key = get_api_key_from_environment("OPENAI_API_KEY") + openai.api_key = maybe_get_api_key_from_environment("OPENAI_API_KEY") completion_data = await self.deserialize(prompt, aiconfig, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. diff --git a/python/src/aiconfig/util/config_utils.py b/python/src/aiconfig/util/config_utils.py index 0b2eb6eb1..4728a67a2 100644 --- a/python/src/aiconfig/util/config_utils.py +++ b/python/src/aiconfig/util/config_utils.py @@ -1,6 +1,7 @@ import copy +import dotenv import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: pass @@ -10,10 +11,27 @@ from ..schema import AIConfig -def get_api_key_from_environment(api_key_name: str): - if api_key_name not in os.environ: - raise Exception("Missing API key '{}' in environment".format(api_key_name)) +def maybe_get_api_key_from_environment( + api_key_name: str, + required: bool = True) -> Union[str, None]: + """Get the API key if it exists, return None or error if it doesn't + + Args: + api_key_name (str): The keyname that we're trying to import from env variable + required (bool, optional): If this is true, we raise an error if the + key is not found + Returns: + Union[str, None]: the value of the key.
If `required` is false, this can be None + """ + dotenv.load_dotenv() + if required: + _get_api_key_from_environment(api_key_name) + return os.getenv(api_key_name) + +def _get_api_key_from_environment(api_key_name: str) -> str: + if api_key_name not in os.environ: + raise KeyError(f"Missing API key '{api_key_name}' in environment") return os.environ[api_key_name] diff --git a/python/tests/test_util/test_config_util.py b/python/tests/test_util/test_config_util.py index a831d4fba..b6054261f 100644 --- a/python/tests/test_util/test_config_util.py +++ b/python/tests/test_util/test_config_util.py @@ -1,10 +1,10 @@ -from aiconfig.util.config_utils import get_api_key_from_environment +from aiconfig.util.config_utils import maybe_get_api_key_from_environment -def test_get_api_key_from_environment(): +def test_maybe_get_api_key_from_environment(): key = "TEST_API_KEY" try: - get_api_key_from_environment(key) + maybe_get_api_key_from_environment(key) except Exception: pass # The expected exception was raised, so do nothing else: