Skip to content

Commit

Permalink
Handle openai change of api base for just embeddings (zilliztech#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
keeganmccallum authored Jul 25, 2023
1 parent a3328f2 commit 03a0597
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions gptcache/embedding/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ class OpenAI(BaseEmbedding):
embed = encoder.to_embeddings(test_sentence)
"""

def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None):
def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None, api_base: str = None):
if not api_key:
if openai.api_key:
api_key = openai.api_key
else:
api_key = os.getenv("OPENAI_API_KEY")
if not api_base:
if openai.api_base:
api_base = openai.api_base
else:
api_base = os.getenv("OPENAI_API_BASE")
openai.api_key = api_key
self.api_base = api_base # don't override all of openai as we may just want to override for say embeddings
self.model = model
if model in self.dim_dict():
self.__dimension = self.dim_dict()[model]
Expand All @@ -48,7 +54,7 @@ def to_embeddings(self, data, **_):
:return: a text embedding in shape of (dim,).
"""
sentence_embeddings = openai.Embedding.create(model=self.model, input=data)
sentence_embeddings = openai.Embedding.create(model=self.model, input=data, api_base=self.api_base)
return np.array(sentence_embeddings["data"][0]["embedding"]).astype("float32")

@property
Expand Down

0 comments on commit 03a0597

Please sign in to comment.