Skip to content

Commit

Permalink
fix openai api server docs
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Dec 24, 2023
1 parent 82ef3a3 commit 093574d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 95 deletions.
29 changes: 11 additions & 18 deletions docs/openai_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,37 +32,30 @@ Now, let us test the API server.
### OpenAI Official SDK
The goal of `openai_api_server.py` is to implement a fully OpenAI-compatible API server, so the models can be used directly with [openai-python](https://github.com/openai/openai-python) library.

First, install OpenAI python package >= 1.0:
```bash
pip install --upgrade openai
```

Then, interact with the Vicuna model:
```python
import openai

openai.api_key = "EMPTY"
openai.base_url = "http://localhost:8000/v1/"

model = "vicuna-7b-v1.5"
prompt = "Once upon a time"

# create a completion
completion = openai.completions.create(model=model, prompt=prompt, max_tokens=64)
# print the completion
print(prompt + completion.choices[0].text)

# create a chat completion
completion = openai.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "Hello! What is your name?"}]
)
# print the completion
print(completion.choices[0].message.content)
Expand Down
104 changes: 27 additions & 77 deletions tests/test_openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,31 @@
Launch:
python3 launch_openai_api_test_server.py
"""
from distutils.version import LooseVersion
import warnings

import openai

try:
from openai import OpenAI, AsyncOpenAI
except ImportError:
warnings.warn("openai<1.0 is deprecated")

from fastchat.utils import run_cmd


# Configure the module-level openai>=1.0 client to talk to the local
# test server. `api_key` is required by the client but ignored by the
# server (no auth support yet).
openai.api_key = "EMPTY"  # Not support yet
openai.base_url = "http://localhost:8000/v1/"


def test_list_models():
    """Return the ids of all models advertised by the API server.

    Uses the module-level openai>=1.0 client configured at the top of
    this file (openai.base_url points at the local test server).

    Returns:
        list[str]: model ids reported by the ``/v1/models`` endpoint.
    """
    # NOTE(review): the openai<1.0 `LooseVersion` compatibility branch was
    # removed — this file now requires openai>=1.0.
    model_list = openai.models.list()
    names = [model.id for model in model_list.data]
    return names


def test_completion(model, logprob):
prompt = "Once upon a time"
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
completion = openai.Completion.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)
else:
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
# legacy
completion = client.completions.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)
completion = openai.completions.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)

print(f"full text: {prompt + completion.choices[0].text}", flush=True)
if completion.choices[0].logprobs is not None:
Expand All @@ -61,24 +40,13 @@ def test_completion(model, logprob):

def test_completion_stream(model):
prompt = "Once upon a time"
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
res = openai.Completion.create(
model=model,
prompt=prompt,
max_tokens=64,
stream=True,
temperature=0,
)
else:
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
# legacy
res = client.completions.create(
model=model,
prompt=prompt,
max_tokens=64,
stream=True,
temperature=0,
)
res = openai.completions.create(
model=model,
prompt=prompt,
max_tokens=64,
stream=True,
temperature=0,
)
print(prompt, end="")
for chunk in res:
content = chunk.choices[0].text
Expand All @@ -87,43 +55,25 @@ def test_completion_stream(model):


def test_embedding(model):
    """Request an embedding for a fixed input and print its size and head.

    Args:
        model: Name of an embedding-capable model served by the API server.
    """
    embedding = openai.embeddings.create(model=model, input="Hello world!")
    # Print the dimensionality and the first few values for a quick sanity
    # check of the returned vector.
    print(f"embedding len: {len(embedding.data[0].embedding)}")
    print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}")


def test_chat_completion(model):
    """Run a single non-streaming chat completion and print the reply.

    Args:
        model: Name of a chat model served by the API server.
    """
    # temperature=0 keeps the output deterministic so repeated test runs
    # are comparable.
    completion = openai.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Hello! What is your name?"}],
        temperature=0,
    )
    print(completion.choices[0].message.content)


def test_chat_completion_stream(model):
messages = [{"role": "user", "content": "Hello! What is your name?"}]
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
res = openai.ChatCompletion.create(
model=model, messages=messages, stream=True, temperature=0
)
else:
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
res = client.chat.completions.create(
model=model, messages=messages, stream=True, temperature=0
)
res = openai.chat.completions.create(
model=model, messages=messages, stream=True, temperature=0
)
for chunk in res:
try:
content = chunk.choices[0].delta.content
Expand Down Expand Up @@ -192,7 +142,7 @@ def test_openai_curl():
test_chat_completion_stream(model)
try:
test_embedding(model)
except openai.error.APIError as e:
except openai.APIError as e:
print(f"Embedding error: {e}")

print("===== Test curl =====")
Expand Down

0 comments on commit 093574d

Please sign in to comment.