From 03a059704443961ae5b6ca243e3edc2dc15aeb2a Mon Sep 17 00:00:00 2001 From: Keegan McCallum Date: Mon, 24 Jul 2023 21:03:53 -0700 Subject: [PATCH] Handle openai change of api base for just embeddings (#495) --- gptcache/embedding/openai.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gptcache/embedding/openai.py b/gptcache/embedding/openai.py index ac00333a..96767f48 100644 --- a/gptcache/embedding/openai.py +++ b/gptcache/embedding/openai.py @@ -27,13 +27,19 @@ class OpenAI(BaseEmbedding): embed = encoder.to_embeddings(test_sentence) """ - def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None): + def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None, api_base: str = None): if not api_key: if openai.api_key: api_key = openai.api_key else: api_key = os.getenv("OPENAI_API_KEY") + if not api_base: + if openai.api_base: + api_base = openai.api_base + else: + api_base = os.getenv("OPENAI_API_BASE") openai.api_key = api_key + self.api_base = api_base # don't override all of openai as we may just want to override for say embeddings self.model = model if model in self.dim_dict(): self.__dimension = self.dim_dict()[model] @@ -48,7 +54,7 @@ def to_embeddings(self, data, **_): :return: a text embedding in shape of (dim,). """ - sentence_embeddings = openai.Embedding.create(model=self.model, input=data) + sentence_embeddings = openai.Embedding.create(model=self.model, input=data, api_base=self.api_base) return np.array(sentence_embeddings["data"][0]["embedding"]).astype("float32") @property