fix: Fix deprecated max_tokens param in openai ChatCompletionRequest
Replace it with the newer parameter: max_completion_tokens.
mickqian committed Jan 25, 2025
1 parent 665e5e8 commit eed4b36
Showing 22 changed files with 105 additions and 81 deletions.
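
For context, here is a minimal before/after sketch of what the rename means for a caller of the Chat Completions API. The endpoint URL, API key, and model name are illustrative placeholders rather than part of this commit.

```python
import openai

# Any OpenAI-compatible endpoint works the same way; values here are examples.
client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    # Previously: max_tokens=64 (deprecated for chat completions).
    max_completion_tokens=64,
)
print(response.choices[0].message.content)
```
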
2 changes: 1 addition & 1 deletion benchmark/llava_bench/bench_sglang.py
@@ -16,7 +16,7 @@
@sgl.function
def image_qa(s, image_file, question):
s += sgl.user(sgl.image(image_file) + question)
s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))
s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_completion_tokens))


def main(args):
12 changes: 6 additions & 6 deletions docs/backend/openai_api_completions.ipynb
@@ -76,7 +76,7 @@
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
" max_completion_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
@@ -114,7 +114,7 @@
" {\"role\": \"user\", \"content\": \"What were their major achievements?\"},\n",
" ],\n",
" temperature=0.3, # Lower temperature for more focused responses\n",
" max_tokens=128, # Reasonable length for a concise response\n",
" max_completion_tokens=128, # Reasonable length for a concise response\n",
" top_p=0.95, # Slightly higher for better fluency\n",
" presence_penalty=0.2, # Mild penalty to avoid repetition\n",
" frequency_penalty=0.2, # Mild penalty for more natural language\n",
@@ -257,7 +257,7 @@
" \"messages\": [\n",
" {\"role\": \"user\", \"content\": \"Tell me a joke about programming\"}\n",
" ],\n",
" \"max_tokens\": 50,\n",
" \"max_completion_tokens\": 50,\n",
" },\n",
" },\n",
" {\n",
@@ -267,7 +267,7 @@
" \"body\": {\n",
" \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" \"messages\": [{\"role\": \"user\", \"content\": \"What is Python?\"}],\n",
" \"max_tokens\": 50,\n",
" \"max_completion_tokens\": 50,\n",
" },\n",
" },\n",
"]\n",
@@ -369,7 +369,7 @@
" \"content\": \"Write a detailed story about topic. Make it very long.\",\n",
" },\n",
" ],\n",
" \"max_tokens\": 500,\n",
" \"max_completion_tokens\": 500,\n",
" },\n",
" }\n",
" )\n",
@@ -446,7 +446,7 @@
" \"content\": \"Write a detailed story about topic. Make it very long.\",\n",
" },\n",
" ],\n",
" \"max_tokens\": 500,\n",
" \"max_completion_tokens\": 500,\n",
" },\n",
" }\n",
" )\n",
6 changes: 3 additions & 3 deletions docs/backend/openai_api_vision.ipynb
@@ -89,7 +89,7 @@
" ]\n",
" }\n",
" ],\n",
" \"max_tokens\": 300\n",
" \"max_completion_tokens\": 300\n",
" }'\n",
"\"\"\"\n",
"\n",
@@ -130,7 +130,7 @@
" ],\n",
" }\n",
" ],\n",
" \"max_tokens\": 300,\n",
" \"max_completion_tokens\": 300,\n",
"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -173,7 +173,7 @@
" ],\n",
" }\n",
" ],\n",
" max_tokens=300,\n",
" max_completion_tokens=300,\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
8 changes: 4 additions & 4 deletions docs/backend/structured_outputs.ipynb
@@ -98,7 +98,7 @@
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=128,\n",
" max_completion_tokens=128,\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\n",
@@ -150,7 +150,7 @@
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=128,\n",
" max_completion_tokens=128,\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n",
@@ -191,7 +191,7 @@
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=32,\n",
" max_completion_tokens=32,\n",
" extra_body={\"ebnf\": ebnf_grammar},\n",
")\n",
"\n",
@@ -217,7 +217,7 @@
" {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=128,\n",
" max_completion_tokens=128,\n",
" extra_body={\"regex\": \"(Paris|London)\"},\n",
")\n",
"\n",
16 changes: 8 additions & 8 deletions docs/frontend/frontend.md
@@ -19,9 +19,9 @@ from sglang import function, system, user, assistant, gen, set_default_backend,
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += assistant(gen("answer_1", max_completion_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
s += assistant(gen("answer_2", max_completion_tokens=256))

set_default_backend(RuntimeEndpoint("http://localhost:30000"))

@@ -50,9 +50,9 @@ from sglang import function, system, user, assistant, gen, set_default_backend,
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += assistant(gen("answer_1", max_completion_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
s += assistant(gen("answer_2", max_completion_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

@@ -114,7 +114,7 @@ def tip_suggestion(s):
forks = s.fork(2)
for i, f in enumerate(forks):
f += f"Now, expand tip {i+1} into a paragraph:\n"
f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
f += sgl.gen(f"detailed_tip", max_completion_tokens=256, stop="\n\n")

s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
@@ -128,7 +128,7 @@ Use `sgl.image` to pass an image as input.
@sgl.function
def image_qa(s, image_file, question):
s += sgl.user(sgl.image(image_file) + question)
s += sgl.assistant(sgl.gen("answer", max_tokens=256)
s += sgl.assistant(sgl.gen("answer", max_completion_tokens=256)
```

See also [local_example_llava_next.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/quick_start/local_example_llava_next.py).
@@ -172,7 +172,7 @@ character_regex = (
@sgl.function
def character_gen(s, name):
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
s += sgl.gen("json_output", max_completion_tokens=256, regex=character_regex)
```

See also [json_decode.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/usage/json_decode.py) for an additional example of specifying formats with Pydantic models.
@@ -229,7 +229,7 @@ def chat_example(s):
s += "Question: What is the capital of France?"

s += sgl.assistant_begin()
s += "Answer: " + sgl.gen(max_tokens=100, stop="\n")
s += "Answer: " + sgl.gen(max_completion_tokens=100, stop="\n")
s += sgl.assistant_end()
```

4 changes: 2 additions & 2 deletions docs/start/send_request.ipynb
@@ -124,7 +124,7 @@
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
" max_completion_tokens=64,\n",
")\n",
"print_highlight(response)"
]
@@ -153,7 +153,7 @@
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
" max_completion_tokens=64,\n",
" stream=True,\n",
")\n",
"\n",
12 changes: 6 additions & 6 deletions python/sglang/api.py
@@ -73,7 +73,7 @@ def get_server_info(backend: Optional[BaseBackend] = None):

def gen(
name: Optional[str] = None,
-    max_tokens: Optional[int] = None,
+    max_completion_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
@@ -113,7 +113,7 @@ def gen(

return SglGen(
name,
-        max_tokens,
+        max_completion_tokens,
min_tokens,
stop,
stop_token_ids,
@@ -136,7 +136,7 @@

def gen_int(
name: Optional[str] = None,
-    max_tokens: Optional[int] = None,
+    max_completion_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
@@ -153,7 +153,7 @@
):
return SglGen(
name,
-        max_tokens,
+        max_completion_tokens,
None,
stop,
stop_token_ids,
@@ -175,7 +175,7 @@

def gen_string(
name: Optional[str] = None,
-    max_tokens: Optional[int] = None,
+    max_completion_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
@@ -192,7 +192,7 @@
):
return SglGen(
name,
-        max_tokens,
+        max_completion_tokens,
None,
stop,
stop_token_ids,
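
Since the frontend `gen`, `gen_int`, and `gen_string` signatures above now take `max_completion_tokens`, SGL programs pass the renamed keyword. A small usage sketch mirroring the docs/frontend/frontend.md examples earlier in this diff (function and variable names are illustrative):

```python
import sglang as sgl


@sgl.function
def quick_answer(s, question):
    s += sgl.user(question)
    # The keyword now matches the renamed parameter in sglang.api.gen.
    s += sgl.assistant(sgl.gen("answer", max_completion_tokens=128))
```
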
6 changes: 3 additions & 3 deletions python/sglang/bench_serving.py
@@ -86,7 +86,7 @@ async def async_request_trt_llm(
"text_input": request_func_input.prompt,
"temperature": 0.000001,
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"min_length": request_func_input.output_len,
"end_id": 1048576,
@@ -160,7 +160,7 @@ async def async_request_openai_completions(
"prompt": prompt,
"temperature": 0.0,
"best_of": 1,
"max_tokens": request_func_input.output_len,
"max_completion_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
@@ -239,7 +239,7 @@ async def async_request_truss(
"prompt": prompt,
"temperature": 0.0,
"best_of": 1,
"max_tokens": request_func_input.output_len,
"max_completion_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
23 changes: 11 additions & 12 deletions python/sglang/lang/backend/openai.py
@@ -2,7 +2,7 @@
import logging
import time
import warnings
-from typing import Callable, List, Optional, Union
+from typing import List, Optional

import numpy as np

@@ -18,7 +18,6 @@
except ImportError as e:
openai = tiktoken = e


logger = logging.getLogger(__name__)


@@ -112,18 +111,18 @@ def _prepare_spec_execution(
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
if "max_completion_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_completion_tokens"] = num_api_spec_tokens
else:
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
assert self.spec_kwargs["max_completion_tokens"] == num_api_spec_tokens

-        params = sampling_params.to_openai_kwargs()
+        params = sampling_params.to_openai_kwargs(self.is_chat_model)
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
if key in ["max_completion_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
"The parameter max_completion_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
@@ -160,7 +159,7 @@ def generate(
else:
prompt = s.text_

-        kwargs = sampling_params.to_openai_kwargs()
+        kwargs = sampling_params.to_openai_kwargs(self.is_chat_model)
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
@@ -173,7 +172,7 @@
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
-            kwargs = sampling_params.to_openai_kwargs()
+            kwargs = sampling_params.to_openai_kwargs(self.is_chat_model)
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
@@ -189,7 +188,7 @@
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
-            kwargs = sampling_params.to_openai_kwargs()
+            kwargs = sampling_params.to_openai_kwargs(self.is_chat_model)
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
@@ -279,7 +278,7 @@ def generate_stream(
else:
prompt = s.text_

-        kwargs = sampling_params.to_openai_kwargs()
+        kwargs = sampling_params.to_openai_kwargs(self.is_chat_model)
generator = openai_completion_stream(
client=self.client,
token_usage=self.token_usage,
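
The backend calls above now pass `self.is_chat_model` into `sampling_params.to_openai_kwargs(...)`. That method lives outside this diff, so the following is only a sketch of what the new argument plausibly enables — emitting `max_completion_tokens` for chat models while keeping `max_tokens` for the legacy completions endpoint. The class and attribute names below are assumptions, not code from this commit.

```python
from typing import Any, Dict, List, Optional


class SamplingParamsSketch:
    """Stand-in for the sampling-params object; fields are assumed."""

    def __init__(
        self,
        max_completion_tokens: int = 128,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
    ) -> None:
        self.max_completion_tokens = max_completion_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.stop = stop

    def to_openai_kwargs(self, is_chat_model: bool) -> Dict[str, Any]:
        kwargs: Dict[str, Any] = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stop": self.stop or None,
        }
        if is_chat_model:
            # The Chat Completions API deprecates max_tokens; send the new field.
            kwargs["max_completion_tokens"] = self.max_completion_tokens
        else:
            # The legacy Completions API still expects max_tokens.
            kwargs["max_tokens"] = self.max_completion_tokens
        return kwargs
```
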
16 changes: 10 additions & 6 deletions python/sglang/lang/choices.py
@@ -69,11 +69,13 @@ def __call__(
its average logprob for comparison against the longer option."""

num_options = len(choices)
-        max_tokens = max(len(option) for option in input_token_logprobs)
+        max_completion_tokens = max(len(option) for option in input_token_logprobs)
logprob_matrix = self._build_logprob_matrix(
-            input_token_logprobs, max_tokens, num_options
+            input_token_logprobs, max_completion_tokens, num_options
)
-        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
+        remaining = self._greedy_selection(
+            logprob_matrix, num_options, max_completion_tokens
+        )

best_choice = choices[remaining[0]]
meta_info = {
@@ -84,13 +86,15 @@
}
return ChoicesDecision(decision=best_choice, meta_info=meta_info)

-    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
-        logprob_matrix = np.zeros((num_options, max_tokens))
+    def _build_logprob_matrix(
+        self, input_token_logprobs, max_completion_tokens, num_options
+    ):
+        logprob_matrix = np.zeros((num_options, max_completion_tokens))
for i, option in enumerate(input_token_logprobs):
actual_logprobs = [token[0] for token in option]
avg_logprob = np.mean(actual_logprobs)
logprob_matrix[i, : len(option)] = actual_logprobs
-            if len(option) < max_tokens:
+            if len(option) < max_completion_tokens:
logprob_matrix[i, len(option) :] = avg_logprob
return logprob_matrix

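
The choice-selection code above pads shorter options with their own average logprob before the greedy comparison. A standalone illustration of that padding with synthetic numbers (not part of the commit):

```python
import numpy as np

# Two candidate completions of different token counts; each inner tuple's
# first element is that token's logprob (other fields omitted for brevity).
input_token_logprobs = [
    [(-0.1,), (-0.3,)],           # option 0: 2 tokens
    [(-0.2,), (-0.4,), (-0.9,)],  # option 1: 3 tokens
]

max_completion_tokens = max(len(option) for option in input_token_logprobs)
logprob_matrix = np.zeros((len(input_token_logprobs), max_completion_tokens))

for i, option in enumerate(input_token_logprobs):
    actual_logprobs = [token[0] for token in option]
    logprob_matrix[i, : len(option)] = actual_logprobs
    if len(option) < max_completion_tokens:
        # Pad missing positions with the option's mean logprob so the greedy,
        # column-by-column comparison does not favor shorter options.
        logprob_matrix[i, len(option) :] = np.mean(actual_logprobs)

print(logprob_matrix)
# [[-0.1 -0.3 -0.2]
#  [-0.2 -0.4 -0.9]]
```
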
(The remaining changed files are not shown here.)
