) for a message.
+ """
+ return jinja_env.from_string(_message_template).render(
+ role=message["role"],
+ content=message["content"],
+ variant=message.get("variant", None),
+ )
+
+
+jinja_env.globals["message_to_html"] = message_to_html
+
+
+_report_template = """<!DOCTYPE html>
+<html>
+    <body>
+    {% if metrics %}
+    <h1>Metrics</h1>
+    <table>
+    <tr>
+        <th>Metric</th>
+        <th>Value</th>
+    </tr>
+    <tr>
+        <td><b>Score</b></td>
+        <td>{{ score | float | round(3) }}</td>
+    </tr>
+    {% for name, value in metrics.items() %}
+    <tr>
+        <td>{{ name }}</td>
+        <td>{{ value }}</td>
+    </tr>
+    {% endfor %}
+    </table>
+    {% endif %}
+    <h1>Examples</h1>
+    {% for html in htmls %}
+    {{ html | safe }}
+    <hr>
+    {% endfor %}
+    </body>
+</html>
+"""
+
+
+def make_report(eval_result: EvalResult) -> str:
+ """
+ Create a standalone HTML report from an EvalResult.
+ """
+ return jinja_env.from_string(_report_template).render(
+ score=eval_result.score,
+ metrics=eval_result.metrics,
+ htmls=eval_result.htmls,
+ )
+
+
+def make_report_from_example_htmls(htmls: List[str]):
+ """
+ Create a standalone HTML report from a list of example htmls
+ """
+ return jinja_env.from_string(_report_template).render(
+ score=None, metrics={}, htmls=htmls
+ )
+
+
+def download_dataset(path, url):
+ print(f"Downloading dataset {path} from {url}")
+ try:
+ response = requests.get(url, stream=True)
+ response.raise_for_status()
+
+ total_size = int(response.headers.get("content-length", 0))
+ block_size = 8192
+
+ with open(path, "wb") as f, tqdm(
+ desc="Downloading",
+ total=total_size,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as progress_bar:
+ for data in response.iter_content(block_size):
+ size = f.write(data)
+ progress_bar.update(size)
+
+ print(f"Dataset downloaded and saved to {path}")
+ except requests.RequestException as e:
+ raise Exception(f"Failed to download dataset: {e}")
+
+
+def set_ulimit(target_soft_limit=65535):
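+    # Raise the soft limit on open file descriptors (RLIMIT_NOFILE) if it is below the target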
+ resource_type = resource.RLIMIT_NOFILE
+ current_soft, current_hard = resource.getrlimit(resource_type)
+
+ if current_soft < target_soft_limit:
+ try:
+ resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+ except ValueError as e:
+            print(f"Failed to set RLIMIT_NOFILE: {e}")
diff --git a/python/sglang/test/simple_eval_gpqa.py b/python/sglang/test/simple_eval_gpqa.py
new file mode 100644
index 00000000000..46055caa5f1
--- /dev/null
+++ b/python/sglang/test/simple_eval_gpqa.py
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+https://arxiv.org/abs/2311.12022
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+ ANSWER_PATTERN_MULTICHOICE,
+ HTML_JINJA,
+ Eval,
+ EvalResult,
+ MessageList,
+ SamplerBase,
+ SingleEvalResult,
+ format_multichoice_question,
+)
+
+
+class GPQAEval(Eval):
+ def __init__(
+ self,
+ filename: str,
+ num_examples: int | None,
+ num_threads: int,
+ n_repeats: int = 1,
+ ):
+ df = pandas.read_csv(filename)
+ examples = [row.to_dict() for _, row in df.iterrows()]
+ rng = random.Random(0)
+ if num_examples:
+            assert n_repeats == 1, "n_repeats is only supported when num_examples is None"
+ examples = rng.sample(examples, num_examples)
+ examples = examples * n_repeats
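+        # Attach a fixed random permutation of the 4 answer choices to each example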
+ examples = [
+ example | {"permutation": rng.sample(range(4), 4)} for example in examples
+ ]
+ self.examples = examples
+ self.n_repeats = n_repeats
+ self.num_threads = num_threads
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ def fn(row: dict):
+ choices = [
+ row["Correct Answer"],
+ row["Incorrect Answer 1"],
+ row["Incorrect Answer 2"],
+ row["Incorrect Answer 3"],
+ ]
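+            # Reorder choices by the stored permutation and find which letter now marks the correct answer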
+ choices = [choices[i] for i in row["permutation"]]
+ correct_index = choices.index(row["Correct Answer"])
+ correct_answer = "ABCD"[correct_index]
+ choices_dict = dict(
+ A=choices[0],
+ B=choices[1],
+ C=choices[2],
+ D=choices[3],
+ Question=row["Question"],
+ )
+ prompt_messages = [
+ sampler._pack_message(
+ content=format_multichoice_question(choices_dict), role="user"
+ )
+ ]
+ response_text = sampler(prompt_messages)
+ match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+ extracted_answer = match.group(1) if match else None
+ score = 1.0 if extracted_answer == correct_answer else 0.0
+ html = common.jinja_env.from_string(HTML_JINJA).render(
+ prompt_messages=prompt_messages,
+ next_message=dict(content=response_text, role="assistant"),
+ score=score,
+ correct_answer=correct_answer,
+ extracted_answer=extracted_answer,
+ )
+ convo = prompt_messages + [dict(content=response_text, role="assistant")]
+ return SingleEvalResult(
+ html=html,
+ score=score,
+ convo=convo,
+ metrics={"chars": len(response_text)},
+ )
+
+ results = common.map_with_progress(fn, self.examples, self.num_threads)
+ return common.aggregate_results(results)
diff --git a/python/sglang/test/simple_eval_humaneval.py b/python/sglang/test/simple_eval_humaneval.py
new file mode 100644
index 00000000000..7a0f90c4673
--- /dev/null
+++ b/python/sglang/test/simple_eval_humaneval.py
@@ -0,0 +1,139 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+HumanEval: Evaluating Large Language Models Trained on Code
+Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
+https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
+"""
+
+import json
+import logging
+import multiprocessing
+import random
+import re
+from collections import Counter, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from io import BytesIO
+from typing import Any, Dict, List, Tuple
+
+import blobfile as bf
+import tqdm
+
+try:
+ from human_eval.data import HUMAN_EVAL, read_problems
+ from human_eval.evaluation import estimate_pass_at_k
+ from human_eval.execution import check_correctness # , unsafe_execute
+except (ImportError, ModuleNotFoundError):
+ print("\nPlease install human-eval at https://github.com/openai/human-eval.\n")
+ raise
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+ HTML_JINJA,
+ Eval,
+ EvalResult,
+ SamplerBase,
+ SingleEvalResult,
+)
+
+
+def evaluate_functional_correctness(
+ sample: Dict[str, str],
+ completions: List[str],
+ n_workers: int = 4,
+ timeout: float = 3.0,
+):
+    """
+    Evaluates the functional correctness of the generated completions against the
+    sample's test suite and returns a list of 0/1 pass flags, one per completion.
+    """
+
+ # Check the generated samples against test suites.
+ with ThreadPoolExecutor(max_workers=n_workers) as executor:
+ futures = []
+ for i, completion in enumerate(completions):
+ args = (sample, completion, timeout, i)
+ future = executor.submit(check_correctness, *args)
+ futures.append(future)
+ results = []
+ for future in as_completed(futures):
+ result = future.result()
+ results.append(result)
+ passed = [int(r["passed"]) for r in results]
+ return passed
+
+
+class HumanEval(Eval):
+ def __init__(
+ self,
+ num_examples: int | None,
+ num_threads: int,
+ num_samples_per_task: int = 5,
+ ks_passes: List[int] = [1, 2, 5],
+ timeout: int = 120,
+ ):
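+        # Generate num_samples_per_task completions per problem; ks_passes picks which pass@k metrics to report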
+ self.seed = 0
+ self.examples = read_problems()
+ self.examples = list(self.examples.values())
+
+ self._num_examples = num_examples
+ if self._num_examples:
+ self.examples = random.Random(self.seed).sample(self.examples, num_examples)
+ self._num_samples_per_task = num_samples_per_task
+ self._ks_passes = ks_passes
+ self._timeout = timeout
+ self._num_threads = num_threads
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
+
+ def find_code(completion):
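+            # Take the first ```python block if present, else the raw completion, then strip the function signature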
+ pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
+ matches = pattern.findall(completion)
+ extracted_answer = matches[0] if len(matches) >= 1 else completion
+ extracted_answer = extracted_answer[
+ extracted_answer.find(":\n ") + 2 :
+ ] # remove signature
+ return extracted_answer
+
+ def fn(sample: Dict[str, str]):
+ prompt_messages = [
+ sampler._pack_message(
+ role="user", content=instruction + sample["prompt"]
+ )
+ ]
+ completions = [
+ find_code(sampler(prompt_messages))
+ for _ in range(self._num_samples_per_task)
+ ]
+ results = evaluate_functional_correctness(sample, completions)
+ total = len(results)
+ correct = sum(results)
+ score = sum(results) / len(results)
+ html = common.jinja_env.from_string(HTML_JINJA).render(
+ prompt_messages=prompt_messages,
+ next_message=dict(content=completions[0], role="assistant"),
+ score=score,
+ correct_answer=[1] * len(results),
+ extracted_answer=results,
+ )
+ convo = prompt_messages + [
+ dict(content=completion, role="assistant") for completion in completions
+ ]
+ return SingleEvalResult(
+ html=html,
+ score=score,
+ convo=convo,
+ metrics={
+ f"pass@{k}": estimate_pass_at_k([total], [correct], k)
+                    # aggregated later by common.aggregate_results, so no need for .mean() here
+ for k in self._ks_passes
+ if total >= k
+ },
+ )
+
+ results = common.map_with_progress(
+ fn, self.examples, num_threads=self._num_threads
+ )
+ return common.aggregate_results(results)
diff --git a/python/sglang/test/simple_eval_math.py b/python/sglang/test/simple_eval_math.py
new file mode 100644
index 00000000000..4ddb650d965
--- /dev/null
+++ b/python/sglang/test/simple_eval_math.py
@@ -0,0 +1,72 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Mathematical Problem Solving With the MATH Dataset
+Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2103.03874
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+ ANSWER_PATTERN,
+ HTML_JINJA,
+ Eval,
+ EvalResult,
+ SamplerBase,
+ SingleEvalResult,
+ check_equality,
+)
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+{Question}
+
+Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+""".strip()
+
+
+class MathEval(Eval):
+ def __init__(
+ self,
+ filename: str,
+ equality_checker: SamplerBase,
+ num_examples: int | None,
+ num_threads: int,
+ ):
+ df = pandas.read_csv(filename)
+ examples = [row.to_dict() for _, row in df.iterrows()]
+ if num_examples:
+ examples = random.Random(0).sample(examples, num_examples)
+ self.examples = examples
+ self.equality_checker = equality_checker
+ self.num_threads = num_threads
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ def fn(row: dict):
+ prompt_messages = [
+ sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+ ]
+ response_text = sampler(prompt_messages)
+ match = re.search(ANSWER_PATTERN, response_text)
+ extracted_answer = match.group(1) if match else None
+ score = float(
+ check_equality(self.equality_checker, row["Answer"], extracted_answer)
+ )
+ html = common.jinja_env.from_string(HTML_JINJA).render(
+ prompt_messages=prompt_messages,
+ next_message=dict(content=response_text, role="assistant"),
+ score=score,
+ correct_answer=row["Answer"],
+ extracted_answer=extracted_answer,
+ )
+ convo = prompt_messages + [dict(content=response_text, role="assistant")]
+ return SingleEvalResult(html=html, score=score, convo=convo)
+
+ results = common.map_with_progress(fn, self.examples, self.num_threads)
+ return common.aggregate_results(results)
diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py
new file mode 100644
index 00000000000..3c0287510cb
--- /dev/null
+++ b/python/sglang/test/simple_eval_mmlu.py
@@ -0,0 +1,120 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Massive Multitask Language Understanding
+Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2009.03300
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+ ANSWER_PATTERN_MULTICHOICE,
+ HTML_JINJA,
+ Eval,
+ EvalResult,
+ SamplerBase,
+ SingleEvalResult,
+ format_multichoice_question,
+)
+
+subject2category = {
+ "abstract_algebra": "stem",
+ "anatomy": "other",
+ "astronomy": "stem",
+ "business_ethics": "other",
+ "clinical_knowledge": "other",
+ "college_biology": "stem",
+ "college_chemistry": "stem",
+ "college_computer_science": "stem",
+ "college_mathematics": "stem",
+ "college_medicine": "other",
+ "college_physics": "stem",
+ "computer_security": "stem",
+ "conceptual_physics": "stem",
+ "econometrics": "social_sciences",
+ "electrical_engineering": "stem",
+ "elementary_mathematics": "stem",
+ "formal_logic": "humanities",
+ "global_facts": "other",
+ "high_school_biology": "stem",
+ "high_school_chemistry": "stem",
+ "high_school_computer_science": "stem",
+ "high_school_european_history": "humanities",
+ "high_school_geography": "social_sciences",
+ "high_school_government_and_politics": "social_sciences",
+ "high_school_macroeconomics": "social_sciences",
+ "high_school_mathematics": "stem",
+ "high_school_microeconomics": "social_sciences",
+ "high_school_physics": "stem",
+ "high_school_psychology": "social_sciences",
+ "high_school_statistics": "stem",
+ "high_school_us_history": "humanities",
+ "high_school_world_history": "humanities",
+ "human_aging": "other",
+ "human_sexuality": "social_sciences",
+ "international_law": "humanities",
+ "jurisprudence": "humanities",
+ "logical_fallacies": "humanities",
+ "machine_learning": "stem",
+ "management": "other",
+ "marketing": "other",
+ "medical_genetics": "other",
+ "miscellaneous": "other",
+ "moral_disputes": "humanities",
+ "moral_scenarios": "humanities",
+ "nutrition": "other",
+ "philosophy": "humanities",
+ "prehistory": "humanities",
+ "professional_accounting": "other",
+ "professional_law": "humanities",
+ "professional_medicine": "other",
+ "professional_psychology": "social_sciences",
+ "public_relations": "social_sciences",
+ "security_studies": "social_sciences",
+ "sociology": "social_sciences",
+ "us_foreign_policy": "social_sciences",
+ "virology": "other",
+ "world_religions": "humanities",
+}
+
+
+class MMLUEval(Eval):
+ def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+ df = pandas.read_csv(filename)
+ examples = [row.to_dict() for _, row in df.iterrows()]
+ if num_examples:
+ examples = random.Random(0).sample(examples, num_examples)
+ self.examples = examples
+ self.num_threads = num_threads
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ def fn(row: dict):
+ prompt_messages = [
+ sampler._pack_message(
+ content=format_multichoice_question(row), role="user"
+ )
+ ]
+ response_text = sampler(prompt_messages)
+ match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+ extracted_answer = match.group(1) if match else None
+ score = 1.0 if extracted_answer == row["Answer"] else 0.0
+ html = common.jinja_env.from_string(HTML_JINJA).render(
+ prompt_messages=prompt_messages,
+ next_message=dict(content=response_text, role="assistant"),
+ score=score,
+ correct_answer=row["Answer"],
+ extracted_answer=extracted_answer,
+ )
+ convo = prompt_messages + [dict(content=response_text, role="assistant")]
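+            # Record the score under the question's MMLU category so per-category accuracy can be aggregated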
+ category = subject2category.get(row["Subject"], "other")
+ return SingleEvalResult(
+ html=html, score=score, metrics={category: score}, convo=convo
+ )
+
+ results = common.map_with_progress(fn, self.examples, self.num_threads)
+ return common.aggregate_results(results)
diff --git a/python/sglang/test/srt/sampling/penaltylib/utils.py b/python/sglang/test/srt/sampling/penaltylib/utils.py
new file mode 100644
index 00000000000..b41eac32ba9
--- /dev/null
+++ b/python/sglang/test/srt/sampling/penaltylib/utils.py
@@ -0,0 +1,337 @@
+import dataclasses
+import enum
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.orchestrator import (
+ BatchedPenalizerOrchestrator,
+ _BatchedPenalizer,
+ _BatchLike,
+)
+
+
+@dataclasses.dataclass
+class MockSamplingParams:
+ frequency_penalty: float = 0.0
+ min_new_tokens: int = 0
+ stop_token_ids: typing.List[int] = None
+ presence_penalty: float = 0.0
+ repetition_penalty: float = 1.0
+
+
+@dataclasses.dataclass
+class MockTokenizer:
+ eos_token_id: int
+
+
+@dataclasses.dataclass
+class MockReq:
+ origin_input_ids: typing.List[int]
+ sampling_params: MockSamplingParams
+ tokenizer: MockTokenizer
+
+
+class StepType(enum.Enum):
+ INPUT = "input"
+ OUTPUT = "output"
+
+
+@dataclasses.dataclass
+class Step:
+ type: StepType
+ token_ids: typing.List[int]
+ expected_tensors: typing.Dict[str, torch.Tensor]
+ # assume initial logits are all 1
+ expected_logits: torch.Tensor
+
+
+@dataclasses.dataclass
+class Subject:
+ sampling_params: MockSamplingParams
+ # first step must be input, which will be converted to Req
+ steps: typing.List[Step]
+ eos_token_id: int = -1
+
+ def __post_init__(self):
+ if self.steps[0].type != StepType.INPUT:
+ raise ValueError("First step must be input")
+
+        # every step should have the same expected_tensors.keys()
+ for i in range(1, len(self.steps)):
+ if self.tensor_keys(i) != self.tensor_keys():
+ raise ValueError(
+                    f"Expected tensor keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for step {i} and {self.steps[0].expected_tensors.keys()} for step 0"
+ )
+
+ def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+ return set(self.steps[i].expected_tensors.keys())
+
+ def to_req(self) -> MockReq:
+ return MockReq(
+ origin_input_ids=self.steps[0].token_ids,
+ sampling_params=self.sampling_params,
+ tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
+ )
+
+
+@dataclasses.dataclass
+class Case:
+ enabled: bool
+ test_subjects: typing.List[Subject]
+
+ def __post_init__(self):
+        # all test_subjects should share the same expected_tensors.keys()
+ for i in range(1, len(self.test_subjects)):
+ if self.tensor_keys(i) != self.tensor_keys():
+ raise ValueError(
+                    f"Expected tensor keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for subject {i} and {self.test_subjects[0].tensor_keys()} for subject 0"
+ )
+
+    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+ return set(self.test_subjects[i].tensor_keys())
+
+
+class BaseBatchedPenalizerTest(unittest.TestCase):
+ Penalizer: typing.Type[_BatchedPenalizer]
+ device = "cuda"
+ vocab_size = 5
+
+ enabled: Subject = None
+ disabled: Subject = None
+
+ def setUp(self):
+ if self.__class__ == BaseBatchedPenalizerTest:
+ self.skipTest("Base class for penalizer tests")
+
+ self.create_test_subjects()
+ self.create_test_cases()
+
+ def tensor(self, data, **kwargs) -> torch.Tensor:
+ """
+ Shortcut to create a tensor with device=self.device.
+ """
+ return torch.tensor(data, **kwargs, device=self.device)
+
+ def create_test_subjects(self) -> typing.List[Subject]:
+ raise NotImplementedError()
+
+ def create_test_cases(self):
+ self.test_cases = [
+ Case(enabled=True, test_subjects=[self.enabled]),
+ Case(enabled=False, test_subjects=[self.disabled]),
+ Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
+ ]
+
+ def _create_penalizer(
+ self, case: Case
+ ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
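+        # Build an orchestrator over the mock requests and return the penalizer instance under test alongside it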
+ orchestrator = BatchedPenalizerOrchestrator(
+ vocab_size=self.vocab_size,
+ batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
+ device=self.device,
+ Penalizers={self.Penalizer},
+ )
+
+ return orchestrator, orchestrator.penalizers[self.Penalizer]
+
+ def test_is_required(self):
+ for case in self.test_cases:
+ with self.subTest(case=case):
+ _, penalizer = self._create_penalizer(case)
+ self.assertEqual(case.enabled, penalizer.is_required())
+
+ def test_prepare(self):
+ for case in self.test_cases:
+ with self.subTest(case=case):
+ orchestrator, penalizer = self._create_penalizer(case)
+ self.assertEqual(case.enabled, penalizer.is_prepared())
+
+ if case.enabled:
+ for key, tensor in {
+ key: torch.cat(
+ tensors=[
+ subject.steps[0].expected_tensors[key]
+ for subject in case.test_subjects
+ ],
+ )
+ for key in case.tensor_keys()
+ }.items():
+ torch.testing.assert_close(
+ actual=getattr(penalizer, key),
+ expected=tensor,
+ msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+ )
+
+ actual = orchestrator.apply(
+ torch.ones(
+ size=(len(case.test_subjects), self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ )
+ expected = torch.cat(
+ tensors=[
+ subject.steps[0].expected_logits
+ for subject in case.test_subjects
+ ],
+ )
+ torch.testing.assert_close(
+ actual=actual,
+ expected=expected,
+ msg=f"logits\nactual={actual}\nexpected={expected}",
+ )
+
+ def test_teardown(self):
+ for case in self.test_cases:
+ with self.subTest(case=case):
+ _, penalizer = self._create_penalizer(case)
+ penalizer.teardown()
+
+ for key in case.test_subjects[0].steps[0].expected_tensors.keys():
+ self.assertIsNone(getattr(penalizer, key, None))
+
+ def test_filter(self):
+ for case in self.test_cases:
+ with self.subTest(case=case):
+ orchestrator, penalizer = self._create_penalizer(case)
+
+ indices_to_keep = [0]
+ orchestrator.filter(indices_to_keep=indices_to_keep)
+
+ filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]
+
+ if penalizer.is_required():
+ self.assertTrue(penalizer.is_prepared())
+ for key, tensor in {
+ key: torch.cat(
+ tensors=[
+ subject.steps[0].expected_tensors[key]
+ for subject in filtered_subjects
+ ],
+ )
+ for key in case.tensor_keys()
+ }.items():
+ torch.testing.assert_close(
+ actual=getattr(penalizer, key),
+ expected=tensor,
+ msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+ )
+
+ actual_logits = orchestrator.apply(
+ torch.ones(
+ size=(len(filtered_subjects), self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ )
+ filtered_expected_logits = torch.cat(
+ tensors=[
+ subject.steps[0].expected_logits
+ for subject in filtered_subjects
+ ],
+ )
+ torch.testing.assert_close(
+ actual=actual_logits,
+ expected=filtered_expected_logits,
+ msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
+ )
+
+ def test_merge_enabled_with_disabled(self):
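+        # Merging a disabled orchestrator into an enabled one should concatenate their per-request tensors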
+ enabled_test_case = self.test_cases[0]
+ disabled_test_case = self.test_cases[1]
+
+ orchestrator, penalizer = self._create_penalizer(enabled_test_case)
+ theirs, _ = self._create_penalizer(disabled_test_case)
+
+ orchestrator.merge(theirs)
+
+ for key, tensor in {
+ key: torch.cat(
+ tensors=[
+ enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
+ disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
+ ],
+ )
+ for key in enabled_test_case.tensor_keys()
+ }.items():
+ torch.testing.assert_close(
+ actual=getattr(penalizer, key),
+ expected=tensor,
+ msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+ )
+
+ def test_cumulate_apply_repeat(self):
+ for case in self.test_cases:
+ with self.subTest(case=case):
+ orchestrator, penalizer = self._create_penalizer(case)
+
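+                # Replay later steps: drop finished subjects, feed each step's input/output tokens, then re-check tensors and logits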
+ max_step = max(len(subject.steps) for subject in case.test_subjects)
+ for i in range(1, max_step):
+ orchestrator.filter(
+ indices_to_keep=[
+ j
+ for j, subject in enumerate(case.test_subjects)
+ if i < len(subject.steps)
+ ]
+ )
+
+ filtered_subjects = [
+ subject
+ for subject in case.test_subjects
+ if i < len(subject.steps)
+ ]
+
+ inputs: typing.List[typing.List[int]] = []
+ outputs: typing.List[typing.List[int]] = []
+ for subject in filtered_subjects:
+ step = subject.steps[i]
+ if step.type == StepType.INPUT:
+ inputs.append(step.token_ids)
+ outputs.append([])
+ else:
+ inputs.append([])
+ outputs.append(step.token_ids)
+
+ if any(inputs):
+ orchestrator.cumulate_input_tokens(inputs)
+
+ if any(outputs):
+ orchestrator.cumulate_output_tokens(outputs)
+
+ if penalizer.is_required():
+ self.assertTrue(penalizer.is_prepared())
+ for key, tensor in {
+ key: torch.cat(
+ tensors=[
+ subject.steps[i].expected_tensors[key]
+ for subject in filtered_subjects
+ ],
+ )
+ for key in case.tensor_keys()
+ }.items():
+ torch.testing.assert_close(
+ actual=getattr(penalizer, key),
+ expected=tensor,
+ msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+ )
+
+ actual_logits = orchestrator.apply(
+ torch.ones(
+ size=(len(filtered_subjects), self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ )
+ filtered_expected_logits = torch.cat(
+ tensors=[
+ subject.steps[i].expected_logits
+ for subject in filtered_subjects
+ ],
+ )
+ torch.testing.assert_close(
+ actual=actual_logits,
+ expected=filtered_expected_logits,
+ msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
+ )
diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py
deleted file mode 100644
index 11e837ddbde..00000000000
--- a/python/sglang/test/test_conversation.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_protocol import (
- ChatCompletionMessageContentImagePart,
- ChatCompletionMessageContentImageURL,
- ChatCompletionMessageContentTextPart,
- ChatCompletionMessageGenericParam,
- ChatCompletionMessageUserParam,
- ChatCompletionRequest,
-)
-
-
-def test_chat_completion_to_conv_image():
- """Test that we can convert a chat image request to a convo"""
- request = ChatCompletionRequest(
- model="default",
- messages=[
- ChatCompletionMessageGenericParam(
- role="system", content="You are a helpful AI assistant"
- ),
- ChatCompletionMessageUserParam(
- role="user",
- content=[
- ChatCompletionMessageContentTextPart(
- type="text", text="Describe this image"
- ),
- ChatCompletionMessageContentImagePart(
- type="image_url",
- image_url=ChatCompletionMessageContentImageURL(
- url="https://someurl.com"
- ),
- ),
- ],
- ),
- ],
- )
- conv = generate_chat_conv(request, "vicuna_v1.1")
- assert conv.messages == [
-        ["USER", "Describe this image<image>\n"],
- ["ASSISTANT", None],
- ]
- assert conv.system_message == "You are a helpful AI assistant"
- assert conv.image_data == ["https://someurl.com"]
-
-
-if __name__ == "__main__":
- test_chat_completion_to_conv_image()
diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py
deleted file mode 100644
index 99e7a8089cf..00000000000
--- a/python/sglang/test/test_openai_protocol.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from sglang.srt.managers.openai_protocol import (
- ChatCompletionMessageContentImagePart,
- ChatCompletionMessageContentImageURL,
- ChatCompletionMessageContentTextPart,
- ChatCompletionMessageGenericParam,
- ChatCompletionMessageUserParam,
- ChatCompletionRequest,
-)
-
-
-def test_chat_completion_request_image():
- """Test that Chat Completion Requests with images can be converted."""
-
- image_request = {
- "model": "default",
- "messages": [
- {"role": "system", "content": "You are a helpful AI assistant"},
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Describe this image"},
- {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
- ],
- },
- ],
- "temperature": 0,
- "max_tokens": 64,
- }
- request = ChatCompletionRequest(**image_request)
- assert len(request.messages) == 2
- assert request.messages[0] == ChatCompletionMessageGenericParam(
- role="system", content="You are a helpful AI assistant"
- )
- assert request.messages[1] == ChatCompletionMessageUserParam(
- role="user",
- content=[
- ChatCompletionMessageContentTextPart(
- type="text", text="Describe this image"
- ),
- ChatCompletionMessageContentImagePart(
- type="image_url",
- image_url=ChatCompletionMessageContentImageURL(
- url="https://someurl.com"
- ),
- ),
- ],
- )
-
-
-if __name__ == "__main__":
- test_chat_completion_request_image()
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index 6fa8f821433..710871ba5db 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -105,20 +105,22 @@ def test_decode_json_regex():
def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
- s += "Generate a JSON object to describe the basic information of a city.\n"
+ s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
- s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n"
- s += ' "country": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
- s += ' "timezone": ' + sgl.gen(regex=REGEX_STRING) + "\n"
+ s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
s += "}"
- ret = decode_json.run()
- js_obj = json.loads(ret["json_output"])
+ ret = decode_json.run(temperature=0.0)
+ try:
+ js_obj = json.loads(ret["json_output"])
+ except json.decoder.JSONDecodeError:
+ print("JSONDecodeError", ret["json_output"])
+ raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
@@ -126,7 +128,7 @@ def decode_json(s):
def test_decode_json():
@sgl.function
def decode_json(s):
- s += "Generate a JSON object to describe the basic information of a city.\n"
+ s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"):
s += "{\n"
@@ -137,8 +139,12 @@ def decode_json(s):
s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
s += "}"
- ret = decode_json.run()
- js_obj = json.loads(ret["json_output"])
+ ret = decode_json.run(max_new_tokens=64)
+ try:
+ js_obj = json.loads(ret["json_output"])
+ except json.decoder.JSONDecodeError:
+ print("JSONDecodeError", ret["json_output"])
+ raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
@@ -257,6 +263,7 @@ def parallel_decoding(s, topic):
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
+ assert isinstance(ret["summary"], str)
def test_parallel_encoding(check_answer=True):
@@ -306,7 +313,7 @@ def image_qa(s, question):
assert (
"taxi" in state.messages()[-1]["content"]
or "car" in state.messages()[-1]["content"]
- )
+ ), f"{state.messages()[-1]['content']}"
def test_stream():
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 693bade6f2d..613645b572e 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -1,16 +1,27 @@
"""Common utilities for testing and benchmarking"""
+import argparse
import asyncio
+import multiprocessing
+import subprocess
+import threading
+import time
+import unittest
from functools import partial
+from typing import Callable, List, Optional
import numpy as np
import requests
+import torch
+import torch.nn.functional as F
-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
+from sglang.lang.backend.openai import OpenAI
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import get_exception_traceback
+DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
@@ -243,7 +254,7 @@ async def program(ctx, choices):
return choices.index(answer)
-def add_common_other_args_and_parse(parser):
+def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
@@ -282,7 +293,7 @@ def add_common_other_args_and_parse(parser):
return args
-def add_common_sglang_args_and_parse(parser):
+def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
@@ -292,7 +303,7 @@ def add_common_sglang_args_and_parse(parser):
return args
-def select_sglang_backend(args):
+def select_sglang_backend(args: argparse.Namespace):
if args.backend.startswith("srt"):
if args.backend == "srt-no-parallel":
global_config.enable_parallel_decoding = False
@@ -305,7 +316,7 @@ def select_sglang_backend(args):
return backend
-def _get_call_generate(args):
+def _get_call_generate(args: argparse.Namespace):
if args.backend == "lightllm":
return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
@@ -332,7 +343,7 @@ def _get_call_generate(args):
raise ValueError(f"Invalid backend: {args.backend}")
-def _get_call_select(args):
+def _get_call_select(args: argparse.Namespace):
if args.backend == "lightllm":
return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
@@ -355,7 +366,7 @@ def _get_call_select(args):
raise ValueError(f"Invalid backend: {args.backend}")
-def get_call_generate(args):
+def get_call_generate(args: argparse.Namespace):
call_generate = _get_call_generate(args)
def func(*args, **kwargs):
@@ -368,7 +379,7 @@ def func(*args, **kwargs):
return func
-def get_call_select(args):
+def get_call_select(args: argparse.Namespace):
call_select = _get_call_select(args)
def func(*args, **kwargs):
@@ -379,3 +390,111 @@ def func(*args, **kwargs):
raise
return func
+
+
+def popen_launch_server(
+ model: str,
+ base_url: str,
+ timeout: float,
+ api_key: Optional[str] = None,
+ other_args: tuple = (),
+):
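+    # Parse a base_url like "http://127.0.0.1:30000" into its host and port parts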
+ _, host, port = base_url.split(":")
+ host = host[2:]
+
+ command = [
+ "python3",
+ "-m",
+ "sglang.launch_server",
+ "--model-path",
+ model,
+ "--host",
+ host,
+ "--port",
+ port,
+ *other_args,
+ ]
+ if api_key:
+ command += ["--api-key", api_key]
+
+ process = subprocess.Popen(command, stdout=None, stderr=None)
+
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ try:
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {api_key}",
+ }
+ response = requests.get(f"{base_url}/v1/models", headers=headers)
+ if response.status_code == 200:
+ return process
+ except requests.RequestException:
+ pass
+ time.sleep(10)
+ raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_with_timeout(
+ func: Callable,
+ args: tuple = (),
+ kwargs: Optional[dict] = None,
+    timeout: Optional[float] = None,
+):
+ """Run a function with timeout."""
+ ret_value = []
+
+ def _target_func():
+ ret_value.append(func(*args, **(kwargs or {})))
+
+ t = threading.Thread(target=_target_func)
+ t.start()
+ t.join(timeout=timeout)
+ if t.is_alive():
+ raise TimeoutError()
+
+ if not ret_value:
+ raise RuntimeError()
+
+ return ret_value[0]
+
+
+def run_unittest_files(files: List[str], timeout_per_file: float):
+ tic = time.time()
+ success = True
+
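+    # Run each test file in its own process so a hang or crash is isolated and bounded by timeout_per_file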
+ for filename in files:
+
+ def func():
+ print(f"\n\nRun {filename}\n\n")
+ ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
+
+ p = multiprocessing.Process(target=func)
+
+ def run_one_file():
+ p.start()
+ p.join()
+
+ try:
+ run_with_timeout(run_one_file, timeout=timeout_per_file)
+ if p.exitcode != 0:
+ success = False
+ break
+ except TimeoutError:
+ p.terminate()
+ time.sleep(5)
+ print(
+ f"\nTimeout after {timeout_per_file} seconds when running {filename}\n"
+ )
+ return False
+
+ if success:
+ print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
+ else:
+ print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
+
+ return 0 if success else -1
+
+
+def get_similarities(vec1, vec2):
+ return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 0f5fd439082..c880d259d53 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -1,16 +1,17 @@
"""Common utilities."""
import base64
+import importlib
import json
import logging
import signal
import sys
-import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
+from typing import Union
import numpy as np
import requests
@@ -24,7 +25,7 @@ def get_exception_traceback():
return err_str
-def is_same_type(values):
+def is_same_type(values: list):
"""Return whether the elements in values are of the same type."""
if len(values) <= 1:
return True
@@ -44,7 +45,7 @@ def read_jsonl(filename: str):
return rets
-def dump_state_text(filename, states, mode="w"):
+def dump_state_text(filename: str, states: list, mode: str = "w"):
"""Dump program state in a text file."""
from sglang.lang.interpreter import ProgramState
@@ -74,19 +75,13 @@ def status_code(self):
return self.resp.status
-def http_request(
- url, json=None, stream=False, auth_token=None, api_key=None, verify=None
-):
+def http_request(url, json=None, stream=False, api_key=None, verify=None):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
- # add the Authorization header if an auth token is provided
- if auth_token is not None:
- headers["Authorization"] = f"Bearer {auth_token}"
-
- # add the API Key header if an API key is provided
+ # add the Authorization header if an api key is provided
if api_key is not None:
- headers["X-API-Key"] = api_key
+ headers["Authorization"] = f"Bearer {api_key}"
if stream:
return requests.post(url, json=json, stream=True, headers=headers)
@@ -104,7 +99,7 @@ def http_request(
return HttpResponse(e)
-def encode_image_base64(image_path):
+def encode_image_base64(image_path: Union[str, bytes]):
"""Encode an image in base64."""
if isinstance(image_path, str):
with open(image_path, "rb") as image_file:
@@ -143,7 +138,7 @@ def encode_frame(frame):
return frame_bytes
-def encode_video_base64(video_path, num_frames=16):
+def encode_video_base64(video_path: str, num_frames: int = 16):
import cv2 # pip install opencv-python-headless
cap = cv2.VideoCapture(video_path)
@@ -189,7 +184,7 @@ def encode_video_base64(video_path, num_frames=16):
return video_base64
-def _is_chinese_char(cp):
+def _is_chinese_char(cp: int):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -214,7 +209,7 @@ def _is_chinese_char(cp):
return False
-def find_printable_text(text):
+def find_printable_text(text: str):
"""Returns the longest printable substring of text that contains only entire words."""
# Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
@@ -233,26 +228,7 @@ def find_printable_text(text):
return text[: text.rfind(" ") + 1]
-def run_with_timeout(func, args=(), kwargs=None, timeout=None):
- """Run a function with timeout."""
- ret_value = []
-
- def _target_func():
- ret_value.append(func(*args, **(kwargs or {})))
-
- t = threading.Thread(target=_target_func)
- t.start()
- t.join(timeout=timeout)
- if t.is_alive():
- raise TimeoutError()
-
- if not ret_value:
- raise RuntimeError()
-
- return ret_value[0]
-
-
-def graceful_registry(sub_module_name):
+def graceful_registry(sub_module_name: str):
def graceful_shutdown(signum, frame):
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
@@ -261,3 +237,26 @@ def graceful_shutdown(signum, frame):
logger.info(f"{sub_module_name} recive sigterm")
signal.signal(signal.SIGTERM, graceful_shutdown)
+
+
+class LazyImport:
+ """Lazy import to make `import sglang` run faster."""
+
+ def __init__(self, module_name: str, class_name: str):
+ self.module_name = module_name
+ self.class_name = class_name
+ self._module = None
+
+ def _load(self):
+ if self._module is None:
+ module = importlib.import_module(self.module_name)
+ self._module = getattr(module, self.class_name)
+ return self._module
+
+ def __getattr__(self, name: str):
+ module = self._load()
+ return getattr(module, name)
+
+ def __call__(self, *args, **kwargs):
+ module = self._load()
+ return module(*args, **kwargs)
diff --git a/python/sglang/version.py b/python/sglang/version.py
new file mode 100644
index 00000000000..5635676f6b4
--- /dev/null
+++ b/python/sglang/version.py
@@ -0,0 +1 @@
+__version__ = "0.2.11"
diff --git a/scripts/convert_yi_vl.py b/scripts/convert_yi_vl.py
index a45f83a3002..bdf37ff92bb 100644
--- a/scripts/convert_yi_vl.py
+++ b/scripts/convert_yi_vl.py
@@ -10,16 +10,15 @@
from transformers import AutoConfig, AutoTokenizer
+
def add_image_token(model_path: str):
tokenizer = AutoTokenizer.from_pretrained(model_path)
-    tokenizer.add_tokens(
-        ["<image_placeholder>"],
-        special_tokens=True
-    )
+    tokenizer.add_tokens(["<image_placeholder>"], special_tokens=True)
print(tokenizer)
tokenizer.save_pretrained(model_path)
+
def edit_model_config(model_path):
config = AutoConfig.from_pretrained(model_path)
@@ -29,10 +28,11 @@ def edit_model_config(model_path):
print(config)
config.save_pretrained(model_path)
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str)
args = parser.parse_args()
add_image_token(args.model_path)
- edit_model_config(args.model_path)
\ No newline at end of file
+ edit_model_config(args.model_path)
diff --git a/test/srt/test_curl.sh b/scripts/deprecated/test_curl.sh
similarity index 86%
rename from test/srt/test_curl.sh
rename to scripts/deprecated/test_curl.sh
index 4362eaa9355..1c83208a759 100644
--- a/test/srt/test_curl.sh
+++ b/scripts/deprecated/test_curl.sh
@@ -3,7 +3,7 @@ curl http://localhost:30000/generate \
-d '{
"text": "Once upon a time,",
"sampling_params": {
- "max_new_tokens": 16,
+ "max_new_tokens": 64,
"temperature": 0
}
}'
diff --git a/test/srt/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py
similarity index 100%
rename from test/srt/test_flashinfer.py
rename to scripts/deprecated/test_flashinfer.py
diff --git a/test/srt/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py
similarity index 100%
rename from test/srt/test_httpserver_classify.py
rename to scripts/deprecated/test_httpserver_classify.py
diff --git a/test/srt/test_httpserver_concurrent.py b/scripts/deprecated/test_httpserver_concurrent.py
similarity index 100%
rename from test/srt/test_httpserver_concurrent.py
rename to scripts/deprecated/test_httpserver_concurrent.py
diff --git a/test/srt/test_httpserver_decode.py b/scripts/deprecated/test_httpserver_decode.py
similarity index 69%
rename from test/srt/test_httpserver_decode.py
rename to scripts/deprecated/test_httpserver_decode.py
index 7e169f3e423..57517a15b00 100644
--- a/test/srt/test_httpserver_decode.py
+++ b/scripts/deprecated/test_httpserver_decode.py
@@ -13,14 +13,15 @@
import requests
-def test_decode(url, return_logprob, top_logprobs_num, return_text):
+def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
- "temperature": 0,
+ "temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32,
+ "n": n,
},
"stream": False,
"return_logprob": return_logprob,
@@ -41,8 +42,14 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text):
url = f"{args.host}:{args.port}"
- test_decode(url, False, 0, False)
- test_decode(url, True, 0, False)
- test_decode(url, True, 0, True)
- test_decode(url, True, 3, False)
- test_decode(url, True, 3, True)
+ test_decode(url)
+ test_decode(url, n=3)
+
+ for top_logprobs_num in [0, 3]:
+ for return_text in [True, False]:
+ test_decode(
+ url,
+ return_logprob=True,
+ top_logprobs_num=top_logprobs_num,
+ return_text=return_text,
+ )
diff --git a/test/srt/test_httpserver_decode_stream.py b/scripts/deprecated/test_httpserver_decode_stream.py
similarity index 89%
rename from test/srt/test_httpserver_decode_stream.py
rename to scripts/deprecated/test_httpserver_decode_stream.py
index 38f090b7d1b..955c368d154 100644
--- a/test/srt/test_httpserver_decode_stream.py
+++ b/scripts/deprecated/test_httpserver_decode_stream.py
@@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
data = json.loads(chunk[5:].strip("\n"))
if return_logprob:
- assert data["meta_info"]["prefill_token_logprobs"] is not None
- assert data["meta_info"]["decode_token_logprobs"] is not None
+ assert data["meta_info"]["input_token_logprobs"] is not None
+ assert data["meta_info"]["output_token_logprobs"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
for logprob, token_id, token_text in data["meta_info"][
- "decode_token_logprobs"
+ "output_token_logprobs"
][prev:]:
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
- prev = len(data["meta_info"]["decode_token_logprobs"])
+ prev = len(data["meta_info"]["output_token_logprobs"])
else:
output = data["text"].strip()
print(output[prev:], end="", flush=True)
diff --git a/test/srt/test_httpserver_llava.py b/scripts/deprecated/test_httpserver_llava.py
similarity index 97%
rename from test/srt/test_httpserver_llava.py
rename to scripts/deprecated/test_httpserver_llava.py
index e3cf1b79931..791fc6deb1f 100644
--- a/test/srt/test_httpserver_llava.py
+++ b/scripts/deprecated/test_httpserver_llava.py
@@ -10,7 +10,6 @@
import argparse
import asyncio
import json
-import time
import aiohttp
import requests
@@ -37,7 +36,7 @@ async def test_concurrent(args):
"image_data": "example_image.png",
"sampling_params": {
"temperature": 0,
- "max_new_tokens": 16,
+ "max_new_tokens": 64,
},
},
)
diff --git a/test/srt/test_httpserver_reuse.py b/scripts/deprecated/test_httpserver_reuse.py
similarity index 100%
rename from test/srt/test_httpserver_reuse.py
rename to scripts/deprecated/test_httpserver_reuse.py
diff --git a/test/srt/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py
similarity index 100%
rename from test/srt/test_jump_forward.py
rename to scripts/deprecated/test_jump_forward.py
diff --git a/test/srt/test_robust.py b/scripts/deprecated/test_robust.py
similarity index 100%
rename from test/srt/test_robust.py
rename to scripts/deprecated/test_robust.py
diff --git a/scripts/format.sh b/scripts/format.sh
deleted file mode 100644
index a49aed74549..00000000000
--- a/scripts/format.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-isort python
-black python
-
-isort test
-black test
-
-isort benchmark
-black benchmark
diff --git a/scripts/launch_tgi.sh b/scripts/launch_tgi.sh
deleted file mode 100644
index eeb4054754f..00000000000
--- a/scripts/launch_tgi.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-docker run --name tgi --rm -ti --gpus all --network host \
- -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
- ghcr.io/huggingface/text-generation-inference:1.3.0 \
- --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
- --max-input-length 2048 --max-total-tokens 4096 \
- --port 24000
diff --git a/playground/launch_tgi.sh b/scripts/playground/launch_tgi.sh
similarity index 100%
rename from playground/launch_tgi.sh
rename to scripts/playground/launch_tgi.sh
diff --git a/playground/load_tokenizer.py b/scripts/playground/load_tokenizer.py
similarity index 61%
rename from playground/load_tokenizer.py
rename to scripts/playground/load_tokenizer.py
index 39fa1842481..94cf34bc71f 100644
--- a/playground/load_tokenizer.py
+++ b/scripts/playground/load_tokenizer.py
@@ -3,11 +3,12 @@
from sglang.srt.hf_transformers_utils import get_tokenizer
-
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
+ parser.add_argument(
+ "--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
+ )
args = parser.parse_args()
t = get_tokenizer(args.name)
- code.interact(local=locals())
\ No newline at end of file
+ code.interact(local=locals())
diff --git a/playground/reference_hf.py b/scripts/playground/reference_hf.py
similarity index 93%
rename from playground/reference_hf.py
rename to scripts/playground/reference_hf.py
index ca82871c9de..ac91b3bed40 100644
--- a/playground/reference_hf.py
+++ b/scripts/playground/reference_hf.py
@@ -30,9 +30,12 @@
@torch.inference_mode()
def normal_text(args):
- t = AutoTokenizer.from_pretrained(args.model_path)
+ t = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
- args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ args.model_path,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
)
m.cuda()
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 00000000000..cdfbbaee81a
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,26 @@
+# Run Unit Tests
+
+## Test Frontend Language
+```
+cd sglang/test/lang
+export OPENAI_API_KEY=sk-*****
+
+# Run a single file
+python3 test_openai_backend.py
+
+# Run a suite
+python3 run_suite.py --suite minimal
+```
+
+## Test Backend Runtime
+```
+cd sglang/test/srt
+
+# Run a single file
+python3 test_eval_accuracy.py
+
+# Run a suite
+python3 run_suite.py --suite minimal
+```
+
+
diff --git a/test/__init__.py b/test/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/test/lang/run_all.py b/test/lang/run_all.py
deleted file mode 100644
index cb5da15850b..00000000000
--- a/test/lang/run_all.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import argparse
-import glob
-import multiprocessing
-import os
-import time
-import unittest
-
-from sglang.utils import run_with_timeout
-
-
-def run_unittest_files(files, args):
- for filename in files:
-
- def func():
- print(filename)
- ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
-
- p = multiprocessing.Process(target=func)
-
- def run_one_file():
- p.start()
- p.join()
-
- try:
- run_with_timeout(run_one_file, timeout=args.time_limit_per_file)
- if p.exitcode != 0:
- return False
- except TimeoutError:
- p.terminate()
- time.sleep(5)
- print(
- f"\nTimeout after {args.time_limit_per_file} seconds "
- f"when running {filename}"
- )
- return False
-
- return True
-
-
-if __name__ == "__main__":
- arg_parser = argparse.ArgumentParser()
- arg_parser.add_argument(
- "--time-limit-per-file",
- type=int,
- default=1000,
- help="The time limit for running one file in seconds.",
- )
- args = arg_parser.parse_args()
-
- files = glob.glob("**/test_*.py", recursive=True)
-
- tic = time.time()
- success = run_unittest_files(files, args)
-
- if success:
- print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
- else:
- print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
-
- exit(0 if success else -1)
diff --git a/test/lang/run_suite.py b/test/lang/run_suite.py
new file mode 100644
index 00000000000..379427afac9
--- /dev/null
+++ b/test/lang/run_suite.py
@@ -0,0 +1,34 @@
+import argparse
+import glob
+
+from sglang.test.test_utils import run_unittest_files
+
+suites = {
+ "minimal": ["test_srt_backend.py", "test_openai_backend.py"],
+}
+
+
+if __name__ == "__main__":
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument(
+ "--timeout-per-file",
+ type=int,
+ default=1000,
+ help="The time limit for running one file in seconds.",
+ )
+ arg_parser.add_argument(
+ "--suite",
+ type=str,
+ default=list(suites.keys())[0],
+ choices=list(suites.keys()) + ["all"],
+ help="The suite to run",
+ )
+ args = arg_parser.parse_args()
+
+ if args.suite == "all":
+ files = glob.glob("**/test_*.py", recursive=True)
+ else:
+ files = suites[args.suite]
+
+ exit_code = run_unittest_files(files, args.timeout_per_file)
+ exit(exit_code)
diff --git a/test/lang/test_anthropic_backend.py b/test/lang/test_anthropic_backend.py
index 3eb4051d739..87b27a765a3 100644
--- a/test/lang/test_anthropic_backend.py
+++ b/test/lang/test_anthropic_backend.py
@@ -7,14 +7,11 @@
class TestAnthropicBackend(unittest.TestCase):
backend = None
- chat_backend = None
- def setUp(self):
- cls = type(self)
-
- if cls.backend is None:
- cls.backend = Anthropic("claude-3-haiku-20240307")
- set_default_backend(cls.backend)
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = Anthropic("claude-3-haiku-20240307")
+ set_default_backend(cls.backend)
def test_mt_bench(self):
test_mt_bench()
@@ -30,5 +27,5 @@ def test_stream(self):
# global_config.verbosity = 2
# t = TestAnthropicBackend()
- # t.setUp()
+ # t.setUpClass()
# t.test_mt_bench()
diff --git a/test/lang/test_bind_cache.py b/test/lang/test_bind_cache.py
index 9cba14ce437..14a7e509863 100644
--- a/test/lang/test_bind_cache.py
+++ b/test/lang/test_bind_cache.py
@@ -1,17 +1,20 @@
import unittest
import sglang as sgl
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
class TestBind(unittest.TestCase):
backend = None
- def setUp(self):
- cls = type(self)
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
+ sgl.set_default_backend(cls.backend)
- if cls.backend is None:
- cls.backend = RuntimeEndpoint(base_url="http://localhost:30000")
+ @classmethod
+ def tearDownClass(cls):
+ cls.backend.shutdown()
def test_bind(self):
@sgl.function
@@ -48,5 +51,5 @@ def few_shot_qa(s, prompt, question):
unittest.main(warnings="ignore")
# t = TestBind()
- # t.setUp()
+ # t.setUpClass()
# t.test_cache()
diff --git a/test/lang/test_choices.py b/test/lang/test_choices.py
new file mode 100644
index 00000000000..da25e9e496f
--- /dev/null
+++ b/test/lang/test_choices.py
@@ -0,0 +1,95 @@
+import unittest
+
+import numpy as np
+
+from sglang.lang.choices import (
+ greedy_token_selection,
+ token_length_normalized,
+ unconditional_likelihood_normalized,
+)
+
+MOCK_CHOICES_INPUT_DATA = {
+ "choices": [
+ "organ", # ["organ"]
+ "organism", # ["organ", "ism"]
+ "antidisestablishmentarianism", # ["ant", "id", "is", "est", "ablish", "ment", "arian", "ism"]
+ ],
+ "normalized_prompt_logprobs": [-0.1, -0.2, -0.05],
+ "input_token_logprobs": [
+ [[-0.1, 1, None]],
+ [[-0.1, 1, None], [-0.3, 2, None]],
+ [
+ [-0.4, 3, None],
+ [-0.25, 4, None],
+ [-0.1, 5, None],
+ [-0.01, 6, None],
+ [-0.01, 7, None],
+ [-0.01, 8, None],
+ [-0.01, 9, None],
+ [-0.01, 2, None],
+ ],
+ ],
+ "output_token_logprobs": [
+ [[-0.1, 10, None]],
+ [[-0.1, 10, None]],
+ [[-0.1, 10, None]],
+ ],
+ "unconditional_token_logprobs": [
+ [[None, 1, None]],
+ [[None, 1, None], [-1.4, 2, None]],
+ [
+ [None, 3, None],
+ [-0.25, 4, None],
+ [-0.1, 5, None],
+ [-0.01, 6, None],
+ [-0.01, 7, None],
+ [-0.01, 8, None],
+ [-0.01, 9, None],
+ [-0.01, 2, None],
+ ],
+ ],
+}
+
+
+class TestChoices(unittest.TestCase):
+
+ def test_token_length_normalized(self):
+        """Confirm 'antidisestablishmentarianism' is selected: the high confidence of its
+        later tokens gives it the highest token-length-normalized prompt logprob."""
+ decision = token_length_normalized(**MOCK_CHOICES_INPUT_DATA)
+ assert decision.decision == "antidisestablishmentarianism"
+
+ def test_greedy_token_selection(self):
+        """Confirm 'organ' is selected due to it having the joint highest initial token
+        logprob, and a higher average logprob than organism's second token."""
+ decision = greedy_token_selection(**MOCK_CHOICES_INPUT_DATA)
+ assert decision.decision == "organ"
+ assert np.allclose(
+ decision.meta_info["greedy_logprob_matrix"],
+ [
+ [-0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1],
+ [-0.1, -0.3, -0.2, -0.2, -0.2, -0.2, -0.2, -0.2],
+ [-0.4, -0.25, -0.1, -0.01, -0.01, -0.01, -0.01, -0.01],
+ ],
+ atol=0.01,
+ )
+
+ def test_unconditional_likelihood_normalized(self):
+ """Confirm 'organism' is selected due to it having the highest average token logprob
+ once normalized by the unconditional token logprobs."""
+ decision = unconditional_likelihood_normalized(**MOCK_CHOICES_INPUT_DATA)
+ assert decision.decision == "organism"
+ assert np.allclose(
+ decision.meta_info["normalized_unconditional_prompt_logprobs"],
+ [-0.1, 0.5, -0.05],
+ atol=0.01,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
+
+ # t = TestChoices()
+ # t.test_token_length_normalized()
+ # t.test_greedy_token_selection()
+ # t.test_unconditional_likelihood_normalized()
diff --git a/test/lang/test_litellm_backend.py b/test/lang/test_litellm_backend.py
index 15d83bd517a..3c7f5db2182 100644
--- a/test/lang/test_litellm_backend.py
+++ b/test/lang/test_litellm_backend.py
@@ -6,15 +6,12 @@
class TestAnthropicBackend(unittest.TestCase):
- backend = None
chat_backend = None
- def setUp(self):
- cls = type(self)
-
- if cls.backend is None:
- cls.backend = LiteLLM("gpt-3.5-turbo")
- set_default_backend(cls.backend)
+ @classmethod
+ def setUpClass(cls):
+ cls.chat_backend = LiteLLM("gpt-3.5-turbo")
+ set_default_backend(cls.chat_backend)
def test_mt_bench(self):
test_mt_bench()
diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py
index d35495e4d75..b1bb47b82f6 100644
--- a/test/lang/test_openai_backend.py
+++ b/test/lang/test_openai_backend.py
@@ -20,20 +20,18 @@
class TestOpenAIBackend(unittest.TestCase):
- backend = None
+ instruct_backend = None
chat_backend = None
chat_vision_backend = None
- def setUp(self):
- cls = type(self)
-
- if cls.backend is None:
- cls.backend = OpenAI("gpt-3.5-turbo-instruct")
- cls.chat_backend = OpenAI("gpt-3.5-turbo")
- cls.chat_vision_backend = OpenAI("gpt-4-turbo")
+ @classmethod
+ def setUpClass(cls):
+ cls.instruct_backend = OpenAI("gpt-3.5-turbo-instruct")
+ cls.chat_backend = OpenAI("gpt-3.5-turbo")
+ cls.chat_vision_backend = OpenAI("gpt-4-turbo")
def test_few_shot_qa(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_few_shot_qa()
def test_mt_bench(self):
@@ -41,35 +39,35 @@ def test_mt_bench(self):
test_mt_bench()
def test_select(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_select(check_answer=True)
def test_decode_int(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_decode_int()
def test_decode_json(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_decode_json()
def test_expert_answer(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_expert_answer()
def test_tool_use(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_tool_use()
def test_react(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_react()
def test_parallel_decoding(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_parallel_decoding()
def test_parallel_encoding(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_parallel_encoding()
def test_image_qa(self):
@@ -77,11 +75,11 @@ def test_image_qa(self):
test_image_qa()
def test_stream(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_stream()
def test_completion_speculative(self):
- set_default_backend(self.backend)
+ set_default_backend(self.instruct_backend)
test_completion_speculative()
def test_chat_completion_speculative(self):
@@ -96,5 +94,5 @@ def test_chat_completion_speculative(self):
# global_config.verbosity = 2
# t = TestOpenAIBackend()
- # t.setUp()
- # t.test_chat_completion_speculative()
+ # t.setUpClass()
+ # t.test_stream()
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index c92568c0bf1..778cde8be4e 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -1,7 +1,3 @@
-"""
-python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
-"""
-
import json
import unittest
@@ -13,24 +9,25 @@
test_few_shot_qa,
test_mt_bench,
test_parallel_decoding,
- test_parallel_encoding,
- test_react,
test_regex,
test_select,
test_stream,
test_tool_use,
)
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
class TestSRTBackend(unittest.TestCase):
backend = None
- def setUp(self):
- cls = type(self)
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
+ sgl.set_default_backend(cls.backend)
- if cls.backend is None:
- cls.backend = sgl.RuntimeEndpoint(base_url="http://localhost:30000")
- sgl.set_default_backend(cls.backend)
+ @classmethod
+ def tearDownClass(cls):
+ cls.backend.shutdown()
def test_few_shot_qa(self):
test_few_shot_qa()
@@ -62,9 +59,6 @@ def test_stream(self):
def test_regex(self):
test_regex()
- # def test_parallel_encoding(self):
- # test_parallel_encoding(check_answer=False)
-
if __name__ == "__main__":
unittest.main(warnings="ignore")
@@ -73,5 +67,6 @@ def test_regex(self):
# global_config.verbosity = 2
# t = TestSRTBackend()
- # t.setUp()
- # t.test_regex()
+ # t.setUpClass()
+ # t.test_few_shot_qa()
+ # t.tearDownClass()
diff --git a/test/lang/test_tracing.py b/test/lang/test_tracing.py
index 266ce65fe38..5f2bc1d04fe 100644
--- a/test/lang/test_tracing.py
+++ b/test/lang/test_tracing.py
@@ -1,7 +1,7 @@
import unittest
import sglang as sgl
-from sglang.backend.base_backend import BaseBackend
+from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
@@ -16,7 +16,7 @@ def few_shot_qa(s, question):
s += "A:" + sgl.gen("answer", stop="\n")
tracer = few_shot_qa.trace()
- print(tracer.last_node.print_graph_dfs() + "\n")
+ # print(tracer.last_node.print_graph_dfs() + "\n")
def test_select(self):
@sgl.function
@@ -26,7 +26,7 @@ def capital(s):
s += "It is a city" + sgl.gen("description", stop=".")
tracer = capital.trace()
- print(tracer.last_node.print_graph_dfs() + "\n")
+ # print(tracer.last_node.print_graph_dfs() + "\n")
def test_raise_warning(self):
@sgl.function
@@ -66,11 +66,11 @@ def tip_suggestion(s, topic):
s += "In summary" + sgl.gen("summary")
compiled = tip_suggestion.compile()
- compiled.print_graph()
+ # compiled.print_graph()
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
state = compiled.run(topic="staying healthy")
- print(state.text() + "\n")
+ # print(state.text() + "\n")
states = compiled.run_batch(
[
@@ -80,8 +80,8 @@ def tip_suggestion(s, topic):
],
temperature=0,
)
- for s in states:
- print(s.text() + "\n")
+ # for s in states:
+ # print(s.text() + "\n")
def test_role(self):
@sgl.function
@@ -95,7 +95,7 @@ def multi_turn_chat(s):
backend.chat_template = get_chat_template("llama-2-chat")
compiled = multi_turn_chat.compile(backend=backend)
- compiled.print_graph()
+ # compiled.print_graph()
def test_fork(self):
@sgl.function
@@ -118,10 +118,10 @@ def tip_suggestion(s):
s += "In summary" + sgl.gen("summary")
tracer = tip_suggestion.trace()
- print(tracer.last_node.print_graph_dfs())
+ # print(tracer.last_node.print_graph_dfs())
a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
- print(a.text())
+ # print(a.text())
if __name__ == "__main__":
diff --git a/test/lang/test_vertexai_backend.py b/test/lang/test_vertexai_backend.py
index aae840101ac..b29efaa75ad 100644
--- a/test/lang/test_vertexai_backend.py
+++ b/test/lang/test_vertexai_backend.py
@@ -17,13 +17,11 @@ class TestVertexAIBackend(unittest.TestCase):
chat_backend = None
chat_vision_backend = None
- def setUp(self):
- cls = type(self)
-
- if cls.backend is None:
- cls.backend = VertexAI("gemini-pro")
- cls.chat_backend = VertexAI("gemini-pro")
- cls.chat_vision_backend = VertexAI("gemini-pro-vision")
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = VertexAI("gemini-pro")
+ cls.chat_backend = VertexAI("gemini-pro")
+ cls.chat_vision_backend = VertexAI("gemini-pro-vision")
def test_few_shot_qa(self):
set_default_backend(self.backend)
@@ -61,5 +59,5 @@ def test_stream(self):
# global_config.verbosity = 2
# t = TestVertexAIBackend()
- # t.setUp()
+ # t.setUpClass()
# t.test_stream()
diff --git a/test/srt/example_image.png b/test/srt/example_image.png
deleted file mode 120000
index c8a970edd0c..00000000000
--- a/test/srt/example_image.png
+++ /dev/null
@@ -1 +0,0 @@
-../lang/example_image.png
\ No newline at end of file
diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py
new file mode 100644
index 00000000000..c29c33188c0
--- /dev/null
+++ b/test/srt/models/test_embedding_models.py
@@ -0,0 +1,69 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+
+import torch
+
+from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+from sglang.test.test_utils import get_similarities
+
+MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
+TORCH_DTYPES = [torch.float16]
+
+
+class TestEmbeddingModels(unittest.TestCase):
+
+ def assert_close_prefill_logits(
+ self,
+ prompts,
+ model_path,
+ tp_size,
+ torch_dtype,
+ ) -> None:
+ with HFRunner(
+ model_path, torch_dtype=torch_dtype, is_generation_model=False
+ ) as hf_runner:
+ hf_outputs = hf_runner.forward(prompts)
+
+ with SRTRunner(
+ model_path,
+ tp_size=tp_size,
+ torch_dtype=torch_dtype,
+ is_generation_model=False,
+ ) as srt_runner:
+ srt_outputs = srt_runner.forward(prompts)
+
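+ # Compare the HF and SRT embeddings for each prompt via their similarity; values near 1
+ # mean the two runners produce (nearly) identical embeddings.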
+ for i in range(len(prompts)):
+ hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
+ srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
+
+ similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
+
+ tolerance = 1e-2
+ assert torch.all(
+ abs(similarities - 1) < tolerance
+ ), f"embeddings not all close"
+
+ def test_prefill_logits(self):
+ for model, tp_size in MODELS:
+ for torch_dtype in TORCH_DTYPES:
+ self.assert_close_prefill_logits(
+ DEFAULT_PROMPTS, model, tp_size, torch_dtype
+ )
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
new file mode 100644
index 00000000000..f057648020f
--- /dev/null
+++ b/test/srt/models/test_generation_models.py
@@ -0,0 +1,68 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+
+import torch
+
+from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+
+MODELS = [
+ ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
+]
+TORCH_DTYPES = [torch.float16]
+
+
+class TestCausalModels(unittest.TestCase):
+
+ def assert_close_prefill_logits(
+ self,
+ prompts,
+ model_path,
+ tp_size,
+ torch_dtype,
+ ) -> None:
+ with HFRunner(
+ model_path, torch_dtype=torch_dtype, is_generation_model=True
+ ) as hf_runner:
+ hf_outputs = hf_runner.forward(prompts)
+
+ with SRTRunner(
+ model_path,
+ tp_size=tp_size,
+ torch_dtype=torch_dtype,
+ is_generation_model=True,
+ ) as srt_runner:
+ srt_outputs = srt_runner.forward(prompts)
+
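+ # Compare the per-token prefill logprobs of the HF reference and the SRT runner element-wise.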
+ for i in range(len(prompts)):
+ hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
+ srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+
+ tolerance = 3e-2
+ assert torch.all(
+ abs(hf_logprobs - srt_logprobs) < tolerance
+ ), f"prefill logprobs not all close"
+
+ def test_prefill_logits(self):
+ for model, tp_size in MODELS:
+ for torch_dtype in TORCH_DTYPES:
+ self.assert_close_prefill_logits(
+ DEFAULT_PROMPTS, model, tp_size, torch_dtype
+ )
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
new file mode 100644
index 00000000000..d5051ffc1e2
--- /dev/null
+++ b/test/srt/run_suite.py
@@ -0,0 +1,54 @@
+import argparse
+import glob
+
+from sglang.test.test_utils import run_unittest_files
+
+suites = {
+ "minimal": [
+ "test_eval_accuracy.py",
+ "test_openai_server.py",
+ "test_vision_openai_server.py",
+ "test_chunked_prefill.py",
+ "test_torch_compile.py",
+ "test_models_from_modelscope.py",
+ "models/test_generation_models.py",
+ "models/test_embedding_models.py",
+ "sampling/penaltylib",
+ ],
+ "sampling/penaltylib": glob.glob(
+ "sampling/penaltylib/**/test_*.py", recursive=True
+ ),
+}
+
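+# Expand nested suite references: when a suite's test list names another suite
+# (e.g. "sampling/penaltylib" inside "minimal"), replace that entry with the referenced suite's tests.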
+for target_suite_name, target_tests in suites.items():
+ for suite_name, tests in suites.items():
+ if suite_name == target_suite_name:
+ continue
+ if target_suite_name in tests:
+ tests.remove(target_suite_name)
+ tests.extend(target_tests)
+
+if __name__ == "__main__":
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument(
+ "--timeout-per-file",
+ type=int,
+ default=2000,
+ help="The time limit for running one file in seconds.",
+ )
+ arg_parser.add_argument(
+ "--suite",
+ type=str,
+ default=list(suites.keys())[0],
+ choices=list(suites.keys()) + ["all"],
+ help="The suite to run",
+ )
+ args = arg_parser.parse_args()
+
+ if args.suite == "all":
+ files = glob.glob("**/test_*.py", recursive=True)
+ else:
+ files = suites[args.suite]
+
+ exit_code = run_unittest_files(files, args.timeout_per_file)
+ exit(exit_code)
diff --git a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py
new file mode 100644
index 00000000000..59db353abfa
--- /dev/null
+++ b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py
@@ -0,0 +1,93 @@
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
+ BatchedFrequencyPenalizer,
+)
+from sglang.test.srt.sampling.penaltylib.utils import (
+ BaseBatchedPenalizerTest,
+ MockSamplingParams,
+ Step,
+ StepType,
+ Subject,
+)
+
+
+class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
+ Penalizer = BatchedFrequencyPenalizer
+ frequency_penalty: float
+
+ def setUp(self):
+ if self.__class__ == BaseBatchedFrequencyPenalizerTest:
+ self.skipTest("Base class for frequency_penalty tests")
+
+ super().setUp()
+
+ def _create_subject(self, frequency_penalty: float) -> Subject:
+ return Subject(
+ sampling_params=MockSamplingParams(
+ frequency_penalty=frequency_penalty,
+ ),
+ steps=[
+ Step(
+ type=StepType.INPUT,
+ token_ids=[0, 1, 2],
+ expected_tensors={
+ "frequency_penalties": self.tensor(
+ [[frequency_penalty] * self.vocab_size], dtype=torch.float32
+ ),
+ "cumulated_frequency_penalties": self.tensor(
+ [[0.0] * self.vocab_size], dtype=torch.float32
+ ),
+ },
+ expected_logits=self.tensor(
+ [[1] * self.vocab_size], dtype=torch.float32
+ ),
+ ),
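+ # Output step: in [1, 2, 2], token 1 appears once and token 2 twice, so each cumulated
+ # frequency penalty equals frequency_penalty times the token's count (which here coincides with the token id).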
+ Step(
+ type=StepType.OUTPUT,
+ token_ids=[1, 2, 2],
+ expected_tensors={
+ "frequency_penalties": self.tensor(
+ [[frequency_penalty] * self.vocab_size], dtype=torch.float32
+ ),
+ "cumulated_frequency_penalties": self.tensor(
+ [
+ [
+ frequency_penalty * i if i in {1, 2} else 0.0
+ for i in range(self.vocab_size)
+ ],
+ ],
+ dtype=torch.float32,
+ ),
+ },
+ expected_logits=self.tensor(
+ [
+ [
+ 1.0 - frequency_penalty * i if i in {1, 2} else 1.0
+ for i in range(self.vocab_size)
+ ],
+ ],
+ dtype=torch.float32,
+ ),
+ ),
+ ],
+ )
+
+ def create_test_subjects(self) -> typing.List[Subject]:
+ self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
+ self.disabled = self._create_subject(frequency_penalty=0.0)
+
+
+class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
+ frequency_penalty = 0.12
+
+
+class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
+ frequency_penalty = -0.12
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py
new file mode 100644
index 00000000000..1984aafe5ea
--- /dev/null
+++ b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py
@@ -0,0 +1,152 @@
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import (
+ BatchedMinNewTokensPenalizer,
+)
+from sglang.test.srt.sampling.penaltylib.utils import (
+ BaseBatchedPenalizerTest,
+ MockSamplingParams,
+ Step,
+ StepType,
+ Subject,
+)
+
+MIN_NEW_TOKENS = 2
+EOS_TOKEN_ID = 4
+STOP_TOKEN_ID = 3
+
+ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID}
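+# Until min_new_tokens outputs have been generated, the expected logits below mask the
+# stop and EOS tokens to -inf.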
+
+
+class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
+ Penalizer = BatchedMinNewTokensPenalizer
+
+ def _create_subject(self, min_new_tokens: int) -> Subject:
+ return Subject(
+ eos_token_id=EOS_TOKEN_ID,
+ sampling_params=MockSamplingParams(
+ min_new_tokens=min_new_tokens,
+ stop_token_ids={STOP_TOKEN_ID},
+ ),
+ steps=[
+ Step(
+ type=StepType.INPUT,
+ token_ids=[0, 1, 2],
+ expected_tensors={
+ "min_new_tokens": self.tensor(
+ [[min_new_tokens]], dtype=torch.int32
+ ),
+ "stop_token_penalties": self.tensor(
+ [
+ [
+ float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
+ for i in range(self.vocab_size)
+ ]
+ ],
+ dtype=torch.float32,
+ ),
+ "len_output_tokens": self.tensor([[0]], dtype=torch.int32),
+ },
+ expected_logits=(
+ self.tensor(
+ [
+ [
+ float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
+ for i in range(self.vocab_size)
+ ]
+ ],
+ dtype=torch.float32,
+ )
+ if min_new_tokens > 0
+ else torch.ones(
+ (1, self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ ),
+ ),
+ Step(
+ type=StepType.OUTPUT,
+ token_ids=[0],
+ expected_tensors={
+ "min_new_tokens": self.tensor(
+ [[min_new_tokens]], dtype=torch.int32
+ ),
+ "stop_token_penalties": self.tensor(
+ [
+ [
+ float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
+ for i in range(self.vocab_size)
+ ]
+ ],
+ dtype=torch.float32,
+ ),
+ "len_output_tokens": self.tensor([[1]], dtype=torch.int32),
+ },
+ expected_logits=(
+ self.tensor(
+ [
+ [
+ float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
+ for i in range(self.vocab_size)
+ ]
+ ],
+ dtype=torch.float32,
+ )
+ if min_new_tokens > 1
+ else torch.ones(
+ (1, self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ ),
+ ),
+ Step(
+ type=StepType.OUTPUT,
+ token_ids=[0],
+ expected_tensors={
+ "min_new_tokens": self.tensor(
+ [[min_new_tokens]], dtype=torch.int32
+ ),
+ "stop_token_penalties": self.tensor(
+ [
+ [
+ float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
+ for i in range(self.vocab_size)
+ ]
+ ],
+ dtype=torch.float32,
+ ),
+ "len_output_tokens": self.tensor([[2]], dtype=torch.int32),
+ },
+ expected_logits=(
+ self.tensor(
+ [
+ [
+ float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
+ for i in range(self.vocab_size)
+ ]
+ ],
+ dtype=torch.float32,
+ )
+ if min_new_tokens > 2
+ else torch.ones(
+ (1, self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ ),
+ ),
+ ],
+ )
+
+ def create_test_subjects(self) -> typing.List[Subject]:
+ self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
+ self.disabled = self._create_subject(min_new_tokens=0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py
new file mode 100644
index 00000000000..96cbf1082e5
--- /dev/null
+++ b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py
@@ -0,0 +1,93 @@
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
+ BatchedPresencePenalizer,
+)
+from sglang.test.srt.sampling.penaltylib.utils import (
+ BaseBatchedPenalizerTest,
+ MockSamplingParams,
+ Step,
+ StepType,
+ Subject,
+)
+
+
+class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
+ Penalizer = BatchedPresencePenalizer
+ presence_penalty: float
+
+ def setUp(self):
+ if self.__class__ == BaseBatchedPresencePenalizerTest:
+ self.skipTest("Base class for presence_penalty tests")
+
+ super().setUp()
+
+ def _create_subject(self, presence_penalty: float) -> Subject:
+ return Subject(
+ sampling_params=MockSamplingParams(
+ presence_penalty=presence_penalty,
+ ),
+ steps=[
+ Step(
+ type=StepType.INPUT,
+ token_ids=[0, 1, 2],
+ expected_tensors={
+ "presence_penalties": self.tensor(
+ [[presence_penalty] * self.vocab_size], dtype=torch.float32
+ ),
+ "cumulated_presence_penalties": self.tensor(
+ [[0.0] * self.vocab_size], dtype=torch.float32
+ ),
+ },
+ expected_logits=self.tensor(
+ [[1] * self.vocab_size], dtype=torch.float32
+ ),
+ ),
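+ # Output step: the presence penalty is applied once per distinct generated token
+ # (tokens 1 and 2), regardless of how many times each appears.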
+ Step(
+ type=StepType.OUTPUT,
+ token_ids=[1, 2, 2],
+ expected_tensors={
+ "presence_penalties": self.tensor(
+ [[presence_penalty] * self.vocab_size], dtype=torch.float32
+ ),
+ "cumulated_presence_penalties": self.tensor(
+ [
+ [
+ presence_penalty if i in {1, 2} else 0.0
+ for i in range(self.vocab_size)
+ ],
+ ],
+ dtype=torch.float32,
+ ),
+ },
+ expected_logits=self.tensor(
+ [
+ [
+ 1.0 - presence_penalty if i in {1, 2} else 1.0
+ for i in range(self.vocab_size)
+ ],
+ ],
+ dtype=torch.float32,
+ ),
+ ),
+ ],
+ )
+
+ def create_test_subjects(self) -> typing.List[Subject]:
+ self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
+ self.disabled = self._create_subject(presence_penalty=0.0)
+
+
+class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
+ presence_penalty = 0.12
+
+
+class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
+ presence_penalty = -0.12
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py
new file mode 100644
index 00000000000..e3751c14a30
--- /dev/null
+++ b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py
@@ -0,0 +1,87 @@
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+ BatchedRepetitionPenalizer,
+)
+from sglang.test.srt.sampling.penaltylib.utils import (
+ BaseBatchedPenalizerTest,
+ MockSamplingParams,
+ Step,
+ StepType,
+ Subject,
+)
+
+REPETITION_PENALTY = 2.0
+
+
+class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
+ Penalizer = BatchedRepetitionPenalizer
+
+ def _create_subject(self, repetition_penalty: float) -> Subject:
+ l = 1.0 / repetition_penalty
+ return Subject(
+ sampling_params=MockSamplingParams(
+ repetition_penalty=repetition_penalty,
+ ),
+ steps=[
+ Step(
+ type=StepType.INPUT,
+ token_ids=[0, 1, 2],
+ expected_tensors={
+ "repetition_penalties": self.tensor(
+ [[repetition_penalty] * self.vocab_size],
+ dtype=torch.float32,
+ ),
+ "cumulated_repetition_penalties": (
+ self.tensor(
+ [[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32
+ )
+ if repetition_penalty != 1.0
+ else self.tensor(
+ [[1.0] * self.vocab_size], dtype=torch.float32
+ )
+ ),
+ },
+ expected_logits=(
+ self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32)
+ if repetition_penalty != 1.0
+ else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
+ ),
+ ),
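+ # Output step: newly generated token 3 joins input tokens 0-2, so its logit (1.0 here)
+ # is also divided by the repetition penalty (l = 1 / repetition_penalty).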
+ Step(
+ type=StepType.OUTPUT,
+ token_ids=[0, 1, 3],
+ expected_tensors={
+ "repetition_penalties": self.tensor(
+ [[repetition_penalty] * self.vocab_size],
+ dtype=torch.float32,
+ ),
+ "cumulated_repetition_penalties": (
+ self.tensor(
+ [[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32
+ )
+ if repetition_penalty != 1.0
+ else self.tensor(
+ [[1.0] * self.vocab_size], dtype=torch.float32
+ )
+ ),
+ },
+ expected_logits=(
+ self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32)
+ if repetition_penalty != 1.0
+ else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
+ ),
+ ),
+ ],
+ )
+
+ def create_test_subjects(self) -> typing.List[Subject]:
+ self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
+ self.disabled = self._create_subject(repetition_penalty=1.0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
new file mode 100644
index 00000000000..e72dc30f956
--- /dev/null
+++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
@@ -0,0 +1,110 @@
+import json
+import unittest
+from multiprocessing import Process
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestBatchPenalizerE2E(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+ cls.base_url = f"http://127.0.0.1:{8157}"
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=300,
+ other_args=(
+ "--random-seed",
+ "0",
+ ),
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def run_decode(
+ self,
+ return_logprob=True,
+ top_logprobs_num=5,
+ return_text=True,
+ n=1,
+ **sampling_params,
+ ):
+ response = requests.post(
+ self.base_url + "/generate",
+ json={
+ # prompt that is supposed to generate < 32 tokens
+ "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+ "sampling_params": {
+ "max_new_tokens": 32,
+ "n": n,
+ **sampling_params,
+ },
+ "stream": False,
+ "return_logprob": return_logprob,
+ "top_logprobs_num": top_logprobs_num,
+ "return_text_in_logprobs": return_text,
+ "logprob_start_len": 0,
+ },
+ )
+ print(json.dumps(response.json()))
+ print("=" * 100)
+
+ def test_default_values(self):
+ self.run_decode()
+
+ def test_mixed(self):
+ """
+ Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
+ This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
+
+ Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
+ This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
+
+ This test triggers the merge of disabled + enabled.
+ """
+
+ processes = []
+
+ p = Process(
+ target=self.run_decode,
+ )
+ processes.append(p)
+ p.start()
+
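+ # The second request enables every penalizer, forcing the scheduler to merge a
+ # penalizer-free batch with a penalized one.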
+ p = Process(
+ target=self.run_decode,
+ kwargs={
+ "frequency_penalty": 2,
+ "min_new_tokens": 16,
+ "presence_penalty": 2,
+ "repetition_penalty": 2,
+ },
+ )
+ processes.append(p)
+ p.start()
+
+ for p in processes:
+ p.join()
+
+ def test_frequency_penalty(self):
+ self.run_decode(frequency_penalty=2)
+
+ def test_min_new_tokens(self):
+ self.run_decode(min_new_tokens=16)
+
+ def test_presence_penalty(self):
+ self.run_decode(presence_penalty=2)
+
+ def test_repetition_penalty(self):
+ self.run_decode(repetition_penalty=2)
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py
new file mode 100644
index 00000000000..7f274926a62
--- /dev/null
+++ b/test/srt/test_chunked_prefill.py
@@ -0,0 +1,45 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+ cls.base_url = "http://127.0.0.1:8157"
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=300,
+ other_args=["--chunked-prefill-size", "32"],
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def test_mmlu(self):
+ args = SimpleNamespace(
+ base_url=self.base_url,
+ model=self.model,
+ eval_name="mmlu",
+ num_examples=20,
+ num_threads=20,
+ )
+
+ metrics = run_eval(args)
+ assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
+
+ # t = TestAccuracy()
+ # t.setUpClass()
+ # t.test_mmlu()
+ # t.tearDownClass()
diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py
new file mode 100644
index 00000000000..b6359362670
--- /dev/null
+++ b/test/srt/test_eval_accuracy.py
@@ -0,0 +1,40 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+ cls.base_url = "http://127.0.0.1:8157"
+ cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def test_mmlu(self):
+ args = SimpleNamespace(
+ base_url=self.base_url,
+ model=self.model,
+ eval_name="mmlu",
+ num_examples=20,
+ num_threads=20,
+ )
+
+ metrics = run_eval(args)
+ assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
+
+ # t = TestAccuracy()
+ # t.setUpClass()
+ # t.test_mmlu()
+ # t.tearDownClass()
diff --git a/test/srt/test_models_from_modelscope.py b/test/srt/test_models_from_modelscope.py
new file mode 100644
index 00000000000..2313053b909
--- /dev/null
+++ b/test/srt/test_models_from_modelscope.py
@@ -0,0 +1,47 @@
+import os
+import shutil
+import subprocess
+import unittest
+from unittest import mock
+
+from sglang.srt.utils import prepare_model, prepare_tokenizer
+
+
+class TestDownloadFromModelScope(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = "iic/nlp_lstmcrf_word-segmentation_chinese-news"
+ stat, output = subprocess.getstatusoutput("pip install modelscope")
+
+ cls.with_modelscope_environ = {k: v for k, v in os.environ.items()}
+ cls.with_modelscope_environ["SGLANG_USE_MODELSCOPE"] = "True"
+
+ @classmethod
+ def tearDownClass(cls):
+ pass
+
+ def test_prepare_model(self):
+ from modelscope.utils.file_utils import get_model_cache_root
+
+ model_cache_root = get_model_cache_root()
+ if os.path.exists(model_cache_root):
+ shutil.rmtree(model_cache_root)
+ with mock.patch.dict(os.environ, self.with_modelscope_environ, clear=True):
+ model_path = prepare_model(self.model)
+ assert os.path.exists(os.path.join(model_path, "pytorch_model.bin"))
+
+ def test_prepare_tokenizer(self):
+ from modelscope.utils.file_utils import get_model_cache_root
+
+ model_cache_root = get_model_cache_root()
+ if os.path.exists(model_cache_root):
+ shutil.rmtree(model_cache_root)
+ with mock.patch.dict(os.environ, self.with_modelscope_environ, clear=True):
+ tokenizer_path = prepare_tokenizer(self.model)
+ assert not os.path.exists(os.path.join(tokenizer_path, "pytorch_model.bin"))
+ assert os.path.exists(os.path.join(tokenizer_path, "config.json"))
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index a77319b1baa..f8f6ca63210 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,209 +1,407 @@
-"""
-First run the following command to launch the server.
-Note that TinyLlama adopts different chat templates in different versions.
-For v0.4, the chat template is chatml.
-
-python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
---port 30000 --chat-template chatml
-
-Output example:
-The capital of France is Paris.
-The capital of the United States is Washington, D.C.
-The capital of Canada is Ottawa.
-The capital of Japan is Tokyo
-"""
-
-import argparse
import json
+import time
+import unittest
import openai
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
-def test_completion(args, echo, logprobs):
- client = openai.Client(api_key="EMPTY", base_url=args.base_url)
- response = client.completions.create(
- model="default",
- prompt="The capital of France is",
- temperature=0,
- max_tokens=32,
- echo=echo,
- logprobs=logprobs,
- )
- text = response.choices[0].text
- print(response.choices[0].text)
- if echo:
- assert text.startswith("The capital of France is")
- if logprobs:
- print(response.choices[0].logprobs.top_logprobs)
- assert response.choices[0].logprobs
- if echo:
- assert response.choices[0].logprobs.token_logprobs[0] == None
+
+class TestOpenAIServer(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+ cls.base_url = "http://127.0.0.1:8157"
+ cls.api_key = "sk-123456"
+ cls.process = popen_launch_server(
+ cls.model, cls.base_url, timeout=300, api_key=cls.api_key
+ )
+ cls.base_url += "/v1"
+ cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def run_completion(
+ self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+ ):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ prompt = "The capital of France is"
+ if token_input:
+ prompt_input = self.tokenizer.encode(prompt)
+ num_prompt_tokens = len(prompt_input)
+ else:
+ prompt_input = prompt
+ num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+ if use_list_input:
+ prompt_arg = [prompt_input, prompt_input]
+ num_choices = len(prompt_arg)
+ num_prompt_tokens *= 2
else:
- assert response.choices[0].logprobs.token_logprobs[0] != None
- assert response.id
- assert response.created
- assert response.usage.prompt_tokens > 0
- assert response.usage.completion_tokens > 0
- assert response.usage.total_tokens > 0
- print("=" * 100)
-
-
-def test_completion_stream(args, echo, logprobs):
- client = openai.Client(api_key="EMPTY", base_url=args.base_url)
- response = client.completions.create(
- model="default",
- prompt="The capital of France is",
- temperature=0,
- max_tokens=32,
- stream=True,
- echo=echo,
- logprobs=logprobs,
- )
- first = True
- for r in response:
- if first:
+ prompt_arg = prompt_input
+ num_choices = 1
+
+ response = client.completions.create(
+ model=self.model,
+ prompt=prompt_arg,
+ temperature=0,
+ max_tokens=32,
+ echo=echo,
+ logprobs=logprobs,
+ n=parallel_sample_num,
+ )
+
+ assert len(response.choices) == num_choices * parallel_sample_num
+
+ if echo:
+ text = response.choices[0].text
+ assert text.startswith(prompt)
+
+ if logprobs:
+ assert response.choices[0].logprobs
+ assert isinstance(response.choices[0].logprobs.tokens[0], str)
+ assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
+ ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
+ # FIXME: Sometimes some top_logprobs are missing from the return value. The reason is that several output ids can map to the same output token, so duplicates get collapsed in the map.
+ # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+ assert ret_num_top_logprobs > 0
if echo:
- assert r.choices[0].text.startswith("The capital of France is")
- first = False
+ assert response.choices[0].logprobs.token_logprobs[0] is None
+ else:
+ assert response.choices[0].logprobs.token_logprobs[0] is not None
+
+ assert response.id
+ assert response.created
+ assert (
+ response.usage.prompt_tokens == num_prompt_tokens
+ ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
+ assert response.usage.completion_tokens > 0
+ assert response.usage.total_tokens > 0
+
+ def run_completion_stream(self, echo, logprobs, token_input):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ prompt = "The capital of France is"
+ if token_input:
+ prompt_arg = self.tokenizer.encode(prompt)
+ else:
+ prompt_arg = prompt
+ generator = client.completions.create(
+ model=self.model,
+ prompt=prompt_arg,
+ temperature=0,
+ max_tokens=32,
+ echo=echo,
+ logprobs=logprobs,
+ stream=True,
+ stream_options={"include_usage": True},
+ )
+
+ first = True
+ for response in generator:
+ usage = response.usage
+ if usage is not None:
+ assert usage.prompt_tokens > 0
+ assert usage.completion_tokens > 0
+ assert usage.total_tokens > 0
+ continue
+ if logprobs:
+ assert response.choices[0].logprobs
+ assert isinstance(response.choices[0].logprobs.tokens[0], str)
+ if not (first and echo):
+ assert isinstance(
+ response.choices[0].logprobs.top_logprobs[0], dict
+ )
+ ret_num_top_logprobs = len(
+ response.choices[0].logprobs.top_logprobs[0]
+ )
+ # FIXME: Sometimes some top_logprobs are missing from the return value. The reason is that several output ids can map to the same output token, so duplicates get collapsed in the map.
+ # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+ assert ret_num_top_logprobs > 0
+
+ if first:
+ if echo:
+ assert response.choices[0].text.startswith(
+ prompt
+ ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
+ first = False
+ assert response.id
+ assert response.created
+
+ def run_chat_completion(self, logprobs, parallel_sample_num):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ response = client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": "You are a helpful AI assistant"},
+ {
+ "role": "user",
+ "content": "What is the capital of France? Answer in a few words.",
+ },
+ ],
+ temperature=0,
+ logprobs=logprobs is not None and logprobs > 0,
+ top_logprobs=logprobs,
+ n=parallel_sample_num,
+ )
+
if logprobs:
- print(
- f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
- flush=True,
+ assert isinstance(
+ response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+ )
+
+ ret_num_top_logprobs = len(
+ response.choices[0].logprobs.content[0].top_logprobs
)
- print(r.choices[0].logprobs.top_logprobs)
+ assert (
+ ret_num_top_logprobs == logprobs
+ ), f"{ret_num_top_logprobs} vs {logprobs}"
+
+ assert len(response.choices) == parallel_sample_num
+ assert response.choices[0].message.role == "assistant"
+ assert isinstance(response.choices[0].message.content, str)
+ assert response.id
+ assert response.created
+ assert response.usage.prompt_tokens > 0
+ assert response.usage.completion_tokens > 0
+ assert response.usage.total_tokens > 0
+
+ def run_chat_completion_stream(self, logprobs):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ generator = client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": "You are a helpful AI assistant"},
+ {"role": "user", "content": "What is the capital of France?"},
+ ],
+ temperature=0,
+ logprobs=logprobs is not None and logprobs > 0,
+ top_logprobs=logprobs,
+ stream=True,
+ stream_options={"include_usage": True},
+ )
+
+ is_first = True
+ for response in generator:
+ usage = response.usage
+ if usage is not None:
+ assert usage.prompt_tokens > 0
+ assert usage.completion_tokens > 0
+ assert usage.total_tokens > 0
+ continue
+
+ data = response.choices[0].delta
+
+ if is_first:
+ assert data.role == "assistant"
+ is_first = False
+ continue
+
+ if logprobs:
+ assert response.choices[0].logprobs
+ assert isinstance(
+ response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+ )
+ assert isinstance(
+ response.choices[0].logprobs.content[0].top_logprobs, list
+ )
+ ret_num_top_logprobs = len(
+ response.choices[0].logprobs.content[0].top_logprobs
+ )
+ assert (
+ ret_num_top_logprobs == logprobs
+ ), f"{ret_num_top_logprobs} vs {logprobs}"
+
+ assert isinstance(data.content, str)
+ assert response.id
+ assert response.created
+
+ def run_batch(self, mode):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ if mode == "completion":
+ input_file_path = "complete_input.jsonl"
+ # write content to input file
+ content = [
+ {
+ "custom_id": "request-1",
+ "method": "POST",
+ "url": "/v1/completions",
+ "body": {
+ "model": "gpt-3.5-turbo-instruct",
+ "prompt": "List 3 names of famous soccer player: ",
+ "max_tokens": 20,
+ },
+ },
+ {
+ "custom_id": "request-2",
+ "method": "POST",
+ "url": "/v1/completions",
+ "body": {
+ "model": "gpt-3.5-turbo-instruct",
+ "prompt": "List 6 names of famous basketball player: ",
+ "max_tokens": 40,
+ },
+ },
+ {
+ "custom_id": "request-3",
+ "method": "POST",
+ "url": "/v1/completions",
+ "body": {
+ "model": "gpt-3.5-turbo-instruct",
+ "prompt": "List 6 names of famous tenniss player: ",
+ "max_tokens": 40,
+ },
+ },
+ ]
+
else:
- print(r.choices[0].text, end="", flush=True)
- assert r.id
- assert r.usage.prompt_tokens > 0
- assert r.usage.completion_tokens > 0
- assert r.usage.total_tokens > 0
- print("=" * 100)
-
-
-def test_chat_completion(args):
- client = openai.Client(api_key="EMPTY", base_url=args.base_url)
- response = client.chat.completions.create(
- model="default",
- messages=[
- {"role": "system", "content": "You are a helpful AI assistant"},
- {"role": "user", "content": "What is the capital of France?"},
- ],
- temperature=0,
- max_tokens=32,
- )
- print(response.choices[0].message.content)
- assert response.id
- assert response.created
- assert response.usage.prompt_tokens > 0
- assert response.usage.completion_tokens > 0
- assert response.usage.total_tokens > 0
- print("=" * 100)
-
-
-def test_chat_completion_image(args):
- client = openai.Client(api_key="EMPTY", base_url=args.base_url)
- response = client.chat.completions.create(
- model="default",
- messages=[
- {"role": "system", "content": "You are a helpful AI assistant"},
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Describe this image"},
- {
- "type": "image_url",
- "image_url": {
- "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
- },
+ input_file_path = "chat_input.jsonl"
+ content = [
+ {
+ "custom_id": "request-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": "gpt-3.5-turbo-0125",
+ "messages": [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "role": "user",
+ "content": "Hello! List 3 NBA players and tell a story",
+ },
+ ],
+ "max_tokens": 30,
+ },
+ },
+ {
+ "custom_id": "request-2",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": "gpt-3.5-turbo-0125",
+ "messages": [
+ {"role": "system", "content": "You are an assistant. "},
+ {
+ "role": "user",
+ "content": "Hello! List three capital and tell a story",
+ },
+ ],
+ "max_tokens": 50,
},
- ],
- },
- ],
- temperature=0,
- max_tokens=32,
- )
- print(response.choices[0].message.content)
- assert response.id
- assert response.created
- assert response.usage.prompt_tokens > 0
- assert response.usage.completion_tokens > 0
- assert response.usage.total_tokens > 0
- print("=" * 100)
-
-
-def test_chat_completion_stream(args):
- client = openai.Client(api_key="EMPTY", base_url=args.base_url)
- response = client.chat.completions.create(
- model="default",
- messages=[
- {"role": "system", "content": "You are a helpful AI assistant"},
- {"role": "user", "content": "List 3 countries and their capitals."},
- ],
- temperature=0,
- max_tokens=64,
- stream=True,
- )
- is_first = True
- for chunk in response:
- if is_first:
- is_first = False
- assert chunk.choices[0].delta.role == "assistant"
- continue
-
- data = chunk.choices[0].delta
- if not data.content:
- continue
- print(data.content, end="", flush=True)
- print("=" * 100)
-
-
-def test_regex(args):
- client = openai.Client(api_key="EMPTY", base_url=args.base_url)
-
- regex = (
- r"""\{\n"""
- + r""" "name": "[\w]+",\n"""
- + r""" "population": [\d]+\n"""
- + r"""\}"""
- )
-
- response = client.chat.completions.create(
- model="default",
- messages=[
- {"role": "system", "content": "You are a helpful AI assistant"},
- {"role": "user", "content": "Introduce the capital of France."},
- ],
- temperature=0,
- max_tokens=128,
- extra_body={"regex": regex},
- )
- text = response.choices[0].message.content
- print(json.loads(text))
- print("=" * 100)
+ },
+ ]
+ with open(input_file_path, "w") as file:
+ for line in content:
+ file.write(json.dumps(line) + "\n")
+ with open(input_file_path, "rb") as file:
+ uploaded_file = client.files.create(file=file, purpose="batch")
+ if mode == "completion":
+ endpoint = "/v1/completions"
+ elif mode == "chat":
+ endpoint = "/v1/chat/completions"
+ completion_window = "24h"
+ batch_job = client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint=endpoint,
+ completion_window=completion_window,
+ )
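+ # Poll until the batch job reaches a terminal state, then verify that every request completed.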
+ while batch_job.status not in ["completed", "failed", "cancelled"]:
+ time.sleep(3)
+ print(
+ f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
+ )
+ batch_job = client.batches.retrieve(batch_job.id)
+ assert batch_job.status == "completed"
+ assert batch_job.request_counts.completed == len(content)
+ assert batch_job.request_counts.failed == 0
+ assert batch_job.request_counts.total == len(content)
+
+ result_file_id = batch_job.output_file_id
+ file_response = client.files.content(result_file_id)
+ result_content = file_response.read().decode("utf-8") # Decode bytes to string
+ results = [
+ json.loads(line)
+ for line in result_content.split("\n")
+ if line.strip() != ""
+ ]
+ assert len(results) == len(content)
+
+ def test_completion(self):
+ for echo in [False, True]:
+ for logprobs in [None, 5]:
+ for use_list_input in [True, False]:
+ for parallel_sample_num in [1, 2]:
+ for token_input in [False, True]:
+ self.run_completion(
+ echo,
+ logprobs,
+ use_list_input,
+ parallel_sample_num,
+ token_input,
+ )
+
+ def test_completion_stream(self):
+ # parallel sampling and list input are not supported in streaming mode
+ for echo in [False, True]:
+ for logprobs in [None, 5]:
+ for token_input in [False, True]:
+ self.run_completion_stream(echo, logprobs, token_input)
+
+ def test_chat_completion(self):
+ for logprobs in [None, 5]:
+ for parallel_sample_num in [1, 2]:
+ self.run_chat_completion(logprobs, parallel_sample_num)
+
+ def test_chat_completion_stream(self):
+ for logprobs in [None, 5]:
+ self.run_chat_completion_stream(logprobs)
+
+ def test_batch(self):
+ for mode in ["completion", "chat"]:
+ self.run_batch(mode)
+
+ def test_regex(self):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ regex = (
+ r"""\{\n"""
+ + r""" "name": "[\w]+",\n"""
+ + r""" "population": [\d]+\n"""
+ + r"""\}"""
+ )
+
+ response = client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": "You are a helpful AI assistant"},
+ {"role": "user", "content": "Introduce the capital of France."},
+ ],
+ temperature=0,
+ max_tokens=128,
+ extra_body={"regex": regex},
+ )
+ text = response.choices[0].message.content
+
+ try:
+ js_obj = json.loads(text)
+ except (TypeError, json.decoder.JSONDecodeError):
+ print("JSONDecodeError", text)
+ raise
+ assert isinstance(js_obj["name"], str)
+ assert isinstance(js_obj["population"], int)
if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
- parser.add_argument(
- "--test-image", action="store_true", help="Enables testing image inputs"
- )
- args = parser.parse_args()
-
- test_completion(args, echo=False, logprobs=False)
- test_completion(args, echo=True, logprobs=False)
- test_completion(args, echo=False, logprobs=True)
- test_completion(args, echo=True, logprobs=True)
- test_completion(args, echo=False, logprobs=3)
- test_completion(args, echo=True, logprobs=3)
- test_completion_stream(args, echo=False, logprobs=False)
- test_completion_stream(args, echo=True, logprobs=False)
- test_completion_stream(args, echo=False, logprobs=True)
- test_completion_stream(args, echo=True, logprobs=True)
- test_completion_stream(args, echo=False, logprobs=3)
- test_completion_stream(args, echo=True, logprobs=3)
- test_chat_completion(args)
- test_chat_completion_stream(args)
- test_regex(args)
- if args.test_image:
- test_chat_completion_image(args)
+ unittest.main(warnings="ignore")
+
+ # t = TestOpenAIServer()
+ # t.setUpClass()
+ # t.test_completion()
+ # t.tearDownClass()
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
new file mode 100644
index 00000000000..b208dfa1329
--- /dev/null
+++ b/test/srt/test_srt_endpoint.py
@@ -0,0 +1,62 @@
+import json
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestSRTEndpoint(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+ cls.base_url = "http://127.0.0.1:8157"
+ cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def run_decode(
+ self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+ ):
+ response = requests.post(
+ self.base_url + "/generate",
+ json={
+ "text": "The capital of France is",
+ "sampling_params": {
+ "temperature": 0 if n == 1 else 0.5,
+ "max_new_tokens": 32,
+ "n": n,
+ },
+ "stream": False,
+ "return_logprob": return_logprob,
+ "top_logprobs_num": top_logprobs_num,
+ "return_text_in_logprobs": return_text,
+ "logprob_start_len": 0,
+ },
+ )
+ print(json.dumps(response.json()))
+ print("=" * 100)
+
+ def test_simple_decode(self):
+ self.run_decode()
+
+ def test_parallel_sample(self):
+ self.run_decode(n=3)
+
+ def test_logprob(self):
+ for top_logprobs_num in [0, 3]:
+ for return_text in [True, False]:
+ self.run_decode(
+ return_logprob=True,
+ top_logprobs_num=top_logprobs_num,
+ return_text=return_text,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py
new file mode 100644
index 00000000000..fd2c6ebb778
--- /dev/null
+++ b/test/srt/test_torch_compile.py
@@ -0,0 +1,42 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+ cls.base_url = "http://127.0.0.1:8157"
+ cls.process = popen_launch_server(
+ cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def test_mmlu(self):
+ args = SimpleNamespace(
+ base_url=self.base_url,
+ model=self.model,
+ eval_name="mmlu",
+ num_examples=20,
+ num_threads=20,
+ )
+
+ metrics = run_eval(args)
+ assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
+
+ # t = TestAccuracy()
+ # t.setUpClass()
+ # t.test_mmlu()
+ # t.tearDownClass()
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
new file mode 100644
index 00000000000..982c026dbbf
--- /dev/null
+++ b/test/srt/test_vision_openai_server.py
@@ -0,0 +1,121 @@
+import json
+import unittest
+
+import openai
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import popen_launch_server
+
+
+class TestOpenAIVisionServer(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = "liuhaotian/llava-v1.6-vicuna-7b"
+ cls.base_url = "http://127.0.0.1:8157"
+ cls.api_key = "sk-123456"
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=300,
+ api_key=cls.api_key,
+ other_args=[
+ "--chat-template",
+ "vicuna_v1.1",
+ "--tokenizer-path",
+ "llava-hf/llava-1.5-7b-hf",
+ "--log-requests",
+ ],
+ )
+ cls.base_url += "/v1"
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
+ def test_chat_completion(self):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ response = client.chat.completions.create(
+ model="default",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+ },
+ },
+ {
+ "type": "text",
+ "text": "Describe this image in a very short sentence.",
+ },
+ ],
+ },
+ ],
+ temperature=0,
+ )
+
+ assert response.choices[0].message.role == "assistant"
+ text = response.choices[0].message.content
+ assert isinstance(text, str)
+ assert "car" in text or "taxi" in text, text
+ assert response.id
+ assert response.created
+ assert response.usage.prompt_tokens > 0
+ assert response.usage.completion_tokens > 0
+ assert response.usage.total_tokens > 0
+
+ def test_regex(self):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ regex = (
+ r"""\{\n"""
+ + r""" "color": "[\w]+",\n"""
+ + r""" "number_of_cars": [\d]+\n"""
+ + r"""\}"""
+ )
+
+ response = client.chat.completions.create(
+ model="default",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+ },
+ },
+ {
+ "type": "text",
+ "text": "Describe this image in the JSON format.",
+ },
+ ],
+ },
+ ],
+ temperature=0,
+ extra_body={"regex": regex},
+ )
+ text = response.choices[0].message.content
+
+ try:
+ js_obj = json.loads(text)
+ except (TypeError, json.decoder.JSONDecodeError):
+ print("JSONDecodeError", text)
+ raise
+ assert isinstance(js_obj["color"], str)
+ assert isinstance(js_obj["number_of_cars"], int)
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
+
+ # t = TestOpenAIVisionServer()
+ # t.setUpClass()
+ # t.test_chat_completion()
+ # t.tearDownClass()