Skip to content

Commit

Permalink
Openai API migrate (#2765)
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-yang-1 authored Dec 24, 2023
1 parent c70bb3d commit 82ef3a3
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 38 deletions.
22 changes: 14 additions & 8 deletions docs/openai_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,30 @@ pip install --upgrade openai

Then, interact with model vicuna:
```python
from openai import OpenAI
# to get proper authentication, make sure to use a valid key that's listed in
# the --api-keys flag. if no flag value is provided, the `api_key` will be ignored.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1", default_headers={"x-foo": "true"})

model = "vicuna-7b-v1.5"
prompt = "Once upon a time"

# create a completion (legacy)
completion = client.completions.create(
    model=model,
    prompt=prompt
)
# print the completion
print(prompt + completion.choices[0].text)

# create a chat completion
completion = client.chat.completions.create(
    model="vicuna-7b-v1.5",
    response_format={ "type": "json_object" },
    messages=[
        {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
        {"role": "user", "content": "Who won the world series in 2020?"}
    ]
)
# print the completion
print(completion.choices[0].message.content)
Expand Down
1 change: 1 addition & 0 deletions fastchat/serve/openai_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ async def show_available_models():
return ModelList(data=model_cards)


@app.post("/v1chat/completions", dependencies=[Depends(check_api_key)])
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(request: ChatCompletionRequest):
"""Creates a completion for the chat message"""
Expand Down
2 changes: 1 addition & 1 deletion playground/test_embedding/test_sentence_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy.spatial.distance import cosine


def get_embedding_from_api(word, model="vicuna-7b-v1.1"):
def get_embedding_from_api(word, model="vicuna-7b-v1.5"):
if "ada" in model:
resp = openai.Embedding.create(
model=model,
Expand Down
115 changes: 86 additions & 29 deletions tests/test_openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,53 @@
Launch:
python3 launch_openai_api_test_server.py
"""
from distutils.version import LooseVersion
import warnings

import openai

try:
from openai import OpenAI, AsyncOpenAI
except ImportError:
warnings.warn("openai<1.0 is deprecated")

from fastchat.utils import run_cmd

# Endpoint settings for the locally launched FastChat OpenAI-compatible server;
# reused by every test below when building a client.
openai.api_key = "EMPTY"  # Not support yet
openai.api_base = "http://localhost:8000/v1"


def test_list_models():
    """Return the ids of the models served by the local endpoint.

    Dispatches on the installed openai SDK generation: the legacy (<1.0)
    module-level API or the >=1.0 client-based API.
    """
    if LooseVersion(openai.__version__) < LooseVersion("1.0"):
        model_list = openai.Model.list()
    else:
        client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
        model_list = client.models.list()
    # NOTE(review): assumes the legacy response object also supports
    # attribute access (`.data`, `.id`) — confirm with openai<1.0.
    names = [x.id for x in model_list.data]
    return names


def test_completion(model, logprob):
prompt = "Once upon a time"
completion = openai.Completion.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
completion = openai.Completion.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)
else:
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
# legacy
completion = client.completions.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)

print(f"full text: {prompt + completion.choices[0].text}", flush=True)
if completion.choices[0].logprobs is not None:
print(
Expand All @@ -38,42 +61,76 @@ def test_completion(model, logprob):

def test_completion_stream(model):
    """Stream a 64-token text completion for *model*, printing chunks as they arrive.

    Dispatches on the installed openai SDK generation, mirroring test_completion.
    """
    prompt = "Once upon a time"
    if LooseVersion(openai.__version__) < LooseVersion("1.0"):
        res = openai.Completion.create(
            model=model,
            prompt=prompt,
            max_tokens=64,
            stream=True,
            temperature=0,
        )
    else:
        client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
        # legacy (non-chat) completions endpoint of the >=1.0 SDK
        res = client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=64,
            stream=True,
            temperature=0,
        )
    print(prompt, end="")
    for chunk in res:
        # NOTE(review): attribute access is assumed to work for both SDK
        # generations' stream chunks — confirm with openai<1.0.
        content = chunk.choices[0].text
        print(content, end="", flush=True)
    print()


def test_embedding(model):
    """Request an embedding of a fixed input and print its length and first values.

    Dispatches on the installed openai SDK generation.
    """
    if LooseVersion(openai.__version__) < LooseVersion("1.0"):
        embedding = openai.Embedding.create(model=model, input="Hello world!")
    else:
        client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
        embedding = client.embeddings.create(model=model, input="Hello world!")
    print(f"embedding len: {len(embedding.data[0].embedding)}")
    print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}")


def test_chat_completion(model):
    """Run a single non-streaming chat completion against *model* and print the reply.

    Dispatches on the installed openai SDK generation.
    """
    if LooseVersion(openai.__version__) < LooseVersion("1.0"):
        completion = openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
            temperature=0,
        )
    else:
        client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
        completion = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
            temperature=0,
        )
    print(completion.choices[0].message.content)


def test_chat_completion_stream(model):
    """Stream a chat completion for *model*, printing delta content as it arrives.

    Dispatches on the installed openai SDK generation; the per-chunk access
    pattern differs between generations, so both are handled best-effort.
    """
    messages = [{"role": "user", "content": "Hello! What is your name?"}]
    if LooseVersion(openai.__version__) < LooseVersion("1.0"):
        res = openai.ChatCompletion.create(
            model=model, messages=messages, stream=True, temperature=0
        )
    else:
        client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
        res = client.chat.completions.create(
            model=model, messages=messages, stream=True, temperature=0
        )
    for chunk in res:
        try:
            # >=1.0 SDK: delta is a typed object whose content may be None
            # (e.g. on the role-only first chunk).
            content = chunk.choices[0].delta.content
            if content is None:
                content = ""
        except Exception:
            # <1.0 SDK: delta behaves like a dict; fall back to .get().
            content = chunk.choices[0].delta.get("content", "")
        print(content, end="", flush=True)
    print()

Expand Down

0 comments on commit 82ef3a3

Please sign in to comment.