diff --git a/benchmark/llava_bench/bench_sglang.py b/benchmark/llava_bench/bench_sglang.py index f84c8a90fb6b..2e8460ef5dc2 100644 --- a/benchmark/llava_bench/bench_sglang.py +++ b/benchmark/llava_bench/bench_sglang.py @@ -16,7 +16,7 @@ @sgl.function def image_qa(s, image_file, question): s += sgl.user(sgl.image(image_file) + question) - s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens)) + s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_completion_tokens)) def main(args): diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb index 58b524108db1..c1ea2d3e5712 100644 --- a/docs/backend/openai_api_completions.ipynb +++ b/docs/backend/openai_api_completions.ipynb @@ -76,7 +76,7 @@ " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", " ],\n", " temperature=0,\n", - " max_tokens=64,\n", + " max_completion_tokens=64,\n", ")\n", "\n", "print_highlight(f\"Response: {response}\")" @@ -114,7 +114,7 @@ " {\"role\": \"user\", \"content\": \"What were their major achievements?\"},\n", " ],\n", " temperature=0.3, # Lower temperature for more focused responses\n", - " max_tokens=128, # Reasonable length for a concise response\n", + " max_completion_tokens=128, # Reasonable length for a concise response\n", " top_p=0.95, # Slightly higher for better fluency\n", " presence_penalty=0.2, # Mild penalty to avoid repetition\n", " frequency_penalty=0.2, # Mild penalty for more natural language\n", @@ -257,7 +257,7 @@ " \"messages\": [\n", " {\"role\": \"user\", \"content\": \"Tell me a joke about programming\"}\n", " ],\n", - " \"max_tokens\": 50,\n", + " \"max_completion_tokens\": 50,\n", " },\n", " },\n", " {\n", @@ -267,7 +267,7 @@ " \"body\": {\n", " \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " \"messages\": [{\"role\": \"user\", \"content\": \"What is Python?\"}],\n", - " \"max_tokens\": 50,\n", + " \"max_completion_tokens\": 50,\n", " },\n", " },\n", "]\n", @@ -369,7 +369,7 @@ " \"content\": \"Write a detailed story about topic. Make it very long.\",\n", " },\n", " ],\n", - " \"max_tokens\": 500,\n", + " \"max_completion_tokens\": 500,\n", " },\n", " }\n", " )\n", @@ -446,7 +446,7 @@ " \"content\": \"Write a detailed story about topic. 
Make it very long.\",\n", " },\n", " ],\n", - " \"max_tokens\": 500,\n", + " \"max_completion_tokens\": 500,\n", " },\n", " }\n", " )\n", diff --git a/docs/backend/openai_api_vision.ipynb b/docs/backend/openai_api_vision.ipynb index da8864c24c93..5c278422efe6 100644 --- a/docs/backend/openai_api_vision.ipynb +++ b/docs/backend/openai_api_vision.ipynb @@ -89,7 +89,7 @@ " ]\n", " }\n", " ],\n", - " \"max_tokens\": 300\n", + " \"max_completion_tokens\": 300\n", " }'\n", "\"\"\"\n", "\n", @@ -130,7 +130,7 @@ " ],\n", " }\n", " ],\n", - " \"max_tokens\": 300,\n", + " \"max_completion_tokens\": 300,\n", "}\n", "\n", "response = requests.post(url, json=data)\n", @@ -173,7 +173,7 @@ " ],\n", " }\n", " ],\n", - " max_tokens=300,\n", + " max_completion_tokens=300,\n", ")\n", "\n", "print_highlight(response.choices[0].message.content)" diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb index e413743ccfde..77d20aae1312 100644 --- a/docs/backend/structured_outputs.ipynb +++ b/docs/backend/structured_outputs.ipynb @@ -98,7 +98,7 @@ " },\n", " ],\n", " temperature=0,\n", - " max_tokens=128,\n", + " max_completion_tokens=128,\n", " response_format={\n", " \"type\": \"json_schema\",\n", " \"json_schema\": {\n", @@ -150,7 +150,7 @@ " },\n", " ],\n", " temperature=0,\n", - " max_tokens=128,\n", + " max_completion_tokens=128,\n", " response_format={\n", " \"type\": \"json_schema\",\n", " \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n", @@ -191,7 +191,7 @@ " },\n", " ],\n", " temperature=0,\n", - " max_tokens=32,\n", + " max_completion_tokens=32,\n", " extra_body={\"ebnf\": ebnf_grammar},\n", ")\n", "\n", @@ -217,7 +217,7 @@ " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n", " ],\n", " temperature=0,\n", - " max_tokens=128,\n", + " max_completion_tokens=128,\n", " extra_body={\"regex\": \"(Paris|London)\"},\n", ")\n", "\n", diff --git a/docs/frontend/frontend.md b/docs/frontend/frontend.md index 8b56fa487682..555197c31595 100644 --- a/docs/frontend/frontend.md +++ b/docs/frontend/frontend.md @@ -19,9 +19,9 @@ from sglang import function, system, user, assistant, gen, set_default_backend, def multi_turn_question(s, question_1, question_2): s += system("You are a helpful assistant.") s += user(question_1) - s += assistant(gen("answer_1", max_tokens=256)) + s += assistant(gen("answer_1", max_completion_tokens=256)) s += user(question_2) - s += assistant(gen("answer_2", max_tokens=256)) + s += assistant(gen("answer_2", max_completion_tokens=256)) set_default_backend(RuntimeEndpoint("http://localhost:30000")) @@ -50,9 +50,9 @@ from sglang import function, system, user, assistant, gen, set_default_backend, def multi_turn_question(s, question_1, question_2): s += system("You are a helpful assistant.") s += user(question_1) - s += assistant(gen("answer_1", max_tokens=256)) + s += assistant(gen("answer_1", max_completion_tokens=256)) s += user(question_2) - s += assistant(gen("answer_2", max_tokens=256)) + s += assistant(gen("answer_2", max_completion_tokens=256)) set_default_backend(OpenAI("gpt-3.5-turbo")) @@ -114,7 +114,7 @@ def tip_suggestion(s): forks = s.fork(2) for i, f in enumerate(forks): f += f"Now, expand tip {i+1} into a paragraph:\n" - f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + f += sgl.gen(f"detailed_tip", max_completion_tokens=256, stop="\n\n") s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" @@ -128,7 +128,7 @@ Use `sgl.image` to pass an image 
as input. @sgl.function def image_qa(s, image_file, question): s += sgl.user(sgl.image(image_file) + question) - s += sgl.assistant(sgl.gen("answer", max_tokens=256) + s += sgl.assistant(sgl.gen("answer", max_completion_tokens=256) ``` See also [local_example_llava_next.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/quick_start/local_example_llava_next.py). @@ -172,7 +172,7 @@ character_regex = ( @sgl.function def character_gen(s, name): s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n" - s += sgl.gen("json_output", max_tokens=256, regex=character_regex) + s += sgl.gen("json_output", max_completion_tokens=256, regex=character_regex) ``` See also [json_decode.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/usage/json_decode.py) for an additional example of specifying formats with Pydantic models. @@ -229,7 +229,7 @@ def chat_example(s): s += "Question: What is the capital of France?" s += sgl.assistant_begin() - s += "Answer: " + sgl.gen(max_tokens=100, stop="\n") + s += "Answer: " + sgl.gen(max_completion_tokens=100, stop="\n") s += sgl.assistant_end() ``` diff --git a/docs/start/send_request.ipynb b/docs/start/send_request.ipynb index 4cb46f1edc98..44bcb8c38665 100644 --- a/docs/start/send_request.ipynb +++ b/docs/start/send_request.ipynb @@ -124,7 +124,7 @@ " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", " ],\n", " temperature=0,\n", - " max_tokens=64,\n", + " max_completion_tokens=64,\n", ")\n", "print_highlight(response)" ] @@ -153,7 +153,7 @@ " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", " ],\n", " temperature=0,\n", - " max_tokens=64,\n", + " max_completion_tokens=64,\n", " stream=True,\n", ")\n", "\n", diff --git a/python/sglang/api.py b/python/sglang/api.py index 7ef306380a91..8cc4ad12b78a 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -73,7 +73,7 @@ def get_server_info(backend: Optional[BaseBackend] = None): def gen( name: Optional[str] = None, - max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, min_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, @@ -113,7 +113,7 @@ def gen( return SglGen( name, - max_tokens, + max_completion_tokens, min_tokens, stop, stop_token_ids, @@ -136,7 +136,7 @@ def gen( def gen_int( name: Optional[str] = None, - max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -153,7 +153,7 @@ def gen_int( ): return SglGen( name, - max_tokens, + max_completion_tokens, None, stop, stop_token_ids, @@ -175,7 +175,7 @@ def gen_int( def gen_string( name: Optional[str] = None, - max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -192,7 +192,7 @@ def gen_string( ): return SglGen( name, - max_tokens, + max_completion_tokens, None, stop, stop_token_ids, diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 10ce965be742..0419c06657b1 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -86,7 +86,7 @@ async def async_request_trt_llm( "text_input": request_func_input.prompt, "temperature": 
0.000001, "top_p": 1.0, - "max_tokens": request_func_input.output_len, + "max_completion_tokens": request_func_input.output_len, "stream": True, "min_length": request_func_input.output_len, "end_id": 1048576, @@ -160,7 +160,7 @@ async def async_request_openai_completions( "prompt": prompt, "temperature": 0.0, "best_of": 1, - "max_tokens": request_func_input.output_len, + "max_completion_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, **request_func_input.extra_request_body, @@ -239,7 +239,7 @@ async def async_request_truss( "prompt": prompt, "temperature": 0.0, "best_of": 1, - "max_tokens": request_func_input.output_len, + "max_completion_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, **request_func_input.extra_request_body, diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py index 4f37da79b7e8..52bf89b9a9e0 100644 --- a/python/sglang/lang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -2,7 +2,7 @@ import logging import time import warnings -from typing import Callable, List, Optional, Union +from typing import List, Optional import numpy as np @@ -18,7 +18,6 @@ except ImportError as e: openai = tiktoken = e - logger = logging.getLogger(__name__) @@ -112,18 +111,18 @@ def _prepare_spec_execution( num_api_spec_tokens: int, spec_var_name: str, ): - if "max_tokens" not in self.spec_kwargs: - self.spec_kwargs["max_tokens"] = num_api_spec_tokens + if "max_completion_tokens" not in self.spec_kwargs: + self.spec_kwargs["max_completion_tokens"] = num_api_spec_tokens else: - assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens + assert self.spec_kwargs["max_completion_tokens"] == num_api_spec_tokens - params = sampling_params.to_openai_kwargs() + params = sampling_params.to_openai_kwargs(self.is_chat_model) for key, value in params.items(): if key in ["stop"]: continue - if key in ["max_tokens"]: + if key in ["max_completion_tokens"]: warnings.warn( - "The parameter max_tokens will be overwritten by speculated number of tokens." + "The parameter max_completion_tokens will be overwritten by speculated number of tokens." 
) continue if key not in self.spec_kwargs: @@ -160,7 +159,7 @@ def generate( else: prompt = s.text_ - kwargs = sampling_params.to_openai_kwargs() + kwargs = sampling_params.to_openai_kwargs(self.is_chat_model) comp = openai_completion( client=self.client, token_usage=self.token_usage, @@ -173,7 +172,7 @@ def generate( assert ( not self.is_chat_model ), "constrained type not supported on chat model" - kwargs = sampling_params.to_openai_kwargs() + kwargs = sampling_params.to_openai_kwargs(self.is_chat_model) kwargs.pop("stop") comp = openai_completion( client=self.client, @@ -189,7 +188,7 @@ def generate( assert ( not self.is_chat_model ), "constrained type not supported on chat model" - kwargs = sampling_params.to_openai_kwargs() + kwargs = sampling_params.to_openai_kwargs(self.is_chat_model) kwargs.pop("stop") comp = openai_completion( client=self.client, @@ -279,7 +278,7 @@ def generate_stream( else: prompt = s.text_ - kwargs = sampling_params.to_openai_kwargs() + kwargs = sampling_params.to_openai_kwargs(self.is_chat_model) generator = openai_completion_stream( client=self.client, token_usage=self.token_usage, diff --git a/python/sglang/lang/choices.py b/python/sglang/lang/choices.py index e52c6b362179..61b68142aee3 100644 --- a/python/sglang/lang/choices.py +++ b/python/sglang/lang/choices.py @@ -69,11 +69,13 @@ def __call__( its average logprob for comparison against the longer option.""" num_options = len(choices) - max_tokens = max(len(option) for option in input_token_logprobs) + max_completion_tokens = max(len(option) for option in input_token_logprobs) logprob_matrix = self._build_logprob_matrix( - input_token_logprobs, max_tokens, num_options + input_token_logprobs, max_completion_tokens, num_options + ) + remaining = self._greedy_selection( + logprob_matrix, num_options, max_completion_tokens ) - remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens) best_choice = choices[remaining[0]] meta_info = { @@ -84,13 +86,15 @@ def __call__( } return ChoicesDecision(decision=best_choice, meta_info=meta_info) - def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options): - logprob_matrix = np.zeros((num_options, max_tokens)) + def _build_logprob_matrix( + self, input_token_logprobs, max_completion_tokens, num_options + ): + logprob_matrix = np.zeros((num_options, max_completion_tokens)) for i, option in enumerate(input_token_logprobs): actual_logprobs = [token[0] for token in option] avg_logprob = np.mean(actual_logprobs) logprob_matrix[i, : len(option)] = actual_logprobs - if len(option) < max_tokens: + if len(option) < max_completion_tokens: logprob_matrix[i, len(option) :] = avg_logprob return logprob_matrix diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 4c294781c20e..6e4b818c69fa 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -744,7 +744,7 @@ def _resolve_sampling_params(self, sampling_params): # deepcopy is required because the dict has lists inside clone = copy.deepcopy(self.default_sampling_para) - for item in [ + for field in [ "max_new_tokens", "min_new_tokens", "stop", @@ -764,9 +764,9 @@ def _resolve_sampling_params(self, sampling_params): "regex", "json_schema", ]: - value = getattr(sampling_params, item, None) + value = getattr(sampling_params, field, None) if value is not None: - setattr(clone, item, value) + setattr(clone, field, value) if self.chat_template.stop_str: if clone.stop == (): diff --git a/python/sglang/lang/ir.py 
b/python/sglang/lang/ir.py index 1ae5ac1063a1..7e3c0d233e8c 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -57,12 +57,14 @@ def clone(self): self.json_schema, ) - def to_openai_kwargs(self): + def to_openai_kwargs(self, is_chat_model): # OpenAI does not support top_k, so we drop it here if self.regex is not None: warnings.warn("Regular expression is not supported in the OpenAI backend.") return { - "max_tokens": self.max_new_tokens, + ( + "max_completion_tokens" if is_chat_model else "max_tokens" + ): self.max_new_tokens, "stop": self.stop or None, "temperature": self.temperature, "top_p": self.top_p, @@ -91,7 +93,7 @@ def to_anthropic_kwargs(self): "Regular expression is not supported in the Anthropic backend." ) return { - "max_tokens": self.max_new_tokens, + "max_completion_tokens": self.max_new_tokens, "stop_sequences": ( self.stop if isinstance(self.stop, (list, tuple)) else [self.stop] ), @@ -104,7 +106,7 @@ def to_litellm_kwargs(self): if self.regex is not None: warnings.warn("Regular expression is not supported in the LiteLLM backend.") return { - "max_tokens": self.max_new_tokens, + "max_completion_tokens": self.max_new_tokens, "stop": self.stop or None, "temperature": self.temperature, "top_p": self.top_p, diff --git a/python/sglang/llama3_eval.py b/python/sglang/llama3_eval.py index 35bd4a7e4d4c..957e02e8e8e2 100644 --- a/python/sglang/llama3_eval.py +++ b/python/sglang/llama3_eval.py @@ -38,7 +38,14 @@ async def fetch_responses( - client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens + client, + prompt, + semaphore, + index, + provider, + model_size, + output_dir, + max_completion_tokens, ): output_file = os.path.join(output_dir, f"response_{index}.pkl") if os.path.exists(output_file): @@ -50,7 +57,7 @@ async def fetch_responses( model=provider_to_models[provider][model_size], prompt=prompt, temperature=0.0, - max_tokens=max_tokens, + max_tokens=max_completion_tokens, ) if isinstance(response, openai.BadRequestError): with open(output_file, "wb") as f: @@ -130,7 +137,7 @@ async def benchmark(args): args.provider, args.model_size, args.output_dir, - max_tokens=max_tokens, + max_completion_tokens=max_tokens, ) ) ) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 5056ba22ef99..e64b763441c4 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -945,7 +945,7 @@ def v1_chat_generate_request( sampling_params = { "temperature": request.temperature, - "max_new_tokens": request.max_tokens, + "max_new_tokens": request.get_max_output_tokens(), "min_new_tokens": request.min_tokens, "stop": stop, "stop_token_ids": request.stop_token_ids, diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 2ed9006c0ea2..6574892e173b 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -296,6 +296,7 @@ class ChatCompletionRequest(BaseModel): logprobs: bool = False top_logprobs: Optional[int] = None max_tokens: Optional[int] = None + max_completion_tokens: Optional[int] = None n: int = 1 presence_penalty: float = 0.0 response_format: Optional[ResponseFormat] = None @@ -325,6 +326,14 @@ class ChatCompletionRequest(BaseModel): lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None session_params: Optional[Dict] = None + def get_max_output_tokens(self) -> int: + if self.max_completion_tokens: + return self.max_completion_tokens + elif self.max_tokens: 
+ return self.max_tokens + else: + return None + class FunctionResponse(BaseModel): """Function response.""" diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py index 9657e730084c..b9e793ad12c0 100644 --- a/python/sglang/test/few_shot_gsm8k.py +++ b/python/sglang/test/few_shot_gsm8k.py @@ -81,7 +81,7 @@ def few_shot_gsm8k(s, question): s += few_shot_examples + question s += sgl.gen( "answer", - max_tokens=args.max_new_tokens, + max_completion_tokens=args.max_new_tokens, stop=["Question", "Assistant:", "<|separator|>"], ) diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index fe88171ce274..8d2a9bdfdbfa 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -65,7 +65,7 @@ def run_eval(args): sampler = ChatCompletionSampler( model=args.model, - max_tokens=2048, + max_completion_tokens=2048, base_url=base_url, temperature=getattr(args, "temperature", 0.0), ) diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py index 518e6245c00b..fa9933277a26 100644 --- a/python/sglang/test/simple_eval_common.py +++ b/python/sglang/test/simple_eval_common.py @@ -91,7 +91,7 @@ def __init__( model: Optional[str] = None, system_message: Optional[str] = None, temperature: float = 0.0, - max_tokens: int = 2048, + max_completion_tokens: int = 2048, ): self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient()) @@ -101,7 +101,7 @@ def __init__( self.model = model self.system_message = system_message self.temperature = temperature - self.max_tokens = max_tokens + self.max_completion_tokens = max_completion_tokens self.image_format = "url" def _handle_image( @@ -137,7 +137,7 @@ def __call__(self, message_list: MessageList) -> str: model=self.model, messages=message_list, temperature=self.temperature, - max_tokens=self.max_tokens, + max_completion_tokens=self.max_completion_tokens, ) return response.choices[0].message.content # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 088cb0d0af91..6388206c2ec7 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -253,23 +253,23 @@ def parallel_decoding(s, topic): # Generate skeleton for i in range(1, 1 + fork_size): - s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n" + s += f"{i}." + sgl.gen(max_completion_tokens=16, stop=[".", "\n"]) + ".\n" # Generate detailed tips forks = s.fork(fork_size) for i in range(fork_size): forks[ i - ] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:" + ] += f"Now, I expand tip {i + 1} into a detailed paragraph:\nTip {i + 1}:" forks[i] += sgl.gen("detailed_tip", max_tokens, stop=["\n\n"]) forks.join() # Concatenate tips and summarize s += "Here are these tips with detailed explanation:\n" for i in range(fork_size): - s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n" + s += f"Tip {i + 1}:" + forks[i]["detailed_tip"] + "\n" - s += "\nIn summary," + sgl.gen("summary", max_tokens=512) + s += "\nIn summary," + sgl.gen("summary", max_completion_tokens=512) ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3) assert isinstance(ret["summary"], str) @@ -292,7 +292,7 @@ def parallel_encoding(s, question, context_0, context_1, context_2): s += "Now, please answer the following question. " "Do not list options." 
s += "\nQuestion: " + question + "\n" - s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens) + s += "ASSISTANT:" + sgl.gen("answer", max_completion_tokens=max_tokens) ret = parallel_encoding.run( question="Who is the father of Julian?", @@ -560,7 +560,7 @@ def test_gen_min_new_tokens(): def convo_1(s): s += sgl.user("What is the capital of the United States?") s += sgl.assistant( - sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS) + sgl.gen("answer", min_tokens=MIN_TOKENS, max_completion_tokens=MAX_TOKENS) ) def assert_min_tokens(tokenizer, text): diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py index f128aa147dc3..9487888aeb93 100644 --- a/test/srt/test_cache_report.py +++ b/test/srt/test_cache_report.py @@ -74,7 +74,7 @@ def run_openai(self, message): {"role": "user", "content": message}, ], temperature=0, - max_tokens=100, + max_completion_tokens=100, ) return response @@ -85,7 +85,7 @@ async def run_openai_async(self, message): {"role": "user", "content": message}, ], temperature=0, - max_tokens=100, + max_completion_tokens=100, ) return response diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index adb5c18fbe22..1b58e2914c1e 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -108,7 +108,7 @@ def test_json_openai(self): {"role": "user", "content": "Introduce the capital of France."}, ], temperature=0, - max_tokens=128, + max_completion_tokens=128, response_format={ "type": "json_schema", "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)}, diff --git a/test/srt/test_matched_stop.py b/test/srt/test_matched_stop.py index 7b09a6d35f13..9500dad1c51e 100644 --- a/test/srt/test_matched_stop.py +++ b/test/srt/test_matched_stop.py @@ -68,7 +68,7 @@ def run_completions_generation( def run_chat_completions_generation( self, prompt=MANY_NEW_TOKENS_PROMPT, - max_tokens=1, + max_completion_tokens=1, stop=None, finish_reason=None, matched_stop=None, @@ -81,7 +81,7 @@ def run_chat_completions_generation( ], "temperature": 0, "top_p": 1, - "max_tokens": max_tokens, + "max_completion_tokens": max_completion_tokens, } if stop is not None: @@ -102,7 +102,10 @@ def test_finish_stop_str(self): max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" ) self.run_chat_completions_generation( - max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" + max_completion_tokens=1000, + stop="\n", + finish_reason="stop", + matched_stop="\n", ) def test_finish_stop_eos(self): @@ -121,7 +124,7 @@ def test_finish_stop_eos(self): ) self.run_chat_completions_generation( prompt="What is 2 + 2?", - max_tokens=1000, + max_completion_tokens=1000, finish_reason="stop", matched_stop=eos_token_id, ) @@ -131,7 +134,7 @@ def test_finish_length(self): max_tokens=5, finish_reason="length", matched_stop=None ) self.run_chat_completions_generation( - max_tokens=5, finish_reason="length", matched_stop=None + max_completion_tokens=5, finish_reason="length", matched_stop=None )