diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml new file mode 100644 index 00000000000..f9acb728c1c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -0,0 +1,38 @@ +name: 🐞 Bug report +description: Create a report to help us reproduce and fix the bug +title: "[Bug] " +labels: ['Bug'] + +body: +- type: checkboxes + attributes: + label: Checklist + options: + - label: 1. I have searched related issues but cannot get the expected help. + - label: 2. The bug has not been fixed in the latest version. + - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback. + - label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed. +- type: textarea + attributes: + label: Describe the bug + description: A clear and concise description of what the bug is. + validations: + required: true +- type: textarea + attributes: + label: Reproduction + description: | + What command or script did you run? Which **model** are you using? + placeholder: | + A placeholder for the command. + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + Please provide necessary environment information here with `python3 -m sglang.check_env`. + placeholder: Environment here. + render: Shell + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml new file mode 100644 index 00000000000..5ab369f8b09 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -0,0 +1,17 @@ +name: 🚀 Feature request +description: Suggest an idea for this project +title: "[Feature] " + +body: +- type: textarea + attributes: + label: Motivation + description: | + A clear and concise description of the motivation of the feature. + validations: + required: true +- type: textarea + attributes: + label: Related resources + description: | + If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000000..20f4a10bc56 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,15 @@ +Thank you for your contribution, we really appreciate it. The following instructions will help improve your pull request and make it easier to receive feedback. If there are any items you don't understand, don't worry. Just submit the pull request and ask the maintainers for help. + +## Motivation + +Please explain the motivation behind this PR and the goal you aim to achieve with it. + +## Modification + +Briefly describe the changes made in this PR. + +## Checklist + +1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. +2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. +3. Modify documentation as needed, such as docstrings or example tutorials. 
diff --git a/.github/workflows/cache-purge.yml b/.github/workflows/cache-purge.yml new file mode 100644 index 00000000000..c699f49885f --- /dev/null +++ b/.github/workflows/cache-purge.yml @@ -0,0 +1,27 @@ +name: Weekly Cache Purge + +on: + schedule: + - cron: '0 0 * * 0' # Every Sunday at 00:00 + workflow_dispatch: + +jobs: + purge-cache: + if: github.repository == 'sgl-project/sglang' + runs-on: self-hosted + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Purge pip cache + run: | + source $HOME/venv/bin/activate + echo "$HOME/venv/bin" >> $GITHUB_PATH + pip cache purge + + - name: Update dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml new file mode 100644 index 00000000000..e81dd424791 --- /dev/null +++ b/.github/workflows/close-inactive-issues.yml @@ -0,0 +1,91 @@ +name: Close Inactive Issues + +on: + schedule: + - cron: '0 0 * * *' + workflow_dispatch: + +permissions: + issues: write + contents: read + +jobs: + close-inactive-issues: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + steps: + - name: Check and close inactive issues + uses: actions/github-script@v6 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const sixtyDaysAgo = new Date(Date.now() - 60 * 24 * 60 * 60 * 1000); + + const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/'); + console.log(`Owner: ${owner}, Repo: ${repo}`); + + async function fetchIssues(page = 1) { + console.log(`Fetching issues for ${owner}/${repo}, page ${page}`); + return await github.rest.issues.listForRepo({ + owner, + repo, + state: 'open', + sort: 'updated', + direction: 'asc', + per_page: 100, + page: page + }); + } + + async function processIssues() { + console.log('Starting to process issues'); + console.log(`Repository: ${owner}/${repo}`); + + let page = 1; + let hasMoreIssues = true; + while (hasMoreIssues) { + try { + const issues = await fetchIssues(page); + console.log(`Fetched ${issues.data.length} issues on page ${page}`); + + if (issues.data.length === 0) { + hasMoreIssues = false; + break; + } + + for (const issue of issues.data) { + if (new Date(issue.updated_at) < sixtyDaysAgo) { + try { + await github.rest.issues.update({ + owner, + repo, + issue_number: issue.number, + state: 'closed', + labels: [...issue.labels.map(l => l.name), 'inactive'] + }); + await github.rest.issues.createComment({ + owner, + repo, + issue_number: issue.number, + body: 'This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.' + }); + console.log(`Closed issue #${issue.number} due to inactivity.`); + } catch (error) { + console.error(`Failed to close issue #${issue.number}: ${error.message}`); + } + } else { + console.log(`Issue #${issue.number} is still active. 
Stopping processing.`); + hasMoreIssues = false; + break; + } + } + page += 1; + } catch (error) { + console.error(`Error fetching issues on page ${page}: ${error.message}`); + hasMoreIssues = false; + } + } + console.log('Finished processing issues'); + } + + await processIssues(); diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml new file mode 100644 index 00000000000..b82bbdc369a --- /dev/null +++ b/.github/workflows/e2e-test.yml @@ -0,0 +1,58 @@ +name: E2E Test + +on: + push: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + pull_request: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + workflow_dispatch: + +concurrency: + group: e2e-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + e2e-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: self-hosted + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + source $HOME/venv/bin/activate + echo "$HOME/venv/bin" >> $GITHUB_PATH + + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Benchmark Serving Throughput + run: | + python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache & + + echo "Waiting for server to start..." + for i in {1..120}; do + if curl -s http://127.0.0.1:8413/health; then + echo "Server is up!" + break + fi + if [ $i -eq 120 ]; then + echo "Server failed to start within 120 seconds" + exit 1 + fi + sleep 1 + done + + cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 + + echo "Stopping server..." 
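+          # Look up the PID of the sglang server started above by matching the model name and port in `ps` output, then force-kill it.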
+ kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}') diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000000..07614050640 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,19 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install pre-commit hook + run: | + python -m pip install pre-commit + pre-commit install + - name: Linting + run: pre-commit run --all-files diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml new file mode 100644 index 00000000000..abdaa169e21 --- /dev/null +++ b/.github/workflows/release-docker.yml @@ -0,0 +1,52 @@ +name: Release Docker +on: + push: + branches: + - main + paths: + - "python/sglang/version.py" + workflow_dispatch: + +jobs: + publish: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + environment: 'prod' + strategy: + matrix: + cuda_version: ['12.1.1', '12.4.1'] + steps: + - name: Delete huge unnecessary tools folder + run: rm -rf /opt/hostedtoolcache + + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and Push + run: | + version=$(cat python/sglang/version.py | cut -d'"' -f2) + + if [ "${{ matrix.cuda_version }}" = "12.1.1" ]; then + cuda_tag="cu121" + elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then + cuda_tag="cu124" + else + echo "Unsupported CUDA version" + exit 1 + fi + + tag=v${version}-${cuda_tag} + + docker build . 
-f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} -t lmsysorg/sglang:${tag} --no-cache + docker push lmsysorg/sglang:${tag} + + if [ "${{ matrix.cuda_version }}" = "12.1.1" ]; then + docker tag lmsysorg/sglang:${tag} lmsysorg/sglang:latest + docker push lmsysorg/sglang:latest + fi diff --git a/.github/workflows/release-fake-tag.yml b/.github/workflows/release-fake-tag.yml new file mode 100644 index 00000000000..c4b1c338aa4 --- /dev/null +++ b/.github/workflows/release-fake-tag.yml @@ -0,0 +1,35 @@ +name: Release Fake Tag +on: + push: + branches: + - main + paths: + - "python/sglang/version.py" + workflow_dispatch: + +permissions: + contents: write + +jobs: + publish: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + environment: 'prod' + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Get version + id: get_version + run: | + version=$(cat python/sglang/version.py | cut -d'"' -f2) + echo "TAG=v$version" >> $GITHUB_OUTPUT + + - name: Create and push fake tag + env: + GITHUB_TOKEN: ${{ secrets.REPO_TOKEN }} + run: | + git config user.name zhyncs + git config user.email me@zhyncs.com + git checkout -b ${{ steps.get_version.outputs.TAG }} + git push --set-upstream origin ${{ steps.get_version.outputs.TAG }} diff --git a/.github/workflows/release-github.yml b/.github/workflows/release-github.yml new file mode 100644 index 00000000000..12a2309a6f1 --- /dev/null +++ b/.github/workflows/release-github.yml @@ -0,0 +1,25 @@ +name: Release GitHub +on: + workflow_dispatch: +jobs: + publish: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + environment: 'prod' + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Get version + id: get_version + run: | + version=$(cat python/sglang/version.py | cut -d'"' -f2) + echo "TAG=v$version" >> $GITHUB_OUTPUT + + - name: Release + uses: softprops/action-gh-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.REPO_TOKEN }} + with: + name: Release ${{ steps.get_version.outputs.TAG }} + tag_name: ${{ steps.get_version.outputs.TAG }} diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml new file mode 100644 index 00000000000..c79e46cb715 --- /dev/null +++ b/.github/workflows/release-pypi.yml @@ -0,0 +1,29 @@ +name: Release PyPI +on: + push: + branches: + - main + paths: + - "python/sglang/version.py" + workflow_dispatch: + +jobs: + publish: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + environment: 'prod' + steps: + - name: Set up python3.8 + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Checkout repository + uses: actions/checkout@v3 + - name: Upload to pypi + run: | + cd python + cp ../README.md ../LICENSE . 
+ pip install build + python3 -m build + pip install twine + python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml new file mode 100644 index 00000000000..e43caf5f0ac --- /dev/null +++ b/.github/workflows/unit-test.yml @@ -0,0 +1,48 @@ +name: Unit Test + +on: + push: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + pull_request: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + workflow_dispatch: + +concurrency: + group: unit-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: self-hosted + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + source $HOME/venv/bin/activate + echo "$HOME/venv/bin" >> $GITHUB_PATH + + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + pip install accelerate + pip install sentence_transformers + + - name: Test Frontend Language + run: | + cd test/lang + python3 run_suite.py --suite minimal + + - name: Test Backend Runtime + run: | + cd test/srt + python3 run_suite.py --suite minimal diff --git a/.gitignore b/.gitignore index 10e602e830d..ca43e1ccba4 100644 --- a/.gitignore +++ b/.gitignore @@ -181,3 +181,5 @@ tmp*.txt # personnal work_dirs/ *.csv + +!logo.png diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..2fa1254a66d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000000..94f52e9a0f2 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: 2 + +formats: all + +build: + os: "ubuntu-22.04" + tools: + python: "3.12" + + +sphinx: + configuration: docs/en/conf.py + + +python: + install: + - requirements: docs/requirements.txt diff --git a/README.md b/README.md index 2ac666c6b79..9be13509fb9 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,42 @@
-logo +logo + +[![PyPI](https://img.shields.io/pypi/v/sglang)](https://pypi.org/project/sglang) +![PyPI - Downloads](https://img.shields.io/pypi/dm/sglang) +[![license](https://img.shields.io/github/license/sgl-project/sglang.svg)](https://github.com/sgl-project/sglang/tree/main/LICENSE) +[![issue resolution](https://img.shields.io/github/issues-closed-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues) +[![open issues](https://img.shields.io/github/issues-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues) +
-------------------------------------------------------------------------------- -| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) | +| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) | -SGLang is a structured generation language designed for large language models (LLMs). -It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system. +SGLang is a fast serving framework for large language models and vision language models. +It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. The core features include: +- **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, flashinfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). - **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. -- **High-Performance Backend Runtime**: Features RadixAttention for accelerating complex LLM programs by reusing the KV cache across multiple calls. It can also serve as a standalone inference engine with all common techniques implemented (e.g., continuous batching and tensor parallelism). ## News -- [2024/02] 🔥 SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)). -- [2024/01] 🔥 SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)). +- [2024/07] 🔥 Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). +- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). +- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)). + +
+<details> +<summary>More</summary> + +- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)). +- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)). + +</details> +
## Contents - [Install](#install) -- [Quick Start](#quick-start) -- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang) - [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt) +- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang) - [Benchmark And Performance](#benchmark-and-performance) - [Roadmap](#roadmap) - [Citation And Acknowledgment](#citation-and-acknowledgment) @@ -31,42 +45,201 @@ The core features include: ### Method 1: With pip ``` +pip install --upgrade pip pip install "sglang[all]" # Install FlashInfer CUDA kernels -pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ``` ### Method 2: From source ``` -git clone https://github.com/sgl-project/sglang.git +# Use the last release branch +git clone -b v0.2.11 https://github.com/sgl-project/sglang.git cd sglang +pip install --upgrade pip pip install -e "python[all]" # Install FlashInfer CUDA kernels -pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ``` ### Method 3: Using docker -The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags). +The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](docker). +Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). -### Common Notes -- If you see errors from the Triton compiler, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html) by -``` -pip uninstall -y triton triton-nightly -pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly +```bash +docker run --gpus all \ + -p 30000:30000 \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HF_TOKEN=" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --host 0.0.0.0 --port 30000 ``` + +### Common Notes - If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. -## Quick Start +## Backend: SGLang Runtime (SRT) +The SGLang Runtime (SRT) is an efficient serving engine. + +### Quick Start +Launch a server +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 +``` + +Send a request +``` +curl http://localhost:30000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Once upon a time,", + "sampling_params": { + "max_new_tokens": 16, + "temperature": 0 + } + }' +``` +Learn more about the argument format [here](docs/en/sampling_params.md). + +### OpenAI Compatible API +In addition, the server supports OpenAI-compatible APIs. 
+ +```python +import openai +client = openai.Client( + base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + +# Text completion +response = client.completions.create( + model="default", + prompt="The capital of France is", + temperature=0, + max_tokens=32, +) +print(response) + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0, + max_tokens=64, +) +print(response) +``` + +It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/). + +### Additional Server Arguments +- Add `--tp 2` to enable tensor parallelism. If it indicates `peer access is not supported between these two devices`, add the `--enable-p2p-check` option. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2 +``` +- Add `--dp 2` to enable data parallelism. It can also be used together with tp. Data parallelism is better for throughput if there is enough memory. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2 +``` +- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --mem-fraction-static 0.7 +``` +- If you see out-of-memory errors during prefill for long prompts on a model that supports long context, consider using chunked prefill. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --chunked-prefill-size 8192 +``` +- See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance. +- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port. +``` +# Node 0 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0 + +# Node 1 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1 +``` +- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md). +- To enable fp8 quantization, you can add `--quantization fp8` on an fp16 checkpoint or directly load an fp8 checkpoint without specifying any arguments. +- To enable experimental torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes. + +### Use Models From ModelScope +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`.
+``` +export SGLANG_USE_MODELSCOPE=true +``` +Launch the [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) server +``` +SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 +``` + +### Supported Models + +- Llama / Llama 2 / Llama 3 / Llama 3.1 +- Mistral / Mixtral +- Gemma / Gemma 2 +- Qwen / Qwen 2 / Qwen 2 MoE +- DeepSeek / DeepSeek 2 +- LLaVA 1.5 / 1.6 + - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` + - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` + - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 30000` +- LLaVA-NeXT-Video + - see [examples/usage/llava_video](examples/usage/llava_video) +- Yi-VL + - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). +- StableLM +- Command-R +- DBRX +- Grok +- ChatGLM +- InternLM 2 +- Mistral NeMo + +Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md). + +### Run Llama 3.1 405B + +```bash +## Run 405B (fp8) on a single node +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 + +## Run 405B (fp16) on two nodes +# replace `172.16.4.52:20000` with the IP address and port of your first node, and disable CUDA Graph temporarily + +# on the first node +GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph --mem-frac 0.75 + +# on the second node +GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph --mem-frac 0.75 +``` + +### Benchmark Performance + +- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`. + ``` + python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32 + ``` +- Benchmark online serving. Launch a server first and run the following command. + ``` + python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + ``` + +## Frontend: Structured Generation Language (SGLang) +The frontend language can be used with local models or API models. + +### Quick Start The example below shows how to use sglang to answer a multi-turn question. -### Using Local Models +#### Using Local Models First, launch a server with ``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 ``` Then, connect to the server and answer a multi-turn question.
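The multi-turn program itself is unchanged by this diff, so it only appears as context in the next hunk. For reference, a minimal sketch of such a frontend program is shown below; it assumes the server launched above is listening on port 30000, and the two questions are illustrative placeholders.

```python
import sglang as sgl


@sgl.function
def multi_turn_question(s, question_1, question_2):
    # Build a chat-style prompt turn by turn; each sgl.gen fills in one answer.
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))


# Point the frontend at the local SRT server started above (the port is an assumption).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = multi_turn_question.run(
    question_1="What is the capital of the United Kingdom?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])

print(state["answer_1"])
```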
@@ -95,7 +268,7 @@ for m in state.messages(): print(state["answer_1"]) ``` -### Using OpenAI Models +#### Using OpenAI Models Set the OpenAI API Key ``` export OPENAI_API_KEY=sk-****** @@ -126,13 +299,12 @@ for m in state.messages(): print(state["answer_1"]) ``` -### More Examples +#### More Examples Anthropic and VertexAI (Gemini) models are also supported. You can find more examples at [examples/quick_start](examples/quick_start). -## Frontend: Structured Generation Language (SGLang) - +### Language Feature To begin with, import sglang. ```python import sglang as sgl @@ -145,7 +317,7 @@ The system will manage the state, chat template, parallelism and batching for yo The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py) -### Control Flow +#### Control Flow You can use any Python code within the function body, including control flow, nested function calls, and external libraries. ```python @@ -160,7 +332,7 @@ def tool_use(s, question): s += "The key word to search is" + sgl.gen("word") ``` -### Parallelism +#### Parallelism Use `fork` to launch parallel prompts. Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel. @@ -182,7 +354,7 @@ def tip_suggestion(s): s += "In summary" + sgl.gen("summary") ``` -### Multi Modality +#### Multi Modality Use `sgl.image` to pass an image as input. ```python @@ -194,7 +366,7 @@ def image_qa(s, image_file, question): See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py). -### Constrained Decoding +#### Constrained Decoding Use `regex` to specify a regular expression as a decoding constraint. This is only supported for local models. @@ -209,7 +381,7 @@ def regular_expression_gen(s): ) ``` -### JSON Decoding +#### JSON Decoding Use `regex` to specify a JSON schema with a regular expression. ```python @@ -238,8 +410,7 @@ def character_gen(s, name): See also [json_decode.py](examples/usage/json_decode.py) for an additional example on specifying formats with Pydantic models. - -### Batching +#### Batching Use `run_batch` to run a batch of requests with continuous batching. ```python @@ -258,7 +429,7 @@ states = text_qa.run_batch( ) ``` -### Streaming +#### Streaming Add `stream=True` to enable streaming. ```python @@ -277,154 +448,38 @@ for out in state.text_iter(): print(out, end="", flush=True) ``` -### Tips and Implementation Details -- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. -- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. - -## Backend: SGLang Runtime (SRT) -The SGLang Runtime (SRT) is designed to work best with the SGLang frontend. -However, it can also be used as a standalone API server. -In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse. 
+#### Roles -### Usage -Launch a server -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 -``` - -Send a request -``` -curl http://localhost:30000/generate \ - -H "Content-Type: application/json" \ - -d '{ - "text": "Once upon a time,", - "sampling_params": { - "max_new_tokens": 16, - "temperature": 0 - } - }' -``` -Learn more about the argument format [here](docs/sampling_params.md). - -### OpenAI Compatible API -In addition, the server supports an experimental OpenAI-compatible API. +Use `sgl.system`, `sgl.user` and `sgl.assistant` to set roles when using Chat models. You can also define more complex role prompts using begin and end tokens. ```python -import openai -client = openai.Client( - base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") - -# Text completion -response = client.completions.create( - model="default", - prompt="The capital of France is", - temperature=0, - max_tokens=32, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0, - max_tokens=64, -) -print(response) -``` - -By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3. - -If needed, you can also override the chat template when launching the server: - -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2 -``` - -If the chat template you are looking for is missing, you are welcome to contribute it. -Meanwhile, you can also temporarily register your chat template as follows: - -```json -{ - "name": "my_model", - "system": "<|im_start|>system", - "user": "<|im_start|>user", - "assistant": "<|im_start|>assistant", - "sep_style": "CHATML", - "sep": "<|im_end|>", - "stop_str": ["<|im_end|>", "<|im_start|>"] -} -``` +@sgl.function +def chat_example(s): + s += sgl.system("You are a helpful assistant.") + # Same as: s += s.system("You are a helpful assistant.") -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json -``` + with s.user(): + s += "Question: What is the capital of France?" -### Additional Arguments -- Add `--tp 2` to enable tensor parallelism. If it indicates `peer access is not supported between these two devices`, add `--enable-p2p-check` option. + s += sgl.assistant_begin() + s += "Answer: " + sgl.gen(max_tokens=100, stop="\n") + s += sgl.assistant_end() ``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2 -``` -- Add `--dp 2` to enable data parallelism. It can also be used together with tp. Data parallelism is better for throughput if there is enough memory. -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --dp 2 --tp 2 -``` -- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9` -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7 -``` -- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance. 
-### Supported Models -- Llama -- Mistral -- Mixtral -- Qwen / Qwen 2 / Qwen 2 MoE -- Gemma / Gemma 2 - - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32` -- LLaVA - - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000` -- LLaVA-NeXT-Video - - see [srt_example_llava_v.sh](examples/usage/llava_video/srt_example_llava_v.sh) -- Yi-VL - - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). -- StableLM -- Command-R -- DBRX -- Grok -- ChatGLM -- AWQ/GPTQ/Marlin quantization +#### Tips and Implementation Details +- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. +- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. -Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md). ## Benchmark And Performance -- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1 -![llama_7b](assets/llama_7b.jpg) - -- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8 -![mixtral_8x7b](assets/mixtral_8x7b.jpg) +![8b_throughput](https://lmsys.org/images/blog/sglang_llama3/8b_throughput.svg) +![70b_fp8_throughput](https://lmsys.org/images/blog/sglang_llama3/70b_fp8_throughput.svg) -- Learn more about the above [results](docs/benchmark_results.md). -- Synthetic latency and throughput benchmark [scripts](https://github.com/sgl-project/sglang/tree/main/benchmark/latency_throughput). +Learn more at this [blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/). ## Roadmap -https://github.com/sgl-project/sglang/issues/157 +[Development Roadmap (2024 Q3)](https://github.com/sgl-project/sglang/issues/634) ## Citation And Acknowledgment -``` -@misc{zheng2024sglang, - title={SGLang: Efficient Execution of Structured Language Model Programs}, - author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Chuyue Sun and Jeff Huang and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng}, - year={2024}, - eprint={2312.07104}, - archivePrefix={arXiv}, - primaryClass={cs.AI} -} -``` - -We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql). +Please cite our paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. 
+We also learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). diff --git a/assets/llama_7b.jpg b/assets/llama_7b.jpg deleted file mode 100644 index e68960e0160..00000000000 Binary files a/assets/llama_7b.jpg and /dev/null differ diff --git a/assets/mixtral_8x7b.jpg b/assets/mixtral_8x7b.jpg deleted file mode 100644 index 755e4296ea2..00000000000 Binary files a/assets/mixtral_8x7b.jpg and /dev/null differ diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh new file mode 100644 index 00000000000..eae5e22060a --- /dev/null +++ b/benchmark/blog_v0_2/405b_sglang.sh @@ -0,0 +1,24 @@ +# Create dummy weights: +# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder. +# 2. Get `config.json`` from ./config.md +# 3. Download the tokenizer +# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json +# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json + +# Launch sglang +# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 + +# offline +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15 +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21 + +# online +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35 \ No newline at end of file diff --git a/benchmark/blog_v0_2/405b_trt.sh b/benchmark/blog_v0_2/405b_trt.sh new file mode 100644 index 00000000000..1950bc92b52 --- /dev/null +++ b/benchmark/blog_v0_2/405b_trt.sh @@ -0,0 +1,17 @@ +# Launch trtllm +# https://github.com/sgl-project/tensorrt-demo + +# offline +python3 
../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21 + +# online +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34 +python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35 diff --git a/benchmark/blog_v0_2/405b_vllm.sh b/benchmark/blog_v0_2/405b_vllm.sh new file mode 100644 index 00000000000..17833a2000d --- /dev/null +++ b/benchmark/blog_v0_2/405b_vllm.sh @@ -0,0 +1,24 @@ +# Create dummy weights: +# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder. +# 2. Get `config.json`` from ./config.md +# 3. 
Download the tokenizer +# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json +# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json + +# Launch vllm +# python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000 + +# offline +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21 + +# online +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34 +python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35 \ No newline at end of file diff --git a/benchmark/blog_v0_2/README.md b/benchmark/blog_v0_2/README.md new file mode 100644 index 00000000000..57443e5fe21 --- /dev/null +++ b/benchmark/blog_v0_2/README.md @@ -0,0 +1,164 @@ +# How to reproduce the benchmark results of SGLang + +## Prerequisite + +### Install the latest SGLang + +```bash +git clone https://github.com/sgl-project/sglang.git +cd sglang +git checkout v0.2.7 + +pip install --upgrade pip +pip install -e "python[all]" + +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ +``` + +### Set up ulimit and HF_TOKEN + +```bash +ulimit -n 65535 +# Change the token to a real and usable one, with access permissions for the Llama 3 models. 
+export HF_TOKEN=hf_token +``` + +### Launch the server + +```bash +# Meta-Llama-3.1-8B-Instruct +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache + +# Meta-Llama-3.1-70B-Instruct +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --disable-radix-cache --tp 8 + +# Meta-Llama-3-70B-Instruct-FP8 +python -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8 +``` + +## Benchmark + +### Hardware Requirements + +- 8B models: Single NVIDIA A100 80GB GPU +- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8 +- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8 + +Please ensure you have the appropriate hardware before running the benchmarks. + +#### Offline benchmark + +```bash +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl +cat offline.jsonl | cut -d':' -f12 | cut -d',' -f1 +``` + +#### Online benchmark + +```bash +python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl +python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl +cat online.jsonl | cut -d':' -f9 | cut -d',' -f1 +``` + +## Other + +We tried using vLLM 0.5.3.post1, but it often crashes under high loads, and it seems to have similar or worse performance compared to vLLM 0.5.2 from our partial benchmarking, so we are using the older version, vLLM 0.5.2. + +Preparation for TensorRT LLM can refer to https://github.com/sgl-project/tensorrt-demo. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. The instance count for preprocessing and postprocessing in Triton Server is 16. 
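The `cat *.jsonl | cut -d':' -fN` one-liners used to read the results extract a single metric from each line purely by field position, which breaks silently if the output format of `bench_serving` changes. A small sketch of a more robust reader is shown here; the key names are assumptions about what your `bench_serving` version writes to the `--output-file`, so inspect one record before relying on them.

```python
import json

# Print selected metrics from a bench_serving result file (one JSON object per line).
# NOTE: the key names below are assumed; print sorted(record) once to see the real ones.
with open("offline.jsonl") as f:
    for line in f:
        record = json.loads(line)
        print(
            record.get("request_throughput"),
            record.get("output_throughput"),
        )
```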
+ +```bash +# vLLM +pip install vllm==0.5.2 +pip install jsonschema==4.21.1 + +# Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests + +# meta-llama/Meta-Llama-3-70B-Instruct +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor 8 + +# neuralmagic/Meta-Llama-3-70B-Instruct-FP8 +python -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor 8 +``` + +```bash +wget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py +``` + +```bash +# vLLM Offline + +python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl +cat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1 +``` + +```bash +# vLLM Online + +python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl +python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl +cat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1 +``` + +```bash +# TensorRT LLM Offline 8B + +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl +python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 
--random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl +cat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1 +``` + +```bash +# TensorRT LLM Online 8B + +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl +cat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1 +``` + +```bash +# TensorRT LLM Offline 70B + +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_70b.jsonl +python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl +cat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1 +``` + +```bash +# TensorRT LLM Online 70B + +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl +python3 bench_serving.py --backend trt --model 
meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl +python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl +cat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1 +``` diff --git a/benchmark/blog_v0_2/config.md b/benchmark/blog_v0_2/config.md new file mode 100644 index 00000000000..3faf6009bb7 --- /dev/null +++ b/benchmark/blog_v0_2/config.md @@ -0,0 +1,100 @@ +### used for TensorRT LLM + +``` +{ + "architecture": "LlamaForCausalLM", + "dtype": "float16", + "logits_dtype": "float32", + "vocab_size": 128256, + "max_position_embeddings": 8192, + "hidden_size": 16384, + "num_hidden_layers": 126, + "num_attention_heads": 128, + "num_key_value_heads": 16, + "head_size": 128, + "qk_layernorm": false, + "hidden_act": "silu", + "intermediate_size": 53248, + "norm_epsilon": 1e-05, + "position_embedding_type": "rope_gpt_neox", + "use_parallel_embedding": false, + "embedding_sharding_dim": 0, + "share_embedding_table": false, + "mapping": { + "world_size": 8, + "tp_size": 8, + "pp_size": 1, + "gpus_per_node": 8 + }, + "quantization": { + "quant_algo": "FP8", + "kv_cache_quant_algo": null, + "group_size": 128, + "smoothquant_val": null, + "has_zero_point": false, + "pre_quant_scale": false, + "exclude_modules": [ + "lm_head" + ] + }, + "kv_dtype": "float16", + "rotary_scaling": null, + "residual_mlp": false, + "moe_normalization_mode": null, + "rotary_base": 500000.0, + "moe_num_experts": 0, + "moe_top_k": 0, + "moe_tp_mode": 2, + "attn_bias": false, + "disable_weight_only_quant_plugin": false, + "mlp_bias": false +} +``` + +### used for vLLM and SGLang + +``` +{ + "_name_or_path": "dummy_fp8", + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128009, + "hidden_act": "silu", + "hidden_size": 16384, + "initializer_range": 0.02, + "intermediate_size": 53248, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 128, + "num_hidden_layers": 126, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "quantization_config": { + "activation_scheme": "static", + "ignored_layers": [ + "lm_head" + ], + "quant_method": "fp8" + }, + "rope_scaling": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "max_position_embeddings": 131072, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.41.1", + "use_cache": true, + "vocab_size": 128256 +} +``` diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index d1ed22cbe14..298ec11d73d 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -64,7 +64,7 @@ def main(args): @sgl.function def few_shot_gsm8k(s, question): s += few_shot_examples + question - s += sgl.gen("answer", max_tokens=256, stop="Question") + s += sgl.gen("answer", max_tokens=512, stop="Question") ##################################### ########## SGL Program End ########## diff --git a/benchmark/gsm8k/download_data.sh b/benchmark/gsm8k/download_data.sh old mode 100644 new mode 100755 diff --git a/benchmark/latency_throughput/README.md 
b/benchmark/latency_throughput/README.md index af136e1d6a6..b6c2e679718 100644 --- a/benchmark/latency_throughput/README.md +++ b/benchmark/latency_throughput/README.md @@ -1,9 +1,8 @@ - # Benchmark Latency and Throughput ## SGLang -### Launch server +### Launch a server ``` python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 ``` @@ -30,9 +29,14 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r #### Run ShareGPT ``` -python3 bench_throughput.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 +python3 bench_serving.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 ``` +### Profile with Nsight +1. To profile a single batch, use `nsys profile --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512` +2. To profile a server, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B`. + + ## Other baselines ### vLLM @@ -42,14 +46,20 @@ python3 -m vllm.entrypoints.api_server --model meta-llama/Llama-2-7b-chat-hf --t ``` # run synthetic -python3 bench_throughput.py --backend vllm --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 +python3 bench_serving.py --backend vllm --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 ``` ``` # run ShareGPT -python3 bench_throughput.py --backend vllm --port 21000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 +python3 bench_serving.py --backend vllm --port 21000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 ``` +``` +# run one batch +python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B --tensor 8 --disable-log-requests --max-num-seqs 1024 --quantization fp8 + +python3 bench_one.py --input-len 1024 --batch-size 1 1 2 4 8 16 32 64 128 256 512 768 1024 --port 8000 --backend vllm +``` ### LightLLM ``` @@ -57,5 +67,5 @@ python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat ``` ``` -python3 bench_throughput.py --backend lightllm --port 22000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 -``` \ No newline at end of file +python3 bench_serving.py --backend lightllm --port 22000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 +``` diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py index 5b2fa3cbc7c..b390c44a536 100644 --- a/benchmark/latency_throughput/bench_one.py +++ b/benchmark/latency_throughput/bench_one.py @@ -1,52 +1,51 @@ +""" +Usage: +python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512 +""" + import argparse +import json import time +import numpy as np import requests -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="http://127.0.0.1") - parser.add_argument("--port", 
type=int, default=None) - parser.add_argument("--backend", type=str, default="srt") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--max-tokens", type=int, default=256) - args = parser.parse_args() - - if args.port is None: - if args.backend == "srt": - args.port = 30000 - elif args.backend == "vllm": - args.port = 21000 - elif args.backend == "lightllm": - args.port = 22000 - elif args.backend == "ginfer": - args.port = 9988 - else: - raise ValueError(f"Invalid backend: {args.backend}") +def run_one_batch_size(bs): url = f"{args.host}:{args.port}" - a = 20 max_new_tokens = args.max_tokens - prompt = f"{a, }" + + if args.input_len: + input_ids = [ + [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] + for _ in range(bs) + ] + else: + text = [f"{i, }" for i in range(bs)] tic = time.time() if args.backend == "srt": + if args.input_len: + inputs = {"input_ids": input_ids} + else: + inputs = {"text": text} + response = requests.post( url + "/generate", json={ - "text": [prompt] * args.batch_size, "sampling_params": { "temperature": 0, "max_new_tokens": max_new_tokens, "ignore_eos": True, }, + **inputs, }, ) elif args.backend == "lightllm": response = requests.post( url + "/generate", json={ - "inputs": prompt, + "inputs": text[0], "parameters": { "temperature": 0, "max_new_tokens": max_new_tokens, @@ -55,13 +54,19 @@ }, ) elif args.backend == "vllm": + if args.input_len: + inputs = {"prompt": input_ids} + else: + inputs = {"prompt": text} + response = requests.post( - url + "/generate", + url + "/v1/completions", json={ - "prompt": prompt, + "model": args.vllm_model_name, "temperature": 0, "max_tokens": max_new_tokens, "ignore_eos": True, + **inputs, }, ) elif args.backend == "ginfer": @@ -73,7 +78,7 @@ tic = time.time() sample_request = sampler_pb2.SampleTextRequest( - prompt=prompt, + prompt=text[0], settings=sampler_pb2.SampleSettings( max_len=max_new_tokens, rng_seed=0, @@ -91,5 +96,52 @@ ret = response.json() print(ret) - speed = args.batch_size * max_new_tokens / latency - print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s") + input_len = args.input_len if args.input_len else 1 + output_len = max_new_tokens + + output_throughput = bs * max_new_tokens / latency + overall_throughput = bs * (input_len + output_len) / latency + print(f"latency: {latency:.2f} s") + print(f"output throughput: {output_throughput:.2f} token/s") + print(f"(input + output) throughput: {overall_throughput:.2f} token/s") + + with open("results.jsonl", "a") as fout: + res = { + "backend": args.backend, + "input_len": args.input_len, + "output_len": args.max_tokens, + "batch_size": bs, + "latency": latency, + "output_throughput": output_throughput, + "overall_throughput": overall_throughput, + } + fout.write(json.dumps(res) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=None) + parser.add_argument("--backend", type=str, default="srt") + parser.add_argument("--input-len", type=int, default=None) + parser.add_argument("--batch-size", type=int, nargs="*", default=[1]) + parser.add_argument("--max-tokens", type=int, default=256) + parser.add_argument( + "--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B" + ) + args = parser.parse_args() + + if args.port is None: + if args.backend == "srt": + args.port = 30000 + elif args.backend == "vllm": + args.port = 21000 + elif args.backend == 
"lightllm": + args.port = 22000 + elif args.backend == "ginfer": + args.port = 9988 + else: + raise ValueError(f"Invalid backend: {args.backend}") + + for bs in args.batch_size: + run_one_batch_size(bs) diff --git a/benchmark/latency_throughput/bench_serving.py b/benchmark/latency_throughput/bench_serving.py index 1adb78958cc..74fafc9494a 100644 --- a/benchmark/latency_throughput/bench_serving.py +++ b/benchmark/latency_throughput/bench_serving.py @@ -248,7 +248,7 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) - api_url = f"http://{args.host}:{args.port}/generate" + api_url = f"{args.host}:{args.port}/generate" if args.tokenizer.endswith(".json") or args.tokenizer.endswith(".model"): from sglang.srt.hf_transformers_utils import get_tokenizer @@ -297,7 +297,8 @@ def main(args: argparse.Namespace): benchmark_time = benchmark_end_time - benchmark_start_time # Compute the statistics. - avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) + latencies = [latency for _, _, latency in REQUEST_LATENCY] + avg_latency = np.mean(latencies) avg_per_token_latency = np.mean( [ latency / (prompt_len + output_len) @@ -311,6 +312,9 @@ def main(args: argparse.Namespace): np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time ) + # latencies = [round(latency, 2) for _, _, latency in REQUEST_LATENCY] + # print(latencies) + print(f"Total time: {benchmark_time:.2f} s") print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s") print(f"Decoding throughput: {decoding_throughput:.2f} token/s") @@ -329,7 +333,7 @@ def main(args: argparse.Namespace): default="srt", choices=["vllm", "tgi", "srt", "lightllm", "ginfer"], ) - parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--host", type=str, default="http://localhost") parser.add_argument("--port", type=int, default=30000) parser.add_argument("--dataset", type=str, help="Path to the dataset.") parser.add_argument("--input-len", type=int, default=2048) diff --git a/benchmark/line_retrieval/gen_data.py b/benchmark/line_retrieval/gen_data.py index 5763e661590..c88ecba493c 100644 --- a/benchmark/line_retrieval/gen_data.py +++ b/benchmark/line_retrieval/gen_data.py @@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio): ) for i in redirect_indices: target_idx = np.random.choice(min(i * 2 + 100, num_lines)) - lines[ - i - ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." + lines[i] = ( + f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." 
+ ) redirects[i] = target_idx # Build links and find sources diff --git a/docker/Dockerfile b/docker/Dockerfile index 3f2e870082a..9571d71a901 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,6 +1,34 @@ -FROM vllm/vllm-openai +ARG CUDA_VERSION=12.1.1 -RUN pip install --upgrade pip -RUN pip install "sglang[all]" -RUN pip uninstall -y triton triton-nightly && pip install --no-deps --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly -RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 + +ARG PYTHON_VERSION=3 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update -y \ + && apt-get install -y ccache software-properties-common \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-pip \ + && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \ + && python3 --version \ + && python3 -m pip --version \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN apt-get update -y \ + && apt-get install -y git curl sudo + +WORKDIR /sgl-workspace + +RUN pip3 --no-cache-dir install --upgrade pip \ + && pip3 --no-cache-dir install --upgrade setuptools wheel \ + && git clone --depth=1 https://github.com/sgl-project/sglang.git \ + && cd sglang \ + && pip --no-cache-dir install -e "python[all]" \ + && pip3 --no-cache-dir install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ + +ENV DEBIAN_FRONTEND=interactive diff --git a/docs/benchmark_results.md b/docs/benchmark_results.md deleted file mode 100644 index 519dfec3fcc..00000000000 --- a/docs/benchmark_results.md +++ /dev/null @@ -1,22 +0,0 @@ -## Benchmark Results - -We tested our system on the following common LLM workloads and reported the achieved throughput: -- **[MMLU](https://arxiv.org/abs/2009.03300)**: A 5-shot, multi-choice, multi-task benchmark. -- **[HellaSwag](https://arxiv.org/abs/1905.07830)**: A 20-shot, multi-choice sentence completion benchmark. -- **[ReAct Agent](https://arxiv.org/abs/2210.03629)**: An agent task using prompt traces collected from the original ReAct paper. -- **[Tree-of-Thought](https://arxiv.org/pdf/2305.10601.pdf)**: A custom tree search-based prompt for solving GSM-8K problems. -- **JSON Decode**: Extracting information from a Wikipedia page and outputting it in JSON format. -- **Chat (short)**: A synthetic chat benchmark where each conversation includes 4 turns with short LLM outputs. -- **Chat (long)**: A synthetic chat benchmark where each conversation includes 4 turns with long LLM outputs. -- **[DSPy RAG](https://github.com/stanfordnlp/dspy)**: A retrieval-augmented generation pipeline in the DSPy tutorial. -- **[LLaVA Bench](https://github.com/haotian-liu/LLaVA)**: Running LLaVA v1.5, a vision language model on the LLaVA-in-the-wild benchmark. - -We tested both Llama-7B on one NVIDIA A10G GPU (24GB) and Mixtral-8x7B on 8 NVIDIA A10G GPUs with tensor parallelism, using FP16 precision. We used vllm v0.2.5, guidance v0.1.8, Hugging Face TGI v1.3.0, and SGLang v0.1.5. 
- -- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1 -![llama_7b](../assets/llama_7b.jpg) - -- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8 -![mixtral_8x7b](../assets/mixtral_8x7b.jpg) - -The benchmark code is available [here](https://github.com/sgl-project/sglang/tree/main/benchmark). diff --git a/docs/en/Makefile b/docs/en/Makefile new file mode 100644 index 00000000000..9ad4b38e0e3 --- /dev/null +++ b/docs/en/Makefile @@ -0,0 +1,12 @@ +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/en/_static/css/readthedocs.css b/docs/en/_static/css/readthedocs.css new file mode 100644 index 00000000000..aca6649b436 --- /dev/null +++ b/docs/en/_static/css/readthedocs.css @@ -0,0 +1,9 @@ +table.autosummary td { + width: 50% +} + +img.align-center { + display: block; + margin-left: auto; + margin-right: auto; +} diff --git a/docs/en/_static/image/logo.png b/docs/en/_static/image/logo.png new file mode 100644 index 00000000000..2a8bc258f66 Binary files /dev/null and b/docs/en/_static/image/logo.png differ diff --git a/docs/en/choices_methods.md b/docs/en/choices_methods.md new file mode 100644 index 00000000000..e0f3ed313c4 --- /dev/null +++ b/docs/en/choices_methods.md @@ -0,0 +1,77 @@ +# Choices Methods in SGLang +This doc describes the choices methods supported by SGLang. + +The optional `choices_method` arg determines how options supplied to SGLang's `choices` primitive are selected. Only the `RuntimeEndpoint` backend supports the `choices_method` arg. Other backends, such as `OpenAI`, have bespoke selection implementations due to API limitations. + +## Methods + +### Token Length Normalized + +Token length normalized is the default SGLang choices method. It selects the option with the highest average logprob across all of its tokens. + +Usage example (alternatively, simply omit the `choices_method` arg): +```python +@sgl.function +def example(s): + s += sgl.user("What is the capital of France?") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["London", "Paris", "Berlin"], + choices_method=sgl.token_length_normalized, + ) + ) +``` + + +This can perform poorly if an option contains many tokens, where its later tokens are predicted with high confidence based on its earlier tokens. For instance, even strong models will fail the above example if the specified options are `["Paris", "Antidisestablishmentarianism"]`. + +### Greedy Token Selection + +Greedy token selection simply selects the option with the highest logprob for its initial token. For overlapping options where one option is a subset of a longer option, the logprobs of the shorter option are extended using its average logprob for comparison against the longer option. + +Usage example: +```python +@sgl.function +def example(s): + s += sgl.user("What is the capital of France?") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["London", "Paris", "Berlin"], + choices_method=sgl.greedy_token_selection, + ) + ) +``` + +This can perform poorly if an option misleads the model down a bad path based on an attractive initial token. 
For instance, greedy selection will result in an incorrect response for this example: +```python +@sgl.function +def us_president_example(s): + s += sgl.user("Name a US president.") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["Donald Duck", "Millard Fillmore"], + choices_method=sgl.greedy_token_selection, + ) + ) +``` + +### Unconditional Likelihood Normalized + +Unconditional likelihood normalized selects the option with the highest average token logprob once normalized by the unconditional token logprobs, as described in [this EleutherAI blogpost](https://blog.eleuther.ai/multiple-choice-normalization/). This method incurs an additional LLM call to obtain the unconditional likelihoods. + +Usage example: +```python +@sgl.function +def example(s): + s += sgl.user("What is the capital of France?") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["London", "Paris", "Berlin"], + choices_method=sgl.unconditional_likelihood_normalized, + ) + ) +``` \ No newline at end of file diff --git a/docs/en/conf.py b/docs/en/conf.py new file mode 100644 index 00000000000..5a7ed2dbfa8 --- /dev/null +++ b/docs/en/conf.py @@ -0,0 +1,125 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath("../..")) + +version_file = "../../python/sglang/version.py" +with open(version_file, "r") as f: + exec(compile(f.read(), version_file, "exec")) +__version__ = locals()["__version__"] + +project = "SGLang" +copyright = "2023-2024, SGLang" +author = "SGLang Team" + +version = __version__ +release = __version__ + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", + "sphinx.ext.intersphinx", + "sphinx_tabs.tabs", + "myst_parser", + "sphinx_copybutton", + "sphinxcontrib.mermaid", +] + +autosectionlabel_prefix_document = True + +templates_path = ["_templates"] + +source_suffix = { + ".rst": "restructuredtext", + ".md": "markdown", +} + +master_doc = "index" + +language = "en" + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +pygments_style = "sphinx" + +html_theme = "sphinx_book_theme" +html_logo = "_static/image/logo.png" +html_title = project +html_copy_source = True +html_last_updated_fmt = "" + +html_theme_options = { + "path_to_docs": "docs/en", + "repository_url": "https://github.com/sgl-project/sglang", + "repository_branch": "main", + "show_navbar_depth": 3, + "max_navbar_depth": 4, + "collapse_navbar": True, + "use_edit_page_button": True, + "use_source_button": True, + "use_issues_button": True, + "use_repository_button": True, + "use_download_button": True, + "use_sidenotes": True, + "show_toc_level": 2, +} + +html_static_path = ["_static"] +html_css_files = ["css/readthedocs.css"] + +myst_enable_extensions = [ + "dollarmath", + "amsmath", + "deflist", + "colon_fence", +] +myst_heading_anchors = 5 + +htmlhelp_basename = "sglangdoc" + +latex_elements = {} + +latex_documents = [ + (master_doc, "sglang.tex", "sglang Documentation", "SGLang Team", "manual"), +] + +man_pages = [(master_doc, "sglang", "sglang Documentation", [author], 1)] + +texinfo_documents = [ + ( + master_doc, + "sglang", + "sglang Documentation", + author, + "sglang", + "One line description of project.", + "Miscellaneous", + ), +] + +epub_title = project + +epub_exclude_files = ["search.html"] + +copybutton_prompt_text = r">>> |\.\.\. 
" +copybutton_prompt_is_regexp = True + +autodoc_preserve_defaults = True +navigation_with_keys = False + +autodoc_mock_imports = [ + "torch", + "transformers", + "triton", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3.12", None), + "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest", None), + "pillow": ("https://pillow.readthedocs.io/en/stable", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable", None), +} diff --git a/docs/en/contributor_guide.md b/docs/en/contributor_guide.md new file mode 100644 index 00000000000..7a87187c1f0 --- /dev/null +++ b/docs/en/contributor_guide.md @@ -0,0 +1,11 @@ +# Contributor Guide + +## Format Your Code +Use these commands to format your code and pass CI linting tests. + +``` +pip3 install pre-commit +cd sglang +pre-commit install . +pre-commit run --all-files +``` diff --git a/docs/en/custom_chat_template.md b/docs/en/custom_chat_template.md new file mode 100644 index 00000000000..815c7e6760b --- /dev/null +++ b/docs/en/custom_chat_template.md @@ -0,0 +1,28 @@ +# Custom Chat Template in SGLang Runtime + +By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3. + +If needed, you can also override the chat template when launching the server: + +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2 +``` + +If the chat template you are looking for is missing, you are welcome to contribute it. +Meanwhile, you can also temporarily register your chat template as follows: + +```json +{ + "name": "my_model", + "system": "<|im_start|>system", + "user": "<|im_start|>user", + "assistant": "<|im_start|>assistant", + "sep_style": "CHATML", + "sep": "<|im_end|>", + "stop_str": ["<|im_end|>", "<|im_start|>"] +} +``` + +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json +``` \ No newline at end of file diff --git a/docs/hyperparameter_tuning.md b/docs/en/hyperparameter_tuning.md similarity index 82% rename from docs/hyperparameter_tuning.md rename to docs/en/hyperparameter_tuning.md index dec516bc92e..53b92435c79 100644 --- a/docs/hyperparameter_tuning.md +++ b/docs/en/hyperparameter_tuning.md @@ -6,11 +6,11 @@ Achieving a large batch size is the most important thing for attaining high thro When the server is running at full load, look for the following in the log: -```[gpu_id=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` +```[gpu=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` ### Tune Your Request Submission Speed `#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed. -A healthy range for `#queue-req` is `100 - 3000`. +A healthy range for `#queue-req` is `100 - 1000`. ### Tune `--schedule-conservativeness` `token usage` indicates the KV cache memory utilization of the server. `token usage > 0.9` means good utilization. @@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`. If OOM happens during decoding, try to decrease `--max-running-requests`. 
You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding. -### (Minor) Tune `--schedule-heuristic` -If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match. +### (Minor) Tune `--schedule-policy` +If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match. When you have no shared prefixes at all or you always send the requests with the shared prefixes together, -you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve. +you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve. diff --git a/docs/en/index.rst b/docs/en/index.rst new file mode 100644 index 00000000000..5e4701c53b6 --- /dev/null +++ b/docs/en/index.rst @@ -0,0 +1,64 @@ +Welcome to SGLang's tutorials! +==================================== + +.. figure:: ./_static/image/logo.png + :width: 50% + :align: center + :alt: SGLang + :class: no-scaled-link + +.. raw:: html + +

+ SGLang is yet another fast serving framework for large language models and vision language models.
+ +SGLang has the following core features: + +* **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, flashinfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). + +* **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. + +Documentation +------------- + +.. _hyperparameter_tuning: +.. toctree:: + :maxdepth: 1 + :caption: Hyperparameter Tuning + + hyperparameter_tuning.md + +.. _custom_chat_template: +.. toctree:: + :maxdepth: 1 + :caption: Custom Chat Template + + custom_chat_template.md + +.. _model_support: +.. toctree:: + :maxdepth: 1 + :caption: Model Support + + model_support.md + +.. _sampling_params: +.. toctree:: + :maxdepth: 1 + :caption: Sampling Params + + sampling_params.md + +Search Bar +================== + +* :ref:`search` diff --git a/docs/model_support.md b/docs/en/model_support.md similarity index 84% rename from docs/model_support.md rename to docs/en/model_support.md index 08e942938e1..e46e99e85c8 100644 --- a/docs/model_support.md +++ b/docs/en/model_support.md @@ -1,4 +1,4 @@ -## How to Support a New Model +# How to Support a New Model To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn from existing model implementations and create new files for the new models. Most models are based on the transformer architecture, making them very similar. @@ -11,5 +11,6 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang LLa - Change `forward()` functions, and add `input_metadata`. - Add `EntryClass` at the end. - Test correctness by comparing the final logits and outputs of the two following commands: - - `python3 playground/reference_hf.py --model [new model]` - - `python3 -m sglang.bench_latency --model [new model] --correct --output-len 16` + - `python3 scripts/playground/reference_hf.py --model [new model]` + - `python3 -m sglang.bench_latency --model [new model] --correct --output-len 16 --trust-remote-code` + - Update [Supported Models](https://github.com/sgl-project/sglang/tree/main?tab=readme-ov-file#supported-models) at [README](../README.md). diff --git a/docs/release_process.md b/docs/en/release_process.md similarity index 100% rename from docs/release_process.md rename to docs/en/release_process.md diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md new file mode 100644 index 00000000000..5f1cdece6a2 --- /dev/null +++ b/docs/en/sampling_params.md @@ -0,0 +1,417 @@ +# Sampling Parameters in SGLang Runtime +This doc describes the sampling parameters of the SGLang Runtime. + +The `/generate` endpoint accepts the following arguments in the JSON format. + +```python +@dataclass +class GenerateReqInput: + # The input prompt. It can be a single prompt or a batch of prompts. + text: Optional[Union[List[str], str]] = None + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. + image_data: Optional[Union[List[str], str]] = None + # The sampling_params. See descriptions below. 
+ sampling_params: Union[List[Dict], Dict] = None + # The request id. + rid: Optional[Union[List[str], str]] = None + # Whether to return logprobs. + return_logprob: Optional[Union[List[bool], bool]] = None + # The start location of the prompt for return_logprob. + logprob_start_len: Optional[Union[List[int], int]] = None + # The number of top logprobs to return. + top_logprobs_num: Optional[Union[List[int], int]] = None + # Whether to detokenize tokens in text in the returned logprobs. + return_text_in_logprobs: bool = False + # Whether to stream output. + stream: bool = False +``` + +The `sampling_params` follows this format + +```python +# The maximum number of output tokens +max_new_tokens: int = 128, +# Stop when hitting any of the strings in this list. +stop: Optional[Union[str, List[str]]] = None, +# Stop when hitting any of the token_ids in this list. Could be useful when mixed with +# `min_new_tokens`. +stop_token_ids: Optional[List[int]] = [], +# Sampling temperature +temperature: float = 1.0, +# Top-p sampling +top_p: float = 1.0, +# Top-k sampling +top_k: int = -1, +# Whether to ignore EOS token. +ignore_eos: bool = False, +# Whether to skip the special tokens during detokenization. +skip_special_tokens: bool = True, +# Whether to add spaces between special tokens during detokenization. +spaces_between_special_tokens: bool = True, +# Constrains the output to follow a given regular expression. +regex: Optional[str] = None, +# Do parallel sampling and return `n` outputs. +n: int = 1, + +## Penalties. See [Performance Implications on Penalties] section below for more informations. + +# Float that penalizes new tokens based on their frequency in the generated text so far. +# Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to +# repeat tokens. Must be -2 <= value <= 2. Setting to 0 (default) will disable this penalty. +frequency_penalty: float = 0.0, +# Float that penalizes new tokens based on whether they appear in the generated text so far. +# Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat +# tokens. Must be -2 <= value <= 2. Setting to 0 (default) will disable this penalty. +presence_penalty: float = 0.0, +# Float that penalizes new tokens based on whether they appear in the prompt and the generated text +# so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to +# repeat tokens. Must be 0 <= value <= 2. Setting to 1 (default) will disable this penalty. +repetition_penalty: float = 1.0, +# Guides inference to generate at least this number of tokens by penalizing logits of tokenizer's +# EOS token and `stop_token_ids` to -inf, until the output token reaches given length. +# Note that any of the `stop` string can be generated before reaching `min_new_tokens`, as it is +# difficult to infer the correct token ID by given `stop` strings. +# Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty. 
+min_new_tokens: int = 0, +``` + +## Examples + +### Normal +Launch a server +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 +``` + +Send a request +```python +import requests + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, +) +print(response.json()) +``` + +### Streaming +Send a request and stream the output +```python +import requests, json + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + "stream": True, + }, + stream=True, +) + +prev = 0 +for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) +print("") +``` + +### Multi modal + +Launch a server +``` +python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000 +``` + +Download an image +``` +curl -o example_image.png -L https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true +``` + +Send a request +```python +import requests + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", + "image_data": "example_image.png", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, +) +print(response.json()) +``` + +The `image_data` can be a file name, a URL, or a base64 encoded string. See also `python/sglang/srt/utils.py:load_image`. +Streaming is supported in a similar manner as [above](#streaming). + +## Performance Implications on Penalties + +While you can apply penalties by supplying relevant `sampling_params`, this comes with some drawbacks. + +These drawbacks will be applied to every single requests in the same batch, as penalizers also applies in batch. + +### Latency + +While we try to compute penalty algorithms through CUDA, it is still additional computation on top of the basic sampling logic. For detailed overhead, we recommend you to run your own benchmarks, but you can find samples below to get a glimpse. + +### Memory + +Since we compute penalty algorithms through CUDA, the logic stores relevant parameters on GPU. This is usually in a scale of `vocab_size` multiplied by `running_requests`. + +You can run your own benchmark with desired parameters on your own hardware to make sure it's not OOMing before using. + +Tuning `--mem-fraction-static` and/or `--max-running-requests` will help. See [here](hyperparameter_tuning.md#minor-tune---max-prefill-tokens---mem-fraction-static---max-running-requests) for more information. + +### Benchmarks + +All the benchmarks below were ran on NVIDIA H100 SXM5. + +
+ +#### Baseline + +Measured at [dc9d06d886151707f97d0b78095df9de262fd3c9](https://github.com/sgl-project/sglang/commit/dc9d06d886151707f97d0b78095df9de262fd3c9). + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 66.11 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 775118 +Request throughput (req/s): 45.38 +Input token throughput (tok/s): 5727.04 +Output token throughput (tok/s): 11732.16 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 40881.94 +Median E2E Latency (ms): 43967.10 +---------------Time to First Token---------------- +Mean TTFT (ms): 19884.75 +Median TTFT (ms): 14226.56 +P99 TTFT (ms): 47738.97 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 91.96 +Median TPOT (ms): 90.11 +P99 TPOT (ms): 308.54 +---------------Inter-token Latency---------------- +Mean ITL (ms): 174.54 +Median ITL (ms): 58.56 +P99 ITL (ms): 440.18 +================================================== +``` + +#### All Together + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "frequency_penalty": 1.1, + "presence_penalty": 1.1, + "repetition_penalty": 0.1, + "min_new_tokens": 5 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 78.35 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 774756 +Request throughput (req/s): 38.29 +Input token throughput (tok/s): 4832.86 +Output token throughput (tok/s): 9900.39 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 49017.68 +Median E2E Latency (ms): 52825.70 +---------------Time to First Token---------------- +Mean TTFT (ms): 23892.60 +Median TTFT (ms): 18895.47 +P99 TTFT (ms): 57426.01 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 114.54 +Median TPOT (ms): 107.27 +P99 TPOT (ms): 293.31 +---------------Inter-token Latency---------------- +Mean ITL (ms): 205.68 +Median ITL (ms): 73.97 +P99 ITL (ms): 453.86 +================================================== +``` + +#### Frequency Penalty + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "frequency_penalty": 1.1 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 72.72 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 774955 +Request throughput (req/s): 41.26 +Input token throughput (tok/s): 5206.84 +Output token throughput (tok/s): 10666.51 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 45445.56 +Median E2E Latency (ms): 48960.39 +---------------Time to First Token---------------- +Mean TTFT (ms): 22363.16 +Median TTFT (ms): 17125.02 +P99 TTFT (ms): 52920.95 +-----Time per Output Token (excl. 
1st token)------ +Mean TPOT (ms): 104.71 +Median TPOT (ms): 98.30 +P99 TPOT (ms): 268.06 +---------------Inter-token Latency---------------- +Mean ITL (ms): 191.60 +Median ITL (ms): 67.83 +P99 ITL (ms): 455.46 +================================================== +``` + +#### Presence Penalty + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "presence_penalty": 1.1 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 72.04 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 775210 +Request throughput (req/s): 41.64 +Input token throughput (tok/s): 5255.98 +Output token throughput (tok/s): 10767.18 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 44926.61 +Median E2E Latency (ms): 48302.88 +---------------Time to First Token---------------- +Mean TTFT (ms): 22095.39 +Median TTFT (ms): 16740.93 +P99 TTFT (ms): 52554.03 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 103.54 +Median TPOT (ms): 97.37 +P99 TPOT (ms): 271.86 +---------------Inter-token Latency---------------- +Mean ITL (ms): 189.86 +Median ITL (ms): 68.45 +P99 ITL (ms): 447.11 +================================================== +``` + +#### Repetition Penalty + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "repetition_penalty": 0.1 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 74.54 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 766008 +Request throughput (req/s): 40.24 +Input token throughput (tok/s): 5079.36 +Output token throughput (tok/s): 10405.35 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 46530.38 +Median E2E Latency (ms): 50302.65 +---------------Time to First Token---------------- +Mean TTFT (ms): 22603.47 +Median TTFT (ms): 17167.08 +P99 TTFT (ms): 54497.85 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 117.59 +Median TPOT (ms): 101.79 +P99 TPOT (ms): 320.04 +---------------Inter-token Latency---------------- +Mean ITL (ms): 195.26 +Median ITL (ms): 69.51 +P99 ITL (ms): 433.86 +================================================== +``` + +#### Min New Tokens + +The min new tokens penalizer computes until generation process reaches given `min_new_tokens`. + +Dislike other penalizers, setting this to higher value will have more latency implications. 
+ +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "min_new_tokens": 5 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 66.94 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 775220 +Request throughput (req/s): 44.81 +Input token throughput (tok/s): 5656.13 +Output token throughput (tok/s): 11586.90 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 41888.55 +Median E2E Latency (ms): 45354.16 +---------------Time to First Token---------------- +Mean TTFT (ms): 20866.91 +Median TTFT (ms): 16219.79 +P99 TTFT (ms): 49263.91 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 97.05 +Median TPOT (ms): 89.76 +P99 TPOT (ms): 233.50 +---------------Inter-token Latency---------------- +Mean ITL (ms): 179.17 +Median ITL (ms): 55.08 +P99 ITL (ms): 409.12 +================================================== +``` + +
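As a usage reference, the penalties described above can be combined with the regular sampling arguments in a single `/generate` request, in the same way as the earlier examples. The sketch below assumes a server already running on `localhost:30000` as in the Normal example; the specific penalty values are illustrative assumptions, not recommendations.

```python
import requests

# Minimal sketch (illustrative values): one /generate request that combines the
# regular sampling arguments with the penalty parameters documented above.
# Assumes a server is already running on localhost:30000 as in the Normal example.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.5,
            "repetition_penalty": 1.05,
            "min_new_tokens": 5,
        },
    },
)
print(response.json())
```

Leaving a penalty at its default (0.0 for `frequency_penalty`/`presence_penalty`, 1.0 for `repetition_penalty`, 0 for `min_new_tokens`) disables it and avoids the overhead discussed in this section.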
diff --git a/docs/en/setup_runner.md b/docs/en/setup_runner.md new file mode 100644 index 00000000000..34f4576845b --- /dev/null +++ b/docs/en/setup_runner.md @@ -0,0 +1,34 @@ +# Set Up a Self-Hosted Runner for GitHub Actions + +## Configure the Runner + +```bash +# https://github.com/sgl-project/sglang/settings/actions/runners/new?arch=x64&os=linux +# This involves a TOKEN and other private information; open the link above for the specific steps. +``` + +## Start the Runner + +Add `/lib/systemd/system/runner.service`: +``` +[Unit] +StartLimitIntervalSec=0 +[Service] +Environment="CUDA_VISIBLE_DEVICES=7" +Environment="XDG_CACHE_HOME=/data/.cache" +Environment="HF_TOKEN=hf_**" +Environment="OPENAI_API_KEY=sk-**" +Environment="HOME=/data/zhyncs" +Restart=always +RestartSec=1 +ExecStart=/data/zhyncs/actions-runner/run.sh +[Install] +WantedBy=multi-user.target +``` + +```bash +sudo systemctl daemon-reload +sudo systemctl start runner +sudo systemctl enable runner +sudo systemctl status runner +``` diff --git a/docs/en/troubleshooting.md b/docs/en/troubleshooting.md new file mode 100644 index 00000000000..c6c016fd1f7 --- /dev/null +++ b/docs/en/troubleshooting.md @@ -0,0 +1,13 @@ +# Troubleshooting + +This page lists some common errors and tips for fixing them. + +## CUDA error: an illegal memory access was encountered +This error may be due to kernel errors or out-of-memory issues. +- If it is a kernel error, it is not easy to fix. +- If it is out-of-memory, sometimes it will report this error instead of "Out-of-memory." In this case, try setting a smaller value for `--mem-fraction-static`. The default value of `--mem-fraction-static` is around 0.8 - 0.9. https://github.com/sgl-project/sglang/blob/1edd4e07d6ad52f4f63e7f6beaa5987c1e1cf621/python/sglang/srt/server_args.py#L92-L102 + +## The server hangs +If the server hangs, try disabling some optimizations when launching the server. +- Add `--disable-cuda-graph`. +- Add `--disable-flashinfer-sampling`. diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000000..826a34bc157 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,12 @@ +markdown>=3.4.0 +myst-parser +sphinx +sphinx-book-theme +sphinx-copybutton +sphinx-tabs +sphinxcontrib-mermaid +pillow +pydantic +torch +transformers +urllib3<2.0.0 diff --git a/docs/sampling_params.md b/docs/sampling_params.md deleted file mode 100644 index 065bbc2d534..00000000000 --- a/docs/sampling_params.md +++ /dev/null @@ -1,107 +0,0 @@ -## Sampling Parameters of SGLang Runtime -This doc describes the sampling parameters of the SGLang Runtime. - -The `/generate` endpoint accepts the following arguments in the JSON format. 
- -```python -@dataclass -class GenerateReqInput: - # The input prompt - text: Union[List[str], str] - # The token ids for text; one can either specify text or input_ids - input_ids: Optional[Union[List[List[int]], List[int]]] = None - # The image input - image_data: Optional[Union[List[str], str]] = None - # The sampling_params - sampling_params: Union[List[Dict], Dict] = None - # The request id - rid: Optional[Union[List[str], str]] = None - # Whether to return logprobs - return_logprob: Optional[Union[List[bool], bool]] = None - # The start location of the prompt for return_logprob - logprob_start_len: Optional[Union[List[int], int]] = None - # The number of top logprobs to return - top_logprobs_num: Optional[Union[List[int], int]] = None - # Whether to detokenize tokens in logprobs - return_text_in_logprobs: bool = False - # Whether to stream output - stream: bool = False -``` - -The `sampling_params` follows this format - -```python -class SamplingParams: - def __init__( - self, - max_new_tokens: int = 16, - stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - ignore_eos: bool = False, - skip_special_tokens: bool = True, - dtype: Optional[str] = None, - regex: Optional[str] = None, - ) -> None: -``` - -## Examples - -### Normal -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 -``` - -```python -import requests - -response = requests.post( - "http://localhost:30000/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - }, - }, -) -print(response.json()) -``` - -### Streaming - -```python -import requests, json - -response = requests.post( - "http://localhost:30000/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 256, - }, - "stream": True, - }, - stream=True, -) - -prev = 0 -for chunk in response.iter_lines(decode_unicode=False): - chunk = chunk.decode("utf-8") - if chunk and chunk.startswith("data:"): - if chunk == "data: [DONE]": - break - data = json.loads(chunk[5:].strip("\n")) - output = data["text"].strip() - print(output[prev:], end="", flush=True) - prev = len(output) -print("") -``` - -### Multi modal - -See [test_httpserver_llava.py](../test/srt/test_httpserver_llava.py). 
diff --git a/docs/test_process.md b/docs/test_process.md deleted file mode 100644 index 18f91c6d4c3..00000000000 --- a/docs/test_process.md +++ /dev/null @@ -1,94 +0,0 @@ -## SRT Unit Tests - -### Latency Alignment -Make sure your changes do not slow down the following benchmarks -``` -# single gpu -python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256 -python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256 - -# multiple gpu -python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1 -python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32 - -# moe model -python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32 -``` - -### High-level API - -``` -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 -``` - -``` -cd test/lang -python3 test_srt_backend.py -``` - -### Performance - -#### MMLU -``` -cd benchmark/mmlu -``` -Follow README.md to download the data. - -``` -python3 bench_sglang.py --nsub 3 - -# Expected performance on A10G -# Total latency: 8.200 -# Average accuracy: 0.413 -``` - -#### GSM-8K -``` -cd benchmark/gsm8k -``` -Follow README.md to download the data. - -``` -python3 bench_sglang.py --num-q 200 - -# Expected performance on A10G -# Latency: 32.103 -# Accuracy: 0.250 -``` - -#### More -Please also test `benchmark/hellaswag`, `benchmark/latency_throughput`. - -### More Models - -#### LLaVA - -``` -python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 -``` - -``` -cd benchmark/llava_bench -python3 bench_sglang.py - -# Expected performance on A10G -# Latency: 50.031 -``` - -## SGLang Unit Tests -``` -export ANTHROPIC_API_KEY= -export OPENAI_API_KEY= -python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 -``` - -``` -cd test/lang -python3 run_all.py -``` - -## OpenAI API server -``` -cd test/srt -python test_openai_server.py -``` \ No newline at end of file diff --git a/examples/quick_start/anthropic_example_chat.py b/examples/quick_start/anthropic_example_chat.py index 03dbb0a454c..03d699be70b 100644 --- a/examples/quick_start/anthropic_example_chat.py +++ b/examples/quick_start/anthropic_example_chat.py @@ -3,6 +3,7 @@ export ANTHROPIC_API_KEY=sk-****** python3 anthropic_example_chat.py """ + import sglang as sgl @@ -30,7 +31,7 @@ def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -39,13 +40,18 @@ def stream(): def batch(): - states = multi_turn_question.run_batch([ - {"question_1": "What is the capital of the United States?", - "question_2": "List two local attractions."}, - - {"question_1": "What is the capital of France?", - "question_2": "What is the population of this city?"}, - ]) + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this 
city?", + }, + ] + ) for s in states: print(s.messages()) diff --git a/examples/quick_start/anthropic_example_complete.py b/examples/quick_start/anthropic_example_complete.py index 35d0e8f6266..bce2a61ea0f 100644 --- a/examples/quick_start/anthropic_example_complete.py +++ b/examples/quick_start/anthropic_example_complete.py @@ -9,15 +9,14 @@ @sgl.function def few_shot_qa(s, question): - s += ( -""" + s += """ \n\nHuman: What is the capital of France? \n\nAssistant: Paris \n\nHuman: What is the capital of Germany? \n\nAssistant: Berlin \n\nHuman: What is the capital of Italy? \n\nAssistant: Rome -""") +""" s += "\n\nHuman: " + question + "\n" s += "\n\nAssistant:" + sgl.gen("answer", temperature=0) @@ -33,8 +32,8 @@ def single(): def stream(): state = few_shot_qa.run( - question="What is the capital of the United States?", - stream=True) + question="What is the capital of the United States?", stream=True + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -42,10 +41,12 @@ def stream(): def batch(): - states = few_shot_qa.run_batch([ - {"question": "What is the capital of the United States?"}, - {"question": "What is the capital of China?"}, - ]) + states = few_shot_qa.run_batch( + [ + {"question": "What is the capital of the United States?"}, + {"question": "What is the capital of China?"}, + ] + ) for s in states: print(s["answer"]) diff --git a/examples/quick_start/azure_openai_example_chat.py b/examples/quick_start/azure_openai_example_chat.py index 3c40af8d2dd..d53f935f4b3 100644 --- a/examples/quick_start/azure_openai_example_chat.py +++ b/examples/quick_start/azure_openai_example_chat.py @@ -3,9 +3,11 @@ export AZURE_OPENAI_API_KEY=sk-****** python3 openai_example_chat.py """ -import sglang as sgl + import os +import sglang as sgl + @sgl.function def multi_turn_question(s, question_1, question_2): @@ -32,7 +34,7 @@ def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -41,13 +43,18 @@ def stream(): def batch(): - states = multi_turn_question.run_batch([ - {"question_1": "What is the capital of the United States?", - "question_2": "List two local attractions."}, - - {"question_1": "What is the capital of France?", - "question_2": "What is the population of this city?"}, - ]) + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) for s in states: print(s.messages()) diff --git a/examples/quick_start/gemini_example_chat.py b/examples/quick_start/gemini_example_chat.py index aafa1665cc6..0ae6231095f 100644 --- a/examples/quick_start/gemini_example_chat.py +++ b/examples/quick_start/gemini_example_chat.py @@ -3,6 +3,7 @@ export GCP_PROJECT_ID=****** python3 gemini_example_chat.py """ + import sglang as sgl @@ -30,7 +31,7 @@ def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -39,13 +40,18 @@ def stream(): def batch(): - states = multi_turn_question.run_batch([ - {"question_1": "What is the capital of the United States?", - "question_2": "List two local attractions."}, - - {"question_1": "What is the capital of France?", - 
"question_2": "What is the population of this city?"}, - ]) + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) for s in states: print(s.messages()) diff --git a/examples/quick_start/gemini_example_complete.py b/examples/quick_start/gemini_example_complete.py index 255a3ad4c0b..5188bf4185b 100644 --- a/examples/quick_start/gemini_example_complete.py +++ b/examples/quick_start/gemini_example_complete.py @@ -9,15 +9,14 @@ @sgl.function def few_shot_qa(s, question): - s += ( -"""The following are questions with answers. + s += """The following are questions with answers. Q: What is the capital of France? A: Paris Q: What is the capital of Germany? A: Berlin Q: What is the capital of Italy? A: Rome -""") +""" s += "Q: " + question + "\n" s += "A:" + sgl.gen("answer", stop="\n", temperature=0) @@ -33,8 +32,8 @@ def single(): def stream(): state = few_shot_qa.run( - question="What is the capital of the United States?", - stream=True) + question="What is the capital of the United States?", stream=True + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -42,10 +41,12 @@ def stream(): def batch(): - states = few_shot_qa.run_batch([ - {"question": "What is the capital of the United States?"}, - {"question": "What is the capital of China?"}, - ]) + states = few_shot_qa.run_batch( + [ + {"question": "What is the capital of the United States?"}, + {"question": "What is the capital of China?"}, + ] + ) for s in states: print(s["answer"]) diff --git a/examples/quick_start/gemini_example_multimodal_chat.py b/examples/quick_start/gemini_example_multimodal_chat.py index fa5e6e8b7f3..afe0c723ff1 100644 --- a/examples/quick_start/gemini_example_multimodal_chat.py +++ b/examples/quick_start/gemini_example_multimodal_chat.py @@ -3,6 +3,7 @@ export GCP_PROJECT_ID=****** python3 gemini_example_multimodal_chat.py """ + import sglang as sgl @@ -19,7 +20,7 @@ def image_qa(s, image_file1, image_file2, question): image_file1="./images/cat.jpeg", image_file2="./images/dog.jpeg", question="Describe difference of the two images in one sentence.", - stream=True + stream=True, ) for out in state.text_iter("answer"): diff --git a/examples/quick_start/openai_example_chat.py b/examples/quick_start/openai_example_chat.py index 66b8536c070..9511e21cf43 100644 --- a/examples/quick_start/openai_example_chat.py +++ b/examples/quick_start/openai_example_chat.py @@ -3,6 +3,7 @@ export OPENAI_API_KEY=sk-****** python3 openai_example_chat.py """ + import sglang as sgl @@ -31,7 +32,7 @@ def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -40,13 +41,18 @@ def stream(): def batch(): - states = multi_turn_question.run_batch([ - {"question_1": "What is the capital of the United States?", - "question_2": "List two local attractions."}, - - {"question_1": "What is the capital of France?", - "question_2": "What is the population of this city?"}, - ]) + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) for 
s in states: print(s.messages()) diff --git a/examples/quick_start/openai_example_complete.py b/examples/quick_start/openai_example_complete.py index 41b3c9904da..d64bcaf1c30 100644 --- a/examples/quick_start/openai_example_complete.py +++ b/examples/quick_start/openai_example_complete.py @@ -9,15 +9,14 @@ @sgl.function def few_shot_qa(s, question): - s += ( -"""The following are questions with answers. + s += """The following are questions with answers. Q: What is the capital of France? A: Paris Q: What is the capital of Germany? A: Berlin Q: What is the capital of Italy? A: Rome -""") +""" s += "Q: " + question + "\n" s += "A:" + sgl.gen("answer", stop="\n", temperature=0) @@ -33,8 +32,8 @@ def single(): def stream(): state = few_shot_qa.run( - question="What is the capital of the United States?", - stream=True) + question="What is the capital of the United States?", stream=True + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -42,10 +41,12 @@ def stream(): def batch(): - states = few_shot_qa.run_batch([ - {"question": "What is the capital of the United States?"}, - {"question": "What is the capital of China?"}, - ]) + states = few_shot_qa.run_batch( + [ + {"question": "What is the capital of the United States?"}, + {"question": "What is the capital of China?"}, + ] + ) for s in states: print(s["answer"]) diff --git a/examples/quick_start/openrouter_example_chat.py b/examples/quick_start/openrouter_example_chat.py index 43ac3d4e28b..a0b6f15bcbc 100644 --- a/examples/quick_start/openrouter_example_chat.py +++ b/examples/quick_start/openrouter_example_chat.py @@ -3,9 +3,11 @@ export OPENROUTER_API_KEY=sk-****** python3 together_example_chat.py """ -import sglang as sgl + import os +import sglang as sgl + @sgl.function def multi_turn_question(s, question_1, question_2): diff --git a/examples/quick_start/srt_example_chat.py b/examples/quick_start/srt_example_chat.py index 2f261b0952e..b1e1658a2a9 100644 --- a/examples/quick_start/srt_example_chat.py +++ b/examples/quick_start/srt_example_chat.py @@ -2,6 +2,7 @@ Usage: python3 srt_example_chat.py """ + import sglang as sgl @@ -29,7 +30,7 @@ def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -38,13 +39,18 @@ def stream(): def batch(): - states = multi_turn_question.run_batch([ - {"question_1": "What is the capital of the United States?", - "question_2": "List two local attractions."}, - - {"question_1": "What is the capital of France?", - "question_2": "What is the population of this city?"}, - ]) + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) for s in states: print(s.messages()) diff --git a/examples/quick_start/srt_example_complete.py b/examples/quick_start/srt_example_complete.py index 20089167099..056245979f4 100644 --- a/examples/quick_start/srt_example_complete.py +++ b/examples/quick_start/srt_example_complete.py @@ -2,20 +2,20 @@ Usage: python3 srt_example_complete.py """ + import sglang as sgl @sgl.function def few_shot_qa(s, question): - s += ( -"""The following are questions with answers. + s += """The following are questions with answers. Q: What is the capital of France? A: Paris Q: What is the capital of Germany? 
A: Berlin Q: What is the capital of Italy? A: Rome -""") +""" s += "Q: " + question + "\n" s += "A:" + sgl.gen("answer", stop="\n", temperature=0) @@ -31,8 +31,8 @@ def single(): def stream(): state = few_shot_qa.run( - question="What is the capital of the United States?", - stream=True) + question="What is the capital of the United States?", stream=True + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -40,10 +40,12 @@ def stream(): def batch(): - states = few_shot_qa.run_batch([ - {"question": "What is the capital of the United States?"}, - {"question": "What is the capital of China?"}, - ]) + states = few_shot_qa.run_batch( + [ + {"question": "What is the capital of the United States?"}, + {"question": "What is the capital of China?"}, + ] + ) for s in states: print(s["answer"]) diff --git a/examples/quick_start/srt_example_llava.py b/examples/quick_start/srt_example_llava.py index 27685b1d251..5d8f752394f 100644 --- a/examples/quick_start/srt_example_llava.py +++ b/examples/quick_start/srt_example_llava.py @@ -1,6 +1,7 @@ """ Usage: python3 srt_example_llava.py """ + import sglang as sgl @@ -12,9 +13,8 @@ def image_qa(s, image_path, question): def single(): state = image_qa.run( - image_path="images/cat.jpeg", - question="What is this?", - max_new_tokens=128) + image_path="images/cat.jpeg", question="What is this?", max_new_tokens=128 + ) print(state["answer"], "\n") @@ -23,7 +23,8 @@ def stream(): image_path="images/cat.jpeg", question="What is this?", max_new_tokens=64, - stream=True) + stream=True, + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -33,8 +34,8 @@ def stream(): def batch(): states = image_qa.run_batch( [ - {"image_path": "images/cat.jpeg", "question":"What is this?"}, - {"image_path": "images/dog.jpeg", "question":"What is this?"}, + {"image_path": "images/cat.jpeg", "question": "What is this?"}, + {"image_path": "images/dog.jpeg", "question": "What is this?"}, ], max_new_tokens=128, ) @@ -43,8 +44,10 @@ def batch(): if __name__ == "__main__": - runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.6-vicuna-7b", - tokenizer_path="llava-hf/llava-1.5-7b-hf") + runtime = sgl.Runtime( + model_path="liuhaotian/llava-v1.6-vicuna-7b", + tokenizer_path="llava-hf/llava-1.5-7b-hf", + ) sgl.set_default_backend(runtime) print(f"chat template: {runtime.endpoint.chat_template.name}") diff --git a/examples/quick_start/srt_example_yi_vl.py b/examples/quick_start/srt_example_yi_vl.py index 359aacac310..66c7d57126c 100644 --- a/examples/quick_start/srt_example_yi_vl.py +++ b/examples/quick_start/srt_example_yi_vl.py @@ -3,6 +3,7 @@ Requirements: transformers==4.38 """ + import sglang as sgl @@ -17,7 +18,8 @@ def single(): image_path="images/cat.jpeg", question="What is this?", max_new_tokens=64, - stop="###") + stop="###", + ) print(state["answer"], "\n") @@ -27,7 +29,8 @@ def stream(): question="What is this?", max_new_tokens=64, stream=True, - stop="###") + stop="###", + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -37,11 +40,11 @@ def stream(): def batch(): states = image_qa.run_batch( [ - {"image_path": "images/cat.jpeg", "question":"What is this?"}, - {"image_path": "images/dog.jpeg", "question":"What is this?"}, + {"image_path": "images/cat.jpeg", "question": "What is this?"}, + {"image_path": "images/dog.jpeg", "question": "What is this?"}, ], max_new_tokens=64, - stop="###" + stop="###", ) for s in states: print(s["answer"], "\n") diff --git a/examples/quick_start/together_example_chat.py 
b/examples/quick_start/together_example_chat.py index d2834f44e02..2d2059062e6 100644 --- a/examples/quick_start/together_example_chat.py +++ b/examples/quick_start/together_example_chat.py @@ -3,9 +3,11 @@ export TOGETHER_API_KEY=sk-****** python3 together_example_chat.py """ -import sglang as sgl + import os +import sglang as sgl + @sgl.function def multi_turn_question(s, question_1, question_2): @@ -32,7 +34,7 @@ def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -41,13 +43,18 @@ def stream(): def batch(): - states = multi_turn_question.run_batch([ - {"question_1": "What is the capital of the United States?", - "question_2": "List two local attractions."}, - - {"question_1": "What is the capital of France?", - "question_2": "What is the population of this city?"}, - ]) + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) for s in states: print(s.messages()) diff --git a/examples/quick_start/together_example_complete.py b/examples/quick_start/together_example_complete.py index 011c652fd59..d9119ed6cba 100644 --- a/examples/quick_start/together_example_complete.py +++ b/examples/quick_start/together_example_complete.py @@ -4,21 +4,21 @@ python3 together_example_complete.py """ -import sglang as sgl import os +import sglang as sgl + @sgl.function def few_shot_qa(s, question): - s += ( -"""The following are questions with answers. + s += """The following are questions with answers. Q: What is the capital of France? A: Paris Q: What is the capital of Germany? A: Berlin Q: What is the capital of Italy? A: Rome -""") +""" s += "Q: " + question + "\n" s += "A:" + sgl.gen("answer", stop="\n", temperature=0) @@ -34,8 +34,8 @@ def single(): def stream(): state = few_shot_qa.run( - question="What is the capital of the United States?", - stream=True) + question="What is the capital of the United States?", stream=True + ) for out in state.text_iter("answer"): print(out, end="", flush=True) @@ -43,10 +43,12 @@ def stream(): def batch(): - states = few_shot_qa.run_batch([ - {"question": "What is the capital of the United States?"}, - {"question": "What is the capital of China?"}, - ]) + states = few_shot_qa.run_batch( + [ + {"question": "What is the capital of the United States?"}, + {"question": "What is the capital of China?"}, + ] + ) for s in states: print(s["answer"]) diff --git a/examples/usage/async_io.py b/examples/usage/async_io.py index 68714812fba..d12a3a4d9df 100644 --- a/examples/usage/async_io.py +++ b/examples/usage/async_io.py @@ -2,7 +2,9 @@ Usage: python3 async_io.py """ + import asyncio + from sglang import Runtime @@ -14,7 +16,10 @@ async def generate( tokenizer = engine.get_tokenizer() messages = [ - {"role": "system", "content": "You will be given question answer tasks.",}, + { + "role": "system", + "content": "You will be given question answer tasks.", + }, {"role": "user", "content": prompt}, ] @@ -36,5 +41,5 @@ async def generate( prompt = "Who is Alan Turing?" 
sampling_params = {"max_new_tokens": 128} asyncio.run(generate(runtime, prompt, sampling_params)) - + runtime.shutdown() diff --git a/examples/usage/choices_logprob.py b/examples/usage/choices_logprob.py index e261668f8a6..6cd733fe90a 100644 --- a/examples/usage/choices_logprob.py +++ b/examples/usage/choices_logprob.py @@ -20,8 +20,8 @@ def main(): print("questions:", question) print("choice:", state["tool"]) meta_info = state.get_meta_info("tool") - print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) - print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) + print("logprobs of choice 1", meta_info["input_token_logprobs"][0]) + print("logprobs of choice 2", meta_info["input_token_logprobs"][1]) print("-" * 50) # Run a batch @@ -34,8 +34,8 @@ def main(): print("questions:", question) print("choice:", state["tool"]) meta_info = state.get_meta_info("tool") - print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) - print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) + print("logprobs of choice 1", meta_info["input_token_logprobs"][0]) + print("logprobs of choice 2", meta_info["input_token_logprobs"][1]) print("-" * 50) diff --git a/examples/usage/cot_decoding.py b/examples/usage/cot_decoding.py index d81a813c81f..38fbde855bf 100644 --- a/examples/usage/cot_decoding.py +++ b/examples/usage/cot_decoding.py @@ -31,10 +31,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): top_logprobs_num=get_top_k, return_text_in_logprobs=True, ) - logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0] + logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0] - print("Decoding step 0:", - ", ".join(pformat(token[2]) for token in logprobs)) + print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs)) for idx, (f, token) in enumerate(zip(forks, logprobs)): logprob, token_id, text = token f += text @@ -56,17 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): ) # calculate probability disparity between the top and secondary tokens - x1s = [ - exp(xt[0][0]) - for xt in f.get_meta_info("answer")["decode_top_logprobs"] - ] - x2s = [ - exp(xt[1][0]) - for xt in f.get_meta_info("answer")["decode_top_logprobs"] - ] - tokens = [ - xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"] - ] + x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]] + x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]] + tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]] delta = (sum(x1s) - sum(x2s)) / len(x1s) # extract the answer span (without the '<|end_of_text|>' token) @@ -79,42 +70,45 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): top_logprobs_num=2, return_text_in_logprobs=True, ) - answer = answer_forks[idx]['answer_span'].replace('\n', ' ').strip(':') + answer = answer_forks[idx]["answer_span"].replace("\n", " ").strip(":") print( f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}" ) - generated_text = str(answer_forks[idx])[len("ProgramState("):-1] + generated_text = str(answer_forks[idx])[len("ProgramState(") : -1] print(f"{BLUE}{pformat(generated_text)}{CLEAR}") if verbose: answer_tokens = [ - xt[0][2] for xt in answer_forks[idx].get_meta_info( - "answer_span")["decode_top_logprobs"] + xt[0][2] + for xt in answer_forks[idx].get_meta_info("answer_span")[ + "output_top_logprobs" + ] ] answer_x1s = [ - exp(xt[0][0]) 
for xt in answer_forks[idx].get_meta_info( - "answer_span")["decode_top_logprobs"] + exp(xt[0][0]) + for xt in answer_forks[idx].get_meta_info("answer_span")[ + "output_top_logprobs" + ] ] answer_x2s = [ - exp(xt[1][0]) for xt in answer_forks[idx].get_meta_info( - "answer_span")["decode_top_logprobs"] + exp(xt[1][0]) + for xt in answer_forks[idx].get_meta_info("answer_span")[ + "output_top_logprobs" + ] ] for token, x1, x2 in zip(tokens, x1s, x2s): - print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", - end="") + print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="") print("\n===========") for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s): - print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", - end="") + print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="") print() sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) state = cot_decoding.run( - question= - r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?", + question=r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?", get_top_k=10, is_chat_model=True, verbose=False, diff --git a/examples/usage/json_decode.py b/examples/usage/json_decode.py index ec2323e68c3..dc34d3527ba 100644 --- a/examples/usage/json_decode.py +++ b/examples/usage/json_decode.py @@ -3,10 +3,12 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python json_decode.py """ + from enum import Enum -import sglang as sgl from pydantic import BaseModel + +import sglang as sgl from sglang.srt.constrained import build_regex_from_object character_regex = ( diff --git a/examples/usage/json_logprobs.py b/examples/usage/json_logprobs.py index 6b5b9c8fcea..fa0e1b81f33 100644 --- a/examples/usage/json_logprobs.py +++ b/examples/usage/json_logprobs.py @@ -56,14 +56,14 @@ def srt_api_request(name): # fout.write(json.dumps(res, indent=4)) meta_info = res["meta_info"] - assert len(meta_info["prefill_token_logprobs"]) == len( - meta_info["prefill_top_logprobs"] + assert len(meta_info["input_token_logprobs"]) == len( + meta_info["input_top_logprobs"] ) - assert len(meta_info["decode_token_logprobs"]) == len( - meta_info["decode_top_logprobs"] + assert len(meta_info["output_token_logprobs"]) == len( + meta_info["output_top_logprobs"] ) - assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"] - assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1 + assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"] + assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1 return res @@ -72,11 +72,11 @@ def pretty_print(res): meta_info = res["meta_info"] print("\n\n", "=" * 30, "Prefill", "=" * 30) - for i in range(len(meta_info["prefill_token_logprobs"])): - print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="") + for i in range(len(meta_info["input_token_logprobs"])): + print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="") top_ks = ( - [str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]] - if meta_info["prefill_top_logprobs"][i] + [str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]] + if meta_info["input_top_logprobs"][i] else [] ) for top_k in top_ks: @@ -84,9 +84,9 @@ def pretty_print(res): print() print("\n\n", "=" * 30, "Decode", "=" * 30) - for i in 
range(len(meta_info["decode_token_logprobs"])): - print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="") - top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]] + for i in range(len(meta_info["output_token_logprobs"])): + print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="") + top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]] for top_k in top_ks: print(f"{top_k: <15}", end="") print() diff --git a/examples/usage/llava/http_llama3_llava_test.py b/examples/usage/llava/http_llama3_llava_test.py index 113adbc8d78..813a26af531 100644 --- a/examples/usage/llava/http_llama3_llava_test.py +++ b/examples/usage/llava/http_llama3_llava_test.py @@ -14,16 +14,13 @@ import argparse import asyncio +import copy import json import time -import copy import aiohttp import requests - -from llava.conversation import ( - conv_llava_llama_3, -) +from llava.conversation import conv_llava_llama_3 async def send_request(url, data, delay=0): diff --git a/examples/usage/llava/http_qwen_llava_test.py b/examples/usage/llava/http_qwen_llava_test.py index 9ba206415a7..1c29658c609 100644 --- a/examples/usage/llava/http_qwen_llava_test.py +++ b/examples/usage/llava/http_qwen_llava_test.py @@ -14,16 +14,13 @@ import argparse import asyncio +import copy import json import time -import copy import aiohttp import requests - -from llava.conversation import ( - conv_qwen -) +from llava.conversation import conv_qwen async def send_request(url, data, delay=0): diff --git a/examples/usage/llava/srt_llava_next_test.py b/examples/usage/llava/srt_llava_next_test.py index 5333fb7d3cb..5a1211fbf7b 100644 --- a/examples/usage/llava/srt_llava_next_test.py +++ b/examples/usage/llava/srt_llava_next_test.py @@ -2,13 +2,15 @@ Usage: python3 srt_example_llava.py """ +from PIL import ImageFile + import sglang as sgl -from sglang.srt.utils import load_image from sglang.lang.chat_template import get_chat_template +from sglang.srt.utils import load_image -from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images + @sgl.function def image_qa(s, image, question): s += sgl.user(sgl.image(image) + question) diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/usage/llava_video/srt_example_llava_v.py index e18a81ebb89..27ba862d30d 100644 --- a/examples/usage/llava_video/srt_example_llava_v.py +++ b/examples/usage/llava_video/srt_example_llava_v.py @@ -1,16 +1,20 @@ """ -Usage: python3 srt_example_llava.py +Usage: +pip install opencv-python-headless +python3 srt_example_llava.py """ -import sglang as sgl -import os +import argparse import csv +import os import time -import argparse + +import sglang as sgl + @sgl.function def video_qa(s, num_frames, video_path, question): - s += sgl.user(sgl.video(video_path,num_frames) + question) + s += sgl.user(sgl.video(video_path, num_frames) + question) s += sgl.assistant(sgl.gen("answer")) @@ -25,7 +29,6 @@ def single(path, num_frames=16): print(state["answer"], "\n") - def split_into_chunks(lst, num_chunks): """Split a list into a specified number of chunks.""" # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible. @@ -34,7 +37,7 @@ def split_into_chunks(lst, num_chunks): if chunk_size == 0: chunk_size = len(lst) # Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible. 
- chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] + chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] # Ensure we have exactly num_chunks chunks, even if some are empty chunks.extend([[] for _ in range(num_chunks - len(chunks))]) return chunks @@ -42,67 +45,73 @@ def split_into_chunks(lst, num_chunks): def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir): csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv" - with open(csv_filename, 'w', newline='') as csvfile: + with open(csv_filename, "w", newline="") as csvfile: writer = csv.writer(csvfile) - writer.writerow(['video_name', 'answer']) + writer.writerow(["video_name", "answer"]) for video_path, state in zip(batch_video_files, states): video_name = os.path.basename(video_path) writer.writerow([video_name, state["answer"]]) + def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir): final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv" - with open(final_csv_filename, 'w', newline='') as final_csvfile: + with open(final_csv_filename, "w", newline="") as final_csvfile: writer = csv.writer(final_csvfile) - writer.writerow(['video_name', 'answer']) + writer.writerow(["video_name", "answer"]) for batch_idx in range(num_batches): batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv" - with open(batch_csv_filename, 'r') as batch_csvfile: + with open(batch_csv_filename, "r") as batch_csvfile: reader = csv.reader(batch_csvfile) next(reader) # Skip header row for row in reader: writer.writerow(row) os.remove(batch_csv_filename) + def find_video_files(video_dir): # Check if the video_dir is actually a file if os.path.isfile(video_dir): # If it's a file, return it as a single-element list return [video_dir] - + # Original logic to find video files in a directory video_files = [] for root, dirs, files in os.walk(video_dir): for file in files: - if file.endswith(('.mp4', '.avi', '.mov')): + if file.endswith((".mp4", ".avi", ".mov")): video_files.append(os.path.join(root, file)) return video_files + def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64): video_files = find_video_files(video_dir) chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk] num_batches = 0 for i in range(0, len(chunked_video_files), batch_size): - batch_video_files = chunked_video_files[i:i + batch_size] + batch_video_files = chunked_video_files[i : i + batch_size] print(f"Processing batch of {len(batch_video_files)} video(s)...") if not batch_video_files: print("No video files found in the specified directory.") return - + batch_input = [ - { + { "num_frames": num_frames, "video_path": video_path, "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.", - } for video_path in batch_video_files + } + for video_path in batch_video_files ] start_time = time.time() states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2) total_time = time.time() - start_time average_time = total_time / len(batch_video_files) - print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds") + print( + f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. 
Total time for this batch: {total_time:.2f} seconds" + ) save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir) num_batches += 1 @@ -113,16 +122,47 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= if __name__ == "__main__": # Create the parser - parser = argparse.ArgumentParser(description='Run video processing with specified port.') + parser = argparse.ArgumentParser( + description="Run video processing with specified port." + ) # Add an argument for the port - parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.') - parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.') - parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.') - parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.') - parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.') - parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.') - parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.' ) + parser.add_argument( + "--port", + type=int, + default=30000, + help="The master port for distributed serving.", + ) + parser.add_argument( + "--chunk-idx", type=int, default=0, help="The index of the chunk to process." + ) + parser.add_argument( + "--num-chunks", type=int, default=8, help="The number of chunks to process." + ) + parser.add_argument( + "--save-dir", + type=str, + default="./work_dirs/llava_video", + help="The directory to save the processed video files.", + ) + parser.add_argument( + "--video-dir", + type=str, + default="./videos/Q98Z4OTh8RwmDonc.mp4", + help="The directory or path for the processed video files.", + ) + parser.add_argument( + "--model-path", + type=str, + default="lmms-lab/LLaVA-NeXT-Video-7B", + help="The model path for the video processing.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=16, + help="The number of frames to process in each video.", + ) parser.add_argument("--mm_spatial_pool_stride", type=int, default=2) # Parse the arguments @@ -154,7 +194,6 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= if "34b" in args.model_path.lower(): model_overide_args["image_token_index"] = 64002 - if args.num_frames == 32: model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} model_overide_args["max_sequence_length"] = 4096 * 2 @@ -162,22 +201,22 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= elif args.num_frames < 32: pass else: - print("The maximum number of frames to process is 32. Please specify a valid number of frames.") + print( + "The maximum number of frames to process is 32. Please specify a valid number of frames." 
+ ) exit() - runtime = sgl.Runtime( - model_path=args.model_path, #"liuhaotian/llava-v1.6-vicuna-7b", + model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b", tokenizer_path=tokenizer_path, port=cur_port, - additional_ports=[cur_port+1,cur_port+2,cur_port+3,cur_port+4], + additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4], model_overide_args=model_overide_args, - tp_size=1 + tp_size=1, ) sgl.set_default_backend(runtime) print(f"chat template: {runtime.endpoint.chat_template.name}") - # Run a single request # try: print("\n========== single ==========\n") @@ -185,24 +224,29 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= if os.path.isfile(root): video_files = [root] else: - video_files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith(('.mp4', '.avi', '.mov'))] # Add more extensions if needed + video_files = [ + os.path.join(root, f) + for f in os.listdir(root) + if f.endswith((".mp4", ".avi", ".mov")) + ] # Add more extensions if needed start_time = time.time() # Start time for processing a single video for cur_video in video_files[:1]: print(cur_video) single(cur_video, num_frames) end_time = time.time() # End time for processing a single video total_time = end_time - start_time - average_time = total_time / len(video_files) # Calculate the average processing time + average_time = total_time / len( + video_files + ) # Calculate the average processing time print(f"Average processing time per video: {average_time:.2f} seconds") runtime.shutdown() # except Exception as e: # print(e) runtime.shutdown() - # # # Run a batch of requests # print("\n========== batch ==========\n") # if not os.path.exists(args.save_dir): # os.makedirs(args.save_dir) # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks) - # runtime.shutdown() \ No newline at end of file + # runtime.shutdown() diff --git a/examples/usage/openai_batch_chat.py b/examples/usage/openai_batch_chat.py new file mode 100644 index 00000000000..8640d092570 --- /dev/null +++ b/examples/usage/openai_batch_chat.py @@ -0,0 +1,96 @@ +""" +Usage: +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +python openai_batch_chat.py +Note: Before running this script, +you should create the input.jsonl file with the following content: +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world! List 3 NBA players and tell a story"}],"max_tokens": 300}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an assistant. "},{"role": "user", "content": "Hello world! 
List three capital and tell a story"}],"max_tokens": 500}} +""" + +import json +import os +import time + +import openai +from openai import OpenAI + + +class OpenAIBatchProcessor: + def __init__(self, api_key): + # client = OpenAI(api_key=api_key) + client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + + self.client = client + + def process_batch(self, input_file_path, endpoint, completion_window): + + # Upload the input file + with open(input_file_path, "rb") as file: + uploaded_file = self.client.files.create(file=file, purpose="batch") + + # Create the batch job + batch_job = self.client.batches.create( + input_file_id=uploaded_file.id, + endpoint=endpoint, + completion_window=completion_window, + ) + + # Monitor the batch job status + while batch_job.status not in ["completed", "failed", "cancelled"]: + time.sleep(3) # Wait for 3 seconds before checking the status again + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." + ) + batch_job = self.client.batches.retrieve(batch_job.id) + + # Check the batch job status and errors + if batch_job.status == "failed": + print(f"Batch job failed with status: {batch_job.status}") + print(f"Batch job errors: {batch_job.errors}") + return None + + # If the batch job is completed, process the results + if batch_job.status == "completed": + + # print result of batch job + print("batch", batch_job.request_counts) + + result_file_id = batch_job.output_file_id + # Retrieve the file content from the server + file_response = self.client.files.content(result_file_id) + result_content = file_response.read() # Read the content of the file + + # Save the content to a local file + result_file_name = "batch_job_chat_results.jsonl" + with open(result_file_name, "wb") as file: + file.write(result_content) # Write the binary content to the file + # Load data from the saved JSONL file + results = [] + with open(result_file_name, "r", encoding="utf-8") as file: + for line in file: + json_object = json.loads( + line.strip() + ) # Parse each line as a JSON object + results.append(json_object) + + return results + else: + print(f"Batch job failed with status: {batch_job.status}") + return None + + +# Initialize the OpenAIBatchProcessor +api_key = os.environ.get("OPENAI_API_KEY") +processor = OpenAIBatchProcessor(api_key) + +# Process the batch job +input_file_path = "input.jsonl" +endpoint = "/v1/chat/completions" +completion_window = "24h" + +# Process the batch job +results = processor.process_batch(input_file_path, endpoint, completion_window) + +# Print the results +print(results) diff --git a/examples/usage/openai_batch_complete.py b/examples/usage/openai_batch_complete.py new file mode 100644 index 00000000000..af694b54a92 --- /dev/null +++ b/examples/usage/openai_batch_complete.py @@ -0,0 +1,97 @@ +""" +Usage: +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +python openai_batch_complete.py +Note: Before running this script, +you should create the input.jsonl file with the following content: +{"custom_id": "request-1", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 3 names of famous soccer player: ", "max_tokens": 200}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}} +{"custom_id": "request-3", "method": "POST", "url": "/v1/completions", "body": {"model": 
"gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}} +""" + +import json +import os +import time + +import openai +from openai import OpenAI + + +class OpenAIBatchProcessor: + def __init__(self, api_key): + # client = OpenAI(api_key=api_key) + client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + + self.client = client + + def process_batch(self, input_file_path, endpoint, completion_window): + + # Upload the input file + with open(input_file_path, "rb") as file: + uploaded_file = self.client.files.create(file=file, purpose="batch") + + # Create the batch job + batch_job = self.client.batches.create( + input_file_id=uploaded_file.id, + endpoint=endpoint, + completion_window=completion_window, + ) + + # Monitor the batch job status + while batch_job.status not in ["completed", "failed", "cancelled"]: + time.sleep(3) # Wait for 3 seconds before checking the status again + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." + ) + batch_job = self.client.batches.retrieve(batch_job.id) + + # Check the batch job status and errors + if batch_job.status == "failed": + print(f"Batch job failed with status: {batch_job.status}") + print(f"Batch job errors: {batch_job.errors}") + return None + + # If the batch job is completed, process the results + if batch_job.status == "completed": + + # print result of batch job + print("batch", batch_job.request_counts) + + result_file_id = batch_job.output_file_id + # Retrieve the file content from the server + file_response = self.client.files.content(result_file_id) + result_content = file_response.read() # Read the content of the file + + # Save the content to a local file + result_file_name = "batch_job_complete_results.jsonl" + with open(result_file_name, "wb") as file: + file.write(result_content) # Write the binary content to the file + # Load data from the saved JSONL file + results = [] + with open(result_file_name, "r", encoding="utf-8") as file: + for line in file: + json_object = json.loads( + line.strip() + ) # Parse each line as a JSON object + results.append(json_object) + + return results + else: + print(f"Batch job failed with status: {batch_job.status}") + return None + + +# Initialize the OpenAIBatchProcessor +api_key = os.environ.get("OPENAI_API_KEY") +processor = OpenAIBatchProcessor(api_key) + +# Process the batch job +input_file_path = "input_complete.jsonl" +endpoint = "/v1/completions" +completion_window = "24h" + +# Process the batch job +results = processor.process_batch(input_file_path, endpoint, completion_window) + +# Print the results +print(results) diff --git a/examples/usage/openai_chat_speculative.py b/examples/usage/openai_chat_speculative.py index 94eb4327623..a9c5f5afb32 100644 --- a/examples/usage/openai_chat_speculative.py +++ b/examples/usage/openai_chat_speculative.py @@ -15,23 +15,40 @@ export OPENAI_API_KEY=sk-****** python3 openai_chat_speculative.py """ + import sglang as sgl -from sglang import function, set_default_backend, OpenAI +from sglang import OpenAI, function, set_default_backend @function(num_api_spec_tokens=256) def gen_character_spec(s): s += sgl.system("You are a helpful assistant.") s += sgl.user("Construct a character within the following format:") - s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n") + s += sgl.assistant( + "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" + ) s += sgl.user("Please generate new Name, Birthday and Job.\n") - s 
+= sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) + s += sgl.assistant( + "Name:" + + sgl.gen("name", stop="\n") + + "\nBirthday:" + + sgl.gen("birthday", stop="\n") + + "\nJob:" + + sgl.gen("job", stop="\n") + ) @function(num_api_spec_tokens=256) def gen_character_spec_no_few_shot(s): s += sgl.user("Construct a character. For each field stop with a newline\n") - s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) + s += sgl.assistant( + "Name:" + + sgl.gen("name", stop="\n") + + "\nAge:" + + sgl.gen("age", stop="\n") + + "\nJob:" + + sgl.gen("job", stop="\n") + ) @function @@ -45,10 +62,19 @@ def gen_character_normal(s): def multi_turn_question(s, question_1, question_2): s += sgl.system("You are a helpful assistant.") s += sgl.user("Answer questions in the following format:") - s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n") - s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n") - s += sgl.user("Question 1: " + question_1+"\nQuestion 2: " + question_2) - s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n")) + s += sgl.user( + "Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n" + ) + s += sgl.assistant( + "Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n" + ) + s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2) + s += sgl.assistant( + "Answer 1: " + + sgl.gen("answer_1", stop="\n") + + "\nAnswer 2: " + + sgl.gen("answer_2", stop="\n") + ) def test_spec_single_turn(): @@ -97,7 +123,7 @@ def test_spec_multi_turn_stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(): @@ -126,4 +152,4 @@ def test_spec_multi_turn_stream(): print("\n========== test spec multi turn stream ==========\n") # expect error in stream_executor: stream is not supported... - test_spec_multi_turn_stream() \ No newline at end of file + test_spec_multi_turn_stream() diff --git a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py new file mode 100644 index 00000000000..753e66c744f --- /dev/null +++ b/examples/usage/openai_parallel_sample.py @@ -0,0 +1,153 @@ +import openai + +client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + +# Text completion +response = client.completions.create( + model="default", + prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", + n=1, + temperature=0.8, + max_tokens=32, +) +print(response) + + +# Text completion +response = client.completions.create( + model="default", + prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", + n=5, + temperature=0.8, + max_tokens=320, +) +print(response) + + +# Text completion +response = client.completions.create( + model="default", + prompt="I am a robot and I want to study like humans. Now let's tell a story. 
Once upon a time, there was a little", + n=3, + temperature=0.8, + max_tokens=32, +) +print(response) + + +# Text completion +response = client.completions.create( + model="default", + prompt=["The name of the famous soccer player is"], + n=1, + temperature=0.8, + max_tokens=128, +) +print(response) + + +# Text completion +response = client.completions.create( + model="default", + prompt=["The name of the famous soccer player is ", "The capital of US is"], + n=1, + temperature=0.8, + max_tokens=32, +) +print(response) + + +# Text completion +response = client.completions.create( + model="default", + prompt=["The name of the famous soccer player is ", "The capital of US is"], + n=3, + temperature=0.8, + max_tokens=32, +) +print(response) + + +response = client.completions.create( + model="default", + prompt=[ + "prompt1: I am a robot and I want to learn like humans. Now let's begin a tale. Once upon a time, there was a small", + "prompt2: As a robot, my goal is to understand human learning. Let's start a story. In a faraway land, there lived a tiny", + "prompt3: Being a robot, I aspire to study like people. Let's share a story. Long ago, there was a little", + "prompt4: I am a robot aiming to learn like humans. Let's narrate a story. Once, in a distant kingdom, there was a young", + "prompt5: As a robot, I seek to learn in human ways. Let's tell a story. Once upon a time, in a small village, there was a young", + ], + n=1, + temperature=0.8, + max_tokens=320, +) +print(response) + + +# Text completion +response = client.completions.create( + model="default", + prompt=[ + "The capital of France is", + "The capital of Germany is", + "The capital of US is", + ], + n=3, + temperature=0.8, + max_tokens=32, +) +print(response) + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.8, + max_tokens=1, + logprobs=True, + top_logprobs=3, +) +print(response) + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.8, + max_tokens=1, + n=1, +) +print(response) + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.8, + max_tokens=1, + logprobs=True, + top_logprobs=3, +) +print(response) + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.8, + max_tokens=1, + n=4, +) +print(response) diff --git a/examples/usage/openai_speculative.py b/examples/usage/openai_speculative.py index c64694da667..4389cb05959 100644 --- a/examples/usage/openai_speculative.py +++ b/examples/usage/openai_speculative.py @@ -2,7 +2,8 @@ Usage: python3 openai_speculative.py """ -from sglang import function, gen, set_default_backend, OpenAI + +from sglang import OpenAI, function, gen, set_default_backend @function(num_api_spec_tokens=64) @@ -35,7 +36,11 @@ def gen_character_spec_no_few_shot(s): backend = OpenAI("gpt-3.5-turbo-instruct") 
set_default_backend(backend) - for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]: + for function in [ + gen_character_spec, + gen_character_no_spec, + gen_character_spec_no_few_shot, + ]: backend.token_usage.reset() print(f"function: {function.func.__name__}") @@ -46,4 +51,4 @@ def gen_character_spec_no_few_shot(s): print("...birthday:", state["birthday"]) print("...job:", state["job"]) print(backend.token_usage) - print() \ No newline at end of file + print() diff --git a/examples/usage/parallel_sample.py b/examples/usage/parallel_sample.py index 288b48ac0c0..0f3cf170000 100644 --- a/examples/usage/parallel_sample.py +++ b/examples/usage/parallel_sample.py @@ -2,6 +2,7 @@ Usage: python3 parallel_sample.py """ + import sglang as sgl @@ -12,7 +13,6 @@ def parallel_sample(s, question, n): "Reasoning: I need to use a calculator.\n" "Tool: calculator\n" "Answer: 6\n" - "Question: Compute 3 + 2 + 2\n" "Reasoning: I will try a calculator.\n" "Tool: calculator\n" @@ -27,13 +27,9 @@ def parallel_sample(s, question, n): sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) -#sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) +# sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) -state = parallel_sample.run( - question="Compute 5 + 2 + 4.", - n=5, - temperature=1.0 -) +state = parallel_sample.run(question="Compute 5 + 2 + 4.", n=5, temperature=1.0) for i in range(5): obj = { diff --git a/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb b/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb index ce90e2186b3..25b91b7d1dc 100644 --- a/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb +++ b/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb @@ -71,6 +71,7 @@ "source": [ "import json\n", "import os\n", + "from typing import List\n", "\n", "import chromadb\n", "\n", @@ -148,7 +149,7 @@ "outputs": [], "source": [ "@trace\n", - "def retrieval(question: str) -> list[str]:\n", + "def retrieval(question: str) -> List[str]:\n", " return collection.query(\n", " query_texts=[question],\n", " n_results=1\n", @@ -278,7 +279,7 @@ "\n", "\n", "@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n", - "def retrieval(question: str) -> list[str]:\n", + "def retrieval(question: str) -> List[str]:\n", " return collection.query(\n", " query_texts=[question],\n", " n_results=1\n", diff --git a/examples/usage/readme_examples.py b/examples/usage/readme_examples.py index 8789e1b132e..7269ef1485d 100644 --- a/examples/usage/readme_examples.py +++ b/examples/usage/readme_examples.py @@ -3,13 +3,18 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python readme_examples.py """ + import sglang as sgl @sgl.function def tool_use(s, question): s += "To answer this question: " + question + ". " - s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". " + s += ( + "I need to use a " + + sgl.gen("tool", choices=["calculator", "search engine"]) + + ". 
" + ) if s["tool"] == "calculator": s += "The math expression is" + sgl.gen("expression") @@ -75,7 +80,7 @@ def driver_batching(): {"question": "What is the capital of France?"}, {"question": "What is the capital of Japan?"}, ], - progress_bar=True + progress_bar=True, ) for s in states: @@ -85,9 +90,7 @@ def driver_batching(): def driver_stream(): state = text_qa.run( - question="What is the capital of France?", - temperature=0.1, - stream=True + question="What is the capital of France?", temperature=0.1, stream=True ) for out in state.text_iter(): @@ -96,7 +99,7 @@ def driver_stream(): if __name__ == "__main__": - #sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) + # sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) driver_tool_use() diff --git a/examples/usage/streaming.py b/examples/usage/streaming.py index 20feaafbca6..506ee35c6f0 100644 --- a/examples/usage/streaming.py +++ b/examples/usage/streaming.py @@ -2,7 +2,9 @@ Usage: python3 streaming.py """ + import asyncio + import sglang as sgl @@ -22,7 +24,7 @@ def stream_a_variable(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) for out in state.text_iter(var_name="answer_2"): @@ -34,7 +36,7 @@ async def async_stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", question_2="List two local attractions.", - stream=True + stream=True, ) async for out in state.text_async_iter(var_name="answer_2"): diff --git a/examples/usage/triton/models/character_generation/1/model.py b/examples/usage/triton/models/character_generation/1/model.py index e76992f9516..5550e93984b 100644 --- a/examples/usage/triton/models/character_generation/1/model.py +++ b/examples/usage/triton/models/character_generation/1/model.py @@ -1,45 +1,55 @@ -import triton_python_backend_utils as pb_utils import numpy +import triton_python_backend_utils as pb_utils +from pydantic import BaseModel + import sglang as sgl from sglang import function, set_default_backend from sglang.srt.constrained import build_regex_from_object -from pydantic import BaseModel - sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) + class Character(BaseModel): name: str eye_color: str house: str + @function def character_gen(s, name): s += ( name + " is a character in Harry Potter. 
Please fill in the following information about this character.\n" ) - s += sgl.gen("json_output", max_tokens=256, regex=build_regex_from_object(Character)) + s += sgl.gen( + "json_output", max_tokens=256, regex=build_regex_from_object(Character) + ) class TritonPythonModel: def initialize(self, args): print("Initialized.") + def execute(self, requests): responses = [] for request in requests: tensor_in = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT") if tensor_in is None: return pb_utils.InferenceResponse(output_tensors=[]) - - input_list_names = [i.decode('utf-8') if isinstance(i, bytes) else i for i in tensor_in.as_numpy().tolist()] - input_list_dicts = [{"name":i} for i in input_list_names] + input_list_names = [ + i.decode("utf-8") if isinstance(i, bytes) else i + for i in tensor_in.as_numpy().tolist() + ] + + input_list_dicts = [{"name": i} for i in input_list_names] states = character_gen.run_batch(input_list_dicts) character_strs = [state.text() for state in states] - tensor_out = pb_utils.Tensor("OUTPUT_TEXT", numpy.array(character_strs, dtype=object)) + tensor_out = pb_utils.Tensor( + "OUTPUT_TEXT", numpy.array(character_strs, dtype=object) + ) - responses.append(pb_utils.InferenceResponse(output_tensors = [tensor_out])) - return responses \ No newline at end of file + responses.append(pb_utils.InferenceResponse(output_tensors=[tensor_out])) + return responses diff --git a/python/pyproject.toml b/python/pyproject.toml index baa2165d259..32d4912a3ac 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.1.19" -description = "A structured generation langauge for LLMs." +version = "0.2.11" +description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" license = {file = "LICENSE"} @@ -20,12 +20,16 @@ dependencies = [ ] [project.optional-dependencies] -srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", - "psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"] +srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", + "packaging", "pillow", "psutil", "pydantic", "python-multipart", + "torch", "uvicorn", "uvloop", "zmq", + "vllm==0.5.4", "outlines>=0.0.44"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] +test = ["jsonlines", "matplotlib", "pandas"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] +dev = ["sglang[all]", "sglang[test]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" diff --git a/python/sglang/README.md b/python/sglang/README.md new file mode 100644 index 00000000000..c873e1d63f2 --- /dev/null +++ b/python/sglang/README.md @@ -0,0 +1,12 @@ +# Code Structures + + +- `lang`: The frontend language. +- `srt`: The backend engine for running local models. (SRT = SGLang Runtime). +- `test`: Test utilities. +- `api.py`: Public API. +- `bench_latency.py`: Benchmark a single static batch. +- `bench_serving.py`: Benchmark online serving with dynamic requests. +- `global_config.py`: The global configs and constants. +- `launch_server.py`: The entry point of launching local server. +- `utils.py`: Common utilities. 
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index eed6fba6e48..71d7bfeccfa 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -1,6 +1,5 @@ -__version__ = "0.1.19" - # SGL API Components + from sglang.api import ( Runtime, assistant, @@ -16,46 +15,62 @@ select, set_default_backend, system, + system_begin, + system_end, user, user_begin, user_end, video, ) +from sglang.lang.choices import ( + greedy_token_selection, + token_length_normalized, + unconditional_likelihood_normalized, +) -# SGL Backends -from sglang.backend.anthropic import Anthropic -from sglang.backend.litellm import LiteLLM -from sglang.backend.openai import OpenAI -from sglang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.backend.vertexai import VertexAI - -# Global Configurations -from sglang.global_config import global_config - -# public APIs management +# SGLang DSL APIs __all__ = [ - "global_config", - "Anthropic", - "LiteLLM", - "OpenAI", - "RuntimeEndpoint", - "VertexAI", - "function", "Runtime", - "set_default_backend", + "assistant", + "assistant_begin", + "assistant_end", "flush_cache", - "get_server_args", + "function", "gen", "gen_int", "gen_string", + "get_server_args", "image", - "video", "select", + "set_default_backend", "system", + "system_begin", + "system_end", "user", - "assistant", "user_begin", "user_end", - "assistant_begin", - "assistant_end", + "video", + "greedy_token_selection", + "token_length_normalized", + "unconditional_likelihood_normalized", ] + +# Global Configurations +from sglang.global_config import global_config + +__all__ += ["global_config"] + +from sglang.version import __version__ + +__all__ += ["__version__"] + +# SGL Backends +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.utils import LazyImport + +Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") +LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") +OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") +VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") + +__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"] diff --git a/python/sglang/api.py b/python/sglang/api.py index 04389356875..5a177c36b0b 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -4,8 +4,9 @@ import re from typing import Callable, List, Optional, Union -from sglang.backend.base_backend import BaseBackend from sglang.global_config import global_config +from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized from sglang.lang.ir import ( SglExpr, SglExprList, @@ -73,12 +74,18 @@ def gen( return_text_in_logprobs: Optional[bool] = None, dtype: Optional[type] = None, choices: Optional[List[str]] = None, + choices_method: Optional[ChoicesSamplingMethod] = None, regex: Optional[str] = None, ): - """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md""" + """Call the model to generate. 
See the meaning of the arguments in docs/en/sampling_params.md""" if choices: - return SglSelect(name, choices, 0.0 if temperature is None else temperature) + return SglSelect( + name, + choices, + 0.0 if temperature is None else temperature, + token_length_normalized if choices_method is None else choices_method, + ) # check regex is valid if regex is not None: @@ -186,9 +193,10 @@ def select( name: Optional[str] = None, choices: Optional[List[str]] = None, temperature: float = 0.0, + choices_method: ChoicesSamplingMethod = token_length_normalized, ): assert choices is not None - return SglSelect(name, choices, temperature) + return SglSelect(name, choices, temperature, choices_method) def _role_common(name: str, expr: Optional[SglExpr] = None): @@ -210,6 +218,14 @@ def assistant(expr: Optional[SglExpr] = None): return _role_common("assistant", expr) +def system_begin(): + return SglRoleBegin("system") + + +def system_end(): + return SglRoleEnd("system") + + def user_begin(): return SglRoleBegin("user") diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 49727b121e2..50ec8a67f28 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -1,13 +1,21 @@ """ Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py. -# Usage (latency test): +# Usage (latency test) +## with dummy weights: python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy +## sweep through multiple data points and store (append) the results in a jsonl file: +python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl +## do some changes, and store the results under a different run_name: +python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after +## plot the results in series of lines: +python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results" + # Usage (correctness test): python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct -### Reference output: +## Reference output (of the correctness test above, can be gpu dependent): prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633], [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633], [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]], @@ -28,20 +36,24 @@ import argparse import dataclasses +import itertools import logging import multiprocessing import os +import sqlite3 import time - +from typing import Tuple import numpy as np +import pandas as pd import torch import torch.distributed as dist from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req -from sglang.srt.managers.controller.model_runner import ModelRunner +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_config import ModelConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.utils import suppress_other_loggers @@ -49,25 +61,50 @@ @dataclasses.dataclass class BenchArgs: 
- batch_size: int = 1 - input_len: int = 1024 - output_len: int = 4 + run_name: str = "before" + batch_size: Tuple[int] = (1,) + input_len: Tuple[int] = (1024,) + output_len: Tuple[int] = (4,) + result_filename: str = "" correctness_test: bool = False # This is only used for correctness test cut_len: int = 4 + # Plotting args + graph_sql: str = ( + "select run_name, batch_size, prefill_throughput from results where run_name='before'" + ) + graph_filename: str = "out.png" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): - parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size) - parser.add_argument("--input-len", type=int, default=BenchArgs.input_len) - parser.add_argument("--output-len", type=int, default=BenchArgs.output_len) + parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) + parser.add_argument( + "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size + ) + parser.add_argument( + "--input-len", type=int, nargs="+", default=BenchArgs.input_len + ) + parser.add_argument( + "--output-len", type=int, nargs="+", default=BenchArgs.output_len + ) + parser.add_argument( + "--result-filename", type=str, default=BenchArgs.result_filename + ) parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) + # graphing + parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql) + parser.add_argument( + "--graph-filename", type=str, default=BenchArgs.graph_filename + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): - attrs = [attr.name for attr in dataclasses.fields(cls)] - return cls(**{attr: getattr(args, attr) for attr in attrs}) + # use the default value's type to cast the args into the correct types. 
+ attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] + return cls( + **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs} + ) def load_model(server_args, tp_rank): @@ -95,7 +132,7 @@ def load_model(server_args, tp_rank): return model_runner, tokenizer -def prepare_inputs(bench_args, tokenizer): +def prepare_inputs_for_correctness_test(bench_args, tokenizer): prompts = [ "The capital of France is", "The capital of the United Kindom is", @@ -121,7 +158,9 @@ def prepare_inputs(bench_args, tokenizer): return input_ids, reqs -def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner): +def prepare_extend_inputs_for_correctness_test( + bench_args, input_ids, reqs, model_runner +): for i in range(len(reqs)): req = reqs[i] req.input_ids += input_ids[i][bench_args.cut_len :] @@ -131,8 +170,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner): return reqs -def prepare_synthetic_inputs(bench_args, tokenizer): - input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32) +def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): + input_ids = np.ones((batch_size, input_len), dtype=np.int32) sampling_params = SamplingParams( temperature=0, max_new_tokens=BenchArgs.output_len, @@ -150,7 +189,7 @@ def prepare_synthetic_inputs(bench_args, tokenizer): def extend(reqs, model_runner): - batch = Batch.init_new( + batch = ScheduleBatch.init_new( reqs=reqs, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, @@ -158,14 +197,14 @@ def extend(reqs, model_runner): ) batch.prepare_for_extend(model_runner.model_config.vocab_size, None) output = model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids, _ = batch.sample(output.next_token_logits) + next_token_ids = batch.sample(output.next_token_logits) return next_token_ids, output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids.cpu().numpy()) output = model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids, _ = batch.sample(output.next_token_logits) + next_token_ids = batch.sample(output.next_token_logits) return next_token_ids, output.next_token_logits @@ -181,7 +220,7 @@ def correctness_test( model_runner, tokenizer = load_model(server_args, tp_rank) # Prepare inputs - input_ids, reqs = prepare_inputs(bench_args, tokenizer) + input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) if bench_args.cut_len > 0: # Prefill @@ -189,7 +228,9 @@ def correctness_test( rank_print("prefill logits (first half)", next_token_logits) # Prepare extend inputs - reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner) + reqs = prepare_extend_inputs_for_correctness_test( + bench_args, input_ids, reqs, model_runner + ) # Extend next_token_ids, next_token_logits, batch = extend(reqs, model_runner) @@ -197,7 +238,7 @@ def correctness_test( # Decode output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] - for _ in range(bench_args.output_len): + for _ in range(bench_args.output_len[0]): next_token_ids, _ = decode(next_token_ids, batch, model_runner) for i in range(len(reqs)): output_ids[i].append(next_token_ids[i]) @@ -207,6 +248,74 @@ def correctness_test( rank_print(tokenizer.decode(output_ids[i])) +@torch.inference_mode() +def latency_test_run_once( + run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len +): + max_batch_size = model_runner.max_total_num_tokens // 
(input_len + output_len) + if batch_size > max_batch_size: + rank_print( + f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit" + ) + return + + # Clear the pools. + model_runner.req_to_token_pool.clear() + model_runner.token_to_kv_pool.clear() + + measurement_results = { + "run_name": run_name, + "batch_size": batch_size, + "input_len": input_len, + "output_len": output_len, + } + + tot_latency = 0 + + # Prefill + torch.cuda.synchronize() + tic = time.time() + next_token_ids, _, batch = extend(reqs, model_runner) + torch.cuda.synchronize() + prefill_latency = time.time() - tic + tot_latency += prefill_latency + throughput = input_len * batch_size / prefill_latency + rank_print( + f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s" + ) + measurement_results["prefill_latency"] = prefill_latency + measurement_results["prefill_throughput"] = throughput + + # Decode + for i in range(output_len): + torch.cuda.synchronize() + tic = time.time() + next_token_ids, _ = decode(next_token_ids, batch, model_runner) + torch.cuda.synchronize() + latency = time.time() - tic + tot_latency += latency + throughput = batch_size / latency + if i < 5: + rank_print( + f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" + ) + avg_decode_latency = (tot_latency - prefill_latency) / output_len + avg_decode_throughput = batch_size / avg_decode_latency + rank_print( + f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s" + ) + measurement_results["avg_decode_latency"] = avg_decode_latency + measurement_results["avg_decode_throughput"] = avg_decode_throughput + + throughput = (input_len + output_len) * batch_size / tot_latency + rank_print( + f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s" + ) + measurement_results["total_latency"] = tot_latency + measurement_results["total_throughput"] = throughput + return measurement_results + + def latency_test( server_args, bench_args, @@ -216,99 +325,151 @@ def latency_test( # Load the model model_runner, tokenizer = load_model(server_args, tp_rank) - rank_print( - f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}" - ) - # Prepare inputs - reqs = prepare_synthetic_inputs(bench_args, tokenizer) + # Prepare inputs for warm up + reqs = prepare_synthetic_inputs_for_latency_test( + bench_args.batch_size[0], bench_args.input_len[0] + ) - def clear(): - model_runner.req_to_token_pool.clear() - model_runner.token_to_kv_pool.clear() + # Warm up + latency_test_run_once( + bench_args.run_name, + model_runner, + rank_print, + reqs, + bench_args.batch_size[0], + bench_args.input_len[0], + 4, # shorter decoding to speed up the warmup + ) - @torch.inference_mode() - def run_once(output_len): - # Prefill - torch.cuda.synchronize() - tot_latency = 0 - tic = time.time() - next_token_ids, _, batch = extend(reqs, model_runner) - torch.cuda.synchronize() - prefill_latency = time.time() - tic - tot_latency += prefill_latency - throughput = bench_args.input_len * bench_args.batch_size / prefill_latency - rank_print( - f"Prefill. 
latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s" + # Run the sweep + result_list = [] + for bs, il, ol in itertools.product( + bench_args.batch_size, bench_args.input_len, bench_args.output_len + ): + req = prepare_synthetic_inputs_for_latency_test(bs, il) + ret = latency_test_run_once( + bench_args.run_name, model_runner, rank_print, req, bs, il, ol ) + if ret is not None: + result_list.append(ret) - # Decode - for i in range(output_len): - torch.cuda.synchronize() - tic = time.time() - next_token_ids, _ = decode(next_token_ids, batch, model_runner) - torch.cuda.synchronize() - latency = time.time() - tic - tot_latency += latency - throughput = bench_args.batch_size / latency - if i < 5: - rank_print( - f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" - ) - avg_decode_latency = (tot_latency - prefill_latency) / output_len - avg_decode_throughput = bench_args.batch_size / avg_decode_latency - rank_print( - f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s" - ) + # Write results in jsonlines format on rank 0. + if tp_rank == 0 and bench_args.result_filename: + import jsonlines - throughput = ( - (bench_args.input_len + bench_args.output_len) - * bench_args.batch_size - / tot_latency - ) - rank_print( - f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s" - ) + with jsonlines.open(bench_args.result_filename, "a") as f: + f.write_all(result_list) - # Warm up - run_once(4) - clear() - # Run again - run_once(bench_args.output_len) +def plot_latency_test( + server_args, + bench_args, + tp_rank, +): + assert tp_rank == 0 + + # read the jsonl file and load it into an in-memory sqlite DB + df = pd.read_json(bench_args.result_filename, lines=True) + conn = sqlite3.connect(":memory:") + cur = conn.cursor() + + # get the columns and their types + column_names = list(df.iloc[0].keys()) + type_dict = { + str: "TEXT", + np.int64: "INTEGER", + np.float64: "FLOAT", + } + column_types = [type_dict[type(i)] for i in list(df.iloc[0])] + + # create the table + cur.execute( + f""" + CREATE TABLE IF NOT EXISTS results ( + {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])} + ) + """ + ) + conn.commit() + + # write the results to DB + df.to_sql("results", conn, if_exists="replace", index=False) + conn.commit() + + # read it back using sql + df = pd.read_sql_query(bench_args.graph_sql, conn) + conn.close() + + # plot it and save to a file + import matplotlib.pyplot as plt + + assert ( + len(df.columns) == 3 + ), f"The SQL query should have fetched exactly 3 columns, not {df.columns}" + for label in df[df.columns[0]].unique(): + q = f"{df.columns[0]}=='{label}'" + series = df.query(q) + plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o") + plt.xlabel(df.columns[1]) + plt.ylabel(df.columns[2]) + plt.legend() + plt.savefig(bench_args.graph_filename, dpi=300) + + # if in kitty, just dump it to the terminal + if os.environ.get("TERM") == "xterm-kitty": + os.system( + f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}" + ) def main(server_args, bench_args): - print(bench_args) - if bench_args.correctness_test: - work_func = correctness_test + if server_args.model_path: + if bench_args.correctness_test: + work_func = correctness_test + else: + work_func = latency_test + elif os.path.isfile(bench_args.result_filename): + assert bench_args.graph_filename, "please provide a filename for the graph" + work_func = plot_latency_test else: - work_func = latency_test 
- - workers = [] - for tp_rank in range(server_args.tp_size): - proc = multiprocessing.Process( - target=work_func, - args=( - server_args, - bench_args, - tp_rank, - ), + raise ValueError( + "Provide --model-path for running the tests or " + "provide --result-filename for plotting the results" ) - proc.start() - workers.append(proc) - for proc in workers: - proc.join() + if server_args.tp_size == 1: + work_func(server_args, bench_args, 0) + else: + workers = [] + for tp_rank in range(server_args.tp_size): + proc = multiprocessing.Process( + target=work_func, + args=( + server_args, + bench_args, + tp_rank, + ), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() - proc.terminate() + proc.terminate() if __name__ == "__main__": parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser) + # For this script, model-path is not required + assert ( + parser._actions[1].option_strings[0] == "--model-path" + ), "options changed, this code need to be updated" + parser._actions[1].required = False args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py new file mode 100644 index 00000000000..cc240684684 --- /dev/null +++ b/python/sglang/bench_serving.py @@ -0,0 +1,1002 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert 
api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." 
+ + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + + 
+default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: List[Tuple[str, int, int]] = [] + for i in range(num_prompts): + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}" + ) + else: + print("Initial test run completed. 
Starting main benchmark run...") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def fire(args: argparse.Namespace): + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's components +def set_ulimit(target_soft_limit=65535): + 
resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Failed to set RLIMIT_NOFILE: {e}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, it is fetched from the server's /v1/models endpoint.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, the model path is used.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.", + ) + parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--multi", + action="store_true", + help="Use a range of request rates rather than a single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. 
It also supports a list of request rates, requiring the parameters to not equal three.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) + + set_ulimit() + + args = parser.parse_args() + fire(args) diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py new file mode 100644 index 00000000000..cc8ba10e00c --- /dev/null +++ b/python/sglang/check_env.py @@ -0,0 +1,195 @@ +"""Check environment configurations and dependency versions.""" + +import importlib +import os +import resource +import subprocess +import sys +from collections import OrderedDict, defaultdict + +import torch + +# List of packages to check versions for +PACKAGE_LIST = [ + "sglang", + "flashinfer", + "triton", + "transformers", + "requests", + "tqdm", + "numpy", + "aiohttp", + "fastapi", + "hf_transfer", + "huggingface_hub", + "interegular", + "packaging", + "PIL", + "psutil", + "pydantic", + "uvicorn", + "uvloop", + "zmq", + "vllm", + "outlines", + "multipart", + "openai", + "tiktoken", + "anthropic", + "litellm", +] + + +def get_package_versions(packages): + """ + Get versions of specified packages. + """ + versions = {} + for package in packages: + package_name = package.split("==")[0].split(">=")[0].split("<=")[0] + try: + module = importlib.import_module(package_name) + if hasattr(module, "__version__"): + versions[package_name] = module.__version__ + except ModuleNotFoundError: + versions[package_name] = "Module Not Found" + return versions + + +def get_cuda_info(): + """ + Get CUDA-related information if available. + """ + cuda_info = {"CUDA available": torch.cuda.is_available()} + + if cuda_info["CUDA available"]: + cuda_info.update(_get_gpu_info()) + cuda_info.update(_get_cuda_version_info()) + + return cuda_info + + +def _get_gpu_info(): + """ + Get information about available GPUs. + """ + devices = defaultdict(list) + capabilities = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + capability = torch.cuda.get_device_capability(k) + capabilities[f"{capability[0]}.{capability[1]}"].append(str(k)) + + gpu_info = {} + for name, device_ids in devices.items(): + gpu_info[f"GPU {','.join(device_ids)}"] = name + + if len(capabilities) == 1: + # All GPUs have the same compute capability + cap, gpu_ids = list(capabilities.items())[0] + gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap + else: + # GPUs have different compute capabilities + for cap, gpu_ids in capabilities.items(): + gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap + + return gpu_info + + +def _get_cuda_version_info(): + """ + Get CUDA version information. + """ + from torch.utils.cpp_extension import CUDA_HOME + + cuda_info = {"CUDA_HOME": CUDA_HOME} + + if CUDA_HOME and os.path.isdir(CUDA_HOME): + cuda_info.update(_get_nvcc_info()) + cuda_info.update(_get_cuda_driver_version()) + + return cuda_info + + +def _get_nvcc_info(): + """ + Get NVCC version information. 
+ """ + from torch.utils.cpp_extension import CUDA_HOME + + try: + nvcc = os.path.join(CUDA_HOME, "bin/nvcc") + nvcc_output = ( + subprocess.check_output(f'"{nvcc}" -V', shell=True).decode("utf-8").strip() + ) + return { + "NVCC": nvcc_output[ + nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind("Build") + ].strip() + } + except subprocess.SubprocessError: + return {"NVCC": "Not Available"} + + +def _get_cuda_driver_version(): + """ + Get CUDA driver version. + """ + versions = set() + try: + output = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=driver_version", + "--format=csv,noheader,nounits", + ] + ) + versions = set(output.decode().strip().split("\n")) + if len(versions) == 1: + return {"CUDA Driver Version": versions.pop()} + else: + return {"CUDA Driver Versions": ", ".join(sorted(versions))} + except subprocess.SubprocessError: + return {"CUDA Driver Version": "Not Available"} + + +def get_gpu_topology(): + """ + Get GPU topology information. + """ + try: + result = subprocess.run( + ["nvidia-smi", "topo", "-m"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return "\n" + result.stdout if result.returncode == 0 else None + except subprocess.SubprocessError: + return None + + +def check_env(): + """ + Check and print environment information. + """ + env_info = OrderedDict() + env_info["Python"] = sys.version.replace("\n", "") + env_info.update(get_cuda_info()) + env_info["PyTorch"] = torch.__version__ + env_info.update(get_package_versions(PACKAGE_LIST)) + + gpu_topo = get_gpu_topology() + if gpu_topo: + env_info["NVIDIA Topology"] = gpu_topo + + ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + env_info["ulimit soft"] = ulimit_soft + + for k, v in env_info.items(): + print(f"{k}: {v}") + + +if __name__ == "__main__": + check_env() diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 00340b59a36..b02ce9f81ea 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -8,36 +8,42 @@ def __init__(self): # 2: output final text after every run self.verbosity = 0 + # Default backend of the language self.default_backend = None - # Output configs + # Runtime constants: Request dependency time due to network delay + self.request_dependency_delay = 0.02 + self.wait_for_new_request_delay = 0.0006 + + # Runtime constants: New generation token ratio estimation + self.init_new_token_ratio = 0.7 + self.base_min_new_token_ratio = 0.1 + self.new_token_ratio_decay = 0.001 + + # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync. + # This can improve the speed for large batch sizes during prefill. + self.layer_sync_threshold = 8192 + + # Runtime constants: others + self.num_continue_decode_steps = 10 + self.retract_decode_steps = 20 + self.flashinfer_workspace_size = 192 * 1024 * 1024 + + # Output tokenization configs self.skip_special_tokens_in_output = True self.spaces_between_special_tokens_in_out = True - # Optimization configs + # Interpreter optimization configs self.eager_fill_image = False self.enable_precache_with_tracing = True self.enable_parallel_encoding = True self.enable_parallel_decoding = True + # Deprecated # Choices: ["no_adjust", "adjust_cache"] # no_adjust: Do not adjust the position embedding of KV cache. # adjust_cache: Adjust the position embedding of KV cache. 
self.concate_and_append_mode = "no_adjust" - # Request dependency time due to network delay - self.request_dependency_delay = 0.02 - self.wait_for_new_request_delay = 0.0006 - - # New generation token ratio estimation - self.base_new_token_ratio = 0.4 - self.base_min_new_token_ratio = 0.2 - self.new_token_ratio_decay = 0.0001 - self.new_token_ratio_recovery = 0.05 - - # The threshold (number of tokens) to trigger layer-wise cuda sync. - # This can improve the speed for large batch sizes during prefill. - self.layer_sync_threshold = 8192 - global_config = GlobalConfig() diff --git a/python/sglang/backend/__init__.py b/python/sglang/lang/backend/__init__.py similarity index 100% rename from python/sglang/backend/__init__.py rename to python/sglang/lang/backend/__init__.py diff --git a/python/sglang/backend/anthropic.py b/python/sglang/lang/backend/anthropic.py similarity index 97% rename from python/sglang/backend/anthropic.py rename to python/sglang/lang/backend/anthropic.py index d96d0f04fbf..5a36bd9ac8a 100644 --- a/python/sglang/backend/anthropic.py +++ b/python/sglang/lang/backend/anthropic.py @@ -2,7 +2,7 @@ import numpy as np -from sglang.backend.base_backend import BaseBackend +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams diff --git a/python/sglang/backend/base_backend.py b/python/sglang/lang/backend/base_backend.py similarity index 91% rename from python/sglang/backend/base_backend.py rename to python/sglang/lang/backend/base_backend.py index cb504f51b74..185f2e297ae 100644 --- a/python/sglang/backend/base_backend.py +++ b/python/sglang/lang/backend/base_backend.py @@ -1,6 +1,7 @@ from typing import Callable, List, Optional, Union from sglang.lang.chat_template import get_chat_template +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams @@ -64,7 +65,8 @@ def select( s: StreamExecutor, choices: List[str], temperature: float, - ): + choices_method: Optional[ChoicesSamplingMethod] = None, + ) -> ChoicesDecision: raise NotImplementedError() def concatenate_and_append(self, src_rids: List[str], dst_rid: str): diff --git a/python/sglang/backend/litellm.py b/python/sglang/lang/backend/litellm.py similarity index 95% rename from python/sglang/backend/litellm.py rename to python/sglang/lang/backend/litellm.py index d9b4023caf0..5803b5431e2 100644 --- a/python/sglang/backend/litellm.py +++ b/python/sglang/lang/backend/litellm.py @@ -1,6 +1,6 @@ from typing import Mapping, Optional -from sglang.backend.base_backend import BaseBackend +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams @@ -61,7 +61,7 @@ def generate( model=self.model_name, messages=messages, **self.client_params, - **sampling_params.to_anthropic_kwargs(), + **sampling_params.to_litellm_kwargs(), ) comp = ret.choices[0].message.content diff --git a/python/sglang/backend/openai.py b/python/sglang/lang/backend/openai.py similarity index 97% rename from python/sglang/backend/openai.py rename to python/sglang/lang/backend/openai.py index 6f65f4eab82..6fa93d9b2eb 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -6,8 +6,9 @@ import numpy as np -from 
sglang.backend.base_backend import BaseBackend +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams @@ -18,7 +19,7 @@ openai = tiktoken = e -logger = logging.getLogger("openai") +logger = logging.getLogger(__name__) def create_logit_bias_int(tokenizer): @@ -296,7 +297,9 @@ def select( s: StreamExecutor, choices: List[str], temperature: float, - ): + choices_method: ChoicesSamplingMethod, + ) -> ChoicesDecision: + """Note: `choices_method` is not used by the OpenAI backend.""" if self.is_chat_model: raise NotImplementedError( "select/choices is not supported for chat models. " @@ -354,8 +357,10 @@ def select( prompt_tokens.append(ret_token) - decision = choices[np.argmax(scores)] - return decision, scores, None, None + return ChoicesDecision( + decision=choices[np.argmax(scores)], + meta_info={"scores": scores}, + ) def openai_completion( diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py similarity index 79% rename from python/sglang/backend/runtime_endpoint.py rename to python/sglang/lang/backend/runtime_endpoint.py index da27a57e943..7f0db5b3599 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,11 +1,14 @@ import json from typing import List, Optional -import numpy as np - -from sglang.backend.base_backend import BaseBackend from sglang.global_config import global_config +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.choices import ( + ChoicesDecision, + ChoicesSamplingMethod, + token_length_normalized, +) from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams from sglang.utils import http_request @@ -16,7 +19,6 @@ class RuntimeEndpoint(BaseBackend): def __init__( self, base_url: str, - auth_token: Optional[str] = None, api_key: Optional[str] = None, verify: Optional[str] = None, ): @@ -24,13 +26,11 @@ def __init__( self.support_concate_and_append = True self.base_url = base_url - self.auth_token = auth_token self.api_key = api_key self.verify = verify res = http_request( self.base_url + "/get_model_info", - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -38,7 +38,8 @@ def __init__( self.model_info = res.json() self.chat_template = get_chat_template_by_model_path( - self.model_info["model_path"]) + self.model_info["model_path"] + ) def get_model_name(self): return self.model_info["model_path"] @@ -46,7 +47,7 @@ def get_model_name(self): def flush_cache(self): res = http_request( self.base_url + "/flush_cache", - auth_token=self.auth_token, + api_key=self.api_key, verify=self.verify, ) self._assert_success(res) @@ -54,7 +55,7 @@ def flush_cache(self): def get_server_args(self): res = http_request( self.base_url + "/get_server_args", - auth_token=self.auth_token, + api_key=self.api_key, verify=self.verify, ) self._assert_success(res) @@ -67,7 +68,6 @@ def cache_prefix(self, prefix_str: str): res = http_request( self.base_url + "/generate", json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -79,7 +79,6 @@ def commit_lazy_operations(self, s: StreamExecutor): res = 
http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -91,7 +90,6 @@ def fill_image(self, s: StreamExecutor): res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -124,7 +122,12 @@ def generate( else: raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") - for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]: + for item in [ + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + ]: value = getattr(sampling_params, item, None) if value is not None: data[item] = value @@ -134,7 +137,6 @@ def generate( res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -171,7 +173,12 @@ def generate_stream( else: raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") - for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]: + for item in [ + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + ]: value = getattr(sampling_params, item, None) if value is not None: data[item] = value @@ -183,7 +190,6 @@ def generate_stream( self.base_url + "/generate", json=data, stream=True, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -206,21 +212,14 @@ def select( s: StreamExecutor, choices: List[str], temperature: float, - ): + choices_method: ChoicesSamplingMethod, + ) -> ChoicesDecision: assert temperature <= 1e-5 # Cache common prefix data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} - self._add_images(s, data) - res = http_request( - self.base_url + "/generate", - json=data, - auth_token=self.auth_token, - api_key=self.api_key, - verify=self.verify, - ) - self._assert_success(res) - prompt_len = res.json()["meta_info"]["prompt_tokens"] + obj = self._generate_http_request(s, data) + prompt_len = obj["meta_info"]["prompt_tokens"] # Compute logprob data = { @@ -229,40 +228,57 @@ def select( "return_logprob": True, "logprob_start_len": max(prompt_len - 2, 0), } - self._add_images(s, data) - res = http_request( - self.base_url + "/generate", - json=data, - auth_token=self.auth_token, - api_key=self.api_key, - verify=self.verify, - ) - self._assert_success(res) - obj = res.json() + obj = self._generate_http_request(s, data) + normalized_prompt_logprobs = [ r["meta_info"]["normalized_prompt_logprob"] for r in obj ] - decision = choices[np.argmax(normalized_prompt_logprobs)] - prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj] - decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj] - - return ( - decision, - normalized_prompt_logprobs, - prefill_token_logprobs, - decode_token_logprobs, + input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] + output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] + + # Compute unconditional logprobs if required + if choices_method.requires_unconditional_logprobs: + input_ids = [[el[1] for el in subl] for subl in input_token_logprobs] + data = { + "input_ids": input_ids, + "sampling_params": {"max_new_tokens": 0}, + "return_logprob": True, + } + obj = self._generate_http_request(s, data) + unconditional_token_logprobs = [ + r["meta_info"]["input_token_logprobs"] for r in obj + ] + else: + unconditional_token_logprobs = 
None + + return choices_method( + choices=choices, + normalized_prompt_logprobs=normalized_prompt_logprobs, + input_token_logprobs=input_token_logprobs, + output_token_logprobs=output_token_logprobs, + unconditional_token_logprobs=unconditional_token_logprobs, ) def concatenate_and_append(self, src_rids: List[str], dst_rid: str): res = http_request( self.base_url + "/concate_and_append_request", json={"src_rids": src_rids, "dst_rid": dst_rid}, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) self._assert_success(res) + def _generate_http_request(self, s: StreamExecutor, data): + self._add_images(s, data) + res = http_request( + self.base_url + "/generate", + json=data, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + return res.json() + def _add_images(self, s: StreamExecutor, data): if s.images_: assert len(s.images_) == 1, "Only support one image." diff --git a/python/sglang/backend/vertexai.py b/python/sglang/lang/backend/vertexai.py similarity index 94% rename from python/sglang/backend/vertexai.py rename to python/sglang/lang/backend/vertexai.py index f32fca2f408..c27733b3ee8 100644 --- a/python/sglang/backend/vertexai.py +++ b/python/sglang/lang/backend/vertexai.py @@ -1,10 +1,8 @@ import os import warnings -from typing import List, Optional, Union +from typing import Optional -import numpy as np - -from sglang.backend.base_backend import BaseBackend +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams @@ -21,7 +19,7 @@ class VertexAI(BaseBackend): - def __init__(self, model_name): + def __init__(self, model_name, safety_settings=None): super().__init__() if isinstance(GenerativeModel, Exception): @@ -33,6 +31,7 @@ def __init__(self, model_name): self.model_name = model_name self.chat_template = get_chat_template("default") + self.safety_settings = safety_settings def get_chat_template(self): return self.chat_template @@ -54,6 +53,7 @@ def generate( ret = GenerativeModel(self.model_name).generate_content( prompt, generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()), + safety_settings=self.safety_settings, ) comp = ret.text @@ -78,6 +78,7 @@ def generate_stream( prompt, stream=True, generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()), + safety_settings=self.safety_settings, ) for ret in generator: yield ret.text, {} diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 273eb8c3b9f..bfde4bbdb6a 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -84,7 +84,7 @@ def get_chat_template_by_model_path(model_path): "system": ("SYSTEM:", "\n"), "user": ("USER:", "\n"), "assistant": ("ASSISTANT:", "\n"), - } + }, ) ) @@ -177,7 +177,7 @@ def get_chat_template_by_model_path(model_path): "assistant": ("", "<|im_end|>\n"), }, style=ChatTemplateStyle.PLAIN, - stop_str=("<|im_end|>",) + stop_str=("<|im_end|>",), ) ) diff --git a/python/sglang/lang/choices.py b/python/sglang/lang/choices.py new file mode 100644 index 00000000000..e52c6b36217 --- /dev/null +++ b/python/sglang/lang/choices.py @@ -0,0 +1,164 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import numpy as np + + +@dataclass +class ChoicesDecision: + decision: str + meta_info: Optional[Dict[str, Any]] = None + + +class 
ChoicesSamplingMethod(ABC): + + @property + def requires_unconditional_logprobs(self) -> bool: + return False + + @abstractmethod + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: ... + + +class TokenLengthNormalized(ChoicesSamplingMethod): + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: + """Select the option with the highest token length normalized prompt logprob.""" + best_choice = choices[np.argmax(normalized_prompt_logprobs)] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + +token_length_normalized = TokenLengthNormalized() + + +class GreedyTokenSelection(ChoicesSamplingMethod): + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: + """Select the option based on greedy logprob selection. For overlapping options + where one option is a subset of a longer option, extend the shorter option using + its average logprob for comparison against the longer option.""" + + num_options = len(choices) + max_tokens = max(len(option) for option in input_token_logprobs) + logprob_matrix = self._build_logprob_matrix( + input_token_logprobs, max_tokens, num_options + ) + remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens) + + best_choice = choices[remaining[0]] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + "greedy_logprob_matrix": logprob_matrix.tolist(), + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options): + logprob_matrix = np.zeros((num_options, max_tokens)) + for i, option in enumerate(input_token_logprobs): + actual_logprobs = [token[0] for token in option] + avg_logprob = np.mean(actual_logprobs) + logprob_matrix[i, : len(option)] = actual_logprobs + if len(option) < max_tokens: + logprob_matrix[i, len(option) :] = avg_logprob + return logprob_matrix + + def _greedy_selection(self, logprob_matrix, num_options, max_tokens): + remaining = np.arange(num_options) + for j in range(max_tokens): + max_logprob = np.max(logprob_matrix[remaining, j]) + remaining = remaining[logprob_matrix[remaining, j] == max_logprob] + if len(remaining) == 1: + break + return remaining + + +greedy_token_selection = GreedyTokenSelection() + + +class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod): + + @property + def requires_unconditional_logprobs(self) -> bool: + return True + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> 
ChoicesDecision: + """Select the option with the highest average token logprob once normalized by + the unconditional token logprobs. + + The first unconditional token logprob is assumed to be None. If so, it is + replaced with 0 for the purposes of normalization.""" + + if unconditional_token_logprobs is None: + raise ValueError( + "Unconditional token logprobs are required for this method." + ) + + normalized_unconditional_prompt_logprobs = self._normalize_logprobs( + input_token_logprobs, unconditional_token_logprobs + ) + + best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + "unconditional_token_logprobs": unconditional_token_logprobs, + "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs, + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs): + normalized_unconditional_prompt_logprobs = [] + for inputs, unconditionals in zip( + input_token_logprobs, unconditional_token_logprobs + ): + inputs_logprobs = np.array([token[0] for token in inputs]) + unconditionals_logprobs = np.array([token[0] for token in unconditionals]) + unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0 + normalized_unconditional_prompt_logprobs.append( + float(np.mean(inputs_logprobs - unconditionals_logprobs)) + ) + return normalized_unconditional_prompt_logprobs + + +unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized() diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py index 36287cd397c..95af04adb0a 100644 --- a/python/sglang/lang/compiler.py +++ b/python/sglang/lang/compiler.py @@ -125,7 +125,7 @@ def run_internal( def run( self, *, - max_new_tokens: int = 16, + max_new_tokens: int = 128, stop: Union[str, List[str]] = (), temperature: float = 1.0, top_p: float = 1.0, @@ -155,7 +155,7 @@ def run_batch( self, batch_kwargs, *, - max_new_tokens: int = 16, + max_new_tokens: int = 128, stop: Union[str, List[str]] = (), temperature: float = 1.0, top_p: float = 1.0, diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 31999c40027..cf53fac3035 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -288,6 +288,7 @@ def fork( exes[i].text_ = str(self.text_) exes[i].messages_ = list(self.messages_) exes[i].cur_role = self.cur_role + exes[i].cur_role_begin_pos = self.cur_role_begin_pos exes[i].fork_start_text_pos = len(self.text_) exes[i].images_ = list(self.images_) @@ -537,22 +538,17 @@ def _execute_gen(self, expr: SglGen): self.stream_var_event[name].set() def _execute_select(self, expr: SglSelect): - ( - decision, - normalized_prompt_logprobs, - prefill_token_logprobs, - decode_token_logprobs, - ) = self.backend.select(self, expr.choices, expr.temperature) + choices_decision = self.backend.select( + self, expr.choices, expr.temperature, expr.choices_method + ) if expr.name is not None: name = expr.name - self.variables[name] = decision - self.meta_info[name] = { - "normalized_prompt_logprobs": normalized_prompt_logprobs, - "prefill_token_logprobs": prefill_token_logprobs, - "decode_token_logprobs": decode_token_logprobs, - } + self.variables[name] = choices_decision.decision + self.meta_info[name] = choices_decision.meta_info self.variable_event[name].set() - 
self.text_ += decision + if self.stream_var_event: + self.stream_var_event[name].set() + self.text_ += choices_decision.decision def _execute_variable(self, expr: SglVariable): src_executor = expr.source_stream_executor @@ -704,9 +700,9 @@ def __init__(self, stream_executor: StreamExecutor): def _role_common(self, name: str, expr: Optional[SglExpr] = None): if expr is not None: - self.stream_executor.submit( - SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) - ) + role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) + self.stream_executor.submit(role_expr) + return role_expr else: @contextmanager @@ -777,7 +773,14 @@ def text_iter(self, var_name: Optional[str] = None): if self.stream_executor.is_finished: break else: - event = self.stream_executor.stream_var_event[var_name] + event = None + while not event: + if var_name in self.stream_executor.stream_var_event: + event = self.stream_executor.stream_var_event[var_name] + if self.stream_executor.is_finished: + yield "" + return + while True: event.wait() event.clear() @@ -812,7 +815,14 @@ async def text_async_iter( if self.stream_executor.is_finished: break else: - event = self.stream_executor.stream_var_event[var_name] + event = None + while not event: + if var_name in self.stream_executor.stream_var_event: + event = self.stream_executor.stream_var_event[var_name] + if self.stream_executor.is_finished: + yield "" + return + while True: await loop.run_in_executor(None, event.wait) event.clear() diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 83c6f79b0b6..135110c1e0d 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -6,6 +6,7 @@ from typing import List, Optional, Union from sglang.global_config import global_config +from sglang.lang.choices import ChoicesSamplingMethod REGEX_INT = r"[-+]?[0-9]+" REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+" @@ -15,7 +16,7 @@ @dataclasses.dataclass class SglSamplingParams: - max_new_tokens: int = 16 + max_new_tokens: int = 128 stop: Union[str, List[str]] = () temperature: float = 1.0 top_p: float = 1.0 @@ -24,9 +25,9 @@ class SglSamplingParams: presence_penalty: float = 0.0 ignore_eos: bool = False return_logprob: Optional[bool] = None - logprob_start_len: Optional[int] = None, - top_logprobs_num: Optional[int] = None, - return_text_in_logprobs: Optional[bool] = None, + logprob_start_len: Optional[int] = (None,) + top_logprobs_num: Optional[int] = (None,) + return_text_in_logprobs: Optional[bool] = (None,) # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None @@ -99,7 +100,6 @@ def to_litellm_kwargs(self): "stop": self.stop or None, "temperature": self.temperature, "top_p": self.top_p, - "top_k": self.top_k, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, } @@ -140,7 +140,7 @@ def bind(self, **kwargs): def run( self, *args, - max_new_tokens: int = 16, + max_new_tokens: int = 128, stop: Union[str, List[str]] = (), temperature: float = 1.0, top_p: float = 1.0, @@ -179,7 +179,7 @@ def run_batch( self, batch_kwargs, *, - max_new_tokens: int = 16, + max_new_tokens: int = 128, stop: Union[str, List[str]] = (), temperature: float = 1.0, top_p: float = 1.0, @@ -410,7 +410,7 @@ def __init__( dtype: Optional[type] = None, regex: Optional[str] = None, ): - """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md""" + """Call the model to generate. 
See the meaning of the arguments in docs/en/sampling_params.md""" super().__init__() self.name = name self.sampling_params = SglSamplingParams( @@ -462,14 +462,22 @@ def __repr__(self): class SglSelect(SglExpr): - def __init__(self, name: str, choices: List[str], temperature: float): + + def __init__( + self, + name: str, + choices: List[str], + temperature: float, + choices_method: ChoicesSamplingMethod, + ): super().__init__() self.name = name self.choices = choices self.temperature = temperature + self.choices_method = choices_method def __repr__(self): - return f"Select({self.name}, choices={self.choices})" + return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})" class SglFork(SglExpr): diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py index 53f772163f8..cfe9198bcab 100644 --- a/python/sglang/lang/tracer.py +++ b/python/sglang/lang/tracer.py @@ -3,8 +3,8 @@ import uuid from typing import Any, Callable, Dict, List, Optional, Union -from sglang.backend.base_backend import BaseBackend from sglang.global_config import global_config +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.interpreter import ProgramState, ProgramStateGroup from sglang.lang.ir import ( SglArgument, diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index cb2b07251ae..91dc0dc4e95 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -11,4 +11,4 @@ args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) - launch_server(server_args, None) + launch_server(server_args) diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index b71d8701d64..c34dd211672 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,7 +1,6 @@ """Launch the inference server for Llava-video model.""" import argparse -import multiprocessing as mp from sglang.srt.server import ServerArgs, launch_server @@ -27,6 +26,4 @@ server_args = ServerArgs.from_cli_args(args) - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - - launch_server(server_args, pipe_writer, model_overide_args) + launch_server(server_args, model_overide_args, None) diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index b6d5a73584e..7e097c6fc26 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import json from typing import Dict, Optional, Union diff --git a/python/sglang/srt/constrained/base_cache.py b/python/sglang/srt/constrained/base_tool_cache.py similarity index 68% rename from python/sglang/srt/constrained/base_cache.py rename to python/sglang/srt/constrained/base_tool_cache.py index 19139d97a66..4cbb6bd2265 100644 --- a/python/sglang/srt/constrained/base_cache.py +++ b/python/sglang/srt/constrained/base_tool_cache.py @@ -1,9 +1,24 @@ -"""Base cache class.""" +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Base tool cache for constrained decoding tools.""" import time -class BaseCache: +class BaseToolCache: def __init__(self, enable=True): self.enable = enable self.reset() diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index cc0609a58b9..6df6bec51ce 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -1,10 +1,25 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + """Cache for the compressed finite state machine.""" from sglang.srt.constrained import RegexGuide, TransformerTokenizer -from sglang.srt.constrained.base_cache import BaseCache +from sglang.srt.constrained.base_tool_cache import BaseToolCache -class FSMCache(BaseCache): +class FSMCache(BaseToolCache): def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True): super().__init__(enable=enable) @@ -21,7 +36,27 @@ def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True): tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, **tokenizer_args_dict ) - self.outlines_tokenizer = TransformerTokenizer(tokenizer) + try: + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + except AttributeError: + # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) + origin_pad_token_id = tokenizer.pad_token_id + + def fset(self, value): + self._value = value + + type(tokenizer).pad_token_id = property( + fget=type(tokenizer).pad_token_id.fget, fset=fset + ) + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token = ( + self.outlines_tokenizer.tokenizer.pad_token + ) + self.outlines_tokenizer.vocabulary = ( + self.outlines_tokenizer.tokenizer.get_vocab() + ) else: self.outlines_tokenizer = TransformerTokenizer( tokenizer_path, **tokenizer_args_dict diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index a2375e17221..7b694318e49 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """ Faster constrained decoding. Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ @@ -15,7 +30,7 @@ make_byte_level_fsm, make_deterministic_fsm, ) -from sglang.srt.constrained.base_cache import BaseCache +from sglang.srt.constrained.base_tool_cache import BaseToolCache IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" @@ -136,7 +151,7 @@ def is_jump_forward_symbol_state(self, state): ) -class JumpForwardCache(BaseCache): +class JumpForwardCache(BaseToolCache): def __init__(self): super().__init__() diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index b3988cc8237..5ee12169740 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + """Conversation templates.""" # Adapted from @@ -6,7 +21,7 @@ from enum import IntEnum, auto from typing import Dict, List, Optional, Tuple, Union -from sglang.srt.openai_protocol import ChatCompletionRequest +from sglang.srt.openai_api.protocol import ChatCompletionRequest class SeparatorStyle(IntEnum): @@ -421,3 +436,14 @@ def generate_chat_conv( sep2="", ) ) + +# Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442 +register_conv_template( + Conversation( + name="internlm2-chat", + system_template="<|im_start|>system\n{system_message}", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep="\n", + stop_str=["<|im_end|>", "<|action_end|>"], + ) +) diff --git a/python/sglang/srt/flush_cache.py b/python/sglang/srt/flush_cache.py deleted file mode 100644 index 575ba96006d..00000000000 --- a/python/sglang/srt/flush_cache.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Flush the KV cache. - -Usage: -python3 -m sglang.srt.flush_cache --url http://localhost:30000 -""" - -import argparse - -import requests - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--url", type=str, default="http://localhost:30000") - args = parser.parse_args() - - response = requests.get(args.url + "/flush_cache") - assert response.status_code == 200 diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 218af433cdd..508843a395c 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -1,22 +1,44 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + """Utilities for Huggingface Transformers.""" import functools import json import os import warnings -from typing import AbstractSet, Collection, Literal, Optional, Union +from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, + PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) +from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig from sglang.srt.utils import is_multimodal_model +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + ChatGLMConfig.model_type: ChatGLMConfig, + DbrxConfig.model_type: DbrxConfig, +} + def download_from_hf(model_path: str): if os.path.exists(model_path): @@ -40,6 +62,9 @@ def get_config( config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision ) + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained(model, revision=revision) if model_overide_args: config.update(model_overide_args) return config @@ -63,6 +88,10 @@ def get_context_length(config): rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: rope_scaling_factor = config.rope_scaling["factor"] + if "original_max_position_embeddings" in rope_scaling: + rope_scaling_factor = 1 + if config.rope_scaling.get("rope_type", None) == "llama3": + rope_scaling_factor = 1 else: rope_scaling_factor = 1 @@ -230,7 +259,7 @@ def encode_patched( Literal["all"], AbstractSet[str] ] = set(), # noqa: B006 disallowed_special: Union[Literal["all"], Collection[str]] = "all", - ) -> list[int]: + ) -> List[int]: if isinstance(allowed_special, set): allowed_special |= self._default_allowed_special return tiktoken.Encoding.encode( diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py index 0c3102c3fc8..a2dc2ff318f 100644 --- a/python/sglang/srt/layers/context_flashattention_nopad.py +++ b/python/sglang/srt/layers/context_flashattention_nopad.py @@ -1,11 +1,24 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 import torch import triton import triton.language as tl -from sglang.srt.utils import wrap_kernel_launcher - CUDA_CAPABILITY = torch.cuda.get_device_capability() @@ -119,9 +132,6 @@ def _fwd_kernel( tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) -cached_kernel = None - - def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): if CUDA_CAPABILITY[0] >= 8: BLOCK = 128 @@ -139,29 +149,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) num_warps = 4 if Lk <= 64 else 8 - global cached_kernel - if cached_kernel: - cached_kernel( - grid, - num_warps, - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - k.stride(0), - k.stride(1), - v.stride(0), - v.stride(1), - o.stride(0), - o.stride(1), - ) - return - _fwd_kernel[grid]( q, k, @@ -185,4 +172,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps=num_warps, num_stages=1, ) - cached_kernel = wrap_kernel_launcher(_fwd_kernel) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 41c2ca7d134..7398895d62d 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -1,9 +1,23 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import torch import triton import triton.language as tl from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd -from sglang.srt.utils import wrap_kernel_launcher CUDA_CAPABILITY = torch.cuda.get_device_capability() @@ -43,6 +57,8 @@ def _fwd_kernel( stride_buf_vh, stride_req_to_tokens_b, BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, @@ -61,8 +77,10 @@ def _fwd_kernel( cur_batch_req_idx = tl.load(B_req_idx + cur_seq) offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) offs_m = tl.arange(0, BLOCK_M) mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + offs_q = ( (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs @@ -71,10 +89,20 @@ def _fwd_kernel( ) q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + offs_qpe = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_dpe[None, :] + ) + qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) + # stage1: compute scores with prefix offs_n = tl.arange(0, BLOCK_N) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) deno = tl.zeros([BLOCK_M], dtype=tl.float32) e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -96,6 +124,18 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) + if BLOCK_DPE > 0: + offs_kpe = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe, kpe) qk *= sm_scale if logit_cap > 0: @@ -111,7 +151,7 @@ def _fwd_kernel( offs_buf_v = ( offs_kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh - + offs_d[None, :] + + offs_dv[None, :] ) v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) p = p.to(v.dtype) @@ -136,6 +176,21 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) + + if BLOCK_DPE > 0: + offs_kpe = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) + * stride_kbs + + cur_kv_head * stride_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Extend + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe, kpe) + qk *= sm_scale if logit_cap > 0: @@ -155,7 +210,7 @@ def _fwd_kernel( offs_v = ( (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs + cur_kv_head * stride_vh - + offs_d[None, :] + + offs_dv[None, :] ) v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) p = p.to(v.dtype) @@ -167,14 +222,11 @@ def _fwd_kernel( (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs + cur_head * stride_oh - + offs_d[None, :] + + offs_dv[None, :] ) tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) -cached_kernel = None - - def extend_attention_fwd( q_extend, k_extend, @@ -206,8 +258,17 @@ def extend_attention_fwd( o_extend.shape[-1], ) - assert Lq == Lk and Lk == Lv and Lv == Lo - assert Lq in {16, 32, 64, 128, 256} + assert Lq == Lk and Lv == Lo + assert Lq in {16, 32, 64, 128, 256, 576} + assert Lv in {16, 32, 64, 128, 256, 512} + + if Lq == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + else: + BLOCK_DMODEL = Lq + BLOCK_DPE = 0 + 
BLOCK_DV = Lv if CUDA_CAPABILITY[0] >= 8: BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64) @@ -222,40 +283,6 @@ def extend_attention_fwd( num_warps = 4 if Lk <= 64 else 8 num_stages = 1 - global cached_kernel - if cached_kernel: - cached_kernel( - grid, - num_warps, - q_extend, - k_extend, - v_extend, - o_extend, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_start_loc_extend, - b_seq_len_extend, - sm_scale, - kv_group_num, - q_extend.stride(0), - q_extend.stride(1), - k_extend.stride(0), - k_extend.stride(1), - v_extend.stride(0), - v_extend.stride(1), - o_extend.stride(0), - o_extend.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - v_buffer.stride(0), - v_buffer.stride(1), - req_to_tokens.stride(0), - ) - return - _fwd_kernel[grid]( q_extend, k_extend, @@ -283,14 +310,15 @@ def extend_attention_fwd( v_buffer.stride(0), v_buffer.stride(1), req_to_tokens.stride(0), - BLOCK_DMODEL=Lq, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, logit_cap=logit_cap, ) - cached_kernel = wrap_kernel_launcher(_fwd_kernel) def redundant_attention( diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe.py index 7dddabb05f8..c5630fa5db4 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1 """Fused MoE kernel.""" diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py new file mode 100644 index 00000000000..fb8891cb297 --- /dev/null +++ b/python/sglang/srt/layers/linear.py @@ -0,0 +1,884 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py +# FIXME: refactor the linear abstraction +from abc import abstractmethod +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_shard( + param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str +) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. 
+ """ + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization. + + Args: + separate_bias_add: If true, add bias separately after matrix + multiplication. + """ + + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config + ) + + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights( + self, + self.input_size, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + ) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype) + ) + set_weight_attrs(self.bias, {"output_dim": 0}) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config + ) + + self.gather_output = gather_output + + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
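As a concrete illustration of the column-parallel math described in the ColumnParallelLinear docstring above, the following standalone sketch simulates Y = XA with A split along its output dimension on a single process. It is not part of the patch; the shapes are arbitrary, torch.chunk stands in for the per-rank weight shards, and torch.cat stands in for the all-gather performed when gather_output=True (bias omitted for brevity).

import torch

torch.manual_seed(0)
tp_size = 4                    # pretend tensor-parallel world size
x = torch.randn(2, 8)          # input X: [batch, input_size]
a = torch.randn(8, 16)         # full weight A: [input_size, output_size]

# A = [A_1, ..., A_p]: split A along its second (output) dimension.
shards = torch.chunk(a, tp_size, dim=1)

# Each rank computes its partial result Y_i = X @ A_i independently.
partial_outputs = [x @ a_i for a_i in shards]

# gather_output=True corresponds to an all-gather that reassembles the full Y.
y_gathered = torch.cat(partial_outputs, dim=1)
assert torch.allclose(y_gathered, x @ a, atol=1e-5)

In practice gather_output is typically left False, so each rank's partial Y_i feeds directly into a following row-parallel layer without the extra all-gather.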
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) for output_size in self.output_sizes + ] + + if output_sizes is None: + output_sizes = [output_size] + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader, + ) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + if param.data.dtype != loaded_weight.dtype: + param.data = torch.empty_like( + param.data, dtype=loaded_weight.dtype, device="cuda" + ) + + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. 
+ """ + + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): + if param.data.dtype != loaded_weight.dtype: + param.data = torch.empty_like( + param.data, dtype=loaded_weight.dtype, device="cuda" + ) + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if output_dim is None: + if needs_scalar_to_array is not None: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. 
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
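+        # Illustrative sizing (hypothetical numbers): with total_num_heads=32,
+        # total_num_kv_heads=8 and tp_size=4, each rank gets 32 // 4 = 8 query
+        # heads and 8 // 4 = 2 KV heads (num_kv_head_replicas=1). With
+        # tp_size=16, each rank gets 2 query heads and a single KV head that
+        # is shared by 16 // 8 = 2 ranks (num_kv_head_replicas=2).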
+ tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = ( + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + ) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): + if param.data.dtype != loaded_weight.dtype: + param.data = torch.empty_like( + param.data, dtype=loaded_weight.dtype, device="cuda" + ) + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if output_dim is None: + if needs_scalar_to_array is not None: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
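+            # shard_size and shard_offset above are in unpacked elements; for
+            # packed formats they are divided by pack_factor below (e.g. a
+            # 4-bit GPTQ layout typically packs 32 // 4 = 8 weights per int32,
+            # so pack_factor would be 8; illustrative value, the actual one
+            # comes from the param's quantization attributes).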
+            packed_dim = getattr(param, "packed_dim", None)
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+
+                # Special case for Marlin.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset
+                )
+
+            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            if use_bitsandbytes:
+                orig_qkv_offsets = {
+                    "q": (0, self.num_heads * self.head_size),
+                    "k": (
+                        self.num_heads * self.head_size,
+                        self.num_kv_heads * self.head_size,
+                    ),
+                    "v": (
+                        (self.num_heads + self.num_kv_heads) * self.head_size,
+                        self.num_kv_heads * self.head_size,
+                    ),
+                    "total": (
+                        (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+                        0,
+                    ),
+                }
+                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                    param, orig_qkv_offsets, loaded_shard_id
+                )
+
+            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
+            if loaded_shard_id == "q":
+                shard_id = tp_rank
+            else:
+                shard_id = tp_rank // self.num_kv_head_replicas
+            start_idx = shard_id * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        # Special case for AQLM codebooks.
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_index = ["q", "k", "v"].index(loaded_shard_id)
+            param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
+                param_data, loaded_weight, loaded_shard_id
+            )
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "QKVParallelLinear, assume the weight is the same "
+                    "for all partitions."
+                )
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class RowParallelLinear(LinearBase):
+    """Linear layer with row parallelism.
+
+    The linear layer is defined as Y = XA + b. A is parallelized along
+    its first dimension and X along its second dimension as:
+               -   -
+              | A_1 |
+              | .   |
+          A = | .   |        X = [X_1, ..., X_p]
+              | .   |
+              | A_p |
+               -   -
+    Arguments:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias. Note that bias is not parallelized.
+        input_is_parallel: If true, we assume that the input is already
+                           split across the GPUs and we do not split
+                           again.
+        skip_bias_add: This was added to enable performance optimization where
+                       bias can be fused with other element-wise operations.
+                       We skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(
+            input_size, output_size, skip_bias_add, params_dtype, quant_config
+        )
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        # Divide the weight matrix along the last dimension.
+        # Each rank holds one input shard and computes a partial product;
+        # forward() all-reduces the partial sums when reduce_results is set.
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=[self.output_size],
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=self.weight_loader,
+        )
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError(
+                "When reduce_results=False, adding bias to the partial "
+                "outputs can lead to incorrect results"
+            )
+
+        if bias:
+            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": 0,
+                    "weight_loader": self.weight_loader,
+                },
+            )
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        if param.data.dtype != loaded_weight.dtype:
+            param.data = torch.empty_like(
+                param.data, dtype=loaded_weight.dtype, device="cuda"
+            )
+
+        param_data = param.data
+        tp_rank = get_tensor_model_parallel_rank()
+        input_dim = getattr(param, "input_dim", None)
+        if input_dim is not None:
+            shard_size = param.data.shape[input_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size
+            )
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self, input_parallel)
+        if self.reduce_results and self.tp_size > 1:
+            output_ = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output_ = output_parallel
+
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
+        else:
+            output = output_
+            output_bias = self.bias
+        return output, output_bias
+
+    def extra_repr(self) -> str:
+        s = f"input_features={self.input_size_per_partition}"
+        s += f", output_features={self.output_size}"
+        s += f", bias={self.bias is not None}"
+        s += f", tp_size={self.tp_size}"
+        s += f", reduce_results={self.reduce_results}"
+        return s
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 1ed7b8f7d82..5584d01ad83 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -1,7 +1,22 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +""" + """Logits processing.""" import dataclasses -from typing import List, Union +from typing import List, Optional, Union import torch from torch import nn @@ -10,7 +25,7 @@ tensor_model_parallel_all_gather, ) -from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata @dataclasses.dataclass @@ -22,24 +37,23 @@ class LogitProcessorOutput: # The normlaized logprobs of prompts. shape: [#seq] normalized_prompt_logprobs: torch.Tensor - # The logprobs of prefill tokens. shape: [#token, vocab_size] - prefill_token_logprobs: torch.Tensor + # The logprobs of input tokens. shape: [#token, vocab_size] + input_token_logprobs: torch.Tensor - # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - prefill_top_logprobs: List - # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - decode_top_logprobs: List + # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) + input_top_logprobs: List + # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id) + output_top_logprobs: List @dataclasses.dataclass class LogitsMetadata: forward_mode: ForwardMode - extend_seq_lens: torch.Tensor - extend_start_loc: torch.Tensor + return_logprob: bool = False - # For logprobs - return_logprob: bool - top_logprobs_nums: List[int] + extend_seq_lens: Optional[torch.Tensor] = None + extend_start_loc: Optional[torch.Tensor] = None + top_logprobs_nums: Optional[List[int]] = None @classmethod def from_input_metadata(cls, input_metadata: InputMetadata): @@ -59,20 +73,16 @@ def __init__(self, config): self.tp_size = get_tensor_model_parallel_world_size() def _get_normalized_prompt_logprobs( - self, prefill_token_logprobs, logits_metadata: LogitsMetadata + self, input_token_logprobs, logits_metadata: LogitsMetadata ): - logprobs_cumsum = torch.cumsum( - prefill_token_logprobs, dim=0, dtype=torch.float32 - ) + logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) start = logits_metadata.extend_start_loc.clone() end = start + logits_metadata.extend_seq_lens - 2 - start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) - end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) + start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) + end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) sum_logp = ( - logprobs_cumsum[end] - - logprobs_cumsum[start] - + prefill_token_logprobs[start] + logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start] ) normalized_prompt_logprobs = sum_logp / ( (logits_metadata.extend_seq_lens - 1).clamp(min=1) @@ -80,37 +90,51 @@ def _get_normalized_prompt_logprobs( return normalized_prompt_logprobs - def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata): - # TODO: vectorize the code below + @staticmethod + def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): if logits_metadata.forward_mode == ForwardMode.DECODE: - decode_top_logprobs = [] - for i in range(all_logprobs.shape[0]): - k = logits_metadata.top_logprobs_nums[i] - t = all_logprobs[i].topk(k) - v_cpu = t.values.tolist() - p_cpu = t.indices.tolist() - decode_top_logprobs.append(list(zip(v_cpu, p_cpu))) - return None, 
decode_top_logprobs + output_top_logprobs = [] + max_k = max(logits_metadata.top_logprobs_nums) + ret = all_logprobs.topk(max_k, dim=1) + values = ret.values.tolist() + indices = ret.indices.tolist() + for i, k in enumerate(logits_metadata.top_logprobs_nums): + output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k]))) + return None, output_top_logprobs else: - prefill_top_logprobs, decode_top_logprobs = [], [] + # TODO: vectorize the code below + input_top_logprobs, output_top_logprobs = [], [] pt = 0 extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() + + max_k = max(logits_metadata.top_logprobs_nums) + ret = all_logprobs.topk(max_k, dim=1) + values = ret.values.tolist() + indices = ret.indices.tolist() + for i, extend_seq_len in enumerate(extend_seq_lens_cpu): if extend_seq_len == 0: - prefill_top_logprobs.append([]) - decode_top_logprobs.append([]) + input_top_logprobs.append([]) + output_top_logprobs.append([]) continue k = logits_metadata.top_logprobs_nums[i] - t = all_logprobs[pt : pt + extend_seq_len].topk(k) - vs_cpu = t.values.tolist() - ps_cpu = t.indices.tolist() - prefill_top_logprobs.append( - [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] + input_top_logprobs.append( + [ + list(zip(values[pt + j][:k], indices[pt + j][:k])) + for j in range(extend_seq_len - 1) + ] + ) + output_top_logprobs.append( + list( + zip( + values[pt + extend_seq_len - 1][:k], + indices[pt + extend_seq_len - 1][:k], + ) + ) ) - decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) pt += extend_seq_len - return prefill_top_logprobs, decode_top_logprobs + return input_top_logprobs, output_top_logprobs def forward( self, @@ -137,7 +161,7 @@ def forward( last_logits = torch.matmul(last_hidden, weight.T) if self.tp_size > 1: last_logits = tensor_model_parallel_all_gather(last_logits) - last_logits = last_logits[:, : self.config.vocab_size] + last_logits = last_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): last_logits /= self.config.final_logit_softcapping @@ -150,63 +174,75 @@ def forward( next_token_logits=last_logits, next_token_logprobs=None, normalized_prompt_logprobs=None, - prefill_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=None, + input_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, ) else: # When logprob is requested, compute the logits for all tokens. 
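+            # In decode mode only the newly generated position needs logprobs,
+            # so log_softmax is applied to last_logits alone instead of
+            # materializing the full [#token, vocab_size] matrix used by the
+            # extend branch below.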
if logits_metadata.forward_mode == ForwardMode.DECODE: - all_logits = last_logits - else: - all_logits = torch.matmul(hidden_states, weight.T) - if self.tp_size > 1: - all_logits = tensor_model_parallel_all_gather(all_logits) - all_logits = all_logits[:, : self.config.vocab_size] + last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) - all_logprobs = all_logits.float() - del all_logits - all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) - - # Get the logprob of top-k tokens - return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums) - if return_top_logprob: - prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( - all_logprobs, logits_metadata + # Get the logprob of top-k tokens + return_top_logprob = any( + x > 0 for x in logits_metadata.top_logprobs_nums ) - else: - prefill_top_logprobs = decode_top_logprobs = None + if return_top_logprob: + output_top_logprobs = self.get_top_logprobs( + last_logprobs, logits_metadata + )[1] + else: + output_top_logprobs = None - if logits_metadata.forward_mode == ForwardMode.DECODE: return LogitProcessorOutput( next_token_logits=last_logits, - next_token_logprobs=all_logprobs, + next_token_logprobs=last_logprobs, normalized_prompt_logprobs=None, - prefill_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=decode_top_logprobs, + input_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=output_top_logprobs, ) else: + all_logits = torch.matmul(hidden_states, weight.T) + if self.tp_size > 1: + all_logits = tensor_model_parallel_all_gather(all_logits) + all_logits = all_logits[:, : self.config.vocab_size].float() + + all_logprobs = all_logits + del all_logits, hidden_states + all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) + + # Get the logprob of top-k tokens + return_top_logprob = any( + x > 0 for x in logits_metadata.top_logprobs_nums + ) + if return_top_logprob: + input_top_logprobs, output_top_logprobs = self.get_top_logprobs( + all_logprobs, logits_metadata + ) + else: + input_top_logprobs = output_top_logprobs = None + last_logprobs = all_logprobs[last_index] # Compute the logprobs and normalized logprobs for the prefill tokens. # Note that we pad a zero at the end of each sequence for easy computation. 
- prefill_token_logprobs = all_logprobs[ + input_token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), ] normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - prefill_token_logprobs, logits_metadata + input_token_logprobs, logits_metadata ) return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_top_logprobs=decode_top_logprobs, + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_top_logprobs=output_top_logprobs, ) diff --git a/python/sglang/srt/layers/pooler.py b/python/sglang/srt/layers/pooler.py new file mode 100644 index 00000000000..21752366a3c --- /dev/null +++ b/python/sglang/srt/layers/pooler.py @@ -0,0 +1,50 @@ +# adapted from +# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py + +from dataclasses import dataclass +from enum import IntEnum + +import torch +import torch.nn as nn + +from sglang.srt.model_executor.model_runner import InputMetadata + + +class PoolingType(IntEnum): + LAST = 0 + + +@dataclass +class EmbeddingPoolerOutput: + embeddings: torch.Tensor + + +class Pooler(nn.Module): + """A layer that pools specific information from hidden states. + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + Attributes: + pooling_type: The type of pooling to use (LAST, AVERAGE, MAX). + normalize: Whether to normalize the pooled data. + """ + + def __init__(self, pooling_type: PoolingType, normalize: bool): + super().__init__() + self.pooling_type = pooling_type + self.normalize = normalize + + def forward( + self, hidden_states: torch.Tensor, input_metadata: InputMetadata + ) -> EmbeddingPoolerOutput: + if self.pooling_type == PoolingType.LAST: + last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1 + pooled_data = hidden_states[last_token_indices] + else: + raise ValueError(f"Invalid pooling type: {self.pooling_type}") + + if self.normalize: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + + return EmbeddingPoolerOutput(embeddings=pooled_data) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py new file mode 100644 index 00000000000..564a696b0ce --- /dev/null +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -0,0 +1,64 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# temporarily adapted from vLLM +# FIXME: in progress of refactoring the model loader + +from typing import Dict, Type + +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, +) +from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig +from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config +from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig + +from sglang.srt.layers.quantization.fp8 import Fp8Config + +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "fp8": Fp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, + "gptq": GPTQConfig, + "squeezellm": SqueezeLLMConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, +} + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + return QUANTIZATION_METHODS[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", + "QUANTIZATION_METHODS", +] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py new file mode 100644 index 00000000000..12378d50622 --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -0,0 +1,677 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py +# FIXME refactor in progress +from typing import Any, Dict, List, Optional, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase, fused_moe +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQMarlinState, + marlin_permute_scales, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import pack_fp8_to_int32 +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +from sglang.srt.layers.linear import LinearBase, LinearMethodBase + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +def cutlass_fp8_supported() -> bool: + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ) + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["QuantizeMethodBase"]: + + if isinstance(layer, LinearBase): + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. 
Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter( + torch.empty(len(output_partition_sizes), dtype=torch.float32), + requires_grad=False, + ) + scale[:] = torch.finfo(torch.float8_e4m3fn).min + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, + { + **extra_weight_attrs, + "needs_scalar_to_array": True, + }, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + # weight_dtype = (torch.float8_e4m3fn + # if self.quant_config.is_checkpoint_fp8_serialized else + # params_dtype) + weight_dtype = torch.float8_e4m3fn + weight = Parameter( + torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + requires_grad=False, + ) + layer.register_parameter("weight", weight) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + self._create_scale_param( + scale_name="weight_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs, + ) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + self._create_scale_param( + scale_name="input_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs, + ) + + # For GPUs without FP8 hardware support, we use Marlin for fast + # fused dequantization + if self.use_marlin: + layer.marlin_state = GPTQMarlinState.REPACK + + def prepare_layer_for_marlin(self, layer: Module) -> None: + print_warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + assert layer.marlin_state == GPTQMarlinState.REPACK + layer.marlin_state = GPTQMarlinState.READY + + device = layer.weight.device + + # WEIGHTS + # Repack weights to gptq format (packed int32 elements) + packed_gptq_qweight = pack_fp8_to_int32(layer.weight) + + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_gptq_qweight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Currently Marlin doesn't support per-tensor scales, so we + # expand it to channelwise + scales = ( + layer.weight_scale.repeat(1, part_size_n).to(layer.orig_dtype).to(device) + ) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1, + num_bits=8, + ) + layer.weight_scale = Parameter(marlin_scales, requires_grad=False) + + # Allocate marlin workspace + max_workspace_size = ( + part_size_n // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + workspace = torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + layer.workspace = workspace + + def process_weights_after_loading(self, layer: Module) -> None: + if not hasattr(layer, "process_after_load") or not layer.process_after_load: + return + + # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.logical_widths = None + layer.input_scale = None + if self.use_marlin: + self.prepare_layer_for_marlin(layer) + return + + # If checkpoint is fp8, requantize the separately quantized logical + # weights into a single fp8 weight with a single weight scale. + else: + # WEIGHT_SCALE / WEIGHT + # Loop over logical weights, requantizing with single scale. + max_w_scale = layer.weight_scale.max() + + # QKV / MLP is fused in the on disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. As a result, we skip dequant -> requant + # since we already have quantized QKV together. + # Sample Model with fused checkpoint: + # * nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = ( + layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + ) + + if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize( + layer.weight[start:end, :], layer.weight_scale[idx] + ) + + layer.weight[start:end, :] = per_tensor_quantize( + weight_dq, layer.weight_scale.max() + ) + start = end + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # INPUT ACTIVATION SCALE + # Dynamic: set to None (required input to ops.scaled_fp8_quant). + # Static: set to max of the input_scales (since they are equal). 
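+            # At runtime, apply() quantizes activations per tensor with
+            # ops.scaled_fp8_quant; conceptually x is mapped to
+            # clamp(x / scale, fp8_min, fp8_max) in float8_e4m3fn (the same
+            # math as per_tensor_quantize below), so one scalar input_scale
+            # per layer is enough.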
+ if self.quant_config.activation_scheme == "dynamic": + layer.input_scale = None + elif self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) + else: + raise ValueError( + f"Unknown scheme {self.quant_config.activation_scheme}" + ) + + if self.use_marlin: + self.prepare_layer_for_marlin(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if self.use_marlin: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (layer.output_size_per_partition,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=layer.weight, + b_scales=layer.weight_scale, + workspace=layer.workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + else: + + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x + # If static, layer.input_scale is scalar and x_scale is input_scale + + if bias is None and self.cutlass_fp8_supported: + qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + ) + + else: + qinput, x_scale = ops.scaled_fp8_quant( + x, layer.input_scale, batch_dim_padding=17 + ) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. + output, _ = torch._scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + bias=bias, + ) + + return torch.narrow(output, 0, 0, x.shape[0]) + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + layer.process_after_load = True + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
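+        # Shape (num_experts, 2): one scale for w1 and one for w3 per expert.
+        # process_weights_after_loading() later folds them into a single
+        # per-expert scale by taking the max and requantizing.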
+ w13_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + a13_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + if not hasattr(layer, "process_after_load") or not layer.process_after_load: + return + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like( + layer.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like( + layer.w2_weight.data, dtype=torch.float8_e4m3fn + ) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :] + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.a13_scale) or not all_close_1d( + layer.a2_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.a13_scale = torch.nn.Parameter( + layer.a13_scale.max(), requires_grad=False + ) + layer.a2_scale = torch.nn.Parameter( + layer.a2_scale.max(), requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
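+            # Requantization sketch: each half (w1, w3) of w13 is dequantized
+            # with its own scale (w_fp16 = w_fp8 * scale) and requantized with
+            # the per-expert max scale, so both halves end up sharing one
+            # scale at the cost of one extra fp8 rounding.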
+ assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :] = ( + per_tensor_quantize(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + ) -> torch.Tensor: + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + ) + + +# FIXME: not used +class Fp8KVCacheMethod(QuantizeMethodBase): + """Supports loading kv-cache scaling factors from FP8 checkpoints.""" + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """Create "weight" (aka kv_scale) for an attention layer. + + Args: + layer: The layer that is using the QuantizeMethodBase factory. + """ + # Initialize the KV cache scale to 1.0 as the default value. + # If the kv_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") + + def process_weights_after_loading(self, layer: Module) -> None: + # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError( + "Only support per-tensor scaling factor " "for fp8 KV cache" + ) + layer._kv_scale = kv_scale + if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " + "cause accuracy issues. Please make sure kv-cache scaling " + "factor is available in the fp8 checkpoint." 
+ ) + del layer.kv_scale + + +def per_tensor_quantize( + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 2c3e91af107..2afd329f96d 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -1,6 +1,20 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Radix attention.""" -import numpy as np import torch from flashinfer.cascade import merge_state from torch import nn @@ -8,8 +22,8 @@ from sglang.global_config import global_config from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd -from sglang.srt.managers.controller.infer_batch import global_server_args_dict -from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.model_runner import global_server_args_dict class RadixAttention(nn.Module): @@ -21,16 +35,22 @@ def __init__( num_kv_heads: int, layer_id: int, logit_cap: int = -1, + v_head_dim: int = -1, ): super().__init__() self.tp_q_head_num = num_heads self.tp_k_head_num = num_kv_heads self.tp_v_head_num = num_kv_heads self.head_dim = head_dim + self.qk_head_dim = head_dim + self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id - if not global_server_args_dict.get("disable_flashinfer", False): + if ( + not global_server_args_dict.get("disable_flashinfer", False) + and self.qk_head_dim == self.v_head_dim + ): self.extend_forward = self.extend_forward_flashinfer self.decode_forward = self.decode_forward_flashinfer else: @@ -40,24 +60,28 @@ def __init__( self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): - o = torch.empty_like(q) + if self.qk_head_dim != self.v_head_dim: + o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) + else: + o = torch.empty_like(q) + self.store_kv_cache(k, v, input_metadata) extend_attention_fwd( - q.view(-1, self.tp_q_head_num, self.head_dim), + q.view(-1, self.tp_q_head_num, self.qk_head_dim), k.contiguous(), v.contiguous(), - o.view(-1, self.tp_q_head_num, self.head_dim), + o.view(-1, self.tp_q_head_num, self.v_head_dim), 
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), input_metadata.req_to_token_pool.req_to_token, input_metadata.req_pool_indices, - input_metadata.start_loc, + input_metadata.triton_start_loc, input_metadata.seq_lens, - input_metadata.prefix_lens, + input_metadata.triton_prefix_lens, input_metadata.extend_start_loc, input_metadata.extend_seq_lens, - input_metadata.max_seq_len, - input_metadata.max_extend_len, + input_metadata.triton_max_seq_len, + input_metadata.triton_max_extend_len, sm_scale=self.scaling, logit_cap=self.logit_cap, ) @@ -65,19 +89,22 @@ def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): return o def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): - o = torch.empty_like(q) + if self.qk_head_dim != self.v_head_dim: + o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) + else: + o = torch.empty_like(q) self.store_kv_cache(k, v, input_metadata) token_attention_fwd( - q.view(-1, self.tp_q_head_num, self.head_dim), + q.view(-1, self.tp_q_head_num, self.qk_head_dim), input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), - o.view(-1, self.tp_q_head_num, self.head_dim), + o.view(-1, self.tp_q_head_num, self.v_head_dim), input_metadata.req_to_token_pool.req_to_token, input_metadata.req_pool_indices, - input_metadata.start_loc, + input_metadata.triton_start_loc, input_metadata.seq_lens, - input_metadata.max_seq_len, + input_metadata.triton_max_seq_len, input_metadata.total_num_tokens, sm_scale=self.scaling, logit_cap=self.logit_cap, @@ -86,32 +113,47 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): return o def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): - o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - k.contiguous().view(-1, self.tp_k_head_num, self.head_dim), - v.contiguous().view(-1, self.tp_v_head_num, self.head_dim), - causal=True, - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) + if not input_metadata.flashinfer_use_ragged: + self.store_kv_cache(k, v, input_metadata) - if input_metadata.no_prefix: - o = o1 - else: - o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse( + o = input_metadata.flashinfer_prefill_wrapper_paged.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.kv_data[self.layer_id], - causal=False, + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), + causal=True, sm_scale=self.scaling, logits_soft_cap=self.logit_cap, ) + else: + o1, s1 = ( + input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse( + q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + k.contiguous().view(-1, self.tp_k_head_num, self.head_dim), + v.contiguous().view(-1, self.tp_v_head_num, self.head_dim), + causal=True, + sm_scale=self.scaling, + logits_soft_cap=self.logit_cap, + ) + ) - o, _ = merge_state(o1, s1, o2, s2) + if input_metadata.extend_no_prefix: + o = o1 + else: + o2, s2 = ( + input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse( + q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), + causal=False, + sm_scale=self.scaling, + logits_soft_cap=self.logit_cap, + ) + ) - self.store_kv_cache(k, v, input_metadata) + o, _ = 
merge_state(o1, s1, o2, s2) - if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: - torch.cuda.synchronize() + self.store_kv_cache(k, v, input_metadata) + + if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: + torch.cuda.synchronize() return o.view(-1, self.tp_q_head_num * self.head_dim) @@ -120,7 +162,7 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): o = input_metadata.flashinfer_decode_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.kv_data[self.layer_id], + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), sm_scale=self.scaling, logits_soft_cap=self.logit_cap, ) @@ -128,8 +170,8 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): return o.view(-1, self.tp_q_head_num * self.head_dim) def forward(self, q, k, v, input_metadata: InputMetadata): - k = k.view(-1, self.tp_k_head_num, self.head_dim) - v = v.view(-1, self.tp_v_head_num, self.head_dim) + k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) + v = v.view(-1, self.tp_v_head_num, self.v_head_dim) if input_metadata.forward_mode == ForwardMode.EXTEND: return self.extend_forward(q, k, v, input_metadata) @@ -137,17 +179,7 @@ def forward(self, q, k, v, input_metadata: InputMetadata): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): - key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) - value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) - if input_metadata.out_cache_loc is not None: - key_buffer[input_metadata.out_cache_loc] = cache_k - value_buffer[input_metadata.out_cache_loc] = cache_v - elif input_metadata.out_cache_cont_start is not None: - key_buffer[ - input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end - ] = cache_k - value_buffer[ - input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end - ] = cache_v - else: - raise RuntimeError() + k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) + v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) + k_cache[input_metadata.out_cache_loc] = cache_k + v_cache[input_metadata.out_cache_loc] = cache_v diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 9d7bda145a6..ab6e7ba7727 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py @@ -5,8 +20,7 @@ import triton import triton.language as tl -from sglang.srt.managers.controller.model_runner import global_server_args_dict -from sglang.srt.utils import wrap_kernel_launcher +from sglang.srt.managers.schedule_batch import global_server_args_dict if global_server_args_dict.get("attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 @@ -40,6 +54,7 @@ def _fwd_kernel_stage1( att_stride_h, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, ): @@ -59,6 +74,10 @@ def _fwd_kernel_stage1( off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) block_stard_index = start_n * BLOCK_N @@ -83,6 +102,19 @@ def _fwd_kernel_stage1( other=0.0, ).to(REDUCE_TRITON_TYPE) att_value = tl.sum(q[None, :] * k, 1) + if BLOCK_DPE > 0: + qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE) + offs_buf_kpe = ( + k_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[None, :] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=offs_n_new[:, None] < cur_batch_end_index, + other=0.0, + ).to(REDUCE_TRITON_TYPE) + att_value += tl.sum(qpe[None, :] * kpe, 1) att_value *= sm_scale if logit_cap > 0: @@ -162,10 +194,6 @@ def _fwd_kernel_stage2( tl.store(out_ptrs, acc) -cached_kernel_stage1 = None -cached_kernel_stage2 = None - - def _token_att_m_fwd( q, k_buffer, @@ -182,7 +210,14 @@ def _token_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 576} + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + else: + BLOCK_DMODEL = Lk + BLOCK_DPE = 0 batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -194,28 +229,6 @@ def _token_att_m_fwd( else: num_warps = 2 - global cached_kernel_stage1 - if cached_kernel_stage1: - cached_kernel_stage1( - grid, - num_warps, - q, - k_buffer, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - att_out, - Req_to_tokens.stride(0), - q.stride(0), - q.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - att_out.stride(0), - ) - return - _fwd_kernel_stage1[grid]( q, k_buffer, @@ -232,13 +245,13 @@ def _token_att_m_fwd( k_buffer.stride(1), att_out.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, BLOCK_N=BLOCK, logit_cap=logit_cap, num_warps=num_warps, num_stages=1, ) - cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1) def _token_softmax_reducev_fwd( @@ -257,27 +270,6 @@ def _token_softmax_reducev_fwd( num_warps = 1 - global cached_kernel_stage2 - if cached_kernel_stage2: - cached_kernel_stage2( - grid, - num_warps, - logics, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - logics.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), - o.stride(0), - o.stride(1), - req_to_tokens.stride(0), - ) - return - _fwd_kernel_stage2[grid]( logics, v_buffer, @@ -298,7 +290,6 @@ def _token_softmax_reducev_fwd( 
num_warps=num_warps, num_stages=3, ) - cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2) def token_attention_fwd( @@ -312,7 +303,7 @@ def token_attention_fwd( b_seq_len, max_len_in_batch, total_num_tokens, - sm_scale=None, + sm_scale, logit_cap=-1, att_m=None, ): @@ -320,7 +311,6 @@ def token_attention_fwd( att_m = torch.empty( (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda" ) - sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale _token_att_m_fwd( q, diff --git a/python/sglang/srt/managers/controller/dp_worker.py b/python/sglang/srt/managers/controller/dp_worker.py deleted file mode 100644 index 3b6becfd2dc..00000000000 --- a/python/sglang/srt/managers/controller/dp_worker.py +++ /dev/null @@ -1,113 +0,0 @@ -"""A data parallel worker thread.""" - -import asyncio -import logging -import queue -import threading -from typing import Callable, List - -import uvloop -import zmq - -from sglang.global_config import global_config -from sglang.srt.managers.controller.tp_worker import ModelTpClient -from sglang.srt.managers.io_struct import BatchTokenIDOut -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import kill_parent_process -from sglang.utils import get_exception_traceback - -logger = logging.getLogger("srt.controller") -CHECKING_INTERVAL = 5 - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - - -class DataParallelWorkerThread(threading.Thread): - def __init__( - self, - worker_id: int, - request_queue: queue.Queue, - detokenizer_port: int, - step_func: Callable, - ): - super(DataParallelWorkerThread, self).__init__() - self.worker_id = worker_id - self.request_queue = request_queue - self.liveness = True - self.request_dependency_delay = global_config.request_dependency_delay - - context = zmq.asyncio.Context() - self.send_to_detokenizer = context.socket(zmq.PUSH) - self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}") - - self.step = step_func - - async def loop_for_forward(self): - while self.liveness: - requests = [] - while not self.request_queue.empty(): - requests.append(self.request_queue.get()) - - out_pyobjs: List[BatchTokenIDOut] = [] - try: - out_pyobjs = await self.step(requests) - except Exception: - for r in requests: - self.request_queue.put(r) - logger.error( - f"Worker thread {self.worker_id}: " - f"failed to get back from Model Server\n" - f"{get_exception_traceback()}" - ) - self.liveness = False - # Crash the whole server when there are any errors. - # TODO(lianmin): make this an option. 
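`token_attention_fwd` above keeps the two-stage structure: `_fwd_kernel_stage1` writes raw query-key scores for every cached token into `att_m`, and `_fwd_kernel_stage2` applies the softmax and reduces over the value buffer. A per-request PyTorch reference of the same computation, assuming equal Q and KV head counts and a tanh-style soft cap for `logit_cap` (the cap formula is an assumption, since the capped branch is not shown in full here):

```python
import torch

def token_attention_reference(q, k_cache, v_cache, sm_scale, logit_cap=-1):
    # q: [heads, head_dim]; k_cache, v_cache: [seq_len, heads, head_dim]
    # Stage 1 equivalent: raw attention score for every cached token.
    att = torch.einsum("hd,shd->hs", q, k_cache) * sm_scale
    if logit_cap > 0:
        att = logit_cap * torch.tanh(att / logit_cap)  # assumed soft-cap form
    # Stage 2 equivalent: softmax over the sequence, then a weighted sum of V.
    probs = torch.softmax(att, dim=-1)
    return torch.einsum("hs,shd->hd", probs, v_cache)
```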
- kill_parent_process() - return - - for obj in out_pyobjs: - self.send_to_detokenizer.send_pyobj(obj) - - # async sleep for receiving the subsequent request and avoiding cache miss - if len(out_pyobjs) != 0: - has_finished = any( - [obj.finished_reason is not None for obj in out_pyobjs] - ) - if has_finished: - await asyncio.sleep(self.request_dependency_delay) - await asyncio.sleep(global_config.wait_for_new_request_delay) - - async def monitoring(self): - while True: - await asyncio.sleep(CHECKING_INTERVAL) - # can plug in monitoring logic here - - def run(self): - logger.info(f"DataParallelWorkerThread {self.worker_id} start") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.create_task(self.monitoring()) - loop.run_until_complete(self.loop_for_forward()) - - -def start_data_parallel_worker( - server_args: ServerArgs, - port_args: PortArgs, - model_overide_args, - gpu_ids: List[int], - worker_id: int, -): - model_tp_client = ModelTpClient( - gpu_ids, - server_args, - port_args.model_port_args[worker_id], - model_overide_args, - ) - worker_thread = DataParallelWorkerThread( - worker_id=worker_id, - request_queue=queue.Queue(), - detokenizer_port=port_args.detokenizer_port, - step_func=model_tp_client.step, - ) - worker_thread.start() - return worker_thread diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller/manager_multi.py deleted file mode 100644 index 72e3bed808e..00000000000 --- a/python/sglang/srt/managers/controller/manager_multi.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -A controller that manages multiple data parallel workers. -Each data parallel worker can manage multiple tensor parallel workers. -""" - -import asyncio -import logging -from concurrent.futures import ThreadPoolExecutor -from enum import Enum, auto -from typing import Dict - -import zmq -import zmq.asyncio - -from sglang.global_config import global_config -from sglang.srt.managers.controller.dp_worker import ( - DataParallelWorkerThread, - start_data_parallel_worker, -) -from sglang.srt.managers.io_struct import ( - AbortReq, - FlushCacheReq, - TokenizedGenerateReqInput, -) -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.utils import get_exception_traceback - -logger = logging.getLogger("srt.controller") - - -class LoadBalanceMethod(Enum): - ROUND_ROBIN = auto() - SHORTEST_QUEUE = auto() - - @classmethod - def from_str(cls, method: str): - method = method.upper() - try: - return cls[method] - except KeyError as exc: - raise ValueError(f"Invalid load balance method: {method}") from exc - - -class Controller: - def __init__( - self, - load_balance_method: str, - server_args: ServerArgs, - port_args: PortArgs, - model_overide_args, - ): - self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method) - self.server_args = server_args - self.port_args = port_args - - if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN: - self.round_robin_counter = 0 - - self.dispatch_lookup = { - LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, - LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, - } - self.dispatching = self.dispatch_lookup[self.load_balance_method] - - # Init communication - context = zmq.asyncio.Context() - self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") - - # Init status - self.recv_reqs = [] - - # Start data parallel workers - self.workers: Dict[int, DataParallelWorkerThread] = {} - tp_size = 
server_args.tp_size - - def start_dp_worker(i): - try: - gpu_ids = list(range(i * tp_size, (i + 1) * tp_size)) - worker_thread = start_data_parallel_worker( - server_args, port_args, model_overide_args, gpu_ids, i - ) - self.workers[i] = worker_thread - except Exception: - logger.error( - f"Failed to start local worker {i}\n{get_exception_traceback()}" - ) - - for i in range(server_args.dp_size): - start_dp_worker(i) - - # Parallel launch is slower, probably due to the disk bandwidth limitations. - # with ThreadPoolExecutor(server_args.dp_size) as executor: - # executor.map(start_dp_worker, range(server_args.dp_size)) - - def have_any_live_worker(self): - return any(worker_thread.liveness for worker_thread in self.workers.values()) - - def put_req_to_worker(self, worker_id, req): - self.workers[worker_id].request_queue.put(req) - - async def round_robin_scheduler(self, input_requests): - available_workers = list(self.workers.keys()) - for r in input_requests: - self.put_req_to_worker(available_workers[self.round_robin_counter], r) - self.round_robin_counter = (self.round_robin_counter + 1) % len( - available_workers - ) - return - - async def shortest_queue_scheduler(self, input_requests): - for r in input_requests: - worker = min( - self.workers, key=lambda w: self.workers[w].request_queue.qsize() - ) - self.put_req_to_worker(worker, r) - return - - async def remove_dead_workers(self): - for i in list(self.workers.keys()): - worker_thread = self.workers[i] - if not worker_thread.liveness: - worker_thread.join() - # move unsuccessful requests back to the queue - while not worker_thread.request_queue.empty(): - self.recv_reqs.append(worker_thread.request_queue.get()) - del self.workers[i] - logger.info(f"Stale worker {i} removed") - - async def loop_for_forward(self): - while True: - await self.remove_dead_workers() - - if self.have_any_live_worker(): - next_step_input = list(self.recv_reqs) - self.recv_reqs = [] - if next_step_input: - await self.dispatching(next_step_input) - # else: - # logger.error("There is no live worker.") - - await asyncio.sleep(global_config.wait_for_new_request_delay) - - async def loop_for_recv_requests(self): - while True: - recv_req = await self.recv_from_tokenizer.recv_pyobj() - if isinstance(recv_req, FlushCacheReq): - # TODO(lsyin): apply more specific flushCacheReq - for worker_thread in self.workers.values(): - worker_thread.request_queue.put(recv_req) - elif isinstance(recv_req, TokenizedGenerateReqInput): - self.recv_reqs.append(recv_req) - elif isinstance(recv_req, AbortReq): - in_queue = False - for i, req in enumerate(self.recv_reqs): - if req.rid == recv_req.rid: - self.recv_reqs[i] = recv_req - in_queue = True - break - if not in_queue: - # Send abort req to all TP groups - for worker in list(self.workers.keys()): - self.put_req_to_worker(worker, recv_req) - else: - logger.error(f"Invalid object: {recv_req}") - - -def start_controller_process( - server_args: ServerArgs, - port_args: PortArgs, - pipe_writer, - model_overide_args=None, -): - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) - - try: - controller = Controller( - server_args.load_balance_method, server_args, port_args, model_overide_args - ) - except Exception: - pipe_writer.send(get_exception_traceback()) - raise - - pipe_writer.send("init ok") - loop = asyncio.get_event_loop() - asyncio.set_event_loop(loop) - loop.create_task(controller.loop_for_recv_requests()) - loop.run_until_complete(controller.loop_for_forward()) diff --git 
a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py deleted file mode 100644 index 4c2720733d6..00000000000 --- a/python/sglang/srt/managers/controller/manager_single.py +++ /dev/null @@ -1,102 +0,0 @@ -"""A controller that manages a group of tensor parallel workers.""" - -import asyncio -import logging -from concurrent.futures import ThreadPoolExecutor - -import uvloop -import zmq -import zmq.asyncio - -from sglang.global_config import global_config -from sglang.srt.managers.controller.tp_worker import ModelTpClient -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import kill_parent_process -from sglang.utils import get_exception_traceback - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -logger = logging.getLogger("srt.controller") - - -class ControllerSingle: - def __init__(self, model_client: ModelTpClient, port_args: PortArgs): - # Init communication - context = zmq.asyncio.Context(2) - self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") - - self.send_to_detokenizer = context.socket(zmq.PUSH) - self.send_to_detokenizer.connect( - f"tcp://127.0.0.1:{port_args.detokenizer_port}" - ) - - # Init status - self.model_client = model_client - self.recv_reqs = [] - - # Init some configs - self.request_dependency_delay = global_config.request_dependency_delay - - async def loop_for_forward(self): - while True: - next_step_input = list(self.recv_reqs) - self.recv_reqs = [] - out_pyobjs = await self.model_client.step(next_step_input) - - for obj in out_pyobjs: - self.send_to_detokenizer.send_pyobj(obj) - - # async sleep for receiving the subsequent request and avoiding cache miss - slept = False - if len(out_pyobjs) != 0: - has_finished = any( - [obj.finished_reason is not None for obj in out_pyobjs] - ) - if has_finished: - if self.request_dependency_delay > 0: - slept = True - await asyncio.sleep(self.request_dependency_delay) - - if not slept: - await asyncio.sleep(global_config.wait_for_new_request_delay) - - async def loop_for_recv_requests(self): - while True: - recv_req = await self.recv_from_tokenizer.recv_pyobj() - self.recv_reqs.append(recv_req) - - -def start_controller_process( - server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args -): - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) - - try: - tp_size_local = server_args.tp_size // server_args.nnodes - model_client = ModelTpClient( - [i for _ in range(server_args.nnodes) for i in range(tp_size_local)], - server_args, - port_args.model_port_args[0], - model_overide_args, - ) - controller = ControllerSingle(model_client, port_args) - except Exception: - pipe_writer.send(get_exception_traceback()) - raise - - pipe_writer.send("init ok") - - loop = asyncio.new_event_loop() - loop.set_default_executor(ThreadPoolExecutor(max_workers=256)) - asyncio.set_event_loop(loop) - loop.create_task(controller.loop_for_recv_requests()) - try: - loop.run_until_complete(controller.loop_for_forward()) - except Exception: - logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) - finally: - kill_parent_process() diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py deleted file mode 100644 index 30a8001e716..00000000000 --- a/python/sglang/srt/managers/controller/model_runner.py +++ /dev/null @@ -1,334 
+0,0 @@ -"""ModelRunner runs the forward passes of the models.""" - -import importlib -import importlib.resources -import logging -import pkgutil -from functools import lru_cache -from typing import Optional, Type - -import torch -import torch.nn as nn -from vllm.config import DeviceConfig, LoadConfig -from vllm.config import ModelConfig as VllmModelConfig -from vllm.distributed import init_distributed_environment, initialize_model_parallel -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import ModelRegistry - -from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict -from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - get_available_gpu_memory, - is_multimodal_model, - monkey_patch_vllm_dummy_weight_loader, - monkey_patch_vllm_p2p_access_check, -) - -logger = logging.getLogger("srt.model_runner") - - -class ModelRunner: - def __init__( - self, - model_config, - mem_fraction_static: float, - gpu_id: int, - tp_rank: int, - tp_size: int, - nccl_port: int, - server_args: ServerArgs, - ): - # Parse args - self.model_config = model_config - self.mem_fraction_static = mem_fraction_static - self.gpu_id = gpu_id - self.tp_rank = tp_rank - self.tp_size = tp_size - self.nccl_port = nccl_port - self.server_args = server_args - self.is_multimodal_model = is_multimodal_model(self.model_config) - monkey_patch_vllm_dummy_weight_loader() - - # Init torch distributed - torch.cuda.set_device(self.gpu_id) - logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.") - - if not server_args.enable_p2p_check: - monkey_patch_vllm_p2p_access_check(self.gpu_id) - - if server_args.nccl_init_addr: - nccl_init_method = f"tcp://{server_args.nccl_init_addr}" - else: - nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}" - init_distributed_environment( - backend="nccl", - world_size=self.tp_size, - rank=self.tp_rank, - local_rank=self.gpu_id, - distributed_init_method=nccl_init_method, - ) - initialize_model_parallel(tensor_model_parallel_size=self.tp_size) - total_gpu_memory = get_available_gpu_memory( - self.gpu_id, distributed=self.tp_size > 1 - ) - - if self.tp_size > 1: - total_local_gpu_memory = get_available_gpu_memory(self.gpu_id) - if total_local_gpu_memory < total_gpu_memory * 0.9: - raise ValueError( - "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." - ) - - # Set some global args - global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer - global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32 - - # Load the model and create memory pool - self.load_model() - self.init_memory_pool(total_gpu_memory) - self.init_cublas() - self.init_flash_infer() - - def load_model(self): - logger.info( - f"[gpu_id={self.gpu_id}] Load weight begin. 
" - f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" - ) - - device_config = DeviceConfig() - load_config = LoadConfig(load_format=self.server_args.load_format) - vllm_model_config = VllmModelConfig( - model=self.server_args.model_path, - quantization=self.server_args.quantization, - tokenizer=None, - tokenizer_mode=None, - trust_remote_code=self.server_args.trust_remote_code, - dtype=self.server_args.dtype, - seed=42, - skip_tokenizer_init=True, - ) - self.dtype = vllm_model_config.dtype - if self.model_config.model_overide_args is not None: - vllm_model_config.hf_config.update(self.model_config.model_overide_args) - - self.model = get_model( - model_config=vllm_model_config, - device_config=device_config, - load_config=load_config, - lora_config=None, - multimodal_config=None, - parallel_config=None, - scheduler_config=None, - cache_config=None, - ) - logger.info( - f"[gpu_id={self.gpu_id}] Load weight end. " - f"type={type(self.model).__name__}, " - f"dtype={self.dtype}, " - f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" - ) - - def profile_max_num_token(self, total_gpu_memory): - available_gpu_memory = get_available_gpu_memory( - self.gpu_id, distributed=self.tp_size > 1 - ) - head_dim = self.model_config.head_dim - head_num = self.model_config.get_num_kv_heads(self.tp_size) - cell_size = ( - head_num - * head_dim - * self.model_config.num_hidden_layers - * 2 - * torch._utils._element_size(self.dtype) - ) - rest_memory = available_gpu_memory - total_gpu_memory * ( - 1 - self.mem_fraction_static - ) - max_num_token = int(rest_memory * (1 << 30) // cell_size) - return max_num_token - - def init_memory_pool(self, total_gpu_memory): - self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) - - if self.max_total_num_tokens <= 0: - raise RuntimeError( - "Not enough memory. Please try to increase --mem-fraction-static." - ) - - self.req_to_token_pool = ReqToTokenPool( - int(self.max_total_num_tokens / self.model_config.context_len * 256), - self.model_config.context_len + 8, - ) - self.token_to_kv_pool = TokenToKVPool( - self.max_total_num_tokens, - dtype=self.dtype, - head_num=self.model_config.get_num_kv_heads(self.tp_size), - head_dim=self.model_config.head_dim, - layer_num=self.model_config.num_hidden_layers, - ) - logger.info( - f"[gpu_id={self.gpu_id}] Memory pool end. " - f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" - ) - - def init_cublas(self): - """We need to run a small matmul to init cublas. 
Otherwise, it will raise some errors later.""" - dtype = torch.float16 - device = "cuda" - a = torch.ones((16, 16), dtype=dtype, device=device) - b = torch.ones((16, 16), dtype=dtype, device=device) - c = a @ b - return c - - def init_flash_infer(self): - if not global_server_args_dict.get("disable_flashinfer", False): - from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, - ) - from flashinfer.decode import _grouped_size_compiled_for_decode_kernels - - if not _grouped_size_compiled_for_decode_kernels( - self.model_config.num_attention_heads // self.tp_size, - self.model_config.get_num_kv_heads(self.tp_size), - ): - use_tensor_cores = True - else: - use_tensor_cores = False - - workspace_buffers = torch.empty( - 2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda" - ) - self.flashinfer_prefill_wrapper_ragged = ( - BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD") - ) - self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - workspace_buffers[1], "NHD" - ) - self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores - ) - else: - self.flashinfer_prefill_wrapper_ragged = ( - self.flashinfer_prefill_wrapper_paged - ) = None - self.flashinfer_decode_wrapper = None - - @torch.inference_mode() - def forward_extend(self, batch: Batch): - input_metadata = InputMetadata.create( - self, - forward_mode=ForwardMode.EXTEND, - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - prefix_lens=batch.prefix_lens, - position_ids_offsets=batch.position_ids_offsets, - out_cache_loc=batch.out_cache_loc, - top_logprobs_nums=batch.top_logprobs_nums, - return_logprob=batch.return_logprob, - ) - return self.model.forward( - batch.input_ids, input_metadata.positions, input_metadata - ) - - @torch.inference_mode() - def forward_decode(self, batch: Batch): - input_metadata = InputMetadata.create( - self, - forward_mode=ForwardMode.DECODE, - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - prefix_lens=batch.prefix_lens, - position_ids_offsets=batch.position_ids_offsets, - out_cache_loc=batch.out_cache_loc, - out_cache_cont_start=batch.out_cache_cont_start, - out_cache_cont_end=batch.out_cache_cont_end, - top_logprobs_nums=batch.top_logprobs_nums, - return_logprob=batch.return_logprob, - ) - return self.model.forward( - batch.input_ids, input_metadata.positions, input_metadata - ) - - @torch.inference_mode() - def forward_extend_multi_modal(self, batch: Batch): - input_metadata = InputMetadata.create( - self, - forward_mode=ForwardMode.EXTEND, - tp_size=self.tp_size, - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - prefix_lens=batch.prefix_lens, - position_ids_offsets=batch.position_ids_offsets, - out_cache_loc=batch.out_cache_loc, - top_logprobs_nums=batch.top_logprobs_nums, - return_logprob=batch.return_logprob, - flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged, - flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged, - flashinfer_decode_wrapper=self.flashinfer_decode_wrapper, - ) - return self.model.forward( - batch.input_ids, - input_metadata.positions, - input_metadata, - batch.pixel_values, - batch.image_sizes, - batch.image_offsets, - ) - - def forward(self, batch: Batch, forward_mode: ForwardMode): - if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: - return 
self.forward_extend_multi_modal(batch) - elif forward_mode == ForwardMode.DECODE: - return self.forward_decode(batch) - elif forward_mode == ForwardMode.EXTEND: - return self.forward_extend(batch) - else: - raise ValueError(f"Invaid forward mode: {forward_mode}") - - -@lru_cache() -def import_model_classes(): - model_arch_name_to_cls = {} - package_name = "sglang.srt.models" - package = importlib.import_module(package_name) - for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): - if not ispkg: - module = importlib.import_module(name) - if hasattr(module, "EntryClass"): - entry = module.EntryClass - if isinstance( - entry, list - ): # To support multiple model classes in one module - for tmp in entry: - model_arch_name_to_cls[tmp.__name__] = tmp - else: - model_arch_name_to_cls[entry.__name__] = entry - - # compat: some models such as chatglm has incorrect class set in config.json - # usage: [ tuple("From_Entry_Class_Name": EntryClass), ] - if hasattr(module, "EntryClassRemapping") and isinstance( - module.EntryClassRemapping, list - ): - for remap in module.EntryClassRemapping: - if isinstance(remap, tuple) and len(remap) == 2: - model_arch_name_to_cls[remap[0]] = remap[1] - - return model_arch_name_to_cls - - -def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: - model_arch_name_to_cls = import_model_classes() - - if model_arch not in model_arch_name_to_cls: - raise ValueError( - f"Unsupported architectures: {model_arch}. " - f"Supported list: {list(model_arch_name_to_cls.keys())}" - ) - return model_arch_name_to_cls[model_arch] - - -# Monkey patch model loader -setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) diff --git a/python/sglang/srt/managers/controller/schedule_heuristic.py b/python/sglang/srt/managers/controller/schedule_heuristic.py deleted file mode 100644 index 4ae1a7069fd..00000000000 --- a/python/sglang/srt/managers/controller/schedule_heuristic.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Request scheduler heuristic.""" - -import random -from collections import defaultdict - - -class ScheduleHeuristic: - def __init__( - self, - schedule_heuristic, - max_running_seqs, - max_prefill_num_tokens, - max_total_num_tokens, - tree_cache, - ): - self.schedule_heuristic = schedule_heuristic - self.max_running_seqs = max_running_seqs - self.max_prefill_num_tokens = max_prefill_num_tokens - self.max_total_num_tokens = max_total_num_tokens - self.tree_cache = tree_cache - - def get_priority_queue(self, forward_queue): - if self.schedule_heuristic == "lpm": - # longest prefix match - forward_queue.sort(key=lambda x: -len(x.prefix_indices)) - return forward_queue - elif self.schedule_heuristic == "random": - random.shuffle(forward_queue) - return forward_queue - elif self.schedule_heuristic == "fcfs": - return forward_queue - elif self.schedule_heuristic == "dfs-weight": - last_node_to_reqs = defaultdict(list) - for req in forward_queue: - last_node_to_reqs[req.last_node].append(req) - - node_to_weight = defaultdict(int) - for node in last_node_to_reqs: - node_to_weight[node] = len(last_node_to_reqs[node]) - self.calc_weight(self.tree_cache.root_node, node_to_weight) - - q = [] - self.get_dfs_priority( - self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q - ) - assert len(q) == len(forward_queue) - return q - else: - raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") - - def calc_weight(self, cur_node, node_to_weight): - for child in cur_node.children.values(): - self.calc_weight(child, node_to_weight) 
- node_to_weight[cur_node] += node_to_weight[child] - - def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q): - childs = [child for child in cur_node.children.values()] - childs.sort(key=lambda x: -node_to_priority[x]) - for child in childs: - self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) - q.extend(last_node_to_reqs[cur_node]) diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py deleted file mode 100644 index 6d92c6bab39..00000000000 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ /dev/null @@ -1,813 +0,0 @@ -"""A tensor parallel worker.""" - -import asyncio -import logging -import time -import warnings -from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional - -import rpyc -import torch -from rpyc.utils.classic import obtain - -from sglang.global_config import global_config -from sglang.srt.constrained.fsm_cache import FSMCache -from sglang.srt.constrained.jump_forward import JumpForwardCache -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.controller.infer_batch import ( - FINISH_ABORT, - BaseFinishReason, - Batch, - ForwardMode, - Req, -) -from sglang.srt.managers.controller.model_runner import ModelRunner -from sglang.srt.managers.controller.radix_cache import RadixCache -from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic -from sglang.srt.managers.io_struct import ( - AbortReq, - BatchTokenIDOut, - FlushCacheReq, - TokenizedGenerateReqInput, -) -from sglang.srt.model_config import ModelConfig -from sglang.srt.server_args import ModelPortArgs, ServerArgs -from sglang.srt.utils import ( - connect_rpyc_service, - get_int_token_logit_bias, - is_multimodal_model, - set_random_seed, - start_rpyc_service_process, - suppress_other_loggers, -) -from sglang.utils import get_exception_traceback - -logger = logging.getLogger("srt.tp_worker") - - -class ModelTpServer: - def __init__( - self, - gpu_id: int, - tp_rank: int, - server_args: ServerArgs, - model_port_args: ModelPortArgs, - model_overide_args, - ): - server_args, model_port_args = obtain(server_args), obtain(model_port_args) - suppress_other_loggers() - - # Copy arguments - self.gpu_id = gpu_id - self.tp_rank = tp_rank - self.tp_size = server_args.tp_size - self.dp_size = server_args.dp_size - self.schedule_heuristic = server_args.schedule_heuristic - self.disable_regex_jump_forward = server_args.disable_regex_jump_forward - - # Init model and tokenizer - self.model_config = ModelConfig( - server_args.model_path, - server_args.trust_remote_code, - context_length=server_args.context_length, - model_overide_args=model_overide_args, - ) - self.model_runner = ModelRunner( - model_config=self.model_config, - mem_fraction_static=server_args.mem_fraction_static, - gpu_id=gpu_id, - tp_rank=tp_rank, - tp_size=server_args.tp_size, - nccl_port=model_port_args.nccl_port, - server_args=server_args, - ) - - if is_multimodal_model(server_args.model_path): - self.processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - ) - self.tokenizer = self.processor.tokenizer - else: - self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - ) - self.max_total_num_tokens = self.model_runner.max_total_num_tokens - self.max_prefill_tokens = 
( - 4096 - if server_args.max_prefill_tokens is None - else server_args.max_prefill_tokens - ) - self.max_running_requests = ( - self.max_total_num_tokens // 2 - if server_args.max_running_requests is None - else server_args.max_running_requests - ) - self.int_token_logit_bias = torch.tensor( - get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) - ) - set_random_seed(server_args.random_seed) - - # Print info - logger.info( - f"[gpu_id={self.gpu_id}] " - f"max_total_num_tokens={self.max_total_num_tokens}, " - f"max_prefill_tokens={self.max_prefill_tokens}, " - f"context_len={self.model_config.context_len}" - ) - if self.tp_rank == 0: - logger.info( - f"[gpu_id={self.gpu_id}] " - f"server_args: {server_args.print_mode_args()}" - ) - - # Init cache - self.tree_cache = RadixCache( - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - disable=server_args.disable_radix_cache, - ) - self.tree_cache_metrics = {"total": 0, "hit": 0} - self.scheduler = ScheduleHeuristic( - self.schedule_heuristic, - self.max_running_requests, - self.max_prefill_tokens, - self.max_total_num_tokens, - self.tree_cache, - ) - self.req_to_token_pool = self.model_runner.req_to_token_pool - self.token_to_kv_pool = self.model_runner.token_to_kv_pool - - # Init running status - self.forward_queue: List[Req] = [] - self.running_batch: Batch = None - self.out_pyobjs = [] - self.decode_forward_ct = 0 - self.stream_interval = server_args.stream_interval - self.num_generated_tokens = 0 - self.last_stats_tic = time.time() - - # Init the FSM cache for constrained generation - self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - ) - self.jump_forward_cache = JumpForwardCache() - - # Init new token estimation - assert ( - server_args.schedule_conservativeness >= 0 - ), "Invalid schedule_conservativeness" - self.new_token_ratio = min( - global_config.base_new_token_ratio * server_args.schedule_conservativeness, - 1.0, - ) - self.min_new_token_ratio = min( - global_config.base_min_new_token_ratio - * server_args.schedule_conservativeness, - 1.0, - ) - self.new_token_ratio_decay = global_config.new_token_ratio_decay - self.new_token_ratio_recovery = global_config.new_token_ratio_recovery - - def exposed_step(self, recv_reqs): - if self.tp_size * self.dp_size != 1: - recv_reqs = obtain(recv_reqs) - - try: - # Recv requests - for recv_req in recv_reqs: - if isinstance(recv_req, TokenizedGenerateReqInput): - self.handle_generate_request(recv_req) - elif isinstance(recv_req, FlushCacheReq): - self.flush_cache() - elif isinstance(recv_req, AbortReq): - self.abort_request(recv_req) - else: - raise ValueError(f"Invalid request: {recv_req}") - - # Forward - self.forward_step() - except Exception: - logger.error("Exception in ModelTpServer:\n" + get_exception_traceback()) - raise - - # Return results - ret = self.out_pyobjs - self.out_pyobjs = [] - return ret - - @torch.inference_mode() - def forward_step(self): - new_batch = self.get_new_fill_batch() - - if new_batch is not None: - # Run a new fill batch - self.forward_fill_batch(new_batch) - self.cache_filled_batch(new_batch) - - if not new_batch.is_empty(): - if self.running_batch is None: - self.running_batch = new_batch - else: - self.running_batch.merge(new_batch) - else: - # Run decode batch - if self.running_batch is not None: - # Run a few decode batches continuously for reducing overhead 
- for _ in range(10): - self.num_generated_tokens += len(self.running_batch.reqs) - self.forward_decode_batch(self.running_batch) - - # Print stats - if self.tp_rank == 0: - if self.decode_forward_ct % 40 == 0: - num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool.available_size() - + self.tree_cache.evictable_size() - ) - throughput = self.num_generated_tokens / ( - time.time() - self.last_stats_tic - ) - self.num_generated_tokens = 0 - self.last_stats_tic = time.time() - logger.info( - f"[gpu_id={self.gpu_id}] Decode batch. " - f"#running-req: {len(self.running_batch.reqs)}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {throughput:.2f}, " - f"#queue-req: {len(self.forward_queue)}" - ) - - if self.running_batch.is_empty(): - self.running_batch = None - break - - if self.out_pyobjs and self.running_batch.has_stream(): - break - else: - # Check the available size - available_size = ( - self.token_to_kv_pool.available_size() - + self.tree_cache.evictable_size() - ) - if available_size != self.max_total_num_tokens: - warnings.warn( - "Warning: " - f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n" - "KV cache pool leak detected!" - ) - - def handle_generate_request( - self, - recv_req: TokenizedGenerateReqInput, - ): - req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) - req.pixel_values = recv_req.pixel_values - if req.pixel_values is not None: - req.pad_value = [ - (recv_req.image_hash) % self.model_config.vocab_size, - (recv_req.image_hash >> 16) % self.model_config.vocab_size, - (recv_req.image_hash >> 32) % self.model_config.vocab_size, - (recv_req.image_hash >> 64) % self.model_config.vocab_size, - ] - req.image_size = recv_req.image_size - ( - req.origin_input_ids, - req.image_offset, - ) = self.model_runner.model.pad_input_ids( - req.origin_input_ids_unpadded, - req.pad_value, - req.pixel_values.shape, - req.image_size, - ) - req.sampling_params = recv_req.sampling_params - req.return_logprob = recv_req.return_logprob - req.logprob_start_len = recv_req.logprob_start_len - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream - req.tokenizer = self.tokenizer - - # Init regex fsm - if req.sampling_params.regex is not None: - req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - req.sampling_params.regex - ) - - # Truncate prompts that are too long - req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1] - req.sampling_params.max_new_tokens = min( - req.sampling_params.max_new_tokens, - self.model_config.context_len - 1 - len(req.origin_input_ids), - self.max_total_num_tokens - 128 - len(req.origin_input_ids), - ) - self.forward_queue.append(req) - - def get_new_fill_batch(self) -> Optional[Batch]: - if ( - self.running_batch is not None - and len(self.running_batch.reqs) > self.max_running_requests - ): - return None - - # Compute matched prefix length - for req in self.forward_queue: - req.input_ids = req.origin_input_ids + req.output_ids - prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) - if req.return_logprob: - prefix_indices = prefix_indices[: req.logprob_start_len] - req.extend_input_len = len(req.input_ids) - len(prefix_indices) - req.prefix_indices = prefix_indices - req.last_node = last_node - - # Get priority queue - self.forward_queue = 
self.scheduler.get_priority_queue(self.forward_queue) - - # Add requests if there is available space - can_run_list = [] - new_batch_total_tokens = 0 - new_batch_input_tokens = 0 - - available_size = ( - self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() - ) - if self.running_batch: - available_size -= sum( - [ - (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio - for r in self.running_batch.reqs - ] - ) - - for req in self.forward_queue: - if req.return_logprob and req.normalized_prompt_logprob is None: - # Need at least two tokens to compute normalized logprob - if req.extend_input_len < 2: - delta = 2 - req.extend_input_len - req.extend_input_len += delta - req.prefix_indices = req.prefix_indices[:-delta] - if req.image_offset is not None: - req.image_offset += delta - if req.extend_input_len == 0 and req.max_new_tokens() > 0: - # Need at least one token to compute logits - req.extend_input_len = 1 - req.prefix_indices = req.prefix_indices[:-1] - if req.image_offset is not None: - req.image_offset += 1 - - if ( - req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens - < available_size - and ( - req.extend_input_len + new_batch_input_tokens - <= self.max_prefill_tokens - or len(can_run_list) == 0 - ) - ): - delta = self.tree_cache.inc_lock_ref(req.last_node) - available_size += delta - - if not ( - req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens - < available_size - ): - # Undo locking - delta = self.tree_cache.dec_lock_ref(req.last_node) - available_size += delta - break - else: - # Add this request to the running batch - can_run_list.append(req) - new_batch_total_tokens += ( - req.extend_input_len + req.max_new_tokens() - ) - new_batch_input_tokens += req.extend_input_len - else: - break - if len(can_run_list) == 0: - return None - - # Print stats - if self.tp_rank == 0: - running_req = ( - 0 if self.running_batch is None else len(self.running_batch.reqs) - ) - hit_tokens = sum(len(x.prefix_indices) for x in can_run_list) - self.tree_cache_metrics["total"] += ( - hit_tokens + new_batch_input_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += hit_tokens / 10**9 - tree_cache_hit_rate = ( - self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] - ) - logger.info( - f"[gpu_id={self.gpu_id}] Prefill batch. " - f"#new-seq: {len(can_run_list)}, " - f"#new-token: {new_batch_input_tokens}, " - f"#cached-token: {hit_tokens}, " - f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " - f"#running-req: {running_req}, " - f"#queue-req: {len(self.forward_queue) - len(can_run_list)}" - ) - # logger.debug( - # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " - # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " - # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " - # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. 
" - # ) - - # Return the new batch - new_batch = Batch.init_new( - can_run_list, - self.req_to_token_pool, - self.token_to_kv_pool, - self.tree_cache, - ) - self.forward_queue = [x for x in self.forward_queue if x not in can_run_list] - return new_batch - - def forward_fill_batch(self, batch: Batch): - # Build batch tensors - batch.prepare_for_extend( - self.model_config.vocab_size, self.int_token_logit_bias - ) - - # Forward and sample the next tokens - if batch.extend_num_tokens != 0: - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids, _ = batch.sample(output.next_token_logits) - - # Move logprobs to cpu - if output.next_token_logprobs is not None: - output.next_token_logprobs = output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - output.prefill_token_logprobs = output.prefill_token_logprobs.tolist() - output.normalized_prompt_logprobs = ( - output.normalized_prompt_logprobs.tolist() - ) - - next_token_ids = next_token_ids.tolist() - else: - next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - - # Check finish conditions - pt = 0 - for i, req in enumerate(batch.reqs): - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_ids[i]) - req.check_finished() - - if req.return_logprob: - self.add_logprob_return_values(i, req, pt, next_token_ids, output) - pt += req.extend_input_len - - self.handle_finished_requests(batch) - - def add_logprob_return_values(self, i, req, pt, next_token_ids, output): - if req.normalized_prompt_logprob is None: - req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - - if req.prefill_token_logprobs is None: - # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. 
- req.prefill_token_logprobs = list( - zip( - output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1], - req.input_ids[-req.extend_input_len + 1 :], - ) - ) - if req.logprob_start_len == 0: - req.prefill_token_logprobs = [ - (None, req.input_ids[0]) - ] + req.prefill_token_logprobs - - if req.last_update_decode_tokens != 0: - req.decode_token_logprobs.extend( - list( - zip( - output.prefill_token_logprobs[ - pt - + req.extend_input_len - - req.last_update_decode_tokens : pt - + req.extend_input_len - - 1 - ], - req.input_ids[-req.last_update_decode_tokens + 1 :], - ) - ) - ) - - req.decode_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) - - if req.top_logprobs_num > 0: - if req.prefill_top_logprobs is None: - req.prefill_top_logprobs = output.prefill_top_logprobs[i] - if req.logprob_start_len == 0: - req.prefill_top_logprobs = [None] + req.prefill_top_logprobs - - if req.last_update_decode_tokens != 0: - req.decode_top_logprobs.extend( - output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :] - ) - req.decode_top_logprobs.append(output.decode_top_logprobs[i]) - - def cache_filled_batch(self, batch: Batch): - req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() - for i, req in enumerate(batch.reqs): - new_prefix_indices, new_last_node = self.tree_cache.cache_req( - token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], - last_uncached_pos=len(req.prefix_indices), - req_pool_idx=req_pool_indices_cpu[i], - del_in_memory_pool=False, - old_last_node=req.last_node, - ) - req.prefix_indices, req.last_node = new_prefix_indices, new_last_node - - def forward_decode_batch(self, batch: Batch): - # Check if decode out of memory - if not batch.check_decode_mem(): - old_ratio = self.new_token_ratio - self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0) - - retracted_reqs = batch.retract_decode() - logger.info( - "decode out of memory happened, " - f"#retracted_reqs: {len(retracted_reqs)}, " - f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" - ) - self.forward_queue.extend(retracted_reqs) - else: - self.new_token_ratio = max( - self.new_token_ratio - self.new_token_ratio_decay, - self.min_new_token_ratio, - ) - - if not self.disable_regex_jump_forward: - # Check for jump-forward - jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) - self.forward_queue.extend(jump_forward_reqs) - if batch.is_empty(): - return - - # Update batch tensors - self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) - batch.prepare_for_decode() - - # Forward and sample the next tokens - output = self.model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids, _ = batch.sample(output.next_token_logits) - - # Move logprobs to cpu - if output.next_token_logprobs is not None: - next_token_logprobs = output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - - next_token_ids = next_token_ids.tolist() - - # Check finish condition - for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_id) - req.check_finished() - - if req.return_logprob: - req.decode_token_logprobs.append( - (next_token_logprobs[i], next_token_id) - ) - if req.top_logprobs_num > 0: - req.decode_top_logprobs.append(output.decode_top_logprobs[i]) - - self.handle_finished_requests(batch) - - def handle_finished_requests(self, batch: Batch): - output_rids = [] - 
decoded_texts = [] - surr_output_ids = [] - read_output_ids = [] - output_skip_special_tokens = [] - output_spaces_between_special_tokens = [] - output_meta_info = [] - output_finished_reason: List[BaseFinishReason] = [] - finished_indices = [] - unfinished_indices = [] - for i, req in enumerate(batch.reqs): - if req.finished(): - finished_indices.append(i) - else: - unfinished_indices.append(i) - - if req.finished() or ( - ( - req.stream - and ( - self.decode_forward_ct % self.stream_interval == 0 - or len(req.output_ids) == 1 - ) - ) - ): - output_rids.append(req.rid) - decoded_texts.append(req.decoded_text) - surr_ids, read_ids, _ = req.init_detokenize_incrementally() - surr_output_ids.append(surr_ids) - read_output_ids.append(read_ids) - output_skip_special_tokens.append( - req.sampling_params.skip_special_tokens - ) - output_spaces_between_special_tokens.append( - req.sampling_params.spaces_between_special_tokens - ) - - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.output_ids), - "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "finish_reason": str(req.finished_reason), - } - if req.return_logprob: - ( - meta_info["prefill_token_logprobs"], - meta_info["decode_token_logprobs"], - meta_info["prefill_top_logprobs"], - meta_info["decode_top_logprobs"], - meta_info["normalized_prompt_logprob"], - ) = ( - req.prefill_token_logprobs, - req.decode_token_logprobs, - req.prefill_top_logprobs, - req.decode_top_logprobs, - req.normalized_prompt_logprob, - ) - output_meta_info.append(meta_info) - output_finished_reason.append(req.finished_reason) - - # Send to detokenizer - if output_rids: - self.out_pyobjs.append( - BatchTokenIDOut( - output_rids, - decoded_texts, - surr_output_ids, - read_output_ids, - output_skip_special_tokens, - output_spaces_between_special_tokens, - output_meta_info, - output_finished_reason, - ) - ) - - # Remove finished reqs - if finished_indices: - # Update radix cache - req_pool_indices_cpu = batch.req_pool_indices.tolist() - for i in finished_indices: - req = batch.reqs[i] - self.tree_cache.cache_req( - token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], - last_uncached_pos=len(req.prefix_indices), - req_pool_idx=req_pool_indices_cpu[i], - ) - - self.tree_cache.dec_lock_ref(req.last_node) - - # Update batch tensors - if unfinished_indices: - batch.filter_batch(unfinished_indices) - else: - batch.reqs = [] - - def flush_cache(self): - if len(self.forward_queue) == 0 and ( - self.running_batch is None or len(self.running_batch.reqs) == 0 - ): - self.tree_cache.reset() - self.tree_cache_metrics = {"total": 0, "hit": 0} - self.regex_fsm_cache.reset() - self.req_to_token_pool.clear() - self.token_to_kv_pool.clear() - torch.cuda.empty_cache() - logger.info("Cache flushed successfully!") - else: - warnings.warn( - f"Cache not flushed because there are pending requests. 
" - f"#queue-req: {len(self.forward_queue)}, " - f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" - ) - - def abort_request(self, recv_req): - # Delete requests in the waiting queue - to_del = None - for i, req in enumerate(self.forward_queue): - if req.rid == recv_req.rid: - to_del = i - break - - if to_del is not None: - del self.forward_queue[to_del] - - # Delete requests in the running batch - if self.running_batch: - for req in self.running_batch.reqs: - if req.rid == recv_req.rid: - req.finished_reason = FINISH_ABORT() - break - - -class ModelTpService(rpyc.Service): - exposed_ModelTpServer = ModelTpServer - - -class ModelTpClient: - def __init__( - self, - gpu_ids: List[int], - server_args: ServerArgs, - model_port_args: ModelPortArgs, - model_overide_args, - ): - server_args, model_port_args = obtain(server_args), obtain(model_port_args) - self.tp_size = server_args.tp_size - - if self.tp_size * server_args.dp_size == 1: - # Init model - assert len(gpu_ids) == 1 - self.model_server = ModelTpService().exposed_ModelTpServer( - 0, - gpu_ids[0], - server_args, - model_port_args, - model_overide_args, - ) - - # Wrap functions - def async_wrap(f): - async def _func(*args, **kwargs): - return f(*args, **kwargs) - - return _func - - self.step = async_wrap(self.model_server.exposed_step) - else: - with ThreadPoolExecutor(self.tp_size) as executor: - # Launch model processes - if server_args.nnodes == 1: - self.procs = list( - executor.map( - lambda args: start_rpyc_service_process(*args), - [ - (ModelTpService, p) - for p in model_port_args.model_tp_ports - ], - ) - ) - addrs = [("localhost", p) for p in model_port_args.model_tp_ports] - else: - addrs = [ - (ip, port) - for ip, port in zip( - model_port_args.model_tp_ips, model_port_args.model_tp_ports - ) - ] - - self.model_services = list( - executor.map(lambda args: connect_rpyc_service(*args), addrs) - ) - - # Init model - def init_model(i): - return self.model_services[i].ModelTpServer( - gpu_ids[i], - i, - server_args, - model_port_args, - model_overide_args, - ) - - self.model_servers = list(executor.map(init_model, range(self.tp_size))) - - # Wrap functions - def async_wrap(func_name): - fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers] - - async def _func(*args, **kwargs): - tasks = [f(*args, **kwargs) for f in fs] - await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks]) - return obtain(tasks[0].value) - - return _func - - self.step = async_wrap("step") diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py new file mode 100644 index 00000000000..dcd984e0f2d --- /dev/null +++ b/python/sglang/srt/managers/controller_multi.py @@ -0,0 +1,217 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +A controller that manages multiple data parallel workers. +Each data parallel worker can manage multiple tensor parallel workers. 
+""" + +import dataclasses +import logging +import multiprocessing +import os +from enum import Enum, auto + +import numpy as np +import zmq + +from sglang.srt.managers.controller_single import ( + start_controller_process as start_controller_process_single, +) +from sglang.srt.managers.io_struct import ( + AbortReq, + FlushCacheReq, + TokenizedGenerateReqInput, +) +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import kill_parent_process +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class LoadBalanceMethod(Enum): + """Load balance method.""" + + ROUND_ROBIN = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, method: str): + method = method.upper() + try: + return cls[method] + except KeyError as exc: + raise ValueError(f"Invalid load balance method: {method}") from exc + + +@dataclasses.dataclass +class WorkerHandle: + """Store the handle of a data parallel worker.""" + + proc: multiprocessing.Process + queue: multiprocessing.Queue + + +class ControllerMulti: + """A controller that manages multiple data parallel workers.""" + + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args, + ): + # Parse args + self.server_args = server_args + self.port_args = port_args + self.model_overide_args = model_overide_args + self.load_balance_method = LoadBalanceMethod.from_str( + server_args.load_balance_method + ) + + # Init communication + context = zmq.Context() + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") + + # Dispatch method + self.round_robin_counter = 0 + dispatch_lookup = { + LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, + LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + } + self.dispatching = dispatch_lookup[self.load_balance_method] + + # Start data parallel workers + self.workers = [] + for i in range(server_args.dp_size): + self.start_dp_worker(i) + + def start_dp_worker(self, dp_worker_id: int): + tp_size = self.server_args.tp_size + + pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe( + duplex=False + ) + + gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size)) + queue = multiprocessing.Queue() + proc = multiprocessing.Process( + target=start_controller_process_single, + args=( + self.server_args, + self.port_args, + pipe_controller_writer, + self.model_overide_args, + True, + gpu_ids, + dp_worker_id, + queue, + ), + ) + proc.start() + + controller_init_state = pipe_controller_reader.recv() + if controller_init_state != "init ok": + raise RuntimeError( + f"Initialization failed. 
controller_init_state: {controller_init_state}" + ) + self.workers.append( + WorkerHandle( + proc=proc, + queue=queue, + ) + ) + + def round_robin_scheduler(self, input_requests): + for r in input_requests: + self.workers[self.round_robin_counter].queue.put(r) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) + + def shortest_queue_scheduler(self, input_requests): + for r in input_requests: + queue_sizes = [worker.queue.qsize() for worker in self.workers] + wid = np.argmin(queue_sizes) + self.workers[wid].queue.put(r) + + def loop_for_forward(self): + while True: + recv_reqs = self.recv_requests() + self.dispatching(recv_reqs) + + def recv_requests(self): + recv_reqs = [] + + while True: + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + + if isinstance(recv_req, FlushCacheReq): + # TODO(lsyin): apply more specific flushCacheReq + for worker in self.workers: + worker.queue.put(recv_req) + elif isinstance(recv_req, AbortReq): + in_queue = False + for i, req in enumerate(recv_reqs): + if req.rid == recv_req.rid: + recv_reqs[i] = recv_req + in_queue = True + break + if not in_queue: + # Send abort req to all TP groups + for worker in self.workers: + worker.queue.put(recv_req) + elif isinstance(recv_req, TokenizedGenerateReqInput): + recv_reqs.append(recv_req) + else: + logger.error(f"Invalid object: {recv_req}") + + return recv_reqs + + +def start_controller_process( + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, + model_overide_args: dict, +): + """Start a controller process.""" + + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + try: + controller = ControllerMulti(server_args, port_args, model_overide_args) + except Exception: + pipe_writer.send(get_exception_traceback()) + raise + + pipe_writer.send("init ok") + + try: + controller.loop_for_forward() + except Exception: + logger.error("Exception in ControllerMulti:\n" + get_exception_traceback()) + finally: + for w in controller.workers: + os.kill(w.proc.pid, 9) + kill_parent_process() diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py new file mode 100644 index 00000000000..415325b131c --- /dev/null +++ b/python/sglang/srt/managers/controller_single.py @@ -0,0 +1,172 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
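The two dispatch methods above differ only in how they pick a WorkerHandle queue: round-robin cycles through the workers, while shortest-queue looks at current queue depth. A standalone sketch of the same policies over plain standard-library queues (the class names are illustrative, not part of SGLang):

import queue


class RoundRobinDispatcher:
    """Cycle through the worker queues, one request at a time."""

    def __init__(self, worker_queues):
        self.worker_queues = worker_queues
        self.counter = 0

    def dispatch(self, req):
        self.worker_queues[self.counter].put(req)
        self.counter = (self.counter + 1) % len(self.worker_queues)


class ShortestQueueDispatcher:
    """Send each request to the worker with the fewest pending requests."""

    def __init__(self, worker_queues):
        self.worker_queues = worker_queues

    def dispatch(self, req):
        sizes = [q.qsize() for q in self.worker_queues]
        self.worker_queues[sizes.index(min(sizes))].put(req)


workers = [queue.Queue() for _ in range(4)]
rr = RoundRobinDispatcher(workers)
for i in range(8):
    rr.dispatch(f"req-{i}")
print([q.qsize() for q in workers])  # [2, 2, 2, 2]

Round-robin ignores queue depth, which is fine when requests are of similar length; the shortest-queue policy adapts better when per-request work is uneven.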
+""" + +"""A controller that manages a group of tensor parallel workers.""" + +import logging +import multiprocessing +import os +from typing import List + +import zmq + +from sglang.srt.managers.tp_worker import ( + ModelTpServer, + broadcast_recv_input, + launch_tp_servers, +) +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import kill_parent_process +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class ControllerSingle: + """A controller that manages a group of tensor parallel workers.""" + + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args: dict, + gpu_ids: List[int], + is_data_parallel_worker: bool, + dp_worker_id: int, + mp_queue: multiprocessing.Queue, + ): + # Parse args + self.tp_size = server_args.tp_size + self.is_dp_worker = is_data_parallel_worker + self.dp_worker_id = dp_worker_id + self.mp_queue = mp_queue + + # Init communication + context = zmq.Context(2) + + if not self.is_dp_worker: + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind( + f"tcp://127.0.0.1:{port_args.controller_port}" + ) + + self.send_to_detokenizer = context.socket(zmq.PUSH) + self.send_to_detokenizer.connect( + f"tcp://127.0.0.1:{port_args.detokenizer_port}" + ) + + # Launch other tp ranks + tp_size_local = server_args.tp_size // server_args.nnodes + self.tp_procs = [] + if tp_size_local > 1: + tp_rank_range = range(1, tp_size_local) + self.tp_procs = launch_tp_servers( + gpu_ids, + tp_rank_range, + server_args, + port_args.nccl_ports[dp_worker_id], + model_overide_args, + ) + + # Launch tp rank 0 + self.tp_server = ModelTpServer( + gpu_ids[0], + 0, + server_args, + port_args.nccl_ports[dp_worker_id], + model_overide_args, + ) + self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group + + def loop_for_forward(self): + while True: + if not self.is_dp_worker: + recv_reqs = self.recv_requests_from_zmq() + else: + recv_reqs = self.recv_requests_from_mp_queue() + + if self.tp_size > 1: + broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group) + + out_pyobjs = self.tp_server.exposed_step(recv_reqs) + + for obj in out_pyobjs: + self.send_to_detokenizer.send_pyobj(obj) + + def recv_requests_from_zmq(self): + recv_reqs = [] + while True: + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + recv_reqs.append(recv_req) + + return recv_reqs + + def recv_requests_from_mp_queue(self): + recv_reqs = [] + while not self.mp_queue.empty(): + recv_reqs.append(self.mp_queue.get()) + return recv_reqs + + +def start_controller_process( + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer: multiprocessing.connection.Connection, + model_overide_args: dict, + is_data_parallel_worker: bool = False, + gpu_ids: List[int] = None, + dp_worker_id: int = None, + queue: multiprocessing.connection.Connection = None, +): + """Start a controller process.""" + + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + if not is_data_parallel_worker: + tp_size_local = server_args.tp_size // server_args.nnodes + gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] + dp_worker_id = 0 + queue = None + + try: + controller = ControllerSingle( + server_args, + port_args, + model_overide_args, + gpu_ids, + is_data_parallel_worker, + dp_worker_id, + queue, + ) + except Exception: + pipe_writer.send(get_exception_traceback()) + raise + + 
pipe_writer.send("init ok") + + try: + controller.loop_for_forward() + except Exception: + logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) + finally: + for t in controller.tp_procs: + os.kill(t.pid, 9) + kill_parent_process() diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 3e0183b1b9b..623ffe916eb 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -1,21 +1,51 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """DetokenizerManager is a process that detokenizes the token ids.""" import asyncio +import dataclasses import inspect +from typing import List import uvloop import zmq import zmq.asyncio from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR -from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut +from sglang.srt.managers.io_struct import ( + BatchEmbeddingOut, + BatchStrOut, + BatchTokenIDOut, +) +from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +@dataclasses.dataclass +class DecodeStatus: + vid: int + decoded_text: str + decode_ids: List[int] + surr_offset: int + read_offset: int + + class DetokenizerManager: def __init__( self, @@ -35,19 +65,55 @@ def __init__( trust_remote_code=server_args.trust_remote_code, ) + self.decode_status = {} + async def handle_loop(self): while True: recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() + + if isinstance(recv_obj, BatchEmbeddingOut): + self.send_to_tokenizer.send_pyobj( + BatchEmbeddingOut( + rids=recv_obj.rids, + embeddings=recv_obj.embeddings, + meta_info=recv_obj.meta_info, + finished_reason=recv_obj.finished_reason, + ) + ) + continue + assert isinstance(recv_obj, BatchTokenIDOut) + bs = len(recv_obj.rids) + + # Initialize decode status + read_ids, surr_ids = [], [] + for i in range(bs): + rid = recv_obj.rids[i] + vid = recv_obj.vids[i] + if rid not in self.decode_status or self.decode_status[rid].vid != vid: + s = DecodeStatus( + vid=vid, + decoded_text=recv_obj.decoded_texts[i], + decode_ids=recv_obj.decode_ids[i], + surr_offset=0, + read_offset=recv_obj.read_offsets[i], + ) + self.decode_status[rid] = s + else: + s = self.decode_status[rid] + s.decode_ids = recv_obj.decode_ids[i] + + read_ids.append(s.decode_ids[s.surr_offset :]) + surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request surr_texts = self.tokenizer.batch_decode( - recv_obj.surr_output_ids, + surr_ids, skip_special_tokens=recv_obj.skip_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) read_texts = 
self.tokenizer.batch_decode( - recv_obj.read_output_ids, + read_ids, skip_special_tokens=recv_obj.skip_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) @@ -55,11 +121,20 @@ async def handle_loop(self): # Trim stop str # TODO(lmzheng): handle the case where multiple stop strs are hit output_strs = [] - for i in range(len(recv_obj.rids)): + for i in range(bs): + s = self.decode_status[recv_obj.rids[i]] new_text = read_texts[i][len(surr_texts[i]) :] if recv_obj.finished_reason[i] is None: - new_text = find_printable_text(new_text) - output_strs.append(recv_obj.decoded_texts[i] + new_text) + # Streaming chunk: update the decode status + if len(new_text) > 0 and not new_text.endswith("�"): + s.decoded_text = s.decoded_text + new_text + s.surr_offset = s.read_offset + s.read_offset = len(s.decode_ids) + new_text = "" + else: + new_text = find_printable_text(new_text) + + output_strs.append(s.decoded_text + new_text) if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): pos = output_strs[i].find(recv_obj.finished_reason[i].matched) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 7b26a4f2da2..e4c3040c9a9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """ The definition of objects transfered between different processes (TokenizerManager, DetokenizerManager, Controller). @@ -7,31 +22,34 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Union -from sglang.srt.managers.controller.infer_batch import BaseFinishReason +import torch + +from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling_params import SamplingParams @dataclass class GenerateReqInput: - # The input prompt + # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[str], str]] = None - # The token ids for text; one can either specify text or input_ids + # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None - # The image input + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. image_data: Optional[Union[List[str], str]] = None - # The sampling_params + # The sampling_params. See descriptions below. sampling_params: Union[List[Dict], Dict] = None - # The request id + # The request id. rid: Optional[Union[List[str], str]] = None - # Whether to return logprobs + # Whether to return logprobs. return_logprob: Optional[Union[List[bool], bool]] = None - # The start location of the prompt for return_logprob + # The start location of the prompt for return_logprob. logprob_start_len: Optional[Union[List[int], int]] = None - # The number of top logprobs to return + # The number of top logprobs to return. 
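The DecodeStatus bookkeeping in the DetokenizerManager changes above implements incremental detokenization: new text is only committed once it no longer ends in the replacement character "�", so multi-byte UTF-8 sequences are never streamed half-finished. A simplified single-request version of the same loop, assuming any Hugging Face tokenizer (gpt2 is used purely as an example):

from transformers import AutoTokenizer


class IncrementalDecoder:
    """Commit decoded text only when it is printable (no trailing '�')."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.decoded_text = ""
        self.surr_offset = 0  # start of the surrounding context used for decoding
        self.read_offset = 0  # end of the region already committed as text

    def step(self, all_token_ids):
        surr_text = self.tokenizer.decode(all_token_ids[self.surr_offset : self.read_offset])
        read_text = self.tokenizer.decode(all_token_ids[self.surr_offset :])
        new_text = read_text[len(surr_text) :]
        if new_text and not new_text.endswith("�"):
            # Safe to commit: advance both offsets past the decoded region.
            self.decoded_text += new_text
            self.surr_offset = self.read_offset
            self.read_offset = len(all_token_ids)
            return new_text
        return ""  # wait for more tokens, e.g. the rest of a multi-byte character


tokenizer = AutoTokenizer.from_pretrained("gpt2")
decoder = IncrementalDecoder(tokenizer)
ids = []
for token_id in tokenizer.encode("Hello, 世界!"):
    ids.append(token_id)
    print(repr(decoder.step(ids)))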
top_logprobs_num: Optional[Union[List[int], int]] = None - # Whether to detokenize tokens in logprobs + # Whether to detokenize tokens in text in the returned logprobs. return_text_in_logprobs: bool = False - # Whether to stream output + # Whether to stream output. stream: bool = False def post_init(self): @@ -39,11 +57,16 @@ def post_init(self): self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") - - if self.text is not None: - is_single = isinstance(self.text, str) + if ( + isinstance(self.sampling_params, dict) + and self.sampling_params.get("n", 1) != 1 + ): + is_single = False else: - is_single = isinstance(self.input_ids[0], int) + if self.text is not None: + is_single = isinstance(self.text, str) + else: + is_single = isinstance(self.input_ids[0], int) self.is_single = is_single if is_single: @@ -58,7 +81,45 @@ def post_init(self): if self.top_logprobs_num is None: self.top_logprobs_num = 0 else: - num = len(self.text) if self.text is not None else len(self.input_ids) + parallel_sample_num_list = [] + if isinstance(self.sampling_params, dict): + parallel_sample_num = self.sampling_params.get("n", 1) + elif isinstance(self.sampling_params, list): + for sp in self.sampling_params: + parallel_sample_num = sp.get("n", 1) + parallel_sample_num_list.append(parallel_sample_num) + parallel_sample_num = max(parallel_sample_num_list) + all_equal = all( + element == parallel_sample_num + for element in parallel_sample_num_list + ) + if parallel_sample_num > 1 and (not all_equal): + # TODO cope with the case that the parallel_sample_num is different for different samples + raise ValueError( + "The parallel_sample_num should be the same for all samples in sample params." + ) + else: + parallel_sample_num = 1 + self.parallel_sample_num = parallel_sample_num + + if parallel_sample_num != 1: + # parallel sampling +1 represents the original prefill stage + num = parallel_sample_num + 1 + if isinstance(self.text, list): + # suppot batch operation + self.batch_size = len(self.text) + num = num * len(self.text) + elif isinstance(self.input_ids, list) and isinstance( + self.input_ids[0], list + ): + self.batch_size = len(self.input_ids) + num = num * len(self.input_ids) + else: + self.batch_size = 1 + else: + # support select operation + num = len(self.text) if self.text is not None else len(self.input_ids) + self.batch_size = num if self.image_data is None: self.image_data = [None] * num @@ -107,12 +168,63 @@ class TokenizedGenerateReqInput: stream: bool +@dataclass +class EmbeddingReqInput: + # The input prompt. It can be a single prompt or a batch of prompts. + text: Optional[Union[List[str], str]] = None + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None + # The request id. 
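The parallel-sampling branch of post_init above expands each prompt into one shared prefill request plus n sampling requests, and sizes the per-request lists (rid, image_data, sampling params) accordingly. A tiny illustration of that arithmetic (expected_request_count is a hypothetical helper, not part of io_struct):

def expected_request_count(num_prompts: int, n: int) -> int:
    """Internal request count implied by the post_init logic above."""
    if n == 1:
        return num_prompts  # plain batch: one request per prompt
    return (n + 1) * num_prompts  # the +1 is the shared prefill of the common prompt


assert expected_request_count(num_prompts=1, n=1) == 1
assert expected_request_count(num_prompts=2, n=4) == 10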
+ rid: Optional[Union[List[str], str]] = None + # Dummy sampling params for compatibility + sampling_params: Union[List[Dict], Dict] = None + + def post_init(self): + if (self.text is None and self.input_ids is None) or ( + self.text is not None and self.input_ids is not None + ): + raise ValueError("Either text or input_ids should be provided.") + + if self.text is not None: + is_single = isinstance(self.text, str) + else: + is_single = isinstance(self.input_ids[0], int) + self.is_single = is_single + + if is_single: + if self.rid is None: + self.rid = uuid.uuid4().hex + self.sampling_params = {"max_new_tokens": 0} + else: + # support select operation + self.batch_size = ( + len(self.text) if self.text is not None else len(self.input_ids) + ) + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)] + else: + if not isinstance(self.rid, list): + raise ValueError("The rid should be a list.") + self.sampling_params = [ + {"max_new_tokens": 0} for _ in range(self.batch_size) + ] + + +@dataclass +class TokenizedEmbeddingReqInput: + rid: str + input_text: str + input_ids: List[int] + sampling_params: SamplingParams + + @dataclass class BatchTokenIDOut: rids: List[str] + vids: List[int] decoded_texts: List[str] - surr_output_ids: List[List[int]] - read_output_ids: List[List[int]] + decode_ids: List[int] + read_offsets: List[int] skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] meta_info: List[Dict] @@ -127,6 +239,14 @@ class BatchStrOut: finished_reason: List[BaseFinishReason] +@dataclass +class BatchEmbeddingOut: + rids: List[str] + embeddings: List[List[float]] + meta_info: List[Dict] + finished_reason: List[BaseFinishReason] + + @dataclass class FlushCacheReq: pass diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py new file mode 100644 index 00000000000..30a009c2e6a --- /dev/null +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -0,0 +1,207 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Request policy scheduler""" + +import random +from collections import defaultdict +from contextlib import contextmanager + +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch + + +class PolicyScheduler: + def __init__( + self, + policy, + max_running_seqs, + max_prefill_num_tokens, + max_total_num_tokens, + tree_cache, + ): + if tree_cache.disable and policy == "lpm": + # LMP is meaningless when the tree cache is disabled. 
+ policy = "fcfs" + + self.policy = policy + self.max_running_seqs = max_running_seqs + self.max_prefill_num_tokens = max_prefill_num_tokens + self.max_total_num_tokens = max_total_num_tokens + self.tree_cache = tree_cache + + def get_priority_queue(self, waiting_queue): + if self.policy == "lpm": + # longest prefix match + waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) + return waiting_queue + elif self.policy == "fcfs": + # first come first serve + return waiting_queue + elif self.policy == "lof": + # longest output first + waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) + return waiting_queue + elif self.policy == "random": + random.shuffle(waiting_queue) + return waiting_queue + elif self.policy == "dfs-weight": + last_node_to_reqs = defaultdict(list) + for req in waiting_queue: + last_node_to_reqs[req.last_node].append(req) + + node_to_weight = defaultdict(int) + for node in last_node_to_reqs: + node_to_weight[node] = len(last_node_to_reqs[node]) + self.calc_weight(self.tree_cache.root_node, node_to_weight) + + q = [] + self.get_dfs_priority( + self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q + ) + assert len(q) == len(waiting_queue) + return q + else: + raise ValueError(f"Unknown schedule_policy: {self.policy}") + + def calc_weight(self, cur_node, node_to_weight): + for child in cur_node.children.values(): + self.calc_weight(child, node_to_weight) + node_to_weight[cur_node] += node_to_weight[child] + + def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q): + childs = [child for child in cur_node.children.values()] + childs.sort(key=lambda x: -node_to_priority[x]) + for child in childs: + self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) + q.extend(last_node_to_reqs[cur_node]) + + +class PrefillAdder: + def __init__( + self, + tree_cache, + rem_total_tokens, + rem_input_tokens, + rem_chunk_tokens, + ): + self.tree_cache = tree_cache + self.rem_total_tokens = rem_total_tokens + self.rem_input_tokens = rem_input_tokens + self.rem_chunk_tokens = rem_chunk_tokens + + self.can_run_list = [] + self.new_inflight_req = None + self.log_hit_tokens = 0 + self.log_input_tokens = 0 + + def no_remaining_tokens(self): + return ( + self.rem_total_tokens <= 0 + or self.rem_input_tokens <= 0 + or ( + self.rem_chunk_tokens <= 0 + if self.rem_chunk_tokens is not None + else False + ) + ) + + def remove_running_tokens( + self, running_batch: ScheduleBatch, new_token_ratio: float + ): + self.rem_total_tokens -= sum( + [ + (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio + for r in running_batch.reqs + ] + ) + + def _prefill_one_req( + self, prefix_len: int, extend_input_len: int, max_new_tokens: int + ): + self.rem_total_tokens -= extend_input_len + max_new_tokens + self.rem_input_tokens -= extend_input_len + if self.rem_chunk_tokens is not None: + self.rem_chunk_tokens -= extend_input_len + + self.log_hit_tokens += prefix_len + self.log_input_tokens += extend_input_len + + def add_inflight_req(self, req: Req): + req.input_ids = req.origin_input_ids + req.output_ids + req.extend_input_len = len(req.input_ids) - len(req.prefix_indices) + truncated = req.extend_input_len > self.rem_chunk_tokens + req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) + req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len] + self.can_run_list.append(req) + + self._prefill_one_req( + len(req.prefix_indices), + req.extend_input_len, + req.sampling_params.max_new_tokens if not 
truncated else 0, + ) + + # Return if chunked prefill not finished + return req if truncated else None + + @contextmanager + def _lock_node(self, last_node): + try: + delta = self.tree_cache.inc_lock_ref(last_node) + self.rem_total_tokens += delta + yield None + finally: + delta = self.tree_cache.dec_lock_ref(last_node) + self.rem_total_tokens += delta + + def add_one_req(self, req: Req): + total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens + input_tokens = req.extend_input_len + prefix_len = len(req.prefix_indices) + + if total_tokens >= self.rem_total_tokens: + return False + + if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0: + return False + + with self._lock_node(req.last_node): + if total_tokens > self.rem_total_tokens: + return False + + if ( + self.rem_chunk_tokens is None + or input_tokens <= self.rem_chunk_tokens + or (req.return_logprob and req.normalized_prompt_logprob is None) + ): + # Non-chunked prefill + self.can_run_list.append(req) + self.tree_cache.inc_lock_ref(req.last_node) + self._prefill_one_req( + prefix_len, input_tokens, req.sampling_params.max_new_tokens + ) + else: + # Chunked prefill + trunc_len = self.rem_chunk_tokens + if trunc_len == 0: + return False + + req.extend_input_len = trunc_len + req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len] + self.can_run_list.append(req) + self.new_inflight_req = req + self.tree_cache.inc_lock_ref(req.last_node) + self._prefill_one_req(prefix_len, trunc_len, 0) + + return True diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/schedule_batch.py similarity index 54% rename from python/sglang/srt/managers/controller/infer_batch.py rename to python/sglang/srt/managers/schedule_batch.py index 27d041d1dd7..d2101d2c0c0 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,31 +1,49 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
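add_inflight_req and add_one_req above implement chunked prefill against three shrinking budgets (total, input, and per-chunk tokens); a request is truncated whenever its extend length exceeds the remaining chunk budget and stays in flight for the next batch. A simplified sketch of just the chunking decision (Request is a stand-in dataclass, not the SGLang Req class):

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Request:
    input_ids: List[int]
    prefix_len: int  # tokens already present in the prefix cache
    max_new_tokens: int


def plan_prefill(req: Request, rem_chunk_tokens: int) -> Tuple[int, bool]:
    """Return (tokens_to_prefill_now, finished) under a per-batch chunk budget."""
    extend_len = len(req.input_ids) - req.prefix_len
    if extend_len <= rem_chunk_tokens:
        # Fits in this chunk: prefill everything, reserve max_new_tokens for decoding.
        return extend_len, True
    # Does not fit: prefill only up to the budget and keep the request in flight.
    return rem_chunk_tokens, False


req = Request(input_ids=list(range(1000)), prefix_len=200, max_new_tokens=64)
print(plan_prefill(req, rem_chunk_tokens=512))  # (512, False) -> chunked, stays in flight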
+""" + """Meta data for requests and batches""" +import logging import warnings from dataclasses import dataclass -from enum import IntEnum, auto from typing import List, Union import numpy as np import torch +from flashinfer.sampling import top_k_top_p_sampling_from_probs +import sglang.srt.sampling.penaltylib as penaltylib +from sglang.global_config import global_config from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap -from sglang.srt.managers.controller.radix_cache import RadixCache -from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool +from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool +from sglang.srt.mem_cache.radix_cache import RadixCache INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 -# Store some global server args -global_server_args_dict = {} +# Put some global args for easy access +global_server_args_dict = { + "disable_flashinfer": False, + "disable_flashinfer_sampling": False, + "attention_reduce_in_fp32": False, + "enable_mla": False, +} -class ForwardMode(IntEnum): - # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. - PREFILL = auto() - # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt). - EXTEND = auto() - # Decode one token. - DECODE = auto() +logger = logging.getLogger(__name__) class BaseFinishReason: @@ -75,6 +93,7 @@ class Req: """Store all inforamtion of a request.""" def __init__(self, rid, origin_input_text, origin_input_ids): + # Input and output info self.rid = rid self.origin_input_text = origin_input_text self.origin_input_ids_unpadded = origin_input_ids # Before image padding @@ -82,7 +101,19 @@ def __init__(self, rid, origin_input_text, origin_input_ids): self.output_ids = [] # Each decode stage's output ids self.input_ids = None # input_ids = origin_input_ids + output_ids + # Memory info + self.req_pool_idx = None + # For incremental decoding + # ----- | --------- read_ids -------| + # ----- | surr_ids | + # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx | + # ----- ^ ----------- ^ ----------- ^ + # ----- 1 ----------- 2 ----------- 3 + # 1: surr_offset + # 2: read_offset + # 3: last token + self.vid = 0 # version id to sync decode status with in detokenizer_manager self.decoded_text = "" self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None @@ -94,9 +125,14 @@ def __init__(self, rid, origin_input_text, origin_input_ids): # For vision input self.pixel_values = None self.image_size = None - self.image_offset = 0 + self.image_offset = None self.pad_value = None + # Prefix info + self.extend_input_len = 0 + self.prefix_indices = [] + self.last_node = None + # Sampling parameters self.sampling_params = None self.stream = False @@ -105,20 +141,16 @@ def __init__(self, rid, origin_input_text, origin_input_ids): self.tokenizer = None self.finished_reason = None - # Prefix info - self.extend_input_len = 0 - self.prefix_indices = [] - self.last_node = None - # Logprobs self.return_logprob = False + self.embedding = None self.logprob_start_len = 0 self.top_logprobs_num = 0 self.normalized_prompt_logprob = None - self.prefill_token_logprobs = None - self.prefill_top_logprobs = None - self.decode_token_logprobs = [] - self.decode_top_logprobs = [] + self.input_token_logprobs = None + self.input_top_logprobs = None + self.output_token_logprobs = [] + self.output_top_logprobs = [] # The tokens is 
prefilled but need to be considered as decode tokens # and should be updated for the decode logprobs self.last_update_decode_tokens = 0 @@ -132,8 +164,25 @@ def __init__(self, rid, origin_input_text, origin_input_ids): def finished(self) -> bool: return self.finished_reason is not None + def adjust_max_prefix_ids(self): + input_len = len(self.input_ids) + max_prefix_len = input_len + + if self.sampling_params.max_new_tokens > 0: + # Need at least one token to compute logits + max_prefix_len = min(max_prefix_len, input_len - 1) + + if self.return_logprob: + max_prefix_len = min(max_prefix_len, self.logprob_start_len) + + if self.normalized_prompt_logprob is None: + # Need at least two tokens to compute normalized logprob + max_prefix_len = min(max_prefix_len, input_len - 2) + + return self.input_ids[:max_prefix_len] + # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313 - def init_detokenize_incrementally(self): + def init_incremental_detokenize(self): first_iter = self.surr_offset is None or self.read_offset is None if first_iter: @@ -143,13 +192,11 @@ def init_detokenize_incrementally(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - surr_ids = all_ids[self.surr_offset : self.read_offset] - read_ids = all_ids[self.surr_offset :] + return all_ids[self.surr_offset :], self.read_offset - self.surr_offset - return surr_ids, read_ids, len(all_ids) - - def detokenize_incrementally(self, inplace: bool = True): - surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally() + def get_next_inc_detokenization(self): + read_ids, read_offset = self.init_incremental_detokenize() + surr_ids = read_ids[:read_offset] surr_text = self.tokenizer.decode( surr_ids, @@ -163,29 +210,23 @@ def detokenize_incrementally(self, inplace: bool = True): ) if len(new_text) > len(surr_text) and not new_text.endswith("�"): - new_text = new_text[len(surr_text) :] - if inplace: - self.decoded_text += new_text - self.surr_offset = self.read_offset - self.read_offset = num_all_tokens - - return True, new_text + return True, new_text[len(surr_text) :] return False, "" - def max_new_tokens(self): - return self.sampling_params.max_new_tokens - def check_finished(self): if self.finished(): return if len(self.output_ids) >= self.sampling_params.max_new_tokens: - self.finished_reason = FINISH_LENGTH(len(self.output_ids)) + self.finished_reason = FINISH_LENGTH( + length=self.sampling_params.max_new_tokens + ) return + last_token_id = self.output_ids[-1] if ( - self.output_ids[-1] == self.tokenizer.eos_token_id + last_token_id == self.tokenizer.eos_token_id and not self.sampling_params.ignore_eos ): self.finished_reason = FINISH_MATCHED_TOKEN( @@ -193,6 +234,10 @@ def check_finished(self): ) return + if last_token_id in self.sampling_params.stop_token_ids: + self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) + return + if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] @@ -246,8 +291,8 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): k = k + 1 else: break - self.decode_token_logprobs = self.decode_token_logprobs[:k] - self.decode_top_logprobs = self.decode_top_logprobs[:k] + self.output_token_logprobs = self.output_token_logprobs[:k] + self.output_top_logprobs = self.output_top_logprobs[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) 
- k @@ -258,43 +303,32 @@ def __repr__(self): @dataclass -class Batch: +class ScheduleBatch: """Store all inforamtion of a batch.""" + # Request, memory pool, and cache reqs: List[Req] req_to_token_pool: ReqToTokenPool - token_to_kv_pool: TokenToKVPool + token_to_kv_pool: BaseTokenToKVPool tree_cache: RadixCache - # batched arguments to model runner + # Batched arguments to model runner input_ids: torch.Tensor = None req_pool_indices: torch.Tensor = None seq_lens: torch.Tensor = None - prefix_lens: torch.Tensor = None position_ids_offsets: torch.Tensor = None out_cache_loc: torch.Tensor = None - out_cache_cont_start: torch.Tensor = None - out_cache_cont_end: torch.Tensor = None + extend_num_tokens: int = None - # for processing logprobs + # For processing logprobs return_logprob: bool = False top_logprobs_nums: List[int] = None - # for multimodal - pixel_values: List[torch.Tensor] = None - image_sizes: List[List[int]] = None - image_offsets: List[int] = None - - # other arguments for control - output_ids: torch.Tensor = None - extend_num_tokens: int = None - - # batched sampling params + # Batched sampling params temperatures: torch.Tensor = None top_ps: torch.Tensor = None top_ks: torch.Tensor = None - frequency_penalties: torch.Tensor = None - presence_penalties: torch.Tensor = None + penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None logit_bias: torch.Tensor = None @classmethod @@ -309,91 +343,44 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): return_logprob=return_logprob, ) + def batch_size(self): + return len(self.reqs) if self.reqs is not None else 0 + def is_empty(self): return len(self.reqs) == 0 - # whether batch has at least 1 streaming request def has_stream(self) -> bool: + # Return whether batch has at least 1 streaming request return any(r.stream for r in self.reqs) - def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): - device = "cuda" - bs = len(self.reqs) - reqs = self.reqs - input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs] - prefix_indices = [r.prefix_indices for r in reqs] - - # Handle prefix - flatten_input_ids = [] - extend_lens = [] - prefix_lens = [] - seq_lens = [] - - req_pool_indices = self.req_to_token_pool.alloc(bs) - req_pool_indices_cpu = req_pool_indices.cpu().numpy() - for i in range(bs): - flatten_input_ids.extend(input_ids[i]) - extend_lens.append(len(input_ids[i])) - - if len(prefix_indices[i]) == 0: - prefix_lens.append(0) - else: - prefix_lens.append(len(prefix_indices[i])) - self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ - : len(prefix_indices[i]) - ] = prefix_indices[i] - - seq_lens.append(prefix_lens[-1] + extend_lens[-1]) + def alloc_req_slots(self, num_reqs): + req_pool_indices = self.req_to_token_pool.alloc(num_reqs) + if req_pool_indices is None: + raise RuntimeError( + "Out of memory. " + "Please set a smaller number for `--max-running-requests`." 
+ ) + return req_pool_indices - position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device) + def alloc_token_slots(self, num_tokens: int): + out_cache_loc = self.token_to_kv_pool.alloc(num_tokens) - # Alloc mem - seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens) - extend_num_tokens = seq_lens.sum() - prefix_lens.sum() - out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) if out_cache_loc is None: - self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs) - out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) + if self.tree_cache is not None: + self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free) + out_cache_loc = self.token_to_kv_pool.alloc(num_tokens) if out_cache_loc is None: - print("Prefill out of memory. This should never happen.") - self.tree_cache.pretty_print() - exit() - - pt = 0 - for i in range(bs): - self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ - prefix_lens[i] : prefix_lens[i] + extend_lens[i] - ] = out_cache_loc[pt : pt + extend_lens[i]] - pt += extend_lens[i] - - # Handle logit bias but only allocate when needed - logit_bias = None - for i in range(bs): - if reqs[i].sampling_params.dtype == "int": - if logit_bias is None: - logit_bias = torch.zeros( - (bs, vocab_size), dtype=torch.float32, device=device - ) - logit_bias[i] = int_token_logit_bias + logger.error("Prefill out of memory. Try to lower your batch size.") + if self.tree_cache is not None: + self.tree_cache.pretty_print() + exit(1) - # Set fields - self.input_ids = torch.tensor( - flatten_input_ids, dtype=torch.int32, device=device - ) - self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [ - r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens) - ] - self.req_pool_indices = req_pool_indices - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device) - self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device) - self.position_ids_offsets = position_ids_offsets - self.extend_num_tokens = extend_num_tokens - self.out_cache_loc = out_cache_loc - self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + return out_cache_loc + def batch_sampling_params(self, vocab_size, int_token_logit_bias): + device = "cuda" + bs, reqs = self.batch_size(), self.reqs self.temperatures = torch.tensor( [r.sampling_params.temperature for r in reqs], dtype=torch.float, @@ -401,28 +388,87 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor ).view(-1, 1) self.top_ps = torch.tensor( [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device - ).view(-1, 1) + ) self.top_ks = torch.tensor( [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device - ).view(-1, 1) - self.frequency_penalties = torch.tensor( - [r.sampling_params.frequency_penalty for r in reqs], - dtype=torch.float, - device=device, ) - self.presence_penalties = torch.tensor( - [r.sampling_params.presence_penalty for r in reqs], - dtype=torch.float, + + # Each penalizers will do nothing if they evaluate themselves as not required by looking at + # the sampling_params of the requests (See {_is_required()} of each penalizers). So this + # should not add hefty computation overhead other than simple checks. 
+ # + # While we choose not to even create the class instances if they are not required, this + # could add additional complexity to the {ScheduleBatch} class, especially we need to + # handle {filter_batch()} and {merge()} cases as well. + self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( + vocab_size=vocab_size, + batch=self, device=device, + Penalizers={ + penaltylib.BatchedFrequencyPenalizer, + penaltylib.BatchedMinNewTokensPenalizer, + penaltylib.BatchedPresencePenalizer, + penaltylib.BatchedRepetitionPenalizer, + }, ) - self.logit_bias = logit_bias + + # Handle logit bias but only allocate when needed + self.logit_bias = None + for i in range(bs): + if reqs[i].sampling_params.dtype == "int": + if self.logit_bias is None: + self.logit_bias = torch.zeros( + (bs, vocab_size), dtype=torch.float32, device=device + ) + self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias + + def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): + bs = self.batch_size() + reqs = self.reqs + input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs] + extend_num_tokens = sum(len(ids) for ids in input_ids) + seq_lens = [] + + # Allocate memory + req_pool_indices_cpu = self.alloc_req_slots(bs) + out_cache_loc = self.alloc_token_slots(extend_num_tokens) + + pt = 0 + for i, req in enumerate(reqs): + req.req_pool_idx = req_pool_indices_cpu[i] + pre_len, seq_len = len(req.prefix_indices), len(req.input_ids) + ext_len = seq_len - pre_len + seq_lens.append(seq_len) + + if pre_len > 0: + self.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + out_cache_loc[pt : pt + ext_len] + ) + pt += ext_len + + # Set fields + with torch.device("cuda"): + self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) + self.req_pool_indices = torch.tensor(req_pool_indices_cpu) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) + self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64) + + self.extend_num_tokens = extend_num_tokens + self.out_cache_loc = out_cache_loc + self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + + self.batch_sampling_params(vocab_size, int_token_logit_bias) def check_decode_mem(self): - bs = len(self.reqs) + bs = self.batch_size() if self.token_to_kv_pool.available_size() >= bs: return True - self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs) + self.tree_cache.evict(bs, self.token_to_kv_pool.free) if self.token_to_kv_pool.available_size() >= bs: return True @@ -431,7 +477,8 @@ def check_decode_mem(self): def retract_decode(self): sorted_indices = [i for i in range(len(self.reqs))] - # TODO(lsyin): improve the priority of retraction + + # TODO(lsyin): improve retraction policy for radix cache sorted_indices.sort( key=lambda i: ( len(self.reqs[i].output_ids), @@ -442,21 +489,48 @@ def retract_decode(self): retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() - req_pool_indices_cpu = self.req_pool_indices.cpu().numpy() - while self.token_to_kv_pool.available_size() < len(self.reqs): + while ( + self.token_to_kv_pool.available_size() + < len(sorted_indices) * global_config.retract_decode_steps + ): + if len(sorted_indices) == 1: + # Corner case: only one request left + assert ( + self.token_to_kv_pool.available_size() > 0 + ), "No space left for only one request" + break + idx = sorted_indices.pop() req = self.reqs[idx] retracted_reqs.append(req) - # TODO: apply more 
fine-grained retraction - last_uncached_pos = len(req.prefix_indices) - token_indices = self.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[idx] - ][last_uncached_pos : seq_lens_cpu[idx]] - self.token_to_kv_pool.dec_refs(token_indices) + if isinstance(self.tree_cache, ChunkCache): + # ChunkCache does not have eviction + token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][ + : seq_lens_cpu[idx] + ] + self.token_to_kv_pool.free(token_indices) + self.req_to_token_pool.free(req.req_pool_idx) + del self.tree_cache.entries[req.rid] + else: + # TODO: apply more fine-grained retraction + last_uncached_pos = len(req.prefix_indices) + token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][ + last_uncached_pos : seq_lens_cpu[idx] + ] + self.token_to_kv_pool.free(token_indices) + self.req_to_token_pool.free(req.req_pool_idx) + + # release the last node + self.tree_cache.dec_lock_ref(req.last_node) - # release the last node - self.tree_cache.dec_lock_ref(req.last_node) + # NOTE(lsyin): we should use the newly evictable memory instantly. + residual_size = ( + len(sorted_indices) * global_config.retract_decode_steps + - self.token_to_kv_pool.available_size() + ) + residual_size = max(0, residual_size) + self.tree_cache.evict(residual_size, self.token_to_kv_pool.free) req.prefix_indices = None req.last_node = None @@ -468,14 +542,21 @@ def retract_decode(self): self.filter_batch(sorted_indices) - return retracted_reqs + # Reqs in batch are filtered + total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs) + total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) + + new_estimate_ratio = ( + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens + new_estimate_ratio = min(1.0, new_estimate_ratio) + + return retracted_reqs, new_estimate_ratio def check_for_jump_forward(self, model_runner): jump_forward_reqs = [] filter_indices = [i for i in range(len(self.reqs))] - req_pool_indices_cpu = None - for i, req in enumerate(self.reqs): if req.jump_forward_map is not None: jump_forward_bytes = req.jump_forward_map.jump_forward_byte( @@ -502,7 +583,7 @@ def check_for_jump_forward(self, model_runner): cur_output_ids = req.output_ids req.output_ids.extend(suffix_ids) - decode_res, new_text = req.detokenize_incrementally(inplace=False) + decode_res, new_text = req.get_next_inc_detokenization() if not decode_res: req.output_ids = cur_output_ids continue @@ -521,17 +602,11 @@ def check_for_jump_forward(self, model_runner): req.output_ids = cur_output_ids continue - # insert the old request into tree_cache - if req_pool_indices_cpu is None: - req_pool_indices_cpu = self.req_pool_indices.tolist() - self.tree_cache.cache_req( - token_ids=cur_all_ids, - last_uncached_pos=len(req.prefix_indices), - req_pool_idx=req_pool_indices_cpu[i], - ) + # The decode status has diverged from detokenizer_manager + req.vid += 1 - # unlock the last node - self.tree_cache.dec_lock_ref(req.last_node) + # insert the old request into tree_cache + self.tree_cache.cache_finished_req(req, cur_all_ids) # re-applying image padding if req.pixel_values is not None: @@ -548,8 +623,7 @@ def check_for_jump_forward(self, model_runner): jump_forward_reqs.append(req) filter_indices.remove(i) - if len(filter_indices) < len(self.reqs): - self.filter_batch(filter_indices) + self.filter_batch(filter_indices) return jump_forward_reqs @@ -558,69 +632,68 @@ def prepare_for_decode(self, input_ids=None): input_ids = [ r.output_ids[-1] if 
r.output_ids else r.input_ids[-1] for r in self.reqs ] + else: + self.penalizer_orchestrator.cumulate_input_tokens(input_ids) + self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") self.seq_lens.add_(1) - self.prefix_lens = None # Alloc mem - bs = len(self.reqs) - alloc_res = self.token_to_kv_pool.alloc_contiguous(bs) - if alloc_res is None: - self.out_cache_loc = self.token_to_kv_pool.alloc(bs) - - if self.out_cache_loc is None: - print("Decode out of memory. This should never happen.") - self.tree_cache.pretty_print() - exit() - - self.out_cache_cont_start = None - self.out_cache_cont_end = None - else: - self.out_cache_loc = alloc_res[0] - self.out_cache_cont_start = alloc_res[1] - self.out_cache_cont_end = alloc_res[2] + bs = self.batch_size() + self.out_cache_loc = self.alloc_token_slots(bs) self.req_to_token_pool.req_to_token[ self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc def filter_batch(self, unfinished_indices: List[int]): + if unfinished_indices is None or len(unfinished_indices) == 0: + # Filter out all requests + self.reqs = [] + return + + if len(unfinished_indices) == len(self.reqs): + # No need to filter + return + self.reqs = [self.reqs[i] for i in unfinished_indices] new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda") self.seq_lens = self.seq_lens[new_indices] self.input_ids = None self.req_pool_indices = self.req_pool_indices[new_indices] - self.prefix_lens = None self.position_ids_offsets = self.position_ids_offsets[new_indices] - self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None + self.out_cache_loc = None self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) + self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + for item in [ "temperatures", "top_ps", "top_ks", - "frequency_penalties", - "presence_penalties", "logit_bias", ]: self_val = getattr(self, item, None) - # logit_bias can be None - if self_val is not None: + if self_val is not None: # logit_bias can be None setattr(self, item, self_val[new_indices]) - def merge(self, other: "Batch"): + def merge(self, other: "ScheduleBatch"): + # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because + # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it + # needs to be called with pre-merged Batch.reqs. 
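filter_batch above and merge just below keep every per-request tensor row-aligned with ScheduleBatch.reqs by index-selecting or concatenating along the batch dimension. A toy version of that bookkeeping with a single illustrative field (ToyBatch is not an SGLang class; the real batch tensors live on the GPU):

import torch


class ToyBatch:
    def __init__(self, reqs, temperatures):
        self.reqs = reqs  # Python-side request objects
        self.temperatures = temperatures  # one row per request

    def filter(self, keep_indices):
        # Keep the Python list and the tensor rows in the same order.
        self.reqs = [self.reqs[i] for i in keep_indices]
        self.temperatures = self.temperatures[torch.tensor(keep_indices, dtype=torch.long)]

    def merge(self, other):
        self.reqs.extend(other.reqs)
        self.temperatures = torch.cat([self.temperatures, other.temperatures])


a = ToyBatch(["r0", "r1", "r2"], torch.tensor([[0.7], [1.0], [0.2]]))
a.filter([0, 2])  # r1 finished and is dropped
a.merge(ToyBatch(["r3"], torch.tensor([[0.9]])))
print(a.reqs, a.temperatures.shape)  # ['r0', 'r2', 'r3'] torch.Size([3, 1])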
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + self.reqs.extend(other.reqs) self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) - self.prefix_lens = None self.position_ids_offsets = torch.concat( [self.position_ids_offsets, other.position_ids_offsets] ) - self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None + self.out_cache_loc = None self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) @@ -628,8 +701,6 @@ def merge(self, other: "Batch"): "temperatures", "top_ps", "top_ks", - "frequency_penalties", - "presence_penalties", ]: self_val = getattr(self, item, None) other_val = getattr(other, item, None) @@ -653,6 +724,7 @@ def merge(self, other: "Batch"): self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) def sample(self, logits: torch.Tensor): + # TODO(lsyin): move this into a part of layer and run with CUDA Graph # Post process logits logits = logits.contiguous() logits.div_(self.temperatures) @@ -670,16 +742,31 @@ def sample(self, logits: torch.Tensor): ] = 1 logits[i].masked_fill_(~allowed_mask, float("-inf")) - # TODO(lmzheng): apply penalty + logits = self.penalizer_orchestrator.apply(logits) + probs = torch.softmax(logits, dim=-1) - probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks) - sampled_index = torch.multinomial(probs_sort, num_samples=1) - batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view( - -1 - ) - batch_next_token_probs = torch.gather( - probs_sort, dim=1, index=sampled_index - ).view(-1) + + if not global_server_args_dict["disable_flashinfer_sampling"]: + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand( + (max_top_k_round, batch_size), device=probs.device + ) + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + probs, uniform_samples, self.top_ks, self.top_ps + ) + else: + # Here we provide a slower fallback implementation. 
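The fallback branch entered right below reimplements top-k/top-p filtering with plain PyTorch ops when the flashinfer sampling kernel is disabled. A self-contained sketch of the same idea (it mirrors, but is not, the top_k_top_p_sampling_from_probs_torch helper defined further down; scalar top_k/top_p are used here instead of per-request tensors):

import torch


def sample_top_k_top_p(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Sample one token id per row after top-k, then top-p (nucleus) filtering."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # Top-k: zero out everything beyond the k highest-probability tokens.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort = probs_sort.masked_fill(ranks >= top_k, 0.0)
    # Top-p: zero out tokens whose preceding cumulative mass already exceeds p.
    cum = torch.cumsum(probs_sort, dim=-1)
    probs_sort = probs_sort.masked_fill(cum - probs_sort > top_p, 0.0)
    # Renormalize, sample in the sorted space, then map back to vocabulary ids.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


probs = torch.softmax(torch.randn(2, 32000), dim=-1)
print(sample_top_k_top_p(probs, top_k=40, top_p=0.9))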
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch( + probs, self.top_ks, self.top_ps + ) + + if not torch.all(success): + warnings.warn("Sampling failed, fallback to top_k=1 strategy") + probs = probs.masked_fill(torch.isnan(probs), 0.0) + argmax_ids = torch.argmax(probs, dim=-1) + batch_next_token_ids = torch.where( + success, batch_next_token_ids, argmax_ids + ) if has_regex: batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy() @@ -689,203 +776,32 @@ def sample(self, logits: torch.Tensor): req.regex_fsm_state, batch_next_token_ids_cpu[i] ) - return batch_next_token_ids, batch_next_token_probs + self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids) + + return batch_next_token_ids -def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): +def top_k_top_p_sampling_from_probs_torch( + probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor +): + """A top-k and top-k sampling implementation with native pytorch operations.""" probs_sort, probs_idx = probs.sort(dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) - probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0 + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 probs_sort[ - torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks + torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) + >= top_ks.view(-1, 1) ] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) - return probs_sort, probs_idx - - - -@dataclass -class InputMetadata: - """Store all inforamtion of a forward pass.""" - - forward_mode: ForwardMode - batch_size: int - total_num_tokens: int - max_seq_len: int - req_pool_indices: torch.Tensor - start_loc: torch.Tensor - seq_lens: torch.Tensor - prefix_lens: torch.Tensor - positions: torch.Tensor - req_to_token_pool: ReqToTokenPool - token_to_kv_pool: TokenToKVPool - - # for extend - extend_seq_lens: torch.Tensor = None - extend_start_loc: torch.Tensor = None - max_extend_len: int = 0 - - out_cache_loc: torch.Tensor = None - out_cache_cont_start: torch.Tensor = None - out_cache_cont_end: torch.Tensor = None - - return_logprob: bool = False - top_logprobs_nums: List[int] = None - - # for flashinfer - qo_indptr: torch.Tensor = None - kv_indptr: torch.Tensor = None - kv_indices: torch.Tensor = None - kv_last_page_len: torch.Tensor = None - flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None - flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None - flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None - - def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim): - if self.forward_mode == ForwardMode.DECODE: - paged_kernel_lens = self.seq_lens - else: - paged_kernel_lens = self.prefix_lens - self.no_prefix = torch.all(self.prefix_lens == 0) - - kv_indptr = torch.zeros( - (self.batch_size + 1,), dtype=torch.int32, device="cuda" - ) - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - req_pool_indices_cpu = self.req_pool_indices.cpu().numpy() - paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() - kv_indices = torch.cat( - [ - self.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] - ] - for i in range(self.batch_size) - ], - dim=0, - ).contiguous() - kv_last_page_len = torch.ones( - (self.batch_size,), dtype=torch.int32, device="cuda" - ) - - if self.forward_mode == ForwardMode.DECODE: - self.flashinfer_decode_wrapper.end_forward() - 
self.flashinfer_decode_wrapper.begin_forward( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - pos_encoding_mode="NONE", - data_type=self.token_to_kv_pool.kv_data[0].dtype, - ) - else: - # extend part - qo_indptr = torch.zeros( - (self.batch_size + 1,), dtype=torch.int32, device="cuda" - ) - qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0) - - self.flashinfer_prefill_wrapper_ragged.end_forward() - self.flashinfer_prefill_wrapper_ragged.begin_forward( - qo_indptr, - qo_indptr, - num_qo_heads, - num_kv_heads, - head_dim, - ) - - # cached part - self.flashinfer_prefill_wrapper_paged.end_forward() - self.flashinfer_prefill_wrapper_paged.begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) - - def init_extend_args(self): - self.extend_seq_lens = self.seq_lens - self.prefix_lens - self.extend_start_loc = torch.zeros_like(self.seq_lens) - self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) - self.max_extend_len = int(torch.max(self.extend_seq_lens)) - - @classmethod - def create( - cls, - model_runner, - forward_mode, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - out_cache_cont_start=None, - out_cache_cont_end=None, - top_logprobs_nums=None, - return_logprob=False, - ): - batch_size = len(req_pool_indices) - start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) - total_num_tokens = int(torch.sum(seq_lens)) - max_seq_len = int(torch.max(seq_lens)) - - if forward_mode == ForwardMode.DECODE: - positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64) - else: - seq_lens_cpu = seq_lens.cpu().numpy() - prefix_lens_cpu = prefix_lens.cpu().numpy() - position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() - positions = torch.tensor( - np.concatenate( - [ - np.arange( - prefix_lens_cpu[i] + position_ids_offsets_cpu[i], - seq_lens_cpu[i] + position_ids_offsets_cpu[i], - ) - for i in range(batch_size) - ], - axis=0, - ), - device="cuda", - ) - - ret = cls( - forward_mode=forward_mode, - batch_size=batch_size, - total_num_tokens=total_num_tokens, - max_seq_len=max_seq_len, - req_pool_indices=req_pool_indices, - start_loc=start_loc, - seq_lens=seq_lens, - prefix_lens=prefix_lens, - positions=positions, - req_to_token_pool=model_runner.req_to_token_pool, - token_to_kv_pool=model_runner.token_to_kv_pool, - out_cache_loc=out_cache_loc, - out_cache_cont_start=out_cache_cont_start, - out_cache_cont_end=out_cache_cont_end, - return_logprob=return_logprob, - top_logprobs_nums=top_logprobs_nums, - flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged, - flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged, - flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper, + try: + sampled_index = torch.multinomial(probs_sort, num_samples=1) + except RuntimeError: + batch_next_token_ids = torch.zeros( + (probs_sort.shape[0],), dtype=torch.int32, device=probs.device ) + success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device) + return batch_next_token_ids, success - if forward_mode == ForwardMode.EXTEND: - ret.init_extend_args() - - if not global_server_args_dict.get("disable_flashinfer", False): - ret.init_flashinfer_args( - model_runner.model_config.num_attention_heads // model_runner.tp_size, - model_runner.model_config.get_num_kv_heads(model_runner.tp_size), - 
model_runner.model_config.head_dim, - ) - - return ret + batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) + success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device) + return batch_next_token_ids, success diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bd5012904bd..8711c127d65 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """TokenizerManager is a process that tokenizes the text.""" import asyncio @@ -6,7 +21,7 @@ import logging import multiprocessing as mp import os -from typing import Dict, List +from typing import Dict, List, Tuple, Union import numpy as np import transformers @@ -23,16 +38,19 @@ ) from sglang.srt.managers.io_struct import ( AbortReq, + BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import is_multimodal_model, load_image +from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image from sglang.utils import get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -61,15 +79,21 @@ def __init__( self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}") + self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}") self.model_path = server_args.model_path + self.served_model_name = server_args.served_model_name self.hf_config = get_config( self.model_path, trust_remote_code=server_args.trust_remote_code, model_overide_args=model_overide_args, ) - self.context_len = get_context_length(self.hf_config) + self.is_generation = is_generation_model(self.hf_config.architectures) + + if server_args.context_length is not None: + self.context_len = server_args.context_length + else: + self.context_len = get_context_length(self.hf_config) if is_multimodal_model(self.model_path): self.processor = get_processor( @@ -113,141 +137,212 @@ async def get_pixel_values(self, image_data): image_data, aspect_ratio, grid_pinpoints, self.processor ) - async def generate_request(self, obj: GenerateReqInput, request=None): + async def generate_request( + self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None + ): if self.to_create_loop: self.create_handle_loop() obj.post_init() is_single = obj.is_single + if is_single: - rid = obj.rid + async for response in self._handle_single_request(obj, request): + yield response + else: + if isinstance(obj, EmbeddingReqInput): + raise NotImplementedError("Please send 
only one prompt in each request") + if obj.stream: + raise ValueError("Do not support stream for batch mode.") + + async for response in self._handle_batch_request(obj, request): + yield response + async def _handle_single_request( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request, + index=None, + is_cache_for_prefill=False, + ): + if not is_cache_for_prefill: # The normal case with a single prompt + not_use_index = index is None + + rid = obj.rid if not_use_index else obj.rid[index] + input_text = obj.text if not_use_index else obj.text[index] if obj.input_ids is None: - input_ids = self.tokenizer.encode(obj.text) + input_ids = self.tokenizer.encode(input_text) else: - input_ids = obj.input_ids + input_ids = obj.input_ids if not_use_index else obj.input_ids[index] - if len(input_ids) >= self.context_len: - raise ValueError( - f"The input ({len(input_ids)} tokens) is longer than the " - f"model's context length ({self.context_len} tokens)." - ) + self._validate_input_length(input_ids) - sampling_params = SamplingParams(**obj.sampling_params) - if sampling_params.max_new_tokens != 0: - sampling_params.normalize(self.tokenizer) - sampling_params.verify() + sampling_params = self._get_sampling_params( + obj.sampling_params if not_use_index else obj.sampling_params[index] + ) - if isinstance(obj.image_data, list) and len(obj.image_data) > 0: - pixel_values, image_hash, image_size = await self.get_pixel_values( - obj.image_data[0] + if self.is_generation: + pixel_values, image_hash, image_size = await self._get_pixel_values( + obj.image_data if not_use_index else obj.image_data[index] + ) + return_logprob = ( + obj.return_logprob if not_use_index else obj.return_logprob[index] ) - elif isinstance(obj.image_data, str): - pixel_values, image_hash, image_size = await self.get_pixel_values( - obj.image_data + logprob_start_len = ( + obj.logprob_start_len + if not_use_index + else obj.logprob_start_len[index] ) + top_logprobs_num = ( + obj.top_logprobs_num + if not_use_index + else obj.top_logprobs_num[index] + ) + else: # A prefill request to cache the common prompt for parallel sampling + assert self.is_generation + if obj.text is not None: + if isinstance(obj.text, list): + input_text = obj.text[index] + rid = obj.rid[index] + else: + input_text = obj.text + rid = obj.rid[0] + input_ids = self.tokenizer.encode(input_text) else: - pixel_values, image_hash, image_size = None, None, None + input_text = None + if isinstance(obj.input_ids, list) and isinstance( + obj.input_ids[0], list + ): + # when obj["input_ids"] is List[List[int]] + input_ids = obj.input_ids[index] + rid = obj.rid[index] + else: + input_ids = obj.input_ids + rid = obj.rid[0] + + sampling_params = SamplingParams(**obj.sampling_params[0]) + sampling_params.max_new_tokens = 0 + pixel_values, image_hash, image_size = await self._get_pixel_values( + obj.image_data[0] + ) + return_logprob = obj.return_logprob[0] + logprob_start_len = obj.logprob_start_len[0] + top_logprobs_num = obj.top_logprobs_num[0] + + if self.is_generation: tokenized_obj = TokenizedGenerateReqInput( - rid=rid, - input_text=obj.text, - input_ids=input_ids, - pixel_values=pixel_values, - image_hash=image_hash, - image_size=image_size, - sampling_params=sampling_params, - return_logprob=obj.return_logprob, - logprob_start_len=obj.logprob_start_len, - top_logprobs_num=obj.top_logprobs_num, - stream=obj.stream, + rid, + input_text, + input_ids, + pixel_values, + image_hash, + image_size, + sampling_params, + return_logprob, + logprob_start_len, + 
top_logprobs_num, + obj.stream, + ) + else: # is embedding + tokenized_obj = TokenizedEmbeddingReqInput( + rid, + input_text, + input_ids, + sampling_params, ) - self.send_to_router.send_pyobj(tokenized_obj) - event = asyncio.Event() - state = ReqState([], False, event) - self.rid_to_state[rid] = state + self.send_to_router.send_pyobj(tokenized_obj) - while True: - try: - await asyncio.wait_for(event.wait(), timeout=4) - except asyncio.TimeoutError: - if request is not None and await request.is_disconnected(): - self.abort_request(rid) - raise ValueError(f"Abort request {rid}") + event = asyncio.Event() + state = ReqState([], False, event) + self.rid_to_state[rid] = state + if not is_cache_for_prefill: + async for response in self._wait_for_response( + event, state, obj, rid, request + ): + yield response + else: + await self._wait_for_cache_prefill_response(event, state, obj, rid, request) + yield input_ids + + async def _handle_batch_request(self, obj: GenerateReqInput, request): + batch_size = obj.batch_size + parallel_sample_num = obj.parallel_sample_num + + if parallel_sample_num != 1: + # Send prefill requests to cache the common input + parallel_sample_num += 1 + input_id_result = [] if obj.input_ids is None else None + for i in range(batch_size): + async for input_id in self._handle_single_request( + obj, request, index=i, is_cache_for_prefill=True + ): + if input_id_result is not None: + input_id_result.append(input_id) + if input_id_result is not None and len(input_id_result) > 1: + obj.input_ids = input_id_result + elif input_id_result is not None: + obj.input_ids = input_id_result[0] + + # First send out all requests + for i in range(batch_size): + for j in range(parallel_sample_num): + if j == 0 and parallel_sample_num != 1: continue - - out = self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs, + index = i * parallel_sample_num + j + if parallel_sample_num != 1: + # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1 + index += batch_size - 1 - i + rid = obj.rid[index] + if parallel_sample_num == 1: + ## select operation + if obj.input_ids is None: + input_text = obj.text[i] + input_ids = self.tokenizer.encode(obj.text[i]) + else: + input_text = None + input_ids = obj.input_ids[i] + else: + assert obj.input_ids is not None + if batch_size == 1: + input_text = None + input_ids = obj.input_ids + else: + input_text = None + input_ids = obj.input_ids[i] + sampling_params = self._get_sampling_params(obj.sampling_params[index]) + pixel_values, image_hash, image_size = await self._get_pixel_values( + obj.image_data[index] ) - if self.server_args.log_requests and state.finished: - logger.info(f"in={obj.text}, out={out}") - - state.out_list = [] - if state.finished: - del self.rid_to_state[rid] - - yield out - - break - - event.clear() - - yield out - else: - if obj.stream: - raise ValueError("Do not support stream for batch mode.") - - if obj.input_ids is None: - bs = len(obj.text) - else: - bs = len(obj.input_ids) - - for i in range(bs): - rid = obj.rid[i] - - if obj.input_ids is None: - input_text = obj.text[i] - input_ids = self.tokenizer.encode(obj.text[i]) - else: - input_text = None - input_ids = obj.input_ids[i] - - sampling_params = SamplingParams(**obj.sampling_params[i]) - if sampling_params.max_new_tokens != 0: - sampling_params.normalize(self.tokenizer) - sampling_params.verify() - if obj.image_data[i] is None: - 
pixel_values, image_hash, image_size = None, None, None - else: - pixel_values, image_hash, image_size = await self.get_pixel_values( - obj.image_data[i] - ) tokenized_obj = TokenizedGenerateReqInput( - rid=rid, - input_text=input_text, - input_ids=input_ids, - pixel_values=pixel_values, - image_hash=image_hash, - image_size=image_size, - sampling_params=sampling_params, - return_logprob=obj.return_logprob[i], - logprob_start_len=obj.logprob_start_len[i], - top_logprobs_num=obj.top_logprobs_num[i], - stream=obj.stream, + rid, + input_text, + input_ids, + pixel_values, + image_hash, + image_size, + sampling_params, + obj.return_logprob[index], + obj.logprob_start_len[index], + obj.top_logprobs_num[index], + obj.stream, ) self.send_to_router.send_pyobj(tokenized_obj) event = asyncio.Event() state = ReqState([], False, event) self.rid_to_state[rid] = state - - output_list = [] - for i in range(bs): - rid = obj.rid[i] + # Then wait for all responses + output_list = [] + for i in range(batch_size): + for j in range(parallel_sample_num): + if j == 0 and parallel_sample_num != 1: + continue + index = i * parallel_sample_num + j + if parallel_sample_num != 1: + index += batch_size - 1 - i + rid = obj.rid[index] state = self.rid_to_state[rid] while True: @@ -260,25 +355,111 @@ async def generate_request(self, obj: GenerateReqInput, request=None): self.abort_request(rid) raise ValueError(f"Abort request {rid}") continue - output_list.append( self.convert_logprob_style( state.out_list[-1], - obj.return_logprob[i], - obj.top_logprobs_num[i], + obj.return_logprob[index], + obj.top_logprobs_num[index], obj.return_text_in_logprobs, ) ) assert state.finished del self.rid_to_state[rid] + yield output_list + + def _validate_input_length(self, input_ids: List[int]): + if len(input_ids) >= self.context_len: + raise ValueError( + f"The input ({len(input_ids)} tokens) is longer than the " + f"model's context length ({self.context_len} tokens)." 
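A note on the request indexing in `_handle_batch_request` above: when `parallel_sample_num > 1` it is incremented by one so that each prompt also owns a cached-prefill request, and the j-th real sample of the i-th prompt must skip those prefill slots. The helper below is a hypothetical stand-alone restatement of that arithmetic, introduced only for illustration:

def sample_rid_index(i: int, j: int, batch_size: int, parallel_sample_num: int) -> int:
    """Position in obj.rid of sample j for prompt i (sketch of the handler's math).

    parallel_sample_num is assumed to already include the extra
    cached-prefill request, exactly as in the loop above.
    """
    index = i * parallel_sample_num + j
    if parallel_sample_num != 1:
        # Skip the prefill rids that precede the sampling rids; this equals
        # j + i * (parallel_sample_num - 1) + batch_size - 1.
        index += batch_size - 1 - i
    return index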
+ ) + + def _get_sampling_params(self, sampling_params_data: dict): + sampling_params = SamplingParams(**sampling_params_data) + if sampling_params.max_new_tokens != 0: + sampling_params.normalize(self.tokenizer) + sampling_params.verify() + return sampling_params + + async def _get_pixel_values(self, image_data): + if isinstance(image_data, list) and len(image_data) > 0: + return await self.get_pixel_values(image_data[0]) + elif isinstance(image_data, str): + return await self.get_pixel_values(image_data) + else: + return None, None, None + + async def _wait_for_response( + self, + event: asyncio.Event, + state: ReqState, + obj: Union[GenerateReqInput, EmbeddingReqInput], + rid: str, + request, + ): + while True: + try: + await asyncio.wait_for(event.wait(), timeout=4) + except asyncio.TimeoutError: + if request is not None and await request.is_disconnected(): + self.abort_request(rid) + raise ValueError(f"Abort request {rid}") + continue + + if self.is_generation: + out = self.convert_logprob_style( + state.out_list[-1], + obj.return_logprob, + obj.top_logprobs_num, + obj.return_text_in_logprobs, + ) + else: # isinstance(obj, EmbeddingReqInput) + out = state.out_list[-1] + + # Log requests + if self.server_args.log_requests and state.finished: + if obj.text is None: + in_obj = {"text": self.tokenizer.decode(obj.input_ids)} + else: + in_obj = {"text": obj.text} + logger.info(f"in={in_obj}, out={out}") + + state.out_list = [] + if state.finished: + del self.rid_to_state[rid] + yield out + break + + event.clear() + yield out + + async def _wait_for_cache_prefill_response( + self, + event: asyncio.Event, + state: ReqState, + obj: GenerateReqInput, + rid: str, + request, + ): + while True: + try: + await asyncio.wait_for(state.event.wait(), timeout=4) + break + except asyncio.TimeoutError: + if request is not None and await request.is_disconnected(): + for rid in obj.rid: + self.abort_request(rid) + raise ValueError(f"Abort request {rid}") + continue - yield output_list + assert state.finished + del self.rid_to_state[rid] def flush_cache(self): req = FlushCacheReq() self.send_to_router.send_pyobj(req) - def abort_request(self, rid): + def abort_request(self, rid: str): if rid not in self.rid_to_state: return del self.rid_to_state[rid] @@ -306,8 +487,10 @@ def create_handle_loop(self): async def handle_loop(self): while True: - recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj() - assert isinstance(recv_obj, BatchStrOut) + recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = ( + await self.recv_from_detokenizer.recv_pyobj() + ) + assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut)) for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) @@ -315,39 +498,53 @@ async def handle_loop(self): continue recv_obj.meta_info[i]["id"] = rid - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], - } + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": recv_obj.meta_info[i], + } + else: + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": recv_obj.meta_info[i], + } state.out_list.append(out_dict) state.finished = recv_obj.finished_reason[i] is not None state.event.set() def convert_logprob_style( - self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs + self, + ret: dict, + return_logprob: bool, + top_logprobs_num: int, + return_text_in_logprobs: bool, ): if return_logprob: - 
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs + ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens( + ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs ) - ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs + ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens( + ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs ) if top_logprobs_num > 0: - ret["meta_info"][ - "prefill_top_logprobs" - ] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ret["meta_info"]["input_top_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + ret["meta_info"]["input_top_logprobs"], + return_text_in_logprobs, + ) ) - ret["meta_info"][ - "decode_top_logprobs" - ] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ret["meta_info"]["output_top_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs + ) ) return ret - def detokenize_logprob_tokens(self, token_logprobs, decode_to_text): + def detokenize_logprob_tokens( + self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool + ): if not decode_to_text: return [(logprob, token_id, None) for logprob, token_id in token_logprobs] @@ -358,10 +555,14 @@ def detokenize_logprob_tokens(self, token_logprobs, decode_to_text): for (logprob, token_id), token_text, in zip(token_logprobs, token_texts) ] - def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text): - for i, t in enumerate(top_logprobs): - if t: - top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text) + def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): + # TODO: The current implementation only batches the detokenization for top-k tokens per single position. + # We should batch all top-k tokens in all positions. + for i, token_top_logprobs in enumerate(top_logprobs): + if token_top_logprobs: + top_logprobs[i] = self.detokenize_logprob_tokens( + token_top_logprobs, decode_to_text + ) return top_logprobs diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py new file mode 100644 index 00000000000..77941c8af75 --- /dev/null +++ b/python/sglang/srt/managers/tp_worker.py @@ -0,0 +1,833 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""A tensor parallel worker.""" + +import logging +import multiprocessing +import pickle +import time +import warnings +from typing import List, Optional, Union + +import torch +import torch.distributed as dist + +from sglang.global_config import global_config +from sglang.srt.constrained.fsm_cache import FSMCache +from sglang.srt.constrained.jump_forward import JumpForwardCache +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOut, + BatchTokenIDOut, + FlushCacheReq, + TokenizedEmbeddingReqInput, + TokenizedGenerateReqInput, +) +from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder +from sglang.srt.managers.schedule_batch import ( + FINISH_ABORT, + BaseFinishReason, + Req, + ScheduleBatch, +) +from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.model_config import ModelConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + get_int_token_logit_bias, + is_multimodal_model, + set_random_seed, + suppress_other_loggers, +) +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class ModelTpServer: + def __init__( + self, + gpu_id: int, + tp_rank: int, + server_args: ServerArgs, + nccl_port: int, + model_overide_args: dict, + ): + suppress_other_loggers() + + # Copy arguments + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.tp_size = server_args.tp_size + self.dp_size = server_args.dp_size + self.schedule_policy = server_args.schedule_policy + self.disable_regex_jump_forward = server_args.disable_regex_jump_forward + + # Chunked prefill + self.chunked_prefill_size = server_args.chunked_prefill_size + self.current_inflight_req = None + + # Init model and tokenizer + self.model_config = ModelConfig( + server_args.model_path, + server_args.trust_remote_code, + context_length=server_args.context_length, + model_overide_args=model_overide_args, + ) + self.model_runner = ModelRunner( + model_config=self.model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=gpu_id, + tp_rank=tp_rank, + tp_size=server_args.tp_size, + nccl_port=nccl_port, + server_args=server_args, + ) + + if is_multimodal_model(server_args.model_path): + self.processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + self.tokenizer = self.processor.tokenizer + else: + self.tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + self.max_total_num_tokens = self.model_runner.max_total_num_tokens + self.max_prefill_tokens = ( + 16384 + if server_args.max_prefill_tokens is None + else server_args.max_prefill_tokens + ) + self.max_running_requests = min( + ( + self.max_total_num_tokens // 2 + if server_args.max_running_requests is None + else server_args.max_running_requests + ), + self.model_runner.req_to_token_pool.size - 1, + ) + self.int_token_logit_bias = torch.tensor( + get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) + ) + self.max_req_input_len = min( + self.model_config.context_len - 1, + self.max_total_num_tokens - 1, + ) + set_random_seed(server_args.random_seed) + + # 
Print info + logger.info( + f"[gpu={self.gpu_id}] " + f"max_total_num_tokens={self.max_total_num_tokens}, " + f"max_prefill_tokens={self.max_prefill_tokens}, " + f"max_running_requests={self.max_running_requests}, " + f"context_len={self.model_config.context_len}" + ) + + # Init cache + if ( + server_args.chunked_prefill_size is not None + and server_args.disable_radix_cache + ): + self.tree_cache = ChunkCache( + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + ) + else: + self.tree_cache = RadixCache( + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + disable=server_args.disable_radix_cache, + ) + self.tree_cache_metrics = {"total": 0, "hit": 0} + self.scheduler = PolicyScheduler( + self.schedule_policy, + self.max_running_requests, + self.max_prefill_tokens, + self.max_total_num_tokens, + self.tree_cache, + ) + self.req_to_token_pool = self.model_runner.req_to_token_pool + self.token_to_kv_pool = self.model_runner.token_to_kv_pool + + # Init running status + self.waiting_queue: List[Req] = [] + self.running_batch: ScheduleBatch = None + self.out_pyobjs = [] + self.decode_forward_ct = 0 + self.stream_interval = server_args.stream_interval + self.num_generated_tokens = 0 + self.last_stats_tic = time.time() + + # Init the FSM cache for constrained generation + self.regex_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + ) + self.jump_forward_cache = JumpForwardCache() + + # Init new token estimation + assert ( + server_args.schedule_conservativeness >= 0 + ), "Invalid schedule_conservativeness" + self.min_new_token_ratio = min( + global_config.base_min_new_token_ratio + * server_args.schedule_conservativeness, + 1.0, + ) + self.new_token_ratio = self.min_new_token_ratio + self.new_token_ratio_decay = global_config.new_token_ratio_decay + + def exposed_step(self, recv_reqs): + try: + # Recv requests + for recv_req in recv_reqs: + if isinstance( + recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ): + self.handle_generate_request(recv_req) + elif isinstance(recv_req, FlushCacheReq): + self.flush_cache() + elif isinstance(recv_req, AbortReq): + self.abort_request(recv_req) + else: + raise ValueError(f"Invalid request: {recv_req}") + + # Forward + self.forward_step() + except Exception: + logger.error("Exception in ModelTpServer:\n" + get_exception_traceback()) + raise + + # Return results + ret = self.out_pyobjs + self.out_pyobjs = [] + return ret + + @torch.inference_mode() + def forward_step(self): + new_batch = self.get_new_prefill_batch() + + if new_batch is not None: + # Run a new prefill batch + self.forward_prefill_batch(new_batch) + + if not new_batch.is_empty(): + if self.running_batch is None: + self.running_batch = new_batch + else: + self.running_batch.merge(new_batch) + else: + # Run a decode batch + if self.running_batch is not None: + # Run a few decode batches continuously for reducing overhead + for _ in range(global_config.num_continue_decode_steps): + self.num_generated_tokens += len(self.running_batch.reqs) + self.forward_decode_batch(self.running_batch) + + # Print stats + if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: + self.print_stats() + + if self.running_batch.is_empty(): + self.running_batch = None + break + + if self.out_pyobjs and self.running_batch.has_stream(): + break + else: + 
self.check_memory() + self.new_token_ratio = global_config.init_new_token_ratio + + def print_stats(self): + num_used = self.max_total_num_tokens - ( + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() + ) + throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic) + self.num_generated_tokens = 0 + self.last_stats_tic = time.time() + logger.info( + f"[gpu={self.gpu_id}] Decode batch. " + f"#running-req: {len(self.running_batch.reqs)}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"gen throughput (token/s): {throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + + def check_memory(self): + available_size = ( + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() + ) + if available_size != self.max_total_num_tokens: + warnings.warn( + "Warning: " + f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n" + "KV cache pool leak detected!" + ) + + if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size: + warnings.warn( + "Warning: " + f"available req slots={len(self.req_to_token_pool.free_slots)}, " + f"total slots={self.req_to_token_pool.size}\n" + "Memory pool leak detected!" + ) + + def handle_generate_request( + self, + recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], + ): + req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) + req.tokenizer = self.tokenizer + req.sampling_params = recv_req.sampling_params + if self.model_runner.is_generation: + req.pixel_values = recv_req.pixel_values + if req.pixel_values is not None: + req.pad_value = [ + (recv_req.image_hash) % self.model_config.vocab_size, + (recv_req.image_hash >> 16) % self.model_config.vocab_size, + (recv_req.image_hash >> 32) % self.model_config.vocab_size, + (recv_req.image_hash >> 64) % self.model_config.vocab_size, + ] + req.image_size = recv_req.image_size + ( + req.origin_input_ids, + req.image_offset, + ) = self.model_runner.model.pad_input_ids( + req.origin_input_ids_unpadded, + req.pad_value, + req.pixel_values.shape, + req.image_size, + ) + req.return_logprob = recv_req.return_logprob + req.logprob_start_len = recv_req.logprob_start_len + req.top_logprobs_num = recv_req.top_logprobs_num + req.stream = recv_req.stream + + # Init regex fsm + if req.sampling_params.regex is not None: + req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + req.sampling_params.regex + ) + + # Truncate prompts that are too long + if len(req.origin_input_ids) >= self.max_req_input_len: + logger.warn( + "Request length is longer than the KV cache pool size or " + "the max context length. Truncated!!!" 
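A short note on the accounting used by `print_stats` and `check_memory` above: every KV slot is either free in the token pool, evictable in the tree cache, or held by a running request, so the in-use count is the total minus the first two, and it must drop back to zero when the worker is idle. A minimal sketch of that bookkeeping, with `pool` and `cache` as hypothetical stand-ins for the token pool and prefix cache:

def kv_usage(pool, cache, max_total_num_tokens: int) -> float:
    """Fraction of KV-cache tokens held by running requests (sketch)."""
    free_tokens = pool.available_size()        # unreferenced KV slots
    evictable_tokens = cache.evictable_size()  # cached prefixes that can be reclaimed
    in_use = max_total_num_tokens - free_tokens - evictable_tokens
    # check_memory() relies on in_use == 0 whenever no request is running;
    # a nonzero value at idle indicates a leaked KV slot.
    return in_use / max_total_num_tokens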
+ ) + req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + + if self.model_runner.is_generation: + req.sampling_params.max_new_tokens = min( + ( + req.sampling_params.max_new_tokens + if req.sampling_params.max_new_tokens is not None + else 1 << 30 + ), + self.max_req_input_len - 1 - len(req.origin_input_ids), + ) + + self.waiting_queue.append(req) + + def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: + running_bs = ( + len(self.running_batch.reqs) if self.running_batch is not None else 0 + ) + if running_bs >= self.max_running_requests: + return None + + # Compute matched prefix length + for req in self.waiting_queue: + req.input_ids = req.origin_input_ids + req.output_ids + # NOTE: the prefix_indices must always be aligned with last_node + req.prefix_indices, req.last_node = self.tree_cache.match_prefix( + rid=req.rid, key=req.adjust_max_prefix_ids() + ) + req.extend_input_len = len(req.input_ids) - len(req.prefix_indices) + + # Get priority queue + self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue) + + adder = PrefillAdder( + self.tree_cache, + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), + self.max_prefill_tokens, + self.chunked_prefill_size, + ) + + if self.running_batch is not None: + adder.remove_running_tokens(self.running_batch, self.new_token_ratio) + + has_inflight = self.current_inflight_req is not None + if self.current_inflight_req is not None: + self.current_inflight_req = adder.add_inflight_req( + self.current_inflight_req + ) + + for req in self.waiting_queue: + + res = adder.add_one_req(req) + if ( + not res + or adder.no_remaining_tokens() + or running_bs + len(adder.can_run_list) >= self.max_running_requests + ): + break + + can_run_list = adder.can_run_list + + if adder.new_inflight_req is not None: + assert self.current_inflight_req is None + self.current_inflight_req = adder.new_inflight_req + + if len(can_run_list) == 0: + return None + + # Print stats + if self.tp_rank == 0: + self.tree_cache_metrics["total"] += ( + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + tree_cache_hit_rate = ( + self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] + ) + logger.info( + f"[gpu={self.gpu_id}] Prefill batch. 
" + f"#new-seq: {len(can_run_list)}, " + f"#new-token: {adder.log_input_tokens}, " + f"#cached-token: {adder.log_hit_tokens}, " + f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " + f"#running-req: {running_bs}, " + f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" + ) + + # Return the new batch + new_batch = ScheduleBatch.init_new( + can_run_list, + self.req_to_token_pool, + self.token_to_kv_pool, + self.tree_cache, + ) + self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] + return new_batch + + def forward_prefill_batch(self, batch: ScheduleBatch): + # Build batch tensors + batch.prepare_for_extend( + self.model_config.vocab_size, self.int_token_logit_bias + ) + + if self.model_runner.is_generation: + # Forward and sample the next tokens + if batch.extend_num_tokens != 0: + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = batch.sample(output.next_token_logits) + + # Move logprobs to cpu + if output.next_token_logprobs is not None: + output.next_token_logprobs = output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), + next_token_ids, + ].tolist() + output.input_token_logprobs = output.input_token_logprobs.tolist() + output.normalized_prompt_logprobs = ( + output.normalized_prompt_logprobs.tolist() + ) + + next_token_ids = next_token_ids.tolist() + else: + next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) + + # Check finish conditions + pt = 0 + for i, req in enumerate(batch.reqs): + if req is not self.current_inflight_req: + # Inflight reqs' prefill is not finished + req.completion_tokens_wo_jump_forward += 1 + req.output_ids.append(next_token_ids[i]) + req.check_finished() + + if req.finished(): + self.tree_cache.cache_finished_req(req) + else: + self.tree_cache.cache_unfinished_req(req) + + if req is self.current_inflight_req: + # Inflight request would get a new req idx + self.req_to_token_pool.free(req.req_pool_idx) + + if req.return_logprob: + self.add_logprob_return_values(i, req, pt, next_token_ids, output) + pt += req.extend_input_len + else: + assert batch.extend_num_tokens != 0 + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + embeddings = output.embeddings.tolist() + + # Check finish conditions + for i, req in enumerate(batch.reqs): + req.embedding = embeddings[i] + if req is not self.current_inflight_req: + # Inflight reqs' prefill is not finished + req.check_finished() + + if req.finished(): + self.tree_cache.cache_finished_req(req) + else: + self.tree_cache.cache_unfinished_req(req) + + if req is self.current_inflight_req: + # Inflight request would get a new req idx + self.req_to_token_pool.free(req.req_pool_idx) + + self.handle_finished_requests(batch) + + def add_logprob_return_values(self, i, req, pt, next_token_ids, output): + if req.normalized_prompt_logprob is None: + req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] + + if req.input_token_logprobs is None: + # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. 
+ req.input_token_logprobs = list( + zip( + output.input_token_logprobs[pt : pt + req.extend_input_len - 1], + req.input_ids[-req.extend_input_len + 1 :], + ) + ) + if req.logprob_start_len == 0: + req.input_token_logprobs = [ + (None, req.input_ids[0]) + ] + req.input_token_logprobs + + if req.last_update_decode_tokens != 0: + req.output_token_logprobs.extend( + list( + zip( + output.input_token_logprobs[ + pt + + req.extend_input_len + - req.last_update_decode_tokens : pt + + req.extend_input_len + - 1 + ], + req.input_ids[-req.last_update_decode_tokens + 1 :], + ) + ) + ) + + req.output_token_logprobs.append( + (output.next_token_logprobs[i], next_token_ids[i]) + ) + + if req.top_logprobs_num > 0: + if req.input_top_logprobs is None: + req.input_top_logprobs = output.input_top_logprobs[i] + if req.logprob_start_len == 0: + req.input_top_logprobs = [None] + req.input_top_logprobs + + if req.last_update_decode_tokens != 0: + req.output_top_logprobs.extend( + output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :] + ) + req.output_top_logprobs.append(output.output_top_logprobs[i]) + + def forward_decode_batch(self, batch: ScheduleBatch): + # Check if decode out of memory + if not batch.check_decode_mem(): + old_ratio = self.new_token_ratio + + retracted_reqs, new_token_ratio = batch.retract_decode() + self.new_token_ratio = new_token_ratio + + logger.info( + "decode out of memory happened, " + f"#retracted_reqs: {len(retracted_reqs)}, " + f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" + ) + self.waiting_queue.extend(retracted_reqs) + else: + self.new_token_ratio = max( + self.new_token_ratio - self.new_token_ratio_decay, + self.min_new_token_ratio, + ) + + if not self.disable_regex_jump_forward: + # Check for jump-forward + jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) + self.waiting_queue.extend(jump_forward_reqs) + if batch.is_empty(): + return + + # Update batch tensors + self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) + batch.prepare_for_decode() + + # Forward and sample the next tokens + output = self.model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids = batch.sample(output.next_token_logits) + + # Move logprobs to cpu + if output.next_token_logprobs is not None: + next_token_logprobs = output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), + next_token_ids, + ].tolist() + + next_token_ids = next_token_ids.tolist() + + # Check finish condition + for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): + req.completion_tokens_wo_jump_forward += 1 + req.output_ids.append(next_token_id) + req.check_finished() + + if req.finished(): + self.tree_cache.cache_finished_req(req) + + if req.return_logprob: + req.output_token_logprobs.append( + (next_token_logprobs[i], next_token_id) + ) + if req.top_logprobs_num > 0: + req.output_top_logprobs.append(output.output_top_logprobs[i]) + + self.handle_finished_requests(batch) + + def handle_finished_requests(self, batch: ScheduleBatch): + output_rids = [] + output_meta_info = [] + output_finished_reason: List[BaseFinishReason] = [] + if self.model_runner.is_generation: + output_vids = [] + decoded_texts = [] + output_read_ids = [] + output_read_offsets = [] + output_skip_special_tokens = [] + output_spaces_between_special_tokens = [] + else: # for embedding model + output_embeddings = [] + unfinished_indices = [] + + for i, req in enumerate(batch.reqs): + if not req.finished() and req is not 
self.current_inflight_req: + unfinished_indices.append(i) + + if req.finished() or ( + ( + req.stream + and ( + self.decode_forward_ct % self.stream_interval == 0 + or len(req.output_ids) == 1 + ) + ) + ): + output_rids.append(req.rid) + output_finished_reason.append(req.finished_reason) + if self.model_runner.is_generation: + output_vids.append(req.vid) + decoded_texts.append(req.decoded_text) + read_ids, read_offset = req.init_incremental_detokenize() + output_read_ids.append(read_ids) + output_read_offsets.append(read_offset) + output_skip_special_tokens.append( + req.sampling_params.skip_special_tokens + ) + output_spaces_between_special_tokens.append( + req.sampling_params.spaces_between_special_tokens + ) + + meta_info = { + "prompt_tokens": len(req.origin_input_ids), + "completion_tokens": len(req.output_ids), + "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, + "finish_reason": str(req.finished_reason), + } + if req.return_logprob: + ( + meta_info["input_token_logprobs"], + meta_info["output_token_logprobs"], + meta_info["input_top_logprobs"], + meta_info["output_top_logprobs"], + meta_info["normalized_prompt_logprob"], + ) = ( + req.input_token_logprobs, + req.output_token_logprobs, + req.input_top_logprobs, + req.output_top_logprobs, + req.normalized_prompt_logprob, + ) + output_meta_info.append(meta_info) + else: # for embedding model + output_embeddings.append(req.embedding) + meta_info = { + "prompt_tokens": len(req.origin_input_ids), + } + output_meta_info.append(meta_info) + + # Send to detokenizer + if output_rids: + if self.model_runner.is_generation: + self.out_pyobjs.append( + BatchTokenIDOut( + output_rids, + output_vids, + decoded_texts, + output_read_ids, + output_read_offsets, + output_skip_special_tokens, + output_spaces_between_special_tokens, + output_meta_info, + output_finished_reason, + ) + ) + else: # for embedding model + self.out_pyobjs.append( + BatchEmbeddingOut( + output_rids, + output_embeddings, + output_meta_info, + output_finished_reason, + ) + ) + + # Remove finished reqs: update batch tensors + batch.filter_batch(unfinished_indices) + + def flush_cache(self): + if len(self.waiting_queue) == 0 and ( + self.running_batch is None or len(self.running_batch.reqs) == 0 + ): + self.tree_cache.reset() + self.tree_cache_metrics = {"total": 0, "hit": 0} + self.regex_fsm_cache.reset() + self.req_to_token_pool.clear() + self.token_to_kv_pool.clear() + torch.cuda.empty_cache() + logger.info("Cache flushed successfully!") + else: + warnings.warn( + f"Cache not flushed because there are pending requests. 
" + f"#queue-req: {len(self.waiting_queue)}, " + f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" + ) + + def abort_request(self, recv_req): + # Delete requests in the waiting queue + to_del = None + for i, req in enumerate(self.waiting_queue): + if req.rid == recv_req.rid: + to_del = i + break + + if to_del is not None: + del self.waiting_queue[to_del] + + # Delete requests in the running batch + if self.running_batch: + for req in self.running_batch.reqs: + if req.rid == recv_req.rid: + req.finished_reason = FINISH_ABORT() + break + + +def run_tp_server( + gpu_id: int, + tp_rank: int, + server_args: ServerArgs, + nccl_port: int, + model_overide_args: dict, +): + """Run a tensor parallel server.""" + try: + model_server = ModelTpServer( + gpu_id, + tp_rank, + server_args, + nccl_port, + model_overide_args, + ) + tp_cpu_group = model_server.model_runner.tp_group.cpu_group + + while True: + recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group) + model_server.exposed_step(recv_reqs) + except Exception: + logger.error("Exception in run_tp_server:\n" + get_exception_traceback()) + raise + + +def launch_tp_servers( + gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args +): + """Launch multiple tensor parallel servers.""" + procs = [] + for i in tp_rank_range: + proc = multiprocessing.Process( + target=run_tp_server, + args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args), + ) + proc.start() + procs.append(proc) + + return procs + + +def broadcast_recv_input(data, rank, dist_group): + """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" + + if rank == 0: + if len(data) == 0: + tensor_size = torch.tensor([0], dtype=torch.long) + dist.broadcast(tensor_size, src=0, group=dist_group) + else: + serialized_data = pickle.dumps(data) + size = len(serialized_data) + tensor_data = torch.ByteTensor(list(serialized_data)) + tensor_size = torch.tensor([size], dtype=torch.long) + + dist.broadcast(tensor_size, src=0, group=dist_group) + dist.broadcast(tensor_data, src=0, group=dist_group) + else: + tensor_size = torch.tensor([0], dtype=torch.long) + dist.broadcast(tensor_size, src=0, group=dist_group) + size = tensor_size.item() + + if size == 0: + return [] + + tensor_data = torch.empty(size, dtype=torch.uint8) + dist.broadcast(tensor_data, src=0, group=dist_group) + + serialized_data = bytes(tensor_data.tolist()) + data = pickle.loads(serialized_data) + return data diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py new file mode 100644 index 00000000000..fb2b7a627b1 --- /dev/null +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod + + +class BasePrefixCache(ABC): + """Cache can be indexed by either rid or key.""" + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def match_prefix(self, **kwargs): + pass + + @abstractmethod + def insert(self, **kwargs): + pass + + @abstractmethod + def cache_finished_req(self, **kwargs): + pass + + @abstractmethod + def cache_unfinished_req(self, **kwargs): + pass + + @abstractmethod + def evict(self, num_tokens, evict_callback): + pass + + @abstractmethod + def inc_lock_ref(self, node): + pass + + @abstractmethod + def dec_lock_ref(self, node): + pass + + @abstractmethod + def evictable_size(self): + pass + + def total_size(self): + raise NotImplementedError + + def pretty_print(self): + raise NotImplementedError diff --git 
a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py new file mode 100644 index 00000000000..100cbbaec22 --- /dev/null +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -0,0 +1,75 @@ +"""Cache for chunked prefill, used when RadixCache is disabled.""" + +from typing import TYPE_CHECKING + +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + + +class ChunkCacheEntry: + def __init__(self, rid, value): + self.rid = rid + self.value = value + + +class ChunkCache(BasePrefixCache): + def __init__(self, req_to_token_pool, token_to_kv_pool): + self.disable = True + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool = token_to_kv_pool + + self.reset() + + def reset(self): + self.entries = {} + + def match_prefix(self, rid, **kwargs): + if rid not in self.entries: + return [], None + + entry = self.entries[rid] + return entry.value, entry + + def cache_finished_req(self, req: "Req", token_ids=None): + if token_ids is None: + token_ids = (req.input_ids + req.output_ids)[:-1] + + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + assert req.rid in self.entries + self.req_to_token_pool.free(req.req_pool_idx) + self.token_to_kv_pool.free(kv_indices) + + def cache_unfinished_req(self, req: "Req", token_ids=None): + if token_ids is None: + token_ids = req.input_ids + + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + if req.rid not in self.entries: + self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices) + + entry = self.entries[req.rid] + entry.value = kv_indices + req.prefix_indices = kv_indices + req.last_node = entry + + def insert(self): + raise NotImplementedError + + def evict(self, num_tokens, evict_callback): + pass + + def inc_lock_ref(self, node): + return 0 + + def dec_lock_ref(self, node): + return 0 + + def evictable_size(self): + return 0 diff --git a/python/sglang/srt/mem_cache/flush_cache.py b/python/sglang/srt/mem_cache/flush_cache.py new file mode 100644 index 00000000000..3ac425ac8c0 --- /dev/null +++ b/python/sglang/srt/mem_cache/flush_cache.py @@ -0,0 +1,33 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Flush the KV cache. 
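Stepping back from the two implementations above (`ChunkCache` here, `RadixCache` later in this diff): the TP worker drives any `BasePrefixCache` through the same lifecycle, matching a prefix before prefill and re-caching or releasing it afterwards. A schematic sketch of that call order, with `cache`, `req`, and `token_ids` as hypothetical stand-ins shaped like the worker's objects:

def prefix_cache_lifecycle(cache, req, token_ids):
    """Schematic of how the TP worker uses a BasePrefixCache (sketch)."""
    # Before prefill: reuse whatever prefix is already cached for this request.
    req.prefix_indices, req.last_node = cache.match_prefix(
        rid=req.rid, key=token_ids
    )

    # ... run the prefill / decode forward passes ...

    if req.finished():
        # Releases the request slot and any KV slots the cache does not keep.
        cache.cache_finished_req(req)
    else:
        # Re-inserts the tokens computed so far and refreshes req.prefix_indices.
        cache.cache_unfinished_req(req)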
+ +Usage: +python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="http://localhost:30000") + args = parser.parse_args() + + response = requests.get(args.url + "/flush_cache") + assert response.status_code == 200 diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py new file mode 100644 index 00000000000..37ce4296dee --- /dev/null +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -0,0 +1,177 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Memory pool.""" + +import logging +from typing import List + +import torch + +logger = logging.getLogger(__name__) + + +class ReqToTokenPool: + """A memory pool that maps a request to its token locations.""" + + def __init__(self, size: int, max_context_len: int): + self.size = size + self.free_slots = list(range(size)) + self.req_to_token = torch.empty( + (size, max_context_len), dtype=torch.int32, device="cuda" + ) + + def alloc(self, need_size: int) -> List[int]: + if need_size > len(self.free_slots): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + return select_index + + def free(self, free_index): + if isinstance(free_index, (int,)): + self.free_slots.append(free_index) + else: + self.free_slots.extend(free_index) + + def clear(self): + self.free_slots = list(range(self.size)) + + +class BaseTokenToKVPool: + """A memory pool that maps a token to its kv cache locations""" + + def __init__( + self, + size: int, + ): + self.size = size + + # We also add one slot. This slot is used for writing dummy output from padded tokens. 
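A quick usage note for the `ReqToTokenPool` defined above: it is a plain free-list over request slots, so `alloc` pops slot indices (returning None when not enough are left) and `free` accepts either a single index or a list. A tiny sketch, with the pool sizes chosen only to keep the example small (a CUDA device is assumed because the pool allocates its token table on the GPU):

from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

pool = ReqToTokenPool(size=4, max_context_len=8)  # hypothetical tiny sizes

slots = pool.alloc(3)            # e.g. [0, 1, 2]
assert pool.alloc(2) is None     # only one free slot left, so allocation fails
pool.free(slots[0])              # return a single slot ...
pool.free(slots[1:])             # ... or a list of slots
assert len(pool.free_slots) == 4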
+ self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") + + # Prefetch buffer + self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32) + self.prefetch_chunk_size = 512 + + self.can_use_mem_size = self.size + self.clear() + + def available_size(self): + return self.can_use_mem_size + len(self.prefetch_buffer) + + def alloc(self, need_size: int): + buffer_len = len(self.prefetch_buffer) + if need_size <= buffer_len: + select_index = self.prefetch_buffer[:need_size] + self.prefetch_buffer = self.prefetch_buffer[need_size:] + return select_index + + addition_size = need_size - buffer_len + alloc_size = max(addition_size, self.prefetch_chunk_size) + select_index = ( + torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32) + ) + + if select_index.shape[0] < addition_size: + return None + + self.mem_state[select_index] = False + self.can_use_mem_size -= len(select_index) + + self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index)) + ret_index = self.prefetch_buffer[:need_size] + self.prefetch_buffer = self.prefetch_buffer[need_size:] + + return ret_index + + def free(self, free_index: torch.Tensor): + self.mem_state[free_index] = True + self.can_use_mem_size += len(free_index) + + def clear(self): + self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32) + + self.mem_state.fill_(True) + self.can_use_mem_size = self.size + + # We also add one slot. This slot is used for writing dummy output from padded tokens. + self.mem_state[0] = False + + +class MHATokenToKVPool(BaseTokenToKVPool): + + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + ): + super().__init__(size) + + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + for _ in range(layer_num) + ] + + def get_key_buffer(self, layer_id: int): + return self.k_buffer[layer_id] + + def get_value_buffer(self, layer_id: int): + return self.v_buffer[layer_id] + + def get_kv_buffer(self, layer_id: int): + return self.k_buffer[layer_id], self.v_buffer[layer_id] + + +class MLATokenToKVPool(BaseTokenToKVPool): + + def __init__( + self, + size: int, + dtype: torch.dtype, + kv_lora_rank: int, + qk_rope_head_dim: int, + layer_num: int, + ): + super().__init__(size) + + self.kv_lora_rank = kv_lora_rank + self.kv_buffer = [ + torch.empty( + (size + 1, 1, kv_lora_rank + qk_rope_head_dim), + dtype=dtype, + device="cuda", + ) + for _ in range(layer_num) + ] + + def get_key_buffer(self, layer_id: int): + return self.kv_buffer[layer_id] + + def get_value_buffer(self, layer_id: int): + return self.kv_buffer[layer_id][..., : self.kv_lora_rank] + + def get_kv_buffer(self, layer_id: int): + return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) diff --git a/python/sglang/srt/managers/controller/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py similarity index 71% rename from python/sglang/srt/managers/controller/radix_cache.py rename to python/sglang/srt/mem_cache/radix_cache.py index ab8d6b4468b..c2381204925 100644 --- a/python/sglang/srt/managers/controller/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance 
with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """ The radix tree data structure for managing the KV cache. """ @@ -5,9 +20,15 @@ import heapq import time from collections import defaultdict +from typing import TYPE_CHECKING import torch +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + class TreeNode: def __init__(self): @@ -31,7 +52,7 @@ def _key_match(key0, key1): return i -class RadixCache: +class RadixCache(BasePrefixCache): def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool = token_to_kv_pool @@ -47,7 +68,7 @@ def reset(self): self.root_node.lock_ref = 1 self.evictable_size_ = 0 - def match_prefix(self, key): + def match_prefix(self, key, **kwargs): if self.disable: return [], self.root_node @@ -57,7 +78,7 @@ def match_prefix(self, key): if value: value = torch.concat(value) else: - value = torch.tensor([], dtype=torch.int64) + value = torch.tensor([], dtype=torch.int32) return value, last_node[0] def insert(self, key, value=None): @@ -68,39 +89,54 @@ def insert(self, key, value=None): value = [x for x in key] return self._insert_helper(self.root_node, key, value) - def cache_req( - self, - token_ids, - last_uncached_pos, - req_pool_idx, - del_in_memory_pool=True, - old_last_node=None, - ): - # Insert the request into radix cache - indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)] - new_prefix_len = self.insert(token_ids, indices.clone()) + def cache_finished_req(self, req: "Req", token_ids=None): + """Cache request when it finishes.""" + if token_ids is None: + token_ids = (req.input_ids + req.output_ids)[:-1] + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] if self.disable: - if del_in_memory_pool: - self.token_to_kv_pool.dec_refs(indices) - else: - return torch.tensor([], dtype=torch.int64), self.root_node + self.token_to_kv_pool.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + return # Radix Cache takes one ref in memory pool - self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len]) + new_prefix_len = self.insert(token_ids, kv_indices.clone()) + self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len]) - if del_in_memory_pool: - self.req_to_token_pool.free(req_pool_idx) - else: - cached_indices, new_last_node = self.match_prefix(token_ids) - assert len(cached_indices) == len(token_ids) + # Remove req slot release the cache lock + self.req_to_token_pool.free(req.req_pool_idx) + self.dec_lock_ref(req.last_node) + + def cache_unfinished_req(self, req: "Req", token_ids=None): + """Cache request when it is unfinished.""" + if self.disable: + return + + if token_ids is None: + token_ids = req.input_ids + + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + # Radix Cache takes one ref in memory pool + new_prefix_len = self.insert(token_ids, kv_indices.clone()) + self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len]) + + # The prefix 
indices could be updated, reuse it + new_indices, new_last_node = self.match_prefix(token_ids) + assert len(new_indices) == len(token_ids) + self.req_to_token_pool.req_to_token[ + req.req_pool_idx, len(req.prefix_indices) : len(new_indices) + ] = new_indices[len(req.prefix_indices) :] - self.req_to_token_pool.req_to_token[ - req_pool_idx, last_uncached_pos : len(cached_indices) - ] = cached_indices[last_uncached_pos:] - self.dec_lock_ref(old_last_node) - self.inc_lock_ref(new_last_node) - return cached_indices, new_last_node + self.dec_lock_ref(req.last_node) + self.inc_lock_ref(new_last_node) + req.prefix_indices = new_indices + req.last_node = new_last_node def pretty_print(self): self._print_helper(self.root_node, 0) @@ -125,7 +161,8 @@ def evict(self, num_tokens, evict_callback): if x.lock_ref > 0: continue - num_evicted += evict_callback(x.value) + evict_callback(x.value) + num_evicted += len(x.value) self._delete_leaf(x) if len(x.parent.children) == 0: diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py deleted file mode 100644 index 33f4b8784b1..00000000000 --- a/python/sglang/srt/memory_pool.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Memory pool.""" - -import logging - -import torch - -logger = logging.getLogger(__name__) - - -class ReqToTokenPool: - def __init__(self, size, max_context_len): - self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda") - self.can_use_mem_size = size - self.req_to_token = torch.empty( - (size, max_context_len), dtype=torch.int32, device="cuda" - ) - - def alloc(self, need_size): - if need_size > self.can_use_mem_size: - return None - - select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size] - self.mem_state[select_index] = 0 - self.can_use_mem_size -= need_size - return select_index.to(torch.int32) - - def free(self, free_index): - if isinstance(free_index, (int,)): - self.can_use_mem_size += 1 - else: - self.can_use_mem_size += free_index.shape[0] - self.mem_state[free_index] = 1 - - def clear(self): - self.mem_state.fill_(1) - self.can_use_mem_size = len(self.mem_state) - - -class TokenToKVPool: - def __init__(self, size, dtype, head_num, head_dim, layer_num): - self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda") - self.total_ref_ct = 0 - - # [size, key/value, head_num, head_dim] for each layer - self.kv_data = [ - torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda") - for _ in range(layer_num) - ] - - def get_key_buffer(self, layer_id): - return self.kv_data[layer_id][:, 0] - - def get_value_buffer(self, layer_id): - return self.kv_data[layer_id][:, 1] - - def alloc(self, need_size): - select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size] - if select_index.shape[0] < need_size: - return None - - self.add_refs(select_index) - return select_index.to(torch.int32) - - def alloc_contiguous(self, need_size): - empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size] - if empty_index.shape[0] < need_size: - return None - empty_size = len(empty_index) - loc_sum = ( - empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)] - ) - can_used_loc = empty_index[: empty_size - (need_size - 1)][ - loc_sum == need_size - 1 - ] - if can_used_loc.shape[0] == 0: - return None - - start_loc = can_used_loc[0].item() - select_index = torch.arange(start_loc, start_loc + need_size, device="cuda") - self.add_refs(select_index) - return select_index.to(torch.int32), start_loc, start_loc + need_size - - def used_size(self): - return 
len(torch.nonzero(self.mem_state).squeeze(1)) - - def available_size(self): - return torch.sum(self.mem_state == 0).item() - - def add_refs(self, token_index: torch.Tensor): - self.total_ref_ct += len(token_index) - self.mem_state[token_index] += 1 - - def dec_refs(self, token_index: torch.Tensor): - self.total_ref_ct -= len(token_index) - self.mem_state[token_index] -= 1 - - num_freed = torch.sum(self.mem_state[token_index] == 0) - - return num_freed - - def clear(self): - self.mem_state.fill_(0) - self.total_ref_ct = 0 diff --git a/python/sglang/srt/mm_utils.py b/python/sglang/srt/mm_utils.py index 17d9c8c18d1..f253bd39e13 100644 --- a/python/sglang/srt/mm_utils.py +++ b/python/sglang/srt/mm_utils.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py import ast import base64 diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index c2cf7d47e55..ed496515cd3 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -1,3 +1,19 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from enum import IntEnum, auto from typing import Optional from transformers import PretrainedConfig @@ -5,6 +21,11 @@ from sglang.srt.hf_transformers_utils import get_config, get_context_length +class AttentionArch(IntEnum): + MLA = auto() + MHA = auto() + + class ModelConfig: def __init__( self, @@ -36,6 +57,16 @@ def __init__( "head_dim", self.hf_config.hidden_size // self.hf_config.num_attention_heads, ) + + # FIXME: temporary special judge for deepseek v2 MLA architecture + if "DeepseekV2ForCausalLM" in self.hf_config.architectures: + self.head_dim = 256 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + else: + self.attention_arch = AttentionArch.MHA + self.num_attention_heads = self.hf_config.num_attention_heads self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py new file mode 100644 index 00000000000..9bfd4a646c2 --- /dev/null +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -0,0 +1,279 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Run the model with cuda graph.""" + +import bisect +from contextlib import contextmanager + +import torch +from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from flashinfer.decode import _grouped_size_compiled_for_decode_kernels +from vllm.distributed.parallel_state import graph_capture +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.layers.logits_processor import ( + LogitProcessorOutput, + LogitsMetadata, + LogitsProcessor, +) +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ( + ForwardMode, + InputMetadata, + update_flashinfer_indices, +) +from sglang.srt.utils import monkey_patch_vllm_all_gather + + +def _to_torch(model: torch.nn.Module, reverse: bool = False): + for sub in model._modules.values(): + if isinstance(sub, CustomOp): + if reverse: + sub._forward_method = sub.forward_cuda + else: + sub._forward_method = sub.forward_native + if isinstance(sub, torch.nn.Module): + _to_torch(sub, reverse) + + +@contextmanager +def patch_model( + model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator" +): + backup_ca_comm = None + + try: + if use_compile: + _to_torch(model) + monkey_patch_vllm_all_gather() + backup_ca_comm = tp_group.ca_comm + tp_group.ca_comm = None + yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs") + else: + yield model.forward + finally: + if use_compile: + _to_torch(model, reverse=True) + monkey_patch_vllm_all_gather(reverse=True) + tp_group.ca_comm = backup_ca_comm + + +def set_torch_compile_config(): + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + # FIXME: tmp workaround + torch._dynamo.config.accumulated_cache_size_limit = 1024 + + +class CudaGraphRunner: + def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): + self.model_runner = model_runner + self.graphs = {} + self.input_buffers = {} + self.output_buffers = {} + self.flashinfer_handlers = {} + self.graph_memory_pool = None + + # Common inputs + self.max_bs = max_batch_size_to_capture + self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") + self.req_pool_indices = torch.zeros( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) + self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda") + self.position_ids_offsets = torch.zeros( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) + self.out_cache_loc = torch.zeros( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) + + # FlashInfer inputs + self.flashinfer_workspace_buffer = ( + self.model_runner.flashinfer_workspace_buffers[0] + ) + self.flashinfer_kv_indptr = torch.zeros( + (self.max_bs + 1,), dtype=torch.int32, device="cuda" + ) + self.flashinfer_kv_indices = torch.zeros( + (self.max_bs * model_runner.model_config.context_len,), + dtype=torch.int32, + device="cuda", + ) + 
self.flashinfer_kv_last_page_len = torch.ones( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) + + self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] + + if use_torch_compile: + set_torch_compile_config() + + def can_run(self, batch_size): + return batch_size < self.max_bs + + def capture(self, batch_size_list): + self.batch_size_list = batch_size_list + with graph_capture() as graph_capture_context: + self.stream = graph_capture_context.stream + for bs in batch_size_list: + with patch_model( + self.model_runner.model, + bs in self.compile_bs, + self.model_runner.tp_group, + ) as forward: + ( + graph, + input_buffers, + output_buffers, + flashinfer_handler, + ) = self.capture_one_batch_size(bs, forward) + self.graphs[bs] = graph + self.input_buffers[bs] = input_buffers + self.output_buffers[bs] = output_buffers + self.flashinfer_handlers[bs] = flashinfer_handler + + def capture_one_batch_size(self, bs, forward): + graph = torch.cuda.CUDAGraph() + stream = self.stream + + # Common inputs + input_ids = self.input_ids[:bs] + req_pool_indices = self.req_pool_indices[:bs] + seq_lens = self.seq_lens[:bs] + position_ids_offsets = self.position_ids_offsets[:bs] + out_cache_loc = self.out_cache_loc[:bs] + + # FlashInfer inputs + if not _grouped_size_compiled_for_decode_kernels( + self.model_runner.model_config.num_attention_heads + // self.model_runner.tp_size, + self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size), + ): + use_tensor_cores = True + else: + use_tensor_cores = False + flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=use_tensor_cores, + paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], + paged_kv_indices_buffer=self.flashinfer_kv_indices, + paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], + ) + update_flashinfer_indices( + ForwardMode.DECODE, + self.model_runner, + req_pool_indices, + seq_lens, + None, + flashinfer_decode_wrapper, + ) + + # Run and capture + def run_once(): + input_metadata = InputMetadata( + forward_mode=ForwardMode.DECODE, + batch_size=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + return_logprob=False, + top_logprobs_nums=0, + positions=(seq_lens - 1).to(torch.int64), + flashinfer_decode_wrapper=flashinfer_decode_wrapper, + ) + + return forward(input_ids, input_metadata.positions, input_metadata) + + for _ in range(2): + run_once() + + torch.cuda.synchronize() + with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): + out = run_once() + torch.cuda.synchronize() + self.graph_memory_pool = graph.pool() + return graph, None, out, flashinfer_decode_wrapper + + def replay(self, batch: ScheduleBatch): + assert batch.out_cache_loc is not None + raw_bs = len(batch.reqs) + + # Pad + index = bisect.bisect_left(self.batch_size_list, raw_bs) + bs = self.batch_size_list[index] + if bs != raw_bs: + self.seq_lens.fill_(1) + self.position_ids_offsets.zero_() + self.out_cache_loc.zero_() + + # Common inputs + self.input_ids[:raw_bs] = batch.input_ids + self.req_pool_indices[:raw_bs] = batch.req_pool_indices + self.seq_lens[:raw_bs] = batch.seq_lens + self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets + self.out_cache_loc[:raw_bs] = batch.out_cache_loc + + # FlashInfer inputs + update_flashinfer_indices( + 
ForwardMode.DECODE, + self.model_runner, + self.req_pool_indices[:bs], + self.seq_lens[:bs], + None, + self.flashinfer_handlers[bs], + ) + + # Replay + self.graphs[bs].replay() + output = self.output_buffers[bs] + + # Unpad + if bs != raw_bs: + output = LogitProcessorOutput( + next_token_logits=output.next_token_logits[:raw_bs], + next_token_logprobs=None, + normalized_prompt_logprobs=None, + input_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, + ) + + # Extract logprobs + if batch.return_logprob: + output.next_token_logprobs = torch.nn.functional.log_softmax( + output.next_token_logits, dim=-1 + ) + return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) + if return_top_logprob: + logits_metadata = LogitsMetadata( + forward_mode=ForwardMode.DECODE, + top_logprobs_nums=batch.top_logprobs_nums, + ) + output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + output.next_token_logprobs, logits_metadata + )[1] + + return output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py new file mode 100644 index 00000000000..686e7ed86fd --- /dev/null +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -0,0 +1,318 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""ModelRunner runs the forward passes of the models.""" +from dataclasses import dataclass +from enum import IntEnum, auto +from typing import TYPE_CHECKING, List + +import numpy as np +import torch + +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + + +class ForwardMode(IntEnum): + # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. + PREFILL = auto() + # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt). + EXTEND = auto() + # Decode one token. 
+ DECODE = auto() + + +@dataclass +class InputMetadata: + """Store all information of a forward pass.""" + + forward_mode: ForwardMode + batch_size: int + req_pool_indices: torch.Tensor + seq_lens: torch.Tensor + req_to_token_pool: ReqToTokenPool + token_to_kv_pool: BaseTokenToKVPool + + # Output location of the KV cache + out_cache_loc: torch.Tensor + + total_num_tokens: int = None + + # Position information + positions: torch.Tensor = None + + # For extend + extend_seq_lens: torch.Tensor = None + extend_start_loc: torch.Tensor = None + extend_no_prefix: bool = None + + # Output options + return_logprob: bool = False + top_logprobs_nums: List[int] = None + + # For multimodal + pixel_values: List[torch.Tensor] = None + image_sizes: List[List[int]] = None + image_offsets: List[int] = None + + # Triton attention backend + triton_max_seq_len: int = 0 + triton_max_extend_len: int = 0 + triton_start_loc: torch.Tensor = None + triton_prefix_lens: torch.Tensor = None + + # FlashInfer attention backend + flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None + flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None + flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None + flashinfer_use_ragged: bool = False + + def init_multimuldal_info(self, batch: ScheduleBatch): + reqs = batch.reqs + self.pixel_values = [r.pixel_values for r in reqs] + self.image_sizes = [r.image_size for r in reqs] + self.image_offsets = [ + ( + (r.image_offset - len(r.prefix_indices)) + if r.image_offset is not None + else 0 + ) + for r in reqs + ] + + def compute_positions(self, batch: ScheduleBatch): + position_ids_offsets = batch.position_ids_offsets + + if self.forward_mode == ForwardMode.DECODE: + if True: + self.positions = self.seq_lens - 1 + else: + # Deprecated + self.positions = (self.seq_lens - 1) + position_ids_offsets + else: + if True: + self.positions = torch.tensor( + np.concatenate( + [ + np.arange(len(req.prefix_indices), len(req.input_ids)) + for req in batch.reqs + ], + axis=0, + ), + device="cuda", + ) + else: + # Deprecated + position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() + self.positions = torch.tensor( + np.concatenate( + [ + np.arange( + len(req.prefix_indices) + position_ids_offsets_cpu[i], + len(req.input_ids) + position_ids_offsets_cpu[i], + ) + for i, req in enumerate(batch.reqs) + ], + axis=0, + ), + device="cuda", + ) + + # Positions should be in long type + self.positions = self.positions.to(torch.int64) + + def compute_extend_infos(self, batch: ScheduleBatch): + if self.forward_mode == ForwardMode.DECODE: + self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None + else: + prefix_lens_cpu = [ + len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs + ] + self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda") + self.extend_start_loc = torch.zeros_like(self.seq_lens) + self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) + self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu) + + def init_total_num_tokens(self, batch: ScheduleBatch): + self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs) + + @classmethod + def from_schedule_batch( + cls, + model_runner: "ModelRunner", + batch: ScheduleBatch, + forward_mode: ForwardMode, + ): + ret = cls( + forward_mode=forward_mode, + batch_size=batch.batch_size(), + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + req_to_token_pool=model_runner.req_to_token_pool, 
token_to_kv_pool=model_runner.token_to_kv_pool, + out_cache_loc=batch.out_cache_loc, + return_logprob=batch.return_logprob, + top_logprobs_nums=batch.top_logprobs_nums, + ) + + ret.compute_positions(batch) + + ret.compute_extend_infos(batch) + + ret.init_total_num_tokens(batch) + + if forward_mode != ForwardMode.DECODE: + ret.init_multimuldal_info(batch) + + prefix_lens = None + if forward_mode != ForwardMode.DECODE: + prefix_lens = torch.tensor( + [len(r.prefix_indices) for r in batch.reqs], device="cuda" + ) + + if model_runner.server_args.disable_flashinfer: + ret.init_triton_args(batch, prefix_lens) + + flashinfer_use_ragged = False + if not model_runner.server_args.disable_flashinfer: + if ( + forward_mode != ForwardMode.DECODE + and int(torch.sum(ret.seq_lens)) > 4096 + ): + flashinfer_use_ragged = True + ret.init_flashinfer_handlers( + model_runner, prefix_lens, flashinfer_use_ragged + ) + + return ret + + def init_triton_args(self, batch: ScheduleBatch, prefix_lens): + """Init auxiliary variables for triton attention backend.""" + self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs) + self.triton_prefix_lens = prefix_lens + self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) + self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) + + if self.forward_mode == ForwardMode.DECODE: + self.triton_max_extend_len = None + else: + extend_seq_lens = self.seq_lens - prefix_lens + self.triton_max_extend_len = int(torch.max(extend_seq_lens)) + + def init_flashinfer_handlers( + self, model_runner, prefix_lens, flashinfer_use_ragged + ): + update_flashinfer_indices( + self.forward_mode, + model_runner, + self.req_pool_indices, + self.seq_lens, + prefix_lens, + flashinfer_use_ragged=flashinfer_use_ragged, + ) + + ( + self.flashinfer_prefill_wrapper_ragged, + self.flashinfer_prefill_wrapper_paged, + self.flashinfer_decode_wrapper, + self.flashinfer_use_ragged, + ) = ( + model_runner.flashinfer_prefill_wrapper_ragged, + model_runner.flashinfer_prefill_wrapper_paged, + model_runner.flashinfer_decode_wrapper, + flashinfer_use_ragged, + ) + + +def update_flashinfer_indices( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + flashinfer_decode_wrapper=None, + flashinfer_use_ragged=False, +): + """Init auxiliary variables for FlashInfer attention backend.""" + num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size + num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) + head_dim = model_runner.model_config.head_dim + batch_size = len(req_pool_indices) + + if flashinfer_use_ragged: + paged_kernel_lens = prefix_lens + else: + paged_kernel_lens = seq_lens + + kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() + kv_indices = torch.cat( + [ + model_runner.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] + ] + for i in range(batch_size) + ], + dim=0, + ).contiguous() + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + + if forward_mode == ForwardMode.DECODE: + # CUDA graph uses different flashinfer_decode_wrapper + if flashinfer_decode_wrapper is None: + flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper + + flashinfer_decode_wrapper.end_forward() + flashinfer_decode_wrapper.begin_forward( + 
kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + else: + # extend part + qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) + + if flashinfer_use_ragged: + model_runner.flashinfer_prefill_wrapper_ragged.end_forward() + model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + qo_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + + # cached part + model_runner.flashinfer_prefill_wrapper_paged.end_forward() + model_runner.flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py new file mode 100644 index 00000000000..574ad365800 --- /dev/null +++ b/python/sglang/srt/model_executor/model_runner.py @@ -0,0 +1,447 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""ModelRunner runs the forward passes of the models.""" + +import importlib +import importlib.resources +import logging +import pkgutil +import warnings +from functools import lru_cache +from typing import Optional, Type + +import torch +import torch.nn as nn +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, +) +from flashinfer.decode import _grouped_size_compiled_for_decode_kernels +from vllm.config import DeviceConfig, LoadConfig +from vllm.config import ModelConfig as VllmModelConfig +from vllm.distributed import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.model_executor.models import ModelRegistry + +from sglang.global_config import global_config +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.mem_cache.memory_pool import ( + MHATokenToKVPool, + MLATokenToKVPool, + ReqToTokenPool, +) +from sglang.srt.model_config import AttentionArch +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + get_available_gpu_memory, + is_generation_model, + is_llama3_405b_fp8, + is_multimodal_model, + monkey_patch_vllm_dummy_weight_loader, + monkey_patch_vllm_p2p_access_check, + monkey_patch_vllm_qvk_linear_loader, +) + +logger = logging.getLogger(__name__) + + +class ModelRunner: + def __init__( + self, + model_config, + mem_fraction_static: float, + gpu_id: int, + tp_rank: int, + tp_size: int, + nccl_port: int, + server_args: ServerArgs, + ): + # Parse args + self.model_config = model_config + self.mem_fraction_static = mem_fraction_static + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.tp_size = tp_size + self.nccl_port = nccl_port + self.server_args = server_args + self.is_multimodal_model = is_multimodal_model(self.model_config) + 
global_server_args_dict.update( + { + "disable_flashinfer": server_args.disable_flashinfer, + "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, + "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32, + "enable_mla": server_args.enable_mla, + } + ) + + # Init torch distributed + torch.cuda.set_device(self.gpu_id) + logger.info(f"[gpu={self.gpu_id}] Init nccl begin.") + + if not server_args.enable_p2p_check: + monkey_patch_vllm_p2p_access_check(self.gpu_id) + + if server_args.nccl_init_addr: + nccl_init_method = f"tcp://{server_args.nccl_init_addr}" + else: + nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}" + init_distributed_environment( + backend="nccl", + world_size=self.tp_size, + rank=self.tp_rank, + local_rank=self.gpu_id, + distributed_init_method=nccl_init_method, + ) + initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + self.tp_group = get_tp_group() + total_gpu_memory = get_available_gpu_memory( + self.gpu_id, distributed=self.tp_size > 1 + ) + + if self.tp_size > 1: + total_local_gpu_memory = get_available_gpu_memory(self.gpu_id) + if total_local_gpu_memory < total_gpu_memory * 0.9: + raise ValueError( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." + ) + + # Load the model and create memory pool + self.load_model() + self.init_memory_pool( + total_gpu_memory, + server_args.max_num_reqs, + server_args.max_total_tokens, + ) + self.init_cublas() + self.init_flashinfer() + + if self.is_generation: + # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models + # Capture cuda graphs + self.init_cuda_graphs() + + def load_model(self): + logger.info( + f"[gpu={self.gpu_id}] Load weight begin. " + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) + + monkey_patch_vllm_dummy_weight_loader() + device_config = DeviceConfig() + load_config = LoadConfig(load_format=self.server_args.load_format) + vllm_model_config = VllmModelConfig( + model=self.server_args.model_path, + quantization=self.server_args.quantization, + tokenizer=None, + tokenizer_mode=None, + trust_remote_code=self.server_args.trust_remote_code, + dtype=self.server_args.dtype, + seed=42, + skip_tokenizer_init=True, + ) + + if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8: + # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints + self.model_config.hf_config.num_key_value_heads = 8 + vllm_model_config.hf_config.num_key_value_heads = 8 + monkey_patch_vllm_qvk_linear_loader() + + self.dtype = vllm_model_config.dtype + if self.model_config.model_overide_args is not None: + vllm_model_config.hf_config.update(self.model_config.model_overide_args) + + if ( + self.server_args.efficient_weight_load + and "llama" in self.server_args.model_path.lower() + and self.server_args.quantization == "fp8" + ): + from sglang.srt.model_loader.model_loader import get_model + else: + from vllm.model_executor.model_loader import get_model + + self.model = get_model( + model_config=vllm_model_config, + device_config=device_config, + load_config=load_config, + lora_config=None, + multimodal_config=None, + parallel_config=None, + scheduler_config=None, + cache_config=None, + ) + self.is_generation = is_generation_model( + self.model_config.hf_config.architectures + ) + + logger.info( + f"[gpu={self.gpu_id}] Load weight end. 
" + f"type={type(self.model).__name__}, " + f"dtype={self.dtype}, " + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) + + def profile_max_num_token(self, total_gpu_memory): + available_gpu_memory = get_available_gpu_memory( + self.gpu_id, distributed=self.tp_size > 1 + ) + if ( + self.model_config.attention_arch == AttentionArch.MLA + and self.server_args.enable_mla + ): + cell_size = ( + (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) + * self.model_config.num_hidden_layers + * torch._utils._element_size(self.dtype) + ) + else: + cell_size = ( + self.model_config.get_num_kv_heads(self.tp_size) + * self.model_config.head_dim + * self.model_config.num_hidden_layers + * 2 + * torch._utils._element_size(self.dtype) + ) + rest_memory = available_gpu_memory - total_gpu_memory * ( + 1 - self.mem_fraction_static + ) + max_num_token = int(rest_memory * (1 << 30) // cell_size) + return max_num_token + + def init_memory_pool( + self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None + ): + self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + if max_total_tokens is not None: + if max_total_tokens > self.max_total_num_tokens: + warnings.warn( + f"max_total_tokens={max_total_tokens} is larger than the profiled value " + f"{self.max_total_num_tokens}. " + f"Use the profiled value instead." + ) + self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) + + if self.max_total_num_tokens <= 0: + raise RuntimeError( + "Not enough memory. Please try to increase --mem-fraction-static." + ) + + if max_num_reqs is None: + max_num_reqs = min( + max( + int( + self.max_total_num_tokens / self.model_config.context_len * 512 + ), + 2048, + ), + 5120, + ) + + self.req_to_token_pool = ReqToTokenPool( + max_num_reqs, + self.model_config.context_len + 8, + ) + if ( + self.model_config.attention_arch == AttentionArch.MLA + and self.server_args.enable_mla + ): + self.token_to_kv_pool = MLATokenToKVPool( + self.max_total_num_tokens, + dtype=self.dtype, + kv_lora_rank=self.model_config.kv_lora_rank, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + layer_num=self.model_config.num_hidden_layers, + ) + logger.info("using MLA Triton implementaion, flashinfer is disabled") + # FIXME: temporarily only Triton MLA is supported + self.server_args.disable_flashinfer = True + else: + self.token_to_kv_pool = MHATokenToKVPool( + self.max_total_num_tokens, + dtype=self.dtype, + head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_dim=self.model_config.head_dim, + layer_num=self.model_config.num_hidden_layers, + ) + logger.info( + f"[gpu={self.gpu_id}] Memory pool end. " + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) + + def init_cublas(self): + """We need to run a small matmul to init cublas. 
Otherwise, it will raise some errors later.""" + dtype = torch.float16 + device = "cuda" + a = torch.ones((16, 16), dtype=dtype, device=device) + b = torch.ones((16, 16), dtype=dtype, device=device) + c = a @ b + return c + + def init_flashinfer(self): + if self.server_args.disable_flashinfer: + self.flashinfer_prefill_wrapper_ragged = None + self.flashinfer_prefill_wrapper_paged = None + self.flashinfer_decode_wrapper = None + return + + if not _grouped_size_compiled_for_decode_kernels( + self.model_config.num_attention_heads // self.tp_size, + self.model_config.get_num_kv_heads(self.tp_size), + ): + use_tensor_cores = True + else: + use_tensor_cores = False + + self.flashinfer_workspace_buffers = torch.empty( + 2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda" + ) + self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + self.flashinfer_workspace_buffers[0], "NHD" + ) + self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffers[1], "NHD" + ) + self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffers[0], + "NHD", + use_tensor_cores=use_tensor_cores, + ) + + def init_cuda_graphs(self): + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + + if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: + self.cuda_graph_runner = None + return + + logger.info( + f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes." + ) + batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)] + self.cuda_graph_runner = CudaGraphRunner( + self, + max_batch_size_to_capture=max(batch_size_list), + use_torch_compile=self.server_args.enable_torch_compile, + ) + try: + self.cuda_graph_runner.capture(batch_size_list) + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n" + "Possible solutions:\n" + "1. disable torch compile by not using --enable-torch-compile\n" + "2. disable cuda graph by --disable-cuda-graph\n" + "3. 
set --mem-fraction-static to a smaller value\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + + @torch.inference_mode() + def forward_decode(self, batch: ScheduleBatch): + if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): + return self.cuda_graph_runner.replay(batch) + + input_metadata = InputMetadata.from_schedule_batch( + self, batch, ForwardMode.DECODE + ) + + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata + ) + + @torch.inference_mode() + def forward_extend(self, batch: ScheduleBatch): + input_metadata = InputMetadata.from_schedule_batch( + self, batch, forward_mode=ForwardMode.EXTEND + ) + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata + ) + + @torch.inference_mode() + def forward_extend_multi_modal(self, batch: ScheduleBatch): + input_metadata = InputMetadata.from_schedule_batch( + self, batch, forward_mode=ForwardMode.EXTEND + ) + return self.model.forward( + batch.input_ids, + input_metadata.positions, + input_metadata, + input_metadata.pixel_values, + input_metadata.image_sizes, + input_metadata.image_offsets, + ) + + def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): + if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: + return self.forward_extend_multi_modal(batch) + elif forward_mode == ForwardMode.DECODE: + return self.forward_decode(batch) + elif forward_mode == ForwardMode.EXTEND: + return self.forward_extend(batch) + else: + raise ValueError(f"Invalid forward mode: {forward_mode}") + + +@lru_cache() +def import_model_classes(): + model_arch_name_to_cls = {} + package_name = "sglang.srt.models" + package = importlib.import_module(package_name) + for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): + if not ispkg: + module = importlib.import_module(name) + if hasattr(module, "EntryClass"): + entry = module.EntryClass + if isinstance( + entry, list + ): # To support multiple model classes in one module + for tmp in entry: + assert tmp.__name__ not in model_arch_name_to_cls + model_arch_name_to_cls[tmp.__name__] = tmp + else: + assert entry.__name__ not in model_arch_name_to_cls + model_arch_name_to_cls[entry.__name__] = entry + + # compat: some models such as chatglm have an incorrect class set in config.json + # usage: [ tuple("From_Entry_Class_Name": EntryClass), ] + if hasattr(module, "EntryClassRemapping") and isinstance( + module.EntryClassRemapping, list + ): + for remap in module.EntryClassRemapping: + if isinstance(remap, tuple) and len(remap) == 2: + assert remap[0] not in model_arch_name_to_cls + model_arch_name_to_cls[remap[0]] = remap[1] + + return model_arch_name_to_cls + + +def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: + model_arch_name_to_cls = import_model_classes() + + if model_arch not in model_arch_name_to_cls: + raise ValueError( + f"Unsupported architectures: {model_arch}. 
" + f"Supported list: {list(model_arch_name_to_cls.keys())}" + ) + return model_arch_name_to_cls[model_arch] + + +# Monkey patch model loader +setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) diff --git a/python/sglang/srt/model_loader/model_loader.py b/python/sglang/srt/model_loader/model_loader.py new file mode 100644 index 00000000000..4b7e32b6e55 --- /dev/null +++ b/python/sglang/srt/model_loader/model_loader.py @@ -0,0 +1,292 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py +# FIXME: in progress of refactoring the model loader + +import glob +import os +import re +from typing import Any, Dict, Generator, List, Optional, Tuple, Type + +import torch +from torch import nn +from tqdm import tqdm +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoadFormat, + LoRAConfig, + ModelConfig, + MultiModalConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.model_loader.utils import ( + get_model_architecture, + set_default_torch_dtype, +) +from vllm.platforms import current_platform + +from sglang.srt.model_loader.utils import ( + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + get_quant_config, + safetensors_weights_iterator, +) + + +def _get_quantization_config( + model_config: ModelConfig, load_config: LoadConfig +) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." + ) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. 
Supported dtypes: " + f"{supported_dtypes}" + ) + return quant_config + return None + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], +) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs: Dict[str, Any] = {} + + assert lora_config is None + assert multimodal_config is None + + return extra_kwargs + + +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig, +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config(model_config, load_config) + + return model_class( + config=model_config.hf_config, + cache_config=cache_config, + quant_config=quant_config, + efficient_weight_load=True, + **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config), + ) + + +class ModelLoader: + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, self.load_config.download_dir, revision + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder + ) + else: + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt + ) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + model_name_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + ) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + return weights_iterator + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, + self.load_config, + lora_config, + multimodal_config, + cache_config, + ) + weights = self._get_weights_iterator( + model_config.model, + model_config.revision, + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) + + modules = {} + for name, module in model.named_modules(): + modules[name] = module + + def apply_quant_method(module): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # print("before apply quant", module.weight, module.weight.dtype) + quant_method.process_weights_after_loading(module) + # print("after apply quant", module.weight, module.weight.dtype) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. 
+ if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + + if torch.cuda.current_device() == 0: + weights = tqdm( + weights, total=model.get_num_params() * 1.5, desc="load model" + ) + + num_shard = {} + num_loaded = {} + for name, loaded_weight in weights: + model.load_weights(None, name, loaded_weight) + module_name, shard_num = model.get_module_name(name) + num_shard[module_name] = shard_num + if module_name not in num_loaded: + num_loaded[module_name] = 1 + else: + num_loaded[module_name] += 1 + if num_loaded[module_name] == num_shard[module_name]: + apply_quant_method(modules[module_name]) + + return model.eval() + + +def get_model( + *, + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig, +) -> nn.Module: + loader = ModelLoader(load_config) + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py new file mode 100644 index 00000000000..9d6520e2ae5 --- /dev/null +++ b/python/sglang/srt/model_loader/utils.py @@ -0,0 +1,275 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# temporarily adapted from vLLM +# FIXME: in progress of refactoring the model loader +"""Utilities for selecting and loading models.""" +import contextlib +import fnmatch +import hashlib +import json +import logging +import os +import tempfile +from typing import Any, Generator, Iterable, List, Optional, Tuple, Type + +import filelock +import huggingface_hub.constants +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from torch import nn +from tqdm.auto import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.config import LoadConfig, ModelConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +from sglang.srt.layers.quantization import get_quantization_config + +logger = logging.getLogger(__name__) +temp_dir = tempfile.gettempdir() + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. 
+ if ( + model_config.quantization is not None + and model_config.quantization != "fp8" + and "MixtralForCausalLM" in architectures + ): + architectures = ["QuantMixtralForCausalLM"] + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + + Returns: + str: The path to the downloaded model weights. + """ + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. 
+ hf_hub_download( + repo_id=model_name_or_path, + filename=SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have SAFE_WEIGHTS_INDEX_NAME. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the SAFE_WEIGHTS_INDEX_NAME to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files( + hf_weights_files: List[str], hf_folder: str +) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as index_file: + weight_map = json.load(index_file)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] + return hf_weights_files + + +def safetensors_weights_iterator( + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def get_quant_config( + model_config: ModelConfig, load_config: LoadConfig +) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if ( + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + ): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path" + ] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. 
+ with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError(f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}" + ) + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + + return quant_cls.from_config(config) diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index e9ec3e2d2ba..d2ad02fbf4c 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # coding=utf-8 # Adapted from # https://github.com/THUDM/ChatGLM2-6B @@ -30,7 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None @@ -360,6 +375,7 @@ def __init__( self.logits_processor = LogitsProcessor(config) self.sampler = Sampler() + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 2757645e16d..1259285c464 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # coding=utf-8 # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved. 
# @@ -49,7 +64,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata @torch.compile diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b21142d2e66..39ac4aefa72 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from: # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1 # coding=utf-8 @@ -30,7 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class DbrxRouter(nn.Module): @@ -368,6 +383,7 @@ def __init__( ) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py new file mode 100644 index 00000000000..98dcfd28df4 --- /dev/null +++ b/python/sglang/srt/models/deepseek.py @@ -0,0 +1,445 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Adapted from: +# https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py +"""Inference-only Deepseek model.""" +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}." 
+ ) + + self.experts = nn.ModuleList( + [ + DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + for idx in range(self.n_routed_experts) + ] + ) + self.pack_params() + + self.gate = ReplicatedLinear( + config.hidden_size, self.n_routed_experts, bias=False, quant_config=None + ) + + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe( + hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True, + ) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
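+            # Example (illustrative values, not from the original source): with
+            # total_num_kv_heads=8 and tp_size=16, each KV head is replicated on
+            # two ranks; the assert below requires 16 % 8 == 0 and num_kv_heads
+            # evaluates to max(1, 8 // 16) = 1 on every rank.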
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( 
+ self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + DeepseekDecoderLayer( + config, layer_id, cache_config, quant_config=quant_config + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = DeepseekModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ( + "mlp.experts." in name or "mlp.shared_experts." 
in name + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DeepseekForCausalLM diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py new file mode 100644 index 00000000000..739562730b3 --- /dev/null +++ b/python/sglang/srt/models/deepseek_v2.py @@ -0,0 +1,714 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Adapted from: +# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py +"""Inference-only DeepseekV2 model.""" +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class DeepseekV2MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV2MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}." + ) + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, config.n_routed_experts, bias=False, quant_config=None + ) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV2Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_id=None, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = 
num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + ) + # O projection. + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + rope_scaling["type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # self.attn = Attention(self.num_heads, + # self.qk_head_dim, + # self.scaling, + # num_kv_heads=self.num_heads) + + # TODO, support head_size 192 + self.attn = RadixAttention( + self.num_local_heads, + 256, + self.scaling, + num_kv_heads=self.num_local_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank :] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( + -1, self.num_local_heads * 256 + ) + k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view( + -1, self.num_local_heads * 256 + ) + v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view( + -1, self.num_local_heads * 256 + ) + attn_output = self.attn(q, k, v, 
input_metadata) + attn_output = attn_output.view(-1, self.num_local_heads, 256)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV2AttentionMLA(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_id=None, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + ) + # O projection. 
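+        # The output projection maps the concatenated per-head values
+        # (num_heads * v_head_dim) back to hidden_size, matching the o_proj of
+        # the non-MLA DeepseekV2Attention above.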
+ self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + rope_scaling["type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = RadixAttention( + self.num_local_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + self.scaling, + num_kv_heads=1, + layer_id=layer_id, + v_head_dim=self.kv_lora_rank, + ) + + kv_b_proj = self.kv_b_proj + w_kc, w_vc = kv_b_proj.weight.unflatten( + 0, (-1, qk_nope_head_dim + v_head_dim) + ).split([qk_nope_head_dim, v_head_dim], dim=1) + self.w_kc = w_kc + self.w_vc = w_vc + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + q_len = hidden_states.shape[0] + q_input = hidden_states.new_empty( + q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim + ) + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope_out = q_input[..., : self.kv_lora_rank] + torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1)) + + k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1) + k_pe = k_input[..., self.kv_lora_rank :] + v_input = k_input[..., : self.kv_lora_rank] + v_input = self.kv_a_layernorm(v_input.contiguous()) + k_input[..., : self.kv_lora_rank] = v_input + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe + + attn_output = self.attn(q_input, k_input, v_input, input_metadata) + attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) + attn_bmm_output = attn_output.new_empty( + q_len, self.num_local_heads, self.v_head_dim + ) + torch.bmm( + attn_output.transpose(0, 1), + self.w_vc.transpose(1, 2).contiguous(), + out=attn_bmm_output.transpose(0, 1), + ) + + attn_output = attn_bmm_output.flatten(1, 2) + output, _ = self.o_proj(attn_output) + + return output + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + if global_server_args_dict["enable_mla"]: + self.self_attn = DeepseekV2AttentionMLA( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + 
rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + layer_id=layer_id, + ) + else: + self.self_attn = DeepseekV2Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + layer_id=layer_id, + ) + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + config, + layer_id, + cache_config=cache_config, + quant_config=quant_config, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV2ForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = DeepseekV2Model(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, 
config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + ) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = DeepseekV2ForCausalLM diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 2281b4c8082..ce397311569 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + # Adapted from: # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1 """Inference-only Gemma model compatible with HuggingFace weights.""" @@ -22,7 +37,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class GemmaMLP(nn.Module): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index c6c409dee65..db87624d2df 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from: # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py from typing import Iterable, Optional, Set, Tuple, Union @@ -23,11 +38,10 @@ # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class GemmaRMSNorm(CustomOp): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py new file mode 100644 index 00000000000..9a9e2aec3a7 --- /dev/null +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -0,0 +1,297 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Adapted from: +# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py +"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import GPTBigCodeConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class GPTBigCodeAttention(nn.Module): + + def __init__( + self, + layer_id: int, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + assert total_num_heads % self.tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // self.tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.multi_query = config.multi_query + if self.multi_query: + total_num_kv_heads = 1 + self.num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + self.num_kv_heads = self.num_heads + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + scaling=self.scale, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ + self.hidden_size // self.tensor_model_parallel_world_size, + self.kv_dim, + self.kv_dim, + ], + dim=-1, + ) + attn_output = self.attn(q, k, v, input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPTBigMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTBigCodeConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + self.act = get_act_fn( + config.activation_function, quant_config, intermediate_size + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = 
self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + + def __init__( + self, + layer_id: int, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, input_metadata=input_metadata + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPTBigCodeModel(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + + self.embed_dim = config.hidden_size + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding( + self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size + ) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.h = nn.ModuleList( + [ + GPTBigCodeBlock(i, config, cache_config, quant_config) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states, input_metadata) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTBigCodeForCausalLM(nn.Module): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] + + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + + embedding_padding_modules = [] + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.transformer = GPTBigCodeModel( + config, cache_config, quant_config, lora_config + ) + self.lm_head = self.transformer.wte + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(config) + + 
@torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, input_metadata) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, "q") + weight_loader(param, loaded_weight, "k") + weight_loader(param, loaded_weight, "v") + else: + weight_loader(param, loaded_weight) + + +EntryClass = GPTBigCodeForCausalLM diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index cbf29055c04..38297b7d6ec 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" @@ -37,7 +52,7 @@ from sglang.srt.layers.fused_moe import fused_moe from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata use_fused = True @@ -601,6 +616,7 @@ def __init__( # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py new file mode 100644 index 00000000000..394d005042d --- /dev/null +++ b/python/sglang/srt/models/internlm2.py @@ -0,0 +1,332 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# -*- coding: utf-8 -*- +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py + +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class InternLM2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.w2 = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class InternLM2Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.wqkv = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.wqkv(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.wo(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + layer_id=layer_id, + quant_config=quant_config, + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm(hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class InternLM2Model(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + 
InternLMDecoderLayer(config, i, quant_config) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.tok_embeddings(input_ids) + else: + hidden_states = input_embeds + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2ForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = InternLM2Model(config, quant_config) + self.output = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.output.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + if "wqkv" in name: + config = self.config + kv_groups = config.num_attention_heads // config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + loaded_weight = loaded_weight.view( + -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1] + ) + wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1) + wq = wq.reshape(-1, wq.shape[-1]) + wk = wk.reshape(-1, wk.shape[-1]) + wv = wv.reshape(-1, wv.shape[-1]) + weight_loader = param.weight_loader + weight_loader(param, wq, "q") + weight_loader(param, wk, "k") + weight_loader(param, wv, "v") + else: + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = InternLM2ForCausalLM diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 95ba71ee94d..20f8970f7d0 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 """Inference-only LLaMA model compatible with HuggingFace weights.""" @@ -5,14 +20,10 @@ from typing import Any, Dict, Iterable, Optional, Tuple import torch -import tqdm from torch import nn from transformers import LlamaConfig from vllm.config import CacheConfig -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -28,9 +39,9 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class LlamaMLP(nn.Module): @@ -40,6 +51,7 @@ def __init__( intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -47,12 +59,14 @@ def __init__( [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": raise ValueError( @@ -71,6 +85,7 @@ def forward(self, x): class LlamaAttention(nn.Module): def __init__( self, + config: LlamaConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -80,6 +95,7 @@ def __init__( rope_is_neox_style: bool = True, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -97,7 +113,10 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. 
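For clarity, the LlamaMLP above composes MergedColumnParallelLinear and SiluAndMul into a SwiGLU feed-forward. The sketch below shows the same pattern on a single GPU with plain nn.Linear layers and assumed sizes; it is an illustration of the computation, not the tensor-parallel implementation used here.

import torch
import torch.nn.functional as F
from torch import nn

class NaiveSwiGLUMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # gate_up_proj fuses the gate and up projections into a single matmul
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)   # what SiluAndMul computes

x = torch.randn(2, 16, 4096)                       # (batch, seq, hidden), assumed sizes
print(NaiveSwiGLUMLP(4096, 11008)(x).shape)        # torch.Size([2, 16, 4096])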
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -111,12 +130,14 @@ def __init__( self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( @@ -155,6 +176,7 @@ def __init__( config: LlamaConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -163,12 +185,13 @@ def __init__( if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): - rope_scaling[ - "original_max_position_embeddings" - ] = config.original_max_position_embeddings + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) rope_is_neox_style = getattr(config, "rope_is_neox_style", True) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( + config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -178,12 +201,14 @@ def __init__( rope_is_neox_style=rope_is_neox_style, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -231,7 +256,9 @@ def __init__( ) self.layers = nn.ModuleList( [ - LlamaDecoderLayer(config, i, quant_config=quant_config) + LlamaDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) for i in range(config.num_hidden_layers) ] ) @@ -267,6 +294,7 @@ def __init__( config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, + efficient_weight_load=False, ) -> None: super().__init__() self.config = config @@ -275,19 +303,43 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: + ) -> LogitProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def get_module_name(self, name): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id, num_shard) + ("qkv_proj", "q_proj", "q", 3), + ("qkv_proj", "k_proj", "k", 3), + ("qkv_proj", "v_proj", "v", 3), + ("gate_up_proj", "gate_proj", 0, 2), + ("gate_up_proj", "up_proj", 1, 2), + ] + for param_name, weight_name, shard_id, num_shard in 
stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -297,15 +349,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - if get_tensor_model_parallel_rank() == 0: - weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5)) - for name, loaded_weight in weights: + + def load_weights_per_param(name, loaded_weight): if "rotary_emb.inv_freq" in name or "projector" in name: - continue + return if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - continue + return for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -322,12 +373,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: - continue + return if name.startswith("model.vision_tower") and name not in params_dict: - continue + return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if name is None or loaded_weight is None: + for name, loaded_weight in weights: + load_weights_per_param(name, loaded_weight) + else: + load_weights_per_param(name, loaded_weight) + EntryClass = LlamaForCausalLM diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index eb9dde45c1a..02224971d6a 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
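As an illustration of the dual-mode loading pattern introduced above, the self-contained sketch below (hypothetical TinyModel, not SGLang code) shows a nested helper that loads one parameter, with load_weights either iterating a full checkpoint stream or forwarding a single (name, tensor) pair.

from typing import Iterable, Optional, Tuple
import torch

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4, 4, bias=False)

    def load_weights(
        self,
        weights: Optional[Iterable[Tuple[str, torch.Tensor]]],
        name: Optional[str] = None,
        loaded_weight: Optional[torch.Tensor] = None,
    ) -> None:
        params_dict = dict(self.named_parameters())

        def load_weights_per_param(n: str, w: torch.Tensor) -> None:
            params_dict[n].data.copy_(w)

        if name is None or loaded_weight is None:
            for n, w in weights:                         # bulk path: whole checkpoint stream
                load_weights_per_param(n, w)
        else:
            load_weights_per_param(name, loaded_weight)  # streaming path: one tensor at a time

m = TinyModel()
m.load_weights([("proj.weight", torch.eye(4))])          # load everything at once
m.load_weights(None, "proj.weight", torch.zeros(4, 4))   # load a single parameter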
+""" + from typing import Iterable, Optional, Tuple import torch @@ -10,7 +25,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitProcessorOutput -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama2 import LlamaModel @@ -31,6 +46,7 @@ def __init__( ) self.eos_token_id = config.eos_token_id + @torch.no_grad() def forward( self, input_ids: torch.Tensor, @@ -53,9 +69,9 @@ def forward( next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, - prefill_token_logprobs=torch.ones_like(input_ids), - prefill_top_logprobs=None, - decode_top_logprobs=None, + input_token_logprobs=torch.ones_like(input_ids), + input_top_logprobs=None, + output_top_logprobs=None, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py new file mode 100644 index 00000000000..e8e6780472d --- /dev/null +++ b/python/sglang/srt/models/llama_embedding.py @@ -0,0 +1,88 @@ +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel + + +class LlamaEmbeddingModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config=None, + cache_config=None, + efficient_weight_load=False, + ) -> None: + super().__init__() + self.model = LlamaModel(config, quant_config=quant_config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> EmbeddingPoolerOutput: + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.pooler(hidden_states, input_metadata) + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + + def load_weights_per_param(name, loaded_weight): + if "rotary_emb.inv_freq" in name or "projector" in name: + return + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + return + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + return + if name.startswith("model.vision_tower") and name not in params_dict: + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + if name is None or loaded_weight is None: + for name, loaded_weight in weights: + load_weights_per_param(name, loaded_weight) + else: + load_weights_per_param(name, loaded_weight) + + +EntryClass = LlamaEmbeddingModel +# compat: e5-mistral model.config class == MistralModel +EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)] diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 93d492a201d..86cc385a3cf 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Inference-only LLaVa model compatible with HuggingFace weights.""" from typing import Iterable, List, Optional, Tuple @@ -19,13 +34,12 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.controller.infer_batch import ForwardMode -from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM @@ -97,6 +111,7 @@ def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: return image_features + @torch.no_grad() def forward( self, input_ids: torch.LongTensor, diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 47e20583c39..8b81251d692 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
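For reference, an illustrative stand-in (assumed shapes, not the actual sglang.srt.layers.pooler implementation) for the behaviour requested by Pooler(pooling_type=PoolingType.LAST, normalize=True) in the embedding model above: take each sequence's last hidden state and L2-normalize it.

import torch
import torch.nn.functional as F

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: packed tokens of all sequences in the batch, shape (total_tokens, hidden)
    # seq_lens:      per-sequence lengths, shape (batch,)
    last_idx = torch.cumsum(seq_lens, dim=0) - 1   # index of each sequence's final token
    emb = hidden_states[last_idx]                  # (batch, hidden)
    return F.normalize(emb, p=2, dim=-1)           # normalize=True

h = torch.randn(7, 8)                                    # two packed sequences of lengths 3 and 4
print(last_token_pool(h, torch.tensor([3, 4])).shape)    # torch.Size([2, 8])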
+""" + """Inference-only LLaVa video model compatible with HuggingFace weights.""" from typing import Iterable, List, Optional, Tuple @@ -11,13 +26,12 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.controller.infer_batch import ForwardMode -from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM @@ -106,6 +120,7 @@ def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: return image_features + @torch.no_grad() def forward( self, input_ids: torch.LongTensor, diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 072bf99ab19..bf572855e66 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math @@ -5,12 +20,9 @@ import torch from torch import nn - from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size - from vllm.model_executor.layers.activation import SiluAndMul - from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -27,11 +39,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class MiniCPMMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -67,7 +78,6 @@ def forward(self, x): class MiniCPMAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -152,7 +162,6 @@ def forward( class MiniCPMDecoderLayer(nn.Module): - def __init__( self, config, @@ -217,7 +226,6 @@ def forward( class MiniCPMModel(nn.Module): - def __init__( self, config, @@ -274,7 +282,7 @@ def __init__( ) -> None: super().__init__() self.config = config - + self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, quant_config=quant_config) @@ -290,6 +298,7 @@ def __init__( self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/mistral.py b/python/sglang/srt/models/mistral.py index 54794e6fc1c..614c1c1d747 100644 --- a/python/sglang/srt/models/mistral.py +++ b/python/sglang/srt/models/mistral.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the 
License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Inference-only Mistral model.""" from sglang.srt.models.llama2 import LlamaForCausalLM diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index abcde6de500..63053ac50bc 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Mixtral model.""" @@ -35,7 +50,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class MixtralMoE(nn.Module): @@ -460,6 +475,7 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index aa8f8a75945..07caf383343 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1 """Inference-only Mixtral model.""" @@ -30,7 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class MixtralMLP(nn.Module): @@ -322,6 +337,7 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 9c59d14fe27..ffc512b1ca2 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1 from typing import Any, Dict, Iterable, Optional, Tuple @@ -24,7 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class QWenMLP(nn.Module): @@ -237,6 +252,7 @@ def __init__( self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index dc50075caf6..dec962bf0af 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from llama2.py # Modify details for the adaptation of Qwen2 model. 
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" @@ -24,7 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None @@ -261,6 +276,7 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, @@ -312,6 +328,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if ( + self.config.tie_word_embeddings + and name == "model.embed_tokens.weight" + ): + weight_loader(params_dict["lm_head.weight"], loaded_weight) EntryClass = Qwen2ForCausalLM diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 79187cd4351..f96f7e0e484 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + # coding=utf-8 # Adapted from # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py @@ -8,34 +23,36 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig - from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata -class Qwen2MoeMLP(nn.Module): +class Qwen2MoeMLP(nn.Module): def __init__( self, hidden_size: int, @@ -46,17 +63,20 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results) + quant_config=quant_config, + reduce_results=reduce_results, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -67,7 +87,6 @@ def forward(self, x): class Qwen2MoeSparseMoeBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -79,20 +98,22 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.experts = FusedMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config) - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None) + f"the number of experts {config.num_experts}." + ) + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, config.num_experts, bias=False, quant_config=None + ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, @@ -103,9 +124,7 @@ def __init__( ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, - 1, - bias=False) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -114,24 +133,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: - shared_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_output + shared_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class Qwen2MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -190,17 +209,19 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = RadixAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_id) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata + input_metadata: InputMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -211,7 +232,6 @@ def forward( class Qwen2MoeDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -223,8 +243,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = 
getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -239,13 +258,13 @@ def __init__( # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_id + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen2MoeSparseMoeBlock(config=config, - quant_config=quant_config) + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, @@ -253,10 +272,10 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -270,23 +289,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata + input_metadata=input_metadata, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class Qwen2MoeModel(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -301,13 +317,14 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, - layer_id, - cache_config, - quant_config=quant_config) - for layer_id in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Qwen2MoeDecoderLayer( + config, layer_id, cache_config, quant_config=quant_config + ) + for layer_id in range(config.num_hidden_layers) + ] + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -315,7 +332,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - input_embeds: torch.Tensor = None + input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) @@ -324,10 +341,9 @@ def forward( residual = None for i in range(len(self.layers)): layer = self.layers[i] - hidden_states, residual = layer(positions, - hidden_states, - input_metadata, - residual) + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -346,37 +362,34 @@ def __init__( self.config = config self.quant_config = quant_config 
self.model = Qwen2MoeModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() + @torch.no_grad() def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - input_embeds: torch.Tensor = None + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, - input_embeds) - return self.logits_processor(input_ids, hidden_states, self.lm_head.weight, - input_metadata) - - def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata) -> torch.Tensor: - logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight, - input_metadata) - return logits + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) - def sample( + def compute_logits( self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + logits = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + return logits def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -391,18 +404,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] - else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_experts) for shard_id, - weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + ( + ( + "experts.w13_weight" + if weight_name in ["gate_proj", "up_proj"] + else "experts.w2_weight" + ), + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + shard_id, + ) + for expert_id in range(self.config.num_experts) + for shard_id, weight_name in enumerate( + ["gate_proj", "down_proj", "up_proj"] + ) ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -433,11 +455,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. 
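For reference, a short worked example of the expert_params_mapping built in load_weights above: per-expert checkpoint names such as experts.3.gate_proj.weight are rewritten onto the fused FusedMoE parameters (w13_weight for gate/up, w2_weight for down) together with their expert and shard ids. The expert count and the layer name below are assumptions for illustration.

num_experts = 4   # assumed; the real value comes from config.num_experts

expert_params_mapping = [
    (
        "experts.w13_weight" if weight_name in ("gate_proj", "up_proj") else "experts.w2_weight",
        f"experts.{expert_id}.{weight_name}.weight",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
]

name = "model.layers.0.mlp.experts.3.gate_proj.weight"   # hypothetical checkpoint entry
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), expert_id, shard_id)
        # -> model.layers.0.mlp.experts.w13_weight 3 0
        break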
@@ -447,8 +471,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) + EntryClass = Qwen2MoeForCausalLM diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 875ddd70b8d..aeaa46ab122 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + # Adapted from: # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b) @@ -25,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class StablelmMLP(nn.Module): @@ -235,6 +250,7 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 3016bfe13f9..11d4cda1c00 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Inference-only Yi-VL model.""" from typing import Iterable, Optional, Tuple diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py new file mode 100644 index 00000000000..c1213839194 --- /dev/null +++ b/python/sglang/srt/openai_api/adapter.py @@ -0,0 +1,1116 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Conversion between OpenAI APIs and native SRT APIs""" + +import asyncio +import json +import os +import time +import uuid +from http import HTTPStatus +from typing import Dict, List, Optional + +from fastapi import HTTPException, Request, UploadFile +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import ValidationError + +from sglang.srt.conversation import ( + Conversation, + SeparatorStyle, + chat_template_exists, + generate_chat_conv, + register_conv_template, +) +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.openai_api.protocol import ( + BatchRequest, + BatchResponse, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatCompletionTokenLogprob, + ChatMessage, + ChoiceLogprobs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + DeltaMessage, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + FileDeleteResponse, + FileRequest, + FileResponse, + LogProbs, + TopLogprob, + UsageInfo, +) + +chat_template_name = None + + +class FileMetadata: + def __init__(self, filename: str, purpose: str): + self.filename = filename + self.purpose = purpose + + +# In-memory storage for batch jobs and files +batch_storage: Dict[str, BatchResponse] = {} +file_id_request: Dict[str, FileMetadata] = {} +file_id_response: Dict[str, FileResponse] = {} +# map file id to file path in SGlang backend +file_id_storage: Dict[str, str] = {} + + +# backend storage directory +storage_dir = None + + +def format_finish_reason(finish_reason) -> Optional[str]: + if finish_reason.startswith("None"): + return None + elif finish_reason.startswith("FINISH_MATCHED"): + return "stop" + elif finish_reason.startswith("FINISH_LENGTH"): + return "length" + elif finish_reason.startswith("FINISH_ABORT"): + return "abort" + else: + return "unknown" + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +): + error = ErrorResponse(message=message, type=err_type, code=status_code.value) + return JSONResponse(content=error.model_dump(), status_code=error.code) + + +def create_streaming_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +) -> str: + error = ErrorResponse(message=message, type=err_type, code=status_code.value) + json_str = json.dumps({"error": error.model_dump()}) + return json_str + + +def load_chat_template_for_openai_api(chat_template_arg): + global chat_template_name + + print(f"Use chat template: {chat_template_arg}") + if not chat_template_exists(chat_template_arg): + if not os.path.exists(chat_template_arg): + raise RuntimeError( + f"Chat template {chat_template_arg} is not a built-in template name " + "or a valid chat template file path." 
+ ) + with open(chat_template_arg, "r") as filep: + template = json.load(filep) + try: + sep_style = SeparatorStyle[template["sep_style"]] + except KeyError: + raise ValueError( + f"Unknown separator style: {template['sep_style']}" + ) from None + register_conv_template( + Conversation( + name=template["name"], + system_template=template["system"] + "\n{system_message}", + system_message=template.get("system_message", ""), + roles=(template["user"], template["assistant"]), + sep_style=sep_style, + sep=template.get("sep", "\n"), + stop_str=template["stop_str"], + ), + override=True, + ) + chat_template_name = template["name"] + else: + chat_template_name = chat_template_arg + + +async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None): + try: + global storage_dir + if file_storage_pth: + storage_dir = file_storage_pth + # Read the file content + file_content = await file.read() + + # Create an instance of RequestBody + request_body = FileRequest(file=file_content, purpose=purpose) + + # Save the file to the sglang_oai_storage directory + os.makedirs(storage_dir, exist_ok=True) + file_id = f"backend_input_file-{uuid.uuid4()}" + filename = f"{file_id}.jsonl" + file_path = os.path.join(storage_dir, filename) + + with open(file_path, "wb") as f: + f.write(request_body.file) + + # add info to global file map + file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose) + file_id_storage[file_id] = file_path + + # Return the response in the required format + response = FileResponse( + id=file_id, + bytes=len(request_body.file), + created_at=int(time.time()), + filename=file.filename, + purpose=request_body.purpose, + ) + file_id_response[file_id] = response + + return response + except ValidationError as e: + return {"error": "Invalid input", "details": e.errors()} + + +async def v1_delete_file(file_id: str): + # Retrieve the file job from the in-memory storage + file_response = file_id_response.get(file_id) + if file_response is None: + raise HTTPException(status_code=404, detail="File not found") + file_path = file_id_storage.get(file_id) + if file_path is None: + raise HTTPException(status_code=404, detail="File not found") + os.remove(file_path) + del file_id_response[file_id] + del file_id_storage[file_id] + return FileDeleteResponse(id=file_id, deleted=True) + + +async def v1_batches(tokenizer_manager, raw_request: Request): + try: + body = await raw_request.json() + + batch_request = BatchRequest(**body) + + batch_id = f"batch_{uuid.uuid4()}" + + # Create an instance of BatchResponse + batch_response = BatchResponse( + id=batch_id, + endpoint=batch_request.endpoint, + input_file_id=batch_request.input_file_id, + completion_window=batch_request.completion_window, + created_at=int(time.time()), + metadata=batch_request.metadata, + ) + + batch_storage[batch_id] = batch_response + + # Start processing the batch asynchronously + asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) + + # Return the initial batch_response + return batch_response + + except ValidationError as e: + return {"error": "Invalid input", "details": e.errors()} + except Exception as e: + return {"error": str(e)} + + +async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): + try: + # Update the batch status to "in_progress" + batch_storage[batch_id].status = "in_progress" + batch_storage[batch_id].in_progress_at = int(time.time()) + + # Retrieve the input file content + input_file_request = 
file_id_request.get(batch_request.input_file_id) + if not input_file_request: + raise ValueError("Input file not found") + + # Parse the JSONL file and process each request + input_file_path = file_id_storage.get(batch_request.input_file_id) + with open(input_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + total_requests = len(lines) + completed_requests = 0 + failed_requests = 0 + + all_ret = [] + end_point = batch_storage[batch_id].endpoint + file_request_list = [] + all_requests = [] + for line in lines: + request_data = json.loads(line) + file_request_list.append(request_data) + body = request_data["body"] + if end_point == "/v1/chat/completions": + all_requests.append(ChatCompletionRequest(**body)) + elif end_point == "/v1/completions": + all_requests.append(CompletionRequest(**body)) + if end_point == "/v1/chat/completions": + adapted_request, request = v1_chat_generate_request( + all_requests, tokenizer_manager + ) + elif end_point == "/v1/completions": + adapted_request, request = v1_generate_request(all_requests) + try: + ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + if not isinstance(ret, list): + ret = [ret] + if end_point == "/v1/chat/completions": + responses = v1_chat_generate_response(request, ret, to_file=True) + else: + responses = v1_generate_response( + request, ret, tokenizer_manager, to_file=True + ) + + except Exception as e: + error_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": request_data.get("custom_id"), + "response": None, + "error": {"message": str(e)}, + } + all_ret.append(error_json) + failed_requests += len(file_request_list) + + for idx, response in enumerate(responses): + # the batch_req here can be changed to be named within a batch granularity + response_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": file_request_list[idx].get("custom_id"), + "response": response, + "error": None, + } + all_ret.append(response_json) + completed_requests += 1 + # Write results to a new file + output_file_id = f"backend_result_file-{uuid.uuid4()}" + global storage_dir + output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl") + with open(output_file_path, "w", encoding="utf-8") as f: + for ret in all_ret: + f.write(json.dumps(ret) + "\n") + + # Update batch response with output file information + retrieve_batch = batch_storage[batch_id] + retrieve_batch.output_file_id = output_file_id + file_id_storage[output_file_id] = output_file_path + file_id_response[output_file_id] = FileResponse( + id=output_file_id, + bytes=os.path.getsize(output_file_path), + created_at=int(time.time()), + filename=f"{output_file_id}.jsonl", + purpose="batch_result", + ) + # Update batch status to "completed" + retrieve_batch.status = "completed" + retrieve_batch.completed_at = int(time.time()) + retrieve_batch.request_counts = { + "total": total_requests, + "completed": completed_requests, + "failed": failed_requests, + } + + except Exception as e: + print("error in SGlang:", e) + # Update batch status to "failed" + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "failed" + retrieve_batch.failed_at = int(time.time()) + retrieve_batch.errors = {"message": str(e)} + + +async def v1_retrieve_batch(batch_id: str): + # Retrieve the batch job from the in-memory storage + batch_response = batch_storage.get(batch_id) + if batch_response is None: + raise HTTPException(status_code=404, detail="Batch not found") + + return batch_response + + +async def v1_retrieve_file(file_id: str): + # 
Retrieve the batch job from the in-memory storage + file_response = file_id_response.get(file_id) + if file_response is None: + raise HTTPException(status_code=404, detail="File not found") + return file_response + + +async def v1_retrieve_file_content(file_id: str): + file_pth = file_id_storage.get(file_id) + if not file_pth or not os.path.exists(file_pth): + raise HTTPException(status_code=404, detail="File not found") + + def iter_file(): + with open(file_pth, mode="rb") as file_like: + yield from file_like + + return StreamingResponse(iter_file(), media_type="application/octet-stream") + + +def v1_generate_request(all_requests): + prompts = [] + sampling_params_list = [] + return_logprobs = [] + top_logprobs_nums = [] + first_prompt_type = type(all_requests[0].prompt) + + for request in all_requests: + prompt = request.prompt + assert ( + type(prompt) == first_prompt_type + ), "All prompts must be of the same type in file input settings" + prompts.append(prompt) + return_logprobs.append(request.logprobs is not None and request.logprobs > 0) + top_logprobs_nums.append( + request.logprobs if request.logprobs is not None else 0 + ) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": request.stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "n": request.n, + "ignore_eos": request.ignore_eos, + } + ) + if len(all_requests) > 1 and request.n > 1: + raise ValueError( + "Parallel sampling is not supported for completions from files" + ) + + if len(all_requests) == 1: + prompt = prompts[0] + sampling_params_list = sampling_params_list[0] + return_logprobs = return_logprobs[0] + top_logprobs_nums = top_logprobs_nums[0] + if isinstance(prompt, str) or isinstance(prompt[0], str): + prompt_kwargs = {"text": prompt} + else: + prompt_kwargs = {"input_ids": prompt} + else: + if isinstance(prompts[0], str): + prompt_kwargs = {"text": prompts} + else: + prompt_kwargs = {"input_ids": prompts} + + adapted_request = GenerateReqInput( + **prompt_kwargs, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + top_logprobs_num=top_logprobs_nums, + return_text_in_logprobs=True, + stream=all_requests[0].stream, + ) + + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests + + +def v1_generate_response(request, ret, tokenizer_manager, to_file=False): + choices = [] + echo = False + + if (not isinstance(request, list)) and request.echo: + # TODO: handle the case propmt is token ids + if isinstance(request.prompt, list) and isinstance(request.prompt[0], str): + # for the case of multiple str prompts + prompts = request.prompt + elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): + # for the case of multiple token ids prompts + prompts = [ + tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True) + for prompt in request.prompt + ] + elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): + # for the case of single token ids prompt + prompts = [ + tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + ] + else: + # for the case of single str prompt + prompts = [request.prompt] + echo = True + + for idx, ret_item in enumerate(ret): + text = 
ret_item["text"] + if isinstance(request, list) and request[idx].echo: + echo = True + text = request[idx].prompt + text + if (not isinstance(request, list)) and echo: + prompt_index = idx // request.n + text = prompts[prompt_index] + text + + logprobs = False + if isinstance(request, list) and request[idx].logprobs: + logprobs = True + elif (not isinstance(request, list)) and request.logprobs: + logprobs = True + if logprobs: + if echo: + input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] + input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + ) + else: + logprobs = None + + if to_file: + # to make the choise data json serializable + choice_data = { + "index": 0, + "text": text, + "logprobs": logprobs, + "finish_reason": format_finish_reason( + ret_item["meta_info"]["finish_reason"] + ), + } + else: + choice_data = CompletionResponseChoice( + index=idx, + text=text, + logprobs=logprobs, + finish_reason=format_finish_reason( + ret_item["meta_info"]["finish_reason"] + ), + ) + + choices.append(choice_data) + + if to_file: + responses = [] + for i, choice in enumerate(choices): + response = { + "status_code": 200, + "request_id": ret[i]["meta_info"]["id"], + "body": { + # remain the same but if needed we can change that + "id": ret[i]["meta_info"]["id"], + "object": "text_completion", + "created": int(time.time()), + "model": request[i].model, + "choices": choice, + "usage": { + "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], + "completion_tokens": ret[i]["meta_info"]["completion_tokens"], + "total_tokens": ret[i]["meta_info"]["prompt_tokens"] + + ret[i]["meta_info"]["completion_tokens"], + }, + "system_fingerprint": None, + }, + } + responses.append(response) + return responses + else: + prompt_tokens = sum( + ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n) + ) + completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) + response = CompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response + + +async def v1_completions(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + all_requests = [CompletionRequest(**request_json)] + adapted_request, request = v1_generate_request(all_requests) + + if adapted_request.stream: + + async def generate_stream_resp(): + stream_buffer = "" + n_prev_token = 0 + try: + async for content in tokenizer_manager.generate_request( + adapted_request, raw_request + ): + text = content["text"] + prompt_tokens = content["meta_info"]["prompt_tokens"] + completion_tokens = content["meta_info"]["completion_tokens"] + + if not stream_buffer: # The first chunk + if request.echo: + if isinstance(request.prompt, str): + # for the case of single str prompts + prompts = request.prompt + elif isinstance(request.prompt, list) and isinstance( + request.prompt[0], int + ): + prompts = tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + + # Prepend prompt in response text. 
+ text = prompts + text + + if request.logprobs: + # The first chunk and echo is enabled. + if not stream_buffer and request.echo: + input_token_logprobs = content["meta_info"][ + "input_token_logprobs" + ] + input_top_logprobs = content["meta_info"][ + "input_top_logprobs" + ] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=content["meta_info"][ + "output_token_logprobs" + ][n_prev_token:], + output_top_logprobs=content["meta_info"][ + "output_top_logprobs" + ][n_prev_token:], + ) + n_prev_token = len( + content["meta_info"]["output_token_logprobs"] + ) + else: + logprobs = None + + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + choice_data = CompletionResponseStreamChoice( + index=0, + text=delta, + logprobs=logprobs, + finish_reason=format_finish_reason( + content["meta_info"]["finish_reason"] + ), + ) + chunk = CompletionStreamResponse( + id=content["meta_info"]["id"], + object="text_completion", + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + if request.stream_options and request.stream_options.include_usage: + usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + final_usage_chunk = CompletionStreamResponse( + id=str(uuid.uuid4().hex), + choices=[], + model=request.model, + usage=usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True + ) + yield f"data: {final_usage_data}\n\n" + except ValueError as e: + error = create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request), + ) + + # Non-streaming response. + try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = v1_generate_response(request, ret, tokenizer_manager) + return response + + +def v1_chat_generate_request(all_requests, tokenizer_manager): + input_ids = [] + sampling_params_list = [] + image_data_list = [] + return_logprobs = [] + top_logprobs_nums = [] + for request in all_requests: + # Prep the data needed for the underlying GenerateReqInput: + # - prompt: The full prompt string. + # - stop: Custom stop tokens. + # - image_data: None or a list of image strings (URLs or base64 strings). + # None skips any image processing in GenerateReqInput. + if not isinstance(request.messages, str): + # Apply chat template and its stop strings. + if chat_template_name is None: + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + request.messages, tokenize=True, add_generation_prompt=True + ) + stop = request.stop + image_data = None + else: + conv = generate_chat_conv(request, chat_template_name) + prompt = conv.get_prompt() + image_data = conv.image_data + stop = conv.stop_str or [] + if request.stop: + if isinstance(request.stop, str): + stop.append(request.stop) + else: + stop.extend(request.stop) + prompt_ids = tokenizer_manager.tokenizer.encode(prompt) + else: + # Use the raw prompt and stop strings if the messages is already a string. 
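# For the non-string branch above, apply_chat_template performs the prompt
# construction. A standalone sketch of that step, assuming a recent
# `transformers` release and an illustrative (non-patch) model name:
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # illustrative model
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hi."},
]
# tokenize=True mirrors the call above and returns prompt token ids directly;
# tokenize=False would return the rendered prompt string instead.
prompt_ids = tok.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)
print(len(prompt_ids))
# The string-messages branch below simply passes the raw prompt through.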
+ prompt_ids = request.messages + stop = request.stop + image_data = None + input_ids.append(prompt_ids) + return_logprobs.append(request.logprobs) + top_logprobs_nums.append(request.top_logprobs) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "n": request.n, + } + ) + image_data_list.append(image_data) + if len(all_requests) == 1: + input_ids = input_ids[0] + if isinstance(input_ids, str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} + sampling_params_list = sampling_params_list[0] + image_data = image_data_list[0] + return_logprobs = return_logprobs[0] + top_logprobs_nums = top_logprobs_nums[0] + else: + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=image_data, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + top_logprobs_num=top_logprobs_nums, + stream=all_requests[0].stream, + return_text_in_logprobs=True, + ) + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests + + +def v1_chat_generate_response(request, ret, to_file=False): + choices = [] + + for idx, ret_item in enumerate(ret): + logprobs = False + if isinstance(request, list) and request[idx].logprobs: + logprobs = True + elif (not isinstance(request, list)) and request.logprobs: + logprobs = True + if logprobs: + logprobs = to_openai_style_logprobs( + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + ) + token_logprobs = [] + for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs): + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: + for top_token, top_logprob in logprobs.top_logprobs[0].items(): + top_token_bytes = list(top_token.encode("utf-8")) + top_logprobs.append( + TopLogprob( + token=top_token, + bytes=top_token_bytes, + logprob=top_logprob, + ) + ) + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, + ) + ) + + choice_logprobs = ChoiceLogprobs(content=token_logprobs) + else: + choice_logprobs = None + + if to_file: + # to make the choice data json serializable + choice_data = { + "index": 0, + "message": {"role": "assistant", "content": ret_item["text"]}, + "logprobs": choice_logprobs, + "finish_reason": format_finish_reason( + ret_item["meta_info"]["finish_reason"] + ), + } + else: + choice_data = ChatCompletionResponseChoice( + index=idx, + message=ChatMessage(role="assistant", content=ret_item["text"]), + logprobs=choice_logprobs, + finish_reason=format_finish_reason( + ret_item["meta_info"]["finish_reason"] + ), + ) + + choices.append(choice_data) + + if to_file: + responses = [] + + for i, choice in enumerate(choices): + response = { + "status_code": 200, + "request_id": ret[i]["meta_info"]["id"], + "body": { + # remain the same but if needed we can change that + "id": ret[i]["meta_info"]["id"], + "object": "chat.completion", + "created": int(time.time()), + 
"model": request[i].model, + "choices": choice, + "usage": { + "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], + "completion_tokens": ret[i]["meta_info"]["completion_tokens"], + "total_tokens": ret[i]["meta_info"]["prompt_tokens"] + + ret[i]["meta_info"]["completion_tokens"], + }, + "system_fingerprint": None, + }, + } + responses.append(response) + return responses + else: + prompt_tokens = sum( + ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n) + ) + completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) + response = ChatCompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response + + +async def v1_chat_completions(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + all_requests = [ChatCompletionRequest(**request_json)] + adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) + + if adapted_request.stream: + + async def generate_stream_resp(): + is_first = True + + stream_buffer = "" + n_prev_token = 0 + try: + async for content in tokenizer_manager.generate_request( + adapted_request, raw_request + ): + prompt_tokens = content["meta_info"]["prompt_tokens"] + completion_tokens = content["meta_info"]["completion_tokens"] + if request.logprobs: + logprobs = to_openai_style_logprobs( + output_token_logprobs=content["meta_info"][ + "output_token_logprobs" + ][n_prev_token:], + output_top_logprobs=content["meta_info"][ + "output_top_logprobs" + ][n_prev_token:], + ) + + n_prev_token = len( + content["meta_info"]["output_token_logprobs"] + ) + token_logprobs = [] + for token, logprob in zip( + logprobs.tokens, logprobs.token_logprobs + ): + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: + for top_token, top_logprob in logprobs.top_logprobs[ + 0 + ].items(): + top_token_bytes = list(top_token.encode("utf-8")) + top_logprobs.append( + TopLogprob( + token=top_token, + bytes=top_token_bytes, + logprob=top_logprob, + ) + ) + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, + ) + ) + + choice_logprobs = ChoiceLogprobs(content=token_logprobs) + + else: + choice_logprobs = None + + if is_first: + # First chunk with role + is_first = False + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=format_finish_reason( + content["meta_info"]["finish_reason"] + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + text = content["text"] + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=delta), + finish_reason=format_finish_reason( + content["meta_info"]["finish_reason"] + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + if request.stream_options and request.stream_options.include_usage: + usage = UsageInfo( + prompt_tokens=prompt_tokens, + 
completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + final_usage_chunk = ChatCompletionStreamResponse( + id=str(uuid.uuid4().hex), + choices=[], + model=request.model, + usage=usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True + ) + yield f"data: {final_usage_data}\n\n" + except ValueError as e: + error = create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request), + ) + + # Non-streaming response. + try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return create_error_response(str(e)) + if not isinstance(ret, list): + ret = [ret] + + response = v1_chat_generate_response(request, ret) + + return response + + +def v1_embedding_request(all_requests, tokenizer_manager): + prompts = [] + sampling_params_list = [] + first_prompt_type = type(all_requests[0].prompt) + + for request in all_requests: + prompt = request.prompt + assert ( + type(prompt) == first_prompt_type + ), "All prompts must be of the same type in file input settings" + prompts.append(prompt) + + if len(all_requests) == 1: + prompt = prompts[0] + if isinstance(prompt, str) or isinstance(prompt[0], str): + prompt_kwargs = {"text": prompt} + else: + prompt_kwargs = {"input_ids": prompt} + else: + if isinstance(prompts[0], str) or isinstance(propmt[0][0], str): + prompt_kwargs = {"text": prompts} + else: + prompt_kwargs = {"input_ids": prompts} + + adapted_request = EmbeddingReqInput( + **prompt_kwargs, + ) + + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests + + +def v1_embedding_response(request, ret, to_file=False): + response = [] + for idx, ret_item in enumerate(ret): + response.append( + EmbeddingResponse( + index=idx, + embedding=ret[idx], + object="embedding", + ) + ) + return response + + +async def v1_embeddings(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + all_requests = [EmbeddingRequest(**request_json)] + adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager) + + try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = v1_embedding_response(request, ret) + + return response + + +def to_openai_style_logprobs( + input_token_logprobs=None, + output_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, +): + ret_logprobs = LogProbs() + + def append_token_logprobs(token_logprobs): + for logprob, _, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) + + # Not supported yet + ret_logprobs.text_offset.append(-1) + + def append_top_logprobs(top_logprobs): + for tokens in top_logprobs: + if tokens is not None: + ret_logprobs.top_logprobs.append( + {token[2]: token[0] for token in tokens} + ) + else: + ret_logprobs.top_logprobs.append(None) + + if input_token_logprobs is not None: + append_token_logprobs(input_token_logprobs) + if output_token_logprobs is not None: + append_token_logprobs(output_token_logprobs) + if input_top_logprobs is not None: + append_top_logprobs(input_top_logprobs) 
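# Shape check for the helpers above, written as a standalone sketch: each
# token-logprob entry is assumed to be a (logprob, token_id, token_text) tuple
# and each top-logprob entry a list of such tuples (or None), matching how
# append_token_logprobs/append_top_logprobs index into them. The values here
# are made up for illustration only.
example_output_token_logprobs = [(-0.11, 9906, "Hello"), (-1.37, 1917, " world")]
example_output_top_logprobs = [
    [(-0.11, 9906, "Hello"), (-2.40, 17250, "Hi")],
    None,
]

tokens = [t[2] for t in example_output_token_logprobs]          # ["Hello", " world"]
token_logprobs = [t[0] for t in example_output_token_logprobs]  # [-0.11, -1.37]
top_logprobs = [
    {t[2]: t[0] for t in entry} if entry is not None else None
    for entry in example_output_top_logprobs
]
print(tokens, token_logprobs, top_logprobs)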
+ if output_top_logprobs is not None: + append_top_logprobs(output_top_logprobs) + + return ret_logprobs diff --git a/python/sglang/srt/openai_protocol.py b/python/sglang/srt/openai_api/protocol.py similarity index 55% rename from python/sglang/srt/openai_protocol.py rename to python/sglang/srt/openai_api/protocol.py index dfe58e8570c..75f0a1aabaf 100644 --- a/python/sglang/srt/openai_protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Pydantic models for OpenAI API protocol""" import time @@ -7,6 +22,23 @@ from typing_extensions import Literal +class ModelCard(BaseModel): + """Model cards.""" + + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "sglang" + root: Optional[str] = None + + +class ModelList(BaseModel): + """Model list consists of model cards.""" + + object: str = "list" + data: List[ModelCard] = [] + + class ErrorResponse(BaseModel): object: str = "error" message: str @@ -22,12 +54,89 @@ class LogProbs(BaseModel): top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) +class TopLogprob(BaseModel): + token: str + bytes: List[int] + logprob: float + + +class ChatCompletionTokenLogprob(BaseModel): + token: str + bytes: List[int] + logprob: float + top_logprobs: List[TopLogprob] + + +class ChoiceLogprobs(BaseModel): + # build for v1/chat/completions response + content: List[ChatCompletionTokenLogprob] + + class UsageInfo(BaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 +class StreamOptions(BaseModel): + include_usage: Optional[bool] = False + + +class FileRequest(BaseModel): + # https://platform.openai.com/docs/api-reference/files/create + file: bytes # The File object (not file name) to be uploaded + purpose: str = ( + "batch" # The intended purpose of the uploaded file, default is "batch" + ) + + +class FileResponse(BaseModel): + id: str + object: str = "file" + bytes: int + created_at: int + filename: str + purpose: str + + +class FileDeleteResponse(BaseModel): + id: str + object: str = "file" + deleted: bool + + +class BatchRequest(BaseModel): + input_file_id: ( + str # The ID of an uploaded file that contains requests for the new batch + ) + endpoint: str # The endpoint to be used for all requests in the batch + completion_window: str # The time frame within which the batch should be processed + metadata: Optional[dict] = None # Optional custom metadata for the batch + + +class BatchResponse(BaseModel): + id: str + object: str = "batch" + endpoint: str + errors: Optional[dict] = None + input_file_id: str + completion_window: str + status: str = "validating" + output_file_id: Optional[str] = None + error_file_id: Optional[str] = None + created_at: int + in_progress_at: Optional[int] = None + expires_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: 
Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + request_counts: dict = {"total": 0, "completed": 0, "failed": 0} + metadata: Optional[dict] = None + + class CompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create @@ -44,6 +153,7 @@ class CompletionRequest(BaseModel): seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 @@ -51,6 +161,10 @@ class CompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None + ignore_eos: Optional[bool] = False + min_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) class CompletionResponseChoice(BaseModel): @@ -82,7 +196,7 @@ class CompletionStreamResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[CompletionResponseStreamChoice] - usage: UsageInfo + usage: Optional[UsageInfo] = None class ChatCompletionMessageGenericParam(BaseModel): @@ -134,19 +248,23 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = None - max_tokens: Optional[int] = 16 + max_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 user: Optional[str] = None # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
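# A client-side sketch of how the SRT-only extras declared below can be sent
# alongside the standard OpenAI fields; the URL, port, and model name are
# illustrative assumptions, not taken from this patch.
import requests

payload = {
    "model": "Qwen/Qwen2-7B-Instruct",  # illustrative model name
    "messages": [{"role": "user", "content": "List three colors."}],
    "temperature": 0.7,
    "max_tokens": 64,
    "stream": False,  # set stream=True with stream_options={"include_usage": True} to get a final usage chunk
    # SRT-only extras; OpenAI-hosted models would ignore these.
    "regex": None,
    "min_tokens": 4,
    "repetition_penalty": 1.05,
    "stop_token_ids": [],
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
# The corresponding request-model fields follow below.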
regex: Optional[str] = None + min_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) class ChatMessage(BaseModel): @@ -157,8 +275,8 @@ class ChatMessage(BaseModel): class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: str class ChatCompletionResponse(BaseModel): @@ -178,7 +296,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage - logprobs: Optional[LogProbs] = None + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None finish_reason: Optional[str] = None @@ -188,3 +306,21 @@ class ChatCompletionStreamResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = None + + +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings/create + input: Union[List[int], List[List[int]], str, List[str]] + model: str + encoding_format: str = "float" + dimensions: int = None + user: Optional[str] = None + + +class EmbeddingResponse(BaseModel): + index: str + embedding: List[float] = None + object: str = "embedding" + usage: Optional[UsageInfo] = None diff --git a/python/sglang/srt/openai_api_adapter.py b/python/sglang/srt/openai_api_adapter.py deleted file mode 100644 index 4306950f013..00000000000 --- a/python/sglang/srt/openai_api_adapter.py +++ /dev/null @@ -1,411 +0,0 @@ -"""Conversion between OpenAI APIs and native SRT APIs""" - -import asyncio -import json -import os -from http import HTTPStatus - -from fastapi import Request -from fastapi.responses import JSONResponse, StreamingResponse - -from sglang.srt.conversation import ( - Conversation, - SeparatorStyle, - chat_template_exists, - generate_chat_conv, - register_conv_template, -) -from sglang.srt.managers.io_struct import GenerateReqInput -from sglang.srt.openai_protocol import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - ChatMessage, - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - DeltaMessage, - ErrorResponse, - LogProbs, - UsageInfo, -) - -chat_template_name = None - - -def create_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -): - error = ErrorResponse(message=message, type=err_type, code=status_code.value) - return JSONResponse(content=error.model_dump(), status_code=error.code) - - -def create_streaming_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -) -> str: - error = ErrorResponse(message=message, type=err_type, code=status_code.value) - json_str = json.dumps({"error": error.model_dump()}) - return json_str - - -def load_chat_template_for_openai_api(chat_template_arg): - global chat_template_name - - print(f"Use chat template: {chat_template_arg}") - if not chat_template_exists(chat_template_arg): - if not os.path.exists(chat_template_arg): - raise RuntimeError( - f"Chat template {chat_template_arg} is not a built-in template name " - "or a valid chat 
template file path." - ) - with open(chat_template_arg, "r") as filep: - template = json.load(filep) - try: - sep_style = SeparatorStyle[template["sep_style"]] - except KeyError: - raise ValueError( - f"Unknown separator style: {template['sep_style']}" - ) from None - register_conv_template( - Conversation( - name=template["name"], - system_template=template["system"] + "\n{system_message}", - system_message=template.get("system_message", ""), - roles=(template["user"], template["assistant"]), - sep_style=sep_style, - sep=template.get("sep", "\n"), - stop_str=template["stop_str"], - ), - override=True, - ) - chat_template_name = template["name"] - else: - chat_template_name = chat_template_arg - - -async def v1_completions(tokenizer_manager, raw_request: Request): - request_json = await raw_request.json() - request = CompletionRequest(**request_json) - - if request.n != 1: - return create_error_response("n != 1 is not supported") - - adapted_request = GenerateReqInput( - text=request.prompt, - sampling_params={ - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "stop": request.stop, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "regex": request.regex, - }, - return_logprob=request.logprobs is not None and request.logprobs > 0, - top_logprobs_num=request.logprobs if request.logprobs is not None else 0, - return_text_in_logprobs=True, - stream=request.stream, - ) - - if adapted_request.stream: - - async def generate_stream_resp(): - stream_buffer = "" - n_prev_token = 0 - try: - async for content in tokenizer_manager.generate_request( - adapted_request, raw_request - ): - text = content["text"] - prompt_tokens = content["meta_info"]["prompt_tokens"] - completion_tokens = content["meta_info"]["completion_tokens"] - - if not stream_buffer: # The first chunk - if request.echo: - # Prepend prompt in response text. - text = request.prompt + text - - if request.logprobs: - # The first chunk and echo is enabled. 
- if not stream_buffer and request.echo: - prefill_token_logprobs = content["meta_info"][ - "prefill_token_logprobs" - ] - prefill_top_logprobs = content["meta_info"][ - "prefill_top_logprobs" - ] - else: - prefill_token_logprobs = None - prefill_top_logprobs = None - - logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=content["meta_info"][ - "decode_token_logprobs" - ][n_prev_token:], - decode_top_logprobs=content["meta_info"][ - "decode_top_logprobs" - ][n_prev_token:], - ) - - n_prev_token = len( - content["meta_info"]["decode_token_logprobs"] - ) - else: - logprobs = None - - delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - choice_data = CompletionResponseStreamChoice( - index=0, - text=delta, - logprobs=logprobs, - finish_reason=content["meta_info"]["finish_reason"], - ) - chunk = CompletionStreamResponse( - id=content["meta_info"]["id"], - object="text_completion", - choices=[choice_data], - model=request.model, - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - yield f"data: {chunk.model_dump_json()}\n\n" - except ValueError as e: - error = create_streaming_error_response(str(e)) - yield f"data: {error}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - generate_stream_resp(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), - ) - - # Non-streaming response. - try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - except ValueError as e: - return create_error_response(str(e)) - - ret = ret[0] if isinstance(ret, list) else ret - prompt_tokens = ret["meta_info"]["prompt_tokens"] - completion_tokens = ret["meta_info"]["completion_tokens"] - text = ret["text"] - if request.echo: - text = request.prompt + text - - if request.logprobs: - if request.echo: - prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"] - prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"] - else: - prefill_token_logprobs = None - prefill_top_logprobs = None - - logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"], - decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"], - ) - else: - logprobs = None - - choice_data = CompletionResponseChoice( - index=0, - text=text, - logprobs=logprobs, - finish_reason=ret["meta_info"]["finish_reason"], - ) - response = CompletionResponse( - id=ret["meta_info"]["id"], - model=request.model, - choices=[choice_data], - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - return response - - -async def v1_chat_completions(tokenizer_manager, raw_request: Request): - request_json = await raw_request.json() - request = ChatCompletionRequest(**request_json) - - if request.n != 1: - return create_error_response("n != 1 is not supported") - - # Prep the data needed for the underlying GenerateReqInput: - # - prompt: The full prompt string. - # - stop: Custom stop tokens. - # - image_data: None or a list of image strings (URLs or base64 strings). - # None skips any image processing in GenerateReqInput. 
- if not isinstance(request.messages, str): - # Apply chat template and its stop strings. - if chat_template_name is None: - prompt = tokenizer_manager.tokenizer.apply_chat_template( - request.messages, tokenize=False, add_generation_prompt=True - ) - stop = request.stop - image_data = None - else: - conv = generate_chat_conv(request, chat_template_name) - prompt = conv.get_prompt() - image_data = conv.image_data - stop = conv.stop_str or [] - if request.stop: - if isinstance(request.stop, str): - stop.append(request.stop) - else: - stop.extend(request.stop) - else: - # Use the raw prompt and stop strings if the messages is already a string. - prompt = request.messages - stop = request.stop - image_data = None - - adapted_request = GenerateReqInput( - text=prompt, - image_data=image_data, - sampling_params={ - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "stop": stop, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "regex": request.regex, - }, - stream=request.stream, - ) - - if adapted_request.stream: - - async def generate_stream_resp(): - is_first = True - - stream_buffer = "" - try: - async for content in tokenizer_manager.generate_request( - adapted_request, raw_request - ): - if is_first: - # First chunk with role - is_first = False - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(role="assistant"), - finish_reason=content["meta_info"]["finish_reason"], - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - - text = content["text"] - delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=delta), - finish_reason=content["meta_info"]["finish_reason"], - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - except ValueError as e: - error = create_streaming_error_response(str(e)) - yield f"data: {error}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - generate_stream_resp(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), - ) - - # Non-streaming response. 
- try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - except ValueError as e: - return create_error_response(str(e)) - - prompt_tokens = ret["meta_info"]["prompt_tokens"] - completion_tokens = ret["meta_info"]["completion_tokens"] - choice_data = ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=ret["text"]), - finish_reason=ret["meta_info"]["finish_reason"], - ) - response = ChatCompletionResponse( - id=ret["meta_info"]["id"], - model=request.model, - choices=[choice_data], - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - return response - - -def to_openai_style_logprobs( - prefill_token_logprobs=None, - decode_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=None, -): - ret_logprobs = LogProbs() - - def append_token_logprobs(token_logprobs): - for logprob, _, token_text in token_logprobs: - ret_logprobs.tokens.append(token_text) - ret_logprobs.token_logprobs.append(logprob) - - # Not supported yet - ret_logprobs.text_offset.append(-1) - - def append_top_logprobs(top_logprobs): - for tokens in top_logprobs: - if tokens is not None: - ret_logprobs.top_logprobs.append( - {token[2]: token[0] for token in tokens} - ) - else: - ret_logprobs.top_logprobs.append(None) - - if prefill_token_logprobs is not None: - append_token_logprobs(prefill_token_logprobs) - if decode_token_logprobs is not None: - append_token_logprobs(decode_token_logprobs) - if prefill_top_logprobs is not None: - append_top_logprobs(prefill_top_logprobs) - if decode_top_logprobs is not None: - append_top_logprobs(decode_top_logprobs) - - return ret_logprobs diff --git a/python/sglang/srt/sampling/penaltylib/__init__.py b/python/sglang/srt/sampling/penaltylib/__init__.py new file mode 100644 index 00000000000..43fff0fca44 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/__init__.py @@ -0,0 +1,13 @@ +from .orchestrator import BatchedPenalizerOrchestrator +from .penalizers.frequency_penalty import BatchedFrequencyPenalizer +from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer +from .penalizers.presence_penalty import BatchedPresencePenalizer +from .penalizers.repetition_penalty import BatchedRepetitionPenalizer + +__all__ = [ + "BatchedFrequencyPenalizer", + "BatchedMinNewTokensPenalizer", + "BatchedPresencePenalizer", + "BatchedRepetitionPenalizer", + "BatchedPenalizerOrchestrator", +] diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py new file mode 100644 index 00000000000..4214a746bda --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -0,0 +1,357 @@ +import abc +import dataclasses +import typing + +import torch + + +@dataclasses.dataclass +class _ReqLike: + origin_input_ids: typing.Union[torch.Tensor, typing.List[int]] + + +@dataclasses.dataclass +class _BatchLike: + reqs: typing.List[_ReqLike] + + def batch_size(self): + return len(self.reqs) + + +class BatchedPenalizerOrchestrator: + batch: _BatchLike + device: str + vocab_size: int + penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"] + + def __init__( + self, + vocab_size: int, + batch: _BatchLike, + device: str, + Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]], + ): + self.vocab_size = vocab_size + self.batch = batch + self.device = device + + self.penalizers = {Penalizer: Penalizer(self) for Penalizer 
in Penalizers} + + for penalizer in self.penalizers.values(): + penalizer.prepare_if_required() + + self.cumulate_input_tokens( + input_ids=[req.origin_input_ids for req in self.reqs()] + ) + + def reqs(self): + return self.batch.reqs + + def batch_size(self): + return self.batch.batch_size() + + def cumulate_input_tokens( + self, + input_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + """ + Feed the input tokens to the penalizers. + + Args: + input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens. + """ + token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids) + + for penalizer in self.penalizers.values(): + penalizer.cumulate_input_tokens(input_ids=token_ids) + + def cumulate_output_tokens( + self, + output_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + """ + Feed the output tokens to the penalizers. + + Args: + output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. + """ + token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) + + for penalizer in self.penalizers.values(): + penalizer.cumulate_output_tokens(output_ids=token_ids) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply the penalizers to the logits. + Note that it may apply the penalizers in-place. + + Args: + logits (torch.Tensor): The logits to apply the penalizers to. + + Returns: + torch.Tensor: The logits after applying the penalizers. + """ + for penalizer in self.penalizers.values(): + logits = penalizer.apply(logits) + + return logits + + def filter( + self, + indices_to_keep: typing.List[int], + indices_tensor_to_keep: torch.Tensor = None, + ): + """ + Filter the penalizers based on the indices to keep in the batch. + + Args: + indices_to_keep (typing.List[int]): List of indices to keep in the batch. + indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. + """ + empty_indices = len(indices_to_keep) == 0 + + for penalizer in self.penalizers.values(): + if not penalizer.is_required() or empty_indices: + penalizer.teardown() + else: + # create tensor index only when it's needed + if indices_tensor_to_keep is None: + indices_tensor_to_keep = torch.tensor( + indices_to_keep, dtype=torch.int32, device=self.device + ) + + penalizer.filter( + indices_to_keep=indices_to_keep, + indices_tensor_to_keep=indices_tensor_to_keep, + ) + + def merge(self, their: "BatchedPenalizerOrchestrator"): + """ + Merge the penalizers of another orchestrator into this one. + + Note that this function **must** be called _before_ self.batch.reqs is updated (filtered). + Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging. + This step requires the original batch.reqs, before it gets merged with other batch.reqs. + + Args: + their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. + """ + if self.vocab_size != their.vocab_size: + raise ValueError( + f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}" + ) + + for Penalizer, their_penalizer in their.penalizers.items(): + if Penalizer not in self.penalizers: + raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") + + self.penalizers[Penalizer].merge(their_penalizer) + + +class _TokenIDs: + """ + A class that wraps token IDs to provide additional utility functions to penalizers. 
+ + Attributes: + orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to. + token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs. + cached_counts (torch.Tensor): The cached occurrence count tensor. + """ + + orchestrator: BatchedPenalizerOrchestrator + token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]] + cached_counts: torch.Tensor = None + + def __init__( + self, + orchestrator: BatchedPenalizerOrchestrator, + token_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + self.orchestrator = orchestrator + + if not isinstance(token_ids[0], torch.Tensor): + token_ids = [ + torch.tensor( + data=ids, dtype=torch.int64, device=self.orchestrator.device + ) + for ids in token_ids + ] + + self.token_ids = token_ids + + def occurrence_count(self) -> torch.Tensor: + """ + Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch. + + Returns: + torch.Tensor: The occurrence count tensor. + """ + if self.cached_counts is not None: + return self.cached_counts + + token_ids = self.token_ids + + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.unsqueeze(1) + + # needs to be long to be used as index in scatter_add + if token_ids.dtype != torch.int64: + token_ids = token_ids.to(torch.int64) + + padded_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=token_ids, + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + + self.cached_counts = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.int64, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_token_ids, + src=torch.ones_like(padded_token_ids), + )[ + :, : self.orchestrator.vocab_size + ] + + return self.cached_counts + + +class _BatchedPenalizer(abc.ABC): + """ + An abstract class for a batched penalizer. + """ + + orchestrator: BatchedPenalizerOrchestrator + _is_prepared: bool = False + + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): + self.orchestrator = orchestrator + + def is_prepared(self) -> bool: + return self._is_prepared + + def is_required(self) -> bool: + return self._is_required() + + def prepare(self): + if not self.is_prepared(): + self._prepare() + self._is_prepared = True + + def prepare_if_required(self): + if self.is_required(): + self.prepare() + + def teardown(self): + if self.is_prepared(): + self._teardown() + self._is_prepared = False + + def cumulate_input_tokens(self, input_ids: _TokenIDs): + if not self.is_prepared(): + return + + self._cumulate_input_tokens(input_ids=input_ids) + + def cumulate_output_tokens(self, output_ids: _TokenIDs): + if not self.is_prepared(): + return + + self._cumulate_output_tokens(output_ids=output_ids) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.is_prepared(): + return logits + + return self._apply(logits=logits) + + def filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + if not self.is_prepared(): + return + + self._filter( + indices_to_keep=indices_to_keep, + indices_tensor_to_keep=indices_tensor_to_keep, + ) + + def merge(self, their: "_BatchedPenalizer"): + if not self.is_prepared() and not their.is_prepared(): + return + + self.prepare() + their.prepare() + self._merge(their) + + @abc.abstractmethod + def _is_required(self) -> bool: + """ + Check if the penalizer is required to be prepared. 
+ """ + pass + + @abc.abstractmethod + def _prepare(self): + """ + Prepare the penalizer. + Usually, this is where the penalizer initializes its tensors. + """ + pass + + @abc.abstractmethod + def _teardown(self): + """ + Tear down the penalizer. + Usually, this is where the penalizer frees its tensors. + """ + pass + + @abc.abstractmethod + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + """ + Cumulate the input tokens. + Orchestrator will call this function to feed the input tokens to the penalizer. + """ + pass + + @abc.abstractmethod + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + """ + Cumulate the output tokens. + Orchestrator will call this function to feed the output tokens to the penalizer. + """ + pass + + @abc.abstractmethod + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply the penalizer to the logits. + Penalizers can modify the logits in-place if needed. + """ + pass + + @abc.abstractmethod + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + """ + Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. + """ + pass + + @abc.abstractmethod + def _merge(self, their: "_BatchedPenalizer"): + """ + Merge the penalizer with another penalizer. + """ + pass diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py new file mode 100644 index 00000000000..178cb54b24c --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py @@ -0,0 +1,80 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedFrequencyPenalizer(_BatchedPenalizer): + """ + Frequency penalizer penalizes tokens based on their frequency in the output. 
+ """ + + frequency_penalties: torch.Tensor = None + cumulated_frequency_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.frequency_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_frequency_penalties = ( + torch.tensor( + data=[0.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.frequency_penalties = ( + torch.tensor( + data=[ + req.sampling_params.frequency_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_frequency_penalties) + ) + + def _teardown(self): + del self.frequency_penalties + del self.cumulated_frequency_penalties + + self.frequency_penalties = None + self.cumulated_frequency_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.cumulated_frequency_penalties += ( + self.frequency_penalties * output_ids.occurrence_count() + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits -= self.cumulated_frequency_penalties + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] + self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedFrequencyPenalizer"): + self.frequency_penalties = torch.cat( + [self.frequency_penalties, their.frequency_penalties], dim=0 + ) + self.cumulated_frequency_penalties = torch.cat( + [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py new file mode 100644 index 00000000000..c9e0f078ed0 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py @@ -0,0 +1,105 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedMinNewTokensPenalizer(_BatchedPenalizer): + """ + Min new tokens penalizer penalizes tokens based on the length of the output. 
+ """ + + min_new_tokens: torch.Tensor = None + stop_token_penalties: torch.Tensor = None + len_output_tokens: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.min_new_tokens = torch.tensor( + data=[ + req.sampling_params.min_new_tokens for req in self.orchestrator.reqs() + ], + dtype=torch.int32, + device=self.orchestrator.device, + ).unsqueeze_(1) + + padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=[ + torch.tensor( + data=list( + req.sampling_params.stop_token_ids + | {req.tokenizer.eos_token_id} + ), + dtype=torch.int64, + device=self.orchestrator.device, + ) + for req in self.orchestrator.reqs() + ], + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + self.stop_token_penalties = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.float32, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_stop_token_ids, + src=torch.full_like( + input=padded_stop_token_ids, + dtype=torch.float32, + fill_value=float("-inf"), + device=self.orchestrator.device, + ), + )[ + :, : self.orchestrator.vocab_size + ] + + self.len_output_tokens = torch.zeros( + size=(self.orchestrator.batch_size(), 1), + dtype=torch.int32, + device=self.orchestrator.device, + ) + + def _teardown(self): + del self.min_new_tokens + del self.stop_token_penalties + del self.len_output_tokens + + self.min_new_tokens = None + self.stop_token_penalties = None + self.len_output_tokens = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.len_output_tokens += 1 + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits) + logits[mask] += self.stop_token_penalties[mask] + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] + self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] + self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] + + def _merge(self, their: "BatchedMinNewTokensPenalizer"): + self.min_new_tokens = torch.cat( + [self.min_new_tokens, their.min_new_tokens], dim=0 + ) + self.stop_token_penalties = torch.cat( + [self.stop_token_penalties, their.stop_token_penalties], dim=0 + ) + self.len_output_tokens = torch.cat( + [self.len_output_tokens, their.len_output_tokens], dim=0 + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py new file mode 100644 index 00000000000..0593fddc9c3 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py @@ -0,0 +1,79 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedPresencePenalizer(_BatchedPenalizer): + """ + Presence penalizer penalizes tokens based on their presence in the output. 
+ """ + + presence_penalties: torch.Tensor = None + cumulated_presence_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.presence_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_presence_penalties = ( + torch.tensor( + data=[0.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.presence_penalties = ( + torch.tensor( + data=[ + req.sampling_params.presence_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_presence_penalties) + ) + + def _teardown(self): + del self.presence_penalties + del self.cumulated_presence_penalties + + self.presence_penalties = None + self.cumulated_presence_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + mask = output_ids.occurrence_count() > 0 + self.cumulated_presence_penalties[mask] = self.presence_penalties[mask] + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits -= self.cumulated_presence_penalties + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] + self.cumulated_presence_penalties = self.cumulated_presence_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedPresencePenalizer"): + self.presence_penalties = torch.cat( + [self.presence_penalties, their.presence_penalties], dim=0 + ) + self.cumulated_presence_penalties = torch.cat( + [self.cumulated_presence_penalties, their.cumulated_presence_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py new file mode 100644 index 00000000000..ea32addc2ea --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -0,0 +1,83 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedRepetitionPenalizer(_BatchedPenalizer): + """ + Repetition penalizer penalizes tokens based on their repetition in the input and output. 
+ """ + + repetition_penalties: torch.Tensor = None + cumulated_repetition_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.repetition_penalty != 1.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_repetition_penalties = ( + torch.tensor( + data=[1.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.repetition_penalties = ( + torch.tensor( + data=[ + req.sampling_params.repetition_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_repetition_penalties) + ) + + def _teardown(self): + del self.repetition_penalties + del self.cumulated_repetition_penalties + + self.repetition_penalties = None + self.cumulated_repetition_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + mask = input_ids.occurrence_count() > 0 + self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + mask = output_ids.occurrence_count() > 0 + self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + return torch.where( + logits > 0, + logits / self.cumulated_repetition_penalties, + logits * self.cumulated_repetition_penalties, + ) + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] + self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedRepetitionPenalizer"): + self.repetition_penalties = torch.cat( + [self.repetition_penalties, their.repetition_penalties], dim=0 + ) + self.cumulated_repetition_penalties = torch.cat( + [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py index f6b4f570663..39774d9acfc 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling_params.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + """Sampling parameters for text generation.""" from typing import List, Optional, Union @@ -8,31 +23,39 @@ class SamplingParams: def __init__( self, - max_new_tokens: int = 16, + max_new_tokens: int = 128, + min_new_tokens: int = 0, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + repetition_penalty: float = 1.0, ignore_eos: bool = False, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, dtype: Optional[str] = None, regex: Optional[str] = None, + n: int = 1, ) -> None: self.temperature = temperature self.top_p = top_p self.top_k = top_k self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty + self.repetition_penalty = repetition_penalty self.stop_strs = stop + self.stop_token_ids = {*stop_token_ids} self.max_new_tokens = max_new_tokens + self.min_new_tokens = min_new_tokens self.ignore_eos = ignore_eos self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens self.dtype = dtype self.regex = regex + self.n = n # Process some special cases if self.temperature < _SAMPLING_EPS: @@ -63,10 +86,26 @@ def verify(self): raise ValueError( "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." ) - if self.max_new_tokens < 0: + if not 0.0 <= self.repetition_penalty <= 2.0: + raise ValueError( + "repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}." + ) + if not 0 <= self.min_new_tokens: raise ValueError( - f"max_new_tokens must be at least 0, got {self.max_new_tokens}." + f"min_new_tokens must be in (0, max_new_tokens], got " + f"{self.min_new_tokens}." ) + if self.max_new_tokens is not None: + if self.max_new_tokens < 0: + raise ValueError( + f"max_new_tokens must be at least 0, got {self.max_new_tokens}." + ) + if not self.min_new_tokens <= self.max_new_tokens: + raise ValueError( + f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " + f"{self.min_new_tokens}." + ) def normalize(self, tokenizer): # Process stop strings diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index eb37c7bb538..d6e3f31ecbc 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """ The entry point of inference server. SRT = SGLang Runtime. 
@@ -13,7 +28,7 @@ import threading import time from http import HTTPStatus -from typing import Dict, Optional +from typing import Dict, List, Optional, Union # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -23,37 +38,45 @@ import requests import uvicorn import uvloop -from fastapi import FastAPI, Request +from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi.responses import JSONResponse, Response, StreamingResponse -from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.constrained import disable_cache from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller.manager_multi import ( +from sglang.srt.managers.controller_multi import ( start_controller_process as start_controller_process_multi, ) -from sglang.srt.managers.controller.manager_single import ( +from sglang.srt.managers.controller_single import launch_tp_servers +from sglang.srt.managers.controller_single import ( start_controller_process as start_controller_process_single, ) -from sglang.srt.managers.controller.tp_worker import ModelTpService from sglang.srt.managers.detokenizer_manager import start_detokenizer_process -from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.openai_api_adapter import ( +from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, + v1_batches, v1_chat_completions, v1_completions, + v1_delete_file, + v1_files_create, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content, ) -from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs +from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - API_KEY_HEADER_NAME, - APIKeyValidatorMiddleware, + add_api_key_middleware, allocate_init_ports, assert_pkg_version, enable_show_time_cost, - receive_addrs, - send_addrs_to_rank_0, - start_rpyc_service_process, + kill_child_process, + maybe_set_triton_cache_manager, + prepare_model, + prepare_tokenizer, + set_ulimit, ) from sglang.utils import get_exception_traceback @@ -76,6 +99,7 @@ async def health() -> Response: async def get_model_info(): result = { "model_path": tokenizer_manager.model_path, + "is_generation": tokenizer_manager.is_generation, } return result @@ -96,6 +120,7 @@ async def flush_cache(): async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" if obj.stream: async def stream_results(): @@ -126,6 +151,21 @@ async def stream_results(): app.put("/generate")(generate_request) +async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +app.post("/encode")(encode_request) +app.put("/encode")(encode_request) + + @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): return await v1_completions(tokenizer_manager, raw_request) @@ -136,7 +176,57 @@ async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, 
raw_request) -def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None): +@app.get("/v1/models") +def available_models(): + """Show available models.""" + served_model_names = [tokenizer_manager.served_model_name] + model_cards = [] + for served_model_name in served_model_names: + model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + return ModelList(data=model_cards) + + +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + return await v1_files_create( + file, purpose, tokenizer_manager.server_args.file_storage_pth + ) + + +@app.delete("/v1/files/{file_id}") +async def delete_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/delete + return await v1_delete_file(file_id) + + +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(tokenizer_manager, raw_request) + + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + + +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve-contents + return await v1_retrieve_file_content(file_id) + + +def launch_server( + server_args: ServerArgs, + model_overide_args: Optional[dict] = None, + pipe_finish_writer: Optional[mp.connection.Connection] = None, +): + """Launch an HTTP server.""" global tokenizer_manager logging.basicConfig( @@ -144,85 +234,65 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg format="%(message)s", ) - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - if server_args.show_time_cost: - enable_show_time_cost() - if server_args.disable_disk_cache: - disable_cache() - if not server_args.disable_flashinfer: - assert_pkg_version( - "flashinfer", - "0.0.8", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - if server_args.chat_template: - # TODO: replace this with huggingface transformers template - load_chat_template_for_openai_api(server_args.chat_template) + server_args.check_server_args() + _set_envs_and_config(server_args) # Allocate ports - assert server_args.tp_size % server_args.nnodes == 0 - tp_size_local = server_args.tp_size // server_args.nnodes server_args.port, server_args.additional_ports = allocate_init_ports( server_args.port, server_args.additional_ports, - tp_size_local, server_args.dp_size, ) - ports = server_args.additional_ports - model_port_args = [] - for i in range(server_args.dp_size): - model_port_args.append( - ModelPortArgs( - nccl_port=ports[3 + i * (tp_size_local + 1)], - model_tp_ips=[None] * tp_size_local, - model_tp_ports=ports[ - 3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1) - ], - ) - ) port_args = PortArgs( tokenizer_port=ports[0], - router_port=ports[1], + controller_port=ports[1], detokenizer_port=ports[2], - model_port_args=model_port_args, + nccl_ports=ports[3:], ) + logger.info(f"{server_args=}") - # TODO multi-node dp is not supported - assert not (server_args.dp_size > 1 and server_args.node_rank is not None) + # Use model from www.modelscope.cn, first download the model. 
+ server_args.model_path = prepare_model(server_args.model_path) + server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path) + + # Launch processes for multi-node tensor parallelism if server_args.nnodes > 1: if server_args.node_rank != 0: - send_addrs_to_rank_0(model_port_args[0], server_args) - else: - receive_addrs(model_port_args[0], server_args) - for i in range(tp_size_local): - start_rpyc_service_process( - ModelTpService, model_port_args[0].model_tp_ports[i] + tp_size_local = server_args.tp_size // server_args.nnodes + gpu_ids = [ + i for _ in range(server_args.nnodes) for i in range(tp_size_local) + ] + tp_rank_range = list( + range( + server_args.node_rank * tp_size_local, + (server_args.node_rank + 1) * tp_size_local, + ) ) - if server_args.node_rank != 0: - logger.info( - f"[node_rank={server_args.node_rank}]: Listen for connections..." + procs = launch_tp_servers( + gpu_ids, + tp_rank_range, + server_args, + ports[3], + model_overide_args, ) while True: pass # Launch processes tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) - pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False) + pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: start_process = start_controller_process_single else: start_process = start_controller_process_multi - proc_router = mp.Process( + proc_controller = mp.Process( target=start_process, - args=(server_args, port_args, pipe_router_writer, model_overide_args), + args=(server_args, port_args, pipe_controller_writer, model_overide_args), ) - proc_router.start() + proc_controller.start() proc_detoken = mp.Process( target=start_detokenizer_process, args=( @@ -234,68 +304,31 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg proc_detoken.start() # Wait for the model to finish loading - router_init_state = pipe_router_reader.recv() + controller_init_state = pipe_controller_reader.recv() detoken_init_state = pipe_detoken_reader.recv() - if router_init_state != "init ok" or detoken_init_state != "init ok": - proc_router.kill() + if controller_init_state != "init ok" or detoken_init_state != "init ok": + proc_controller.kill() proc_detoken.kill() print( - f"Initialization failed. router_init_state: {router_init_state}", flush=True + f"Initialization failed. controller_init_state: {controller_init_state}", + flush=True, ) print( f"Initialization failed. 
detoken_init_state: {detoken_init_state}", flush=True, ) sys.exit(1) - assert proc_router.is_alive() and proc_detoken.is_alive() + assert proc_controller.is_alive() and proc_detoken.is_alive() - if server_args.api_key and server_args.api_key != "": - app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key) + # Add api key authorization + if server_args.api_key: + add_api_key_middleware(app, server_args.api_key) # Send a warmup request - def _wait_and_warmup(): - headers = {} - url = server_args.url() - if server_args.api_key: - headers[API_KEY_HEADER_NAME] = server_args.api_key - - # Wait until the server is launched - for _ in range(120): - time.sleep(0.5) - try: - requests.get(url + "/get_model_info", timeout=5, headers=headers) - break - except requests.exceptions.RequestException: - pass - - # Send a warmup request - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + "/generate", - json={ - "text": "The capital city of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 8, - }, - }, - headers=headers, - timeout=600, - ) - assert res.status_code == 200 - except Exception as e: - if pipe_finish_writer is not None: - pipe_finish_writer.send(get_exception_traceback()) - print(f"Initialization failed. warmup error: {e}", flush=True) - raise e - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("init ok") - - t = threading.Thread(target=_wait_and_warmup) + t = threading.Thread( + target=_wait_and_warmup, args=(server_args, pipe_finish_writer) + ) t.start() # Listen for requests @@ -312,6 +345,101 @@ def _wait_and_warmup(): t.join() +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # Set ulimit + set_ulimit() + + # Enable show time cost for debugging + if server_args.show_time_cost: + enable_show_time_cost() + + # Disable disk cache + if server_args.disable_disk_cache: + disable_cache() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. 
+ maybe_set_triton_cache_manager() + + # Set global chat template + if server_args.chat_template: + # TODO: replace this with huggingface transformers template + load_chat_template_for_openai_api(server_args.chat_template) + + # Check flashinfer version + if not server_args.disable_flashinfer: + assert_pkg_version( + "flashinfer", + "0.1.3", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + +def _wait_and_warmup(server_args, pipe_finish_writer): + headers = {} + url = server_args.url() + if server_args.api_key: + headers["Authorization"] = f"Bearer {server_args.api_key}" + + # Wait until the server is launched + success = False + for _ in range(120): + time.sleep(1) + try: + res = requests.get(url + "/get_model_info", timeout=5, headers=headers) + assert res.status_code == 200, f"{res}" + success = True + break + except (AssertionError, requests.exceptions.RequestException) as e: + last_traceback = get_exception_traceback() + pass + model_info = res.json() + + if not success: + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + print(f"Initialization failed. warmup error: {last_traceback}", flush=True) + sys.exit(1) + + # Send a warmup request + request_name = "/generate" if model_info["is_generation"] else "/encode" + max_new_tokens = 8 if model_info["is_generation"] else 0 + try: + for _ in range(server_args.dp_size): + res = requests.post( + url + request_name, + json={ + "text": "The capital city of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + }, + headers=headers, + timeout=600, + ) + assert res.status_code == 200, f"{res}" + except Exception as e: + last_traceback = get_exception_traceback() + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + print(f"Initialization failed. warmup error: {last_traceback}", flush=True) + sys.exit(1) + + logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: + pipe_finish_writer.send("init ok") + + class Runtime: """ A wrapper for the server. 
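To make the warmup flow above concrete, the request shape sent by _wait_and_warmup can be reproduced with a short client-side sketch (the base URL below is a placeholder and assumes a server is already listening):

import requests

base_url = "http://127.0.0.1:30000"  # placeholder host/port

# Same JSON payload shape as the warmup request in _wait_and_warmup.
res = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
    timeout=600,
)
assert res.status_code == 200, f"{res}"
print(res.json())

# Embedding models use the new /encode route instead; the Runtime.encode
# wrapper below sends only the text field.
res = requests.post(base_url + "/encode", json={"text": "The capital city of France is"})
print(res.json())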
@@ -333,7 +461,6 @@ def __init__( self.server_args.port, self.server_args.additional_ports = allocate_init_ports( self.server_args.port, self.server_args.additional_ports, - self.server_args.tp_size, self.server_args.dp_size, ) @@ -346,7 +473,7 @@ def __init__( pipe_reader, pipe_writer = mp.Pipe(duplex=False) proc = mp.Process( target=launch_server, - args=(self.server_args, pipe_writer, model_overide_args), + args=(self.server_args, model_overide_args, pipe_writer), ) proc.start() pipe_writer.close() @@ -367,18 +494,12 @@ def __init__( def shutdown(self): if self.pid is not None: - try: - parent = psutil.Process(self.pid) - except psutil.NoSuchProcess: - return - children = parent.children(recursive=True) - for child in children: - child.kill() - psutil.wait_procs(children, timeout=5) - parent.kill() - parent.wait(timeout=5) + kill_child_process(self.pid) self.pid = None + def cache_prefix(self, prefix: str): + self.endpoint.cache_prefix(prefix) + def get_tokenizer(self): return get_tokenizer( self.server_args.tokenizer_path, @@ -386,10 +507,10 @@ def get_tokenizer(self): trust_remote_code=self.server_args.trust_remote_code, ) - async def add_request( + async def async_generate( self, prompt: str, - sampling_params: Dict, + sampling_params: Optional[Dict] = None, ): json_data = { "text": prompt, @@ -412,5 +533,39 @@ async def add_request( yield cur pos += len(cur) + add_request = async_generate + + def generate( + self, + prompt: str, + sampling_params: Optional[Dict] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + top_logprobs_num: Optional[Union[List[int], int]] = None, + ): + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + } + response = requests.post( + self.url + "/generate", + json=json_data, + ) + return json.dumps(response.json()) + + def encode( + self, + prompt: str, + ): + json_data = { + "text": prompt, + } + response = requests.post( + self.url + "/encode", + json=json_data, + ) + return json.dumps(response.json()) + def __del__(self): self.shutdown() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ef8b6d252ae..f42afdf8d56 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1,3 +1,18 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + """The arguments of the server.""" import argparse @@ -17,6 +32,7 @@ class ServerArgs: trust_remote_code: bool = True context_length: Optional[int] = None quantization: Optional[str] = None + served_model_name: Optional[str] = None chat_template: Optional[str] = None # Port @@ -28,12 +44,14 @@ class ServerArgs: mem_fraction_static: Optional[float] = None max_prefill_tokens: Optional[int] = None max_running_requests: Optional[int] = None - schedule_heuristic: str = "lpm" + max_num_reqs: Optional[int] = None + max_total_tokens: Optional[int] = None + schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 # Other runtime options tp_size: int = 1 - stream_interval: int = 8 + stream_interval: int = 1 random_seed: Optional[int] = None # Logging @@ -43,20 +61,28 @@ class ServerArgs: show_time_cost: bool = False # Other - api_key: str = "" + api_key: Optional[str] = None + file_storage_pth: str = "SGlang_storage" # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Chunked Prefill + chunked_prefill_size: Optional[int] = None + # Optimization/debug options disable_flashinfer: bool = False + disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False disable_regex_jump_forward: bool = False disable_cuda_graph: bool = False disable_disk_cache: bool = False - attention_reduce_in_fp32: bool = False + enable_torch_compile: bool = False enable_p2p_check: bool = False + enable_mla: bool = False + attention_reduce_in_fp32: bool = False + efficient_weight_load: bool = False # Distributed args nccl_init_addr: Optional[str] = None @@ -66,15 +92,21 @@ class ServerArgs: def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path + + if self.served_model_name is None: + self.served_model_name = self.model_path + if self.mem_fraction_static is None: - if self.tp_size >= 8: - self.mem_fraction_static = 0.80 + if self.tp_size >= 16: + self.mem_fraction_static = 0.79 + elif self.tp_size >= 8: + self.mem_fraction_static = 0.83 elif self.tp_size >= 4: - self.mem_fraction_static = 0.82 - elif self.tp_size >= 2: self.mem_fraction_static = 0.85 + elif self.tp_size >= 2: + self.mem_fraction_static = 0.87 else: - self.mem_fraction_static = 0.90 + self.mem_fraction_static = 0.88 if isinstance(self.additional_ports, int): self.additional_ports = [self.additional_ports] elif self.additional_ports is None: @@ -164,8 +196,24 @@ def add_cli_args(parser: argparse.ArgumentParser): "--quantization", type=str, default=ServerArgs.quantization, + choices=[ + "awq", + "fp8", + "gptq", + "marlin", + "gptq_marlin", + "awq_marlin", + "squeezellm", + "bitsandbytes", + ], help="The quantization method.", ) + parser.add_argument( + "--served-model-name", + type=str, + default=ServerArgs.served_model_name, + help="Override the model name returned by the v1/models endpoint in OpenAI API server.", + ) parser.add_argument( "--chat-template", type=str, @@ -191,11 +239,23 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The maximum number of running requests.", ) parser.add_argument( - "--schedule-heuristic", + "--max-num-reqs", + type=int, + default=ServerArgs.max_num_reqs, + help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", + ) + parser.add_argument( + "--max-total-tokens", + type=int, + default=ServerArgs.max_total_tokens, + help="The maximum number of tokens in the memory pool. 
If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.", + ) + parser.add_argument( + "--schedule-policy", type=str, - default=ServerArgs.schedule_heuristic, + default=ServerArgs.schedule_policy, choices=["lpm", "random", "fcfs", "dfs-weight"], - help="The scheduling heuristic.", + help="The scheduling policy of the requests.", ) parser.add_argument( "--schedule-conservativeness", @@ -204,6 +264,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.", ) parser.add_argument( + "--tensor-parallel-size", "--tp-size", type=int, default=ServerArgs.tp_size, @@ -241,17 +302,24 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--show-time-cost", action="store_true", - help="Show time cost of custom marks", + help="Show time cost of custom marks.", ) parser.add_argument( "--api-key", type=str, default=ServerArgs.api_key, - help="Set API key of the server", + help="Set API key of the server. It is also used in the OpenAI API compatible server.", + ) + parser.add_argument( + "--file-storage-pth", + type=str, + default=ServerArgs.file_storage_pth, + help="The path of the file storage in backend.", ) # Data parallelism parser.add_argument( + "--data-parallel-size", "--dp-size", type=int, default=ServerArgs.dp_size, @@ -275,25 +343,38 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The nccl init address of multi-node server.", ) parser.add_argument( - "--nnodes", type=int, default=1, help="The number of nodes." + "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." 
) parser.add_argument("--node-rank", type=int, help="The node rank.") + # Chunked prefill + parser.add_argument( + "--chunked-prefill-size", + type=int, + default=ServerArgs.chunked_prefill_size, + help="The size of the chunked prefill.", + ) + # Optimization/debug options parser.add_argument( "--disable-flashinfer", action="store_true", - help="Disable flashinfer inference kernels", + help="Disable flashinfer attention kernels.", + ) + parser.add_argument( + "--disable-flashinfer-sampling", + action="store_true", + help="Disable flashinfer sampling kernels.", ) parser.add_argument( "--disable-radix-cache", action="store_true", - help="Disable RadixAttention", + help="Disable RadixAttention for prefix caching.", ) parser.add_argument( "--disable-regex-jump-forward", action="store_true", - help="Disable regex jump-forward", + help="Disable regex jump-forward.", ) parser.add_argument( "--disable-cuda-graph", @@ -305,6 +386,21 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) + parser.add_argument( + "--enable-torch-compile", + action="store_true", + help="Optimize the model with torch.compile, experimental feature.", + ) + parser.add_argument( + "--enable-p2p-check", + action="store_true", + help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.", + ) + parser.add_argument( + "--enable-mla", + action="store_true", + help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2", + ) parser.add_argument( "--attention-reduce-in-fp32", action="store_true", @@ -312,13 +408,15 @@ def add_cli_args(parser: argparse.ArgumentParser): "This only affects Triton attention kernels", ) parser.add_argument( - "--enable-p2p-check", + "--efficient-weight-load", action="store_true", - help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.", + help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) @classmethod def from_cli_args(cls, args: argparse.Namespace): + args.tp_size = args.tensor_parallel_size + args.dp_size = args.data_parallel_size attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) @@ -334,17 +432,18 @@ def print_mode_args(self): f"disable_disk_cache={self.disable_disk_cache}, " ) - -@dataclasses.dataclass -class ModelPortArgs: - nccl_port: int - model_tp_ips: List[str] - model_tp_ports: List[int] + def check_server_args(self): + assert ( + self.tp_size % self.nnodes == 0 + ), "tp_size must be divisible by number of nodes" + assert not ( + self.dp_size > 1 and self.node_rank is not None + ), "multi-node data parallel is not supported" @dataclasses.dataclass class PortArgs: tokenizer_port: int - router_port: int + controller_port: int detokenizer_port: int - model_port_args: List[ModelPortArgs] + nccl_ports: List[int] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 03a2d60abf4..dd41156f340 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1,11 +1,26 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + """Common utilities.""" import base64 import fcntl import logging -import multiprocessing import os import random +import resource import socket import struct import time @@ -16,13 +31,18 @@ import numpy as np import psutil import requests -import rpyc import torch -import triton +import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version -from rpyc.utils.server import ThreadedServer from starlette.middleware.base import BaseHTTPMiddleware +from torch.nn.parameter import Parameter +from triton.runtime.cache import ( + FileCacheManager, + default_cache_dir, + default_dump_dir, + default_override_dir, +) logger = logging.getLogger(__name__) @@ -148,7 +168,6 @@ def is_port_available(port): def allocate_init_ports( port: Optional[int] = None, additional_ports: Optional[List[int]] = None, - tp_size: int = 1, dp_size: int = 1, ): """Allocate ports for all connections.""" @@ -160,8 +179,8 @@ def allocate_init_ports( ret_ports = list(set(x for x in ret_ports if is_port_available(x))) cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000 - # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size) - num_ports_needed = 4 + dp_size * (1 + tp_size) + # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl) + num_ports_needed = 4 + dp_size while len(ret_ports) < num_ports_needed: if cur_port not in ret_ports and is_port_available(cur_port): ret_ports.append(cur_port) @@ -188,71 +207,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size): return logit_bias -def wrap_kernel_launcher(kernel): - """A faster launcher for triton kernels.""" - if int(triton.__version__.split(".")[0]) >= 3: - return None - - gpu_id = torch.cuda.current_device() - kernels = kernel.cache[gpu_id].values() - kernel = next(iter(kernels)) - - # Different trition versions use different low-level names - if hasattr(kernel, "cu_function"): - kfunction = kernel.cu_function - else: - kfunction = kernel.function - - if hasattr(kernel, "c_wrapper"): - run = kernel.c_wrapper - else: - run = kernel.run - - add_cluster_dim = True - - def ret_func(grid, num_warps, *args): - nonlocal add_cluster_dim - - try: - if add_cluster_dim: - run( - grid[0], - grid[1], - grid[2], - num_warps, - 1, - 1, - 1, - 1, - kernel.shared, - 0, - kfunction, - None, - None, - kernel, - *args, - ) - else: - run( - grid[0], - grid[1], - grid[2], - num_warps, - kernel.shared, - 0, - kfunction, - None, - None, - kernel, - *args, - ) - except TypeError: - add_cluster_dim = not add_cluster_dim - ret_func(grid, num_warps, *args) - - return ret_func - - def is_multimodal_model(model): from sglang.srt.model_config import ModelConfig @@ -269,6 +223,15 @@ def is_multimodal_model(model): raise ValueError("unrecognized type") +def is_generation_model(model_architectures): + if ( + "LlamaEmbeddingModel" in model_architectures + or "MistralModel" in model_architectures + ): + return False + return True + + def decode_video_base64(video_base64): from PIL import Image @@ -371,49 +334,6 @@ def load_image(image_file): return image, image_size -def connect_rpyc_service(host, port): - 
repeat_count = 0 - while repeat_count < 20: - try: - con = rpyc.connect( - host, - port, - config={ - "allow_public_attrs": True, - "allow_pickle": True, - "sync_request_timeout": 3600, - }, - ) - break - except ConnectionRefusedError as e: - time.sleep(1) - repeat_count += 1 - if repeat_count == 20: - raise RuntimeError(f"Connect rpyc error: {e}") - - return con.root - - -def start_rpyc_service(service: rpyc.Service, port: int): - t = ThreadedServer( - service=service, - port=port, - protocol_config={ - "allow_public_attrs": True, - "allow_pickle": True, - "sync_request_timeout": 3600, - }, - ) - t.logger.setLevel(logging.WARN) - t.start() - - -def start_rpyc_service_process(service: rpyc.Service, port: int): - proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port)) - proc.start() - return proc - - def suppress_other_loggers(): from vllm.logger import logger as vllm_default_logger @@ -422,6 +342,9 @@ def suppress_other_loggers(): logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel( logging.WARN ) + logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel( + logging.WARN + ) logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN) @@ -445,13 +368,33 @@ def kill_parent_process(): """Kill the parent process and all children of the parent process.""" current_process = psutil.Process() parent_process = current_process.parent() - children = current_process.children(recursive=True) + children = parent_process.children(recursive=True) for child in children: if child.pid != current_process.pid: os.kill(child.pid, 9) os.kill(parent_process.pid, 9) +def kill_child_process(pid, including_parent=True): + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + children = parent.children(recursive=True) + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + if including_parent: + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + def monkey_patch_vllm_p2p_access_check(gpu_id: int): """ Monkey patch the slow p2p access check in vllm. @@ -474,9 +417,9 @@ def monkey_patch_vllm_dummy_weight_loader(): DummyModelLoader, LoRAConfig, ModelConfig, + MultiModalConfig, ParallelConfig, SchedulerConfig, - MultiModalConfig, _initialize_model, initialize_dummy_weights, nn, @@ -521,24 +464,88 @@ def load_model( setattr(DummyModelLoader, "load_model", load_model) -API_KEY_HEADER_NAME = "X-API-Key" +vllm_all_gather_backup = None + + +def monkey_patch_vllm_all_gather(reverse: bool = False): + """Monkey patch all-gather to remove in-place operations.""" + from torch.distributed import _functional_collectives as funcol + from vllm.distributed.parallel_state import GroupCoordinator + + global vllm_all_gather_backup + if vllm_all_gather_backup is None: + vllm_all_gather_backup = GroupCoordinator.all_gather + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. 
+ output_tensor = torch.empty( + (world_size,) + input_size, dtype=input_.dtype, device=input_.device + ) + output_tensor = funcol.all_gather_tensor( + input_, gather_dim=0, group=self.device_group + ).view((world_size,) + input_size) -class APIKeyValidatorMiddleware(BaseHTTPMiddleware): - def __init__(self, app, api_key: str): - super().__init__(app) - self.api_key = api_key + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] + ) + return output_tensor - async def dispatch(self, request, call_next): - # extract API key from the request headers - api_key_header = request.headers.get(API_KEY_HEADER_NAME) - if not api_key_header or api_key_header != self.api_key: - return JSONResponse( - status_code=403, - content={"detail": "Invalid API Key"}, + if reverse: + setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) + else: + setattr(GroupCoordinator, "all_gather", all_gather) + + +def maybe_set_triton_cache_manager() -> None: + """Set environment variable to tell Triton to use a + custom cache manager""" + cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None) + if cache_manger is None: + manager = "sglang.srt.utils:CustomCacheManager" + logger.debug("Setting Triton cache manager to: %s", manager) + os.environ["TRITON_CACHE_MANAGER"] = manager + + +class CustomCacheManager(FileCacheManager): + # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py + def __init__(self, key, override=False, dump=False): + + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = ( + os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() ) - response = await call_next(request) - return response + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") def get_ip_address(ifname): @@ -559,7 +566,6 @@ def get_ip_address(ifname): def send_addrs_to_rank_0(model_port_args, server_args): assert server_args.node_rank != 0 and server_args.dp_size == 1 - import torch.distributed as dist ifname = os.environ.get( "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0") @@ -591,7 +597,6 @@ def send_addrs_to_rank_0(model_port_args, server_args): def receive_addrs(model_port_args, server_args): assert server_args.node_rank == 0 and server_args.dp_size == 1 - import torch.distributed as dist ifname = os.environ.get( "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0") @@ -624,3 +629,95 @@ def receive_addrs(model_port_args, server_args): dist.barrier() dist.destroy_process_group() + + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except 
ValueError as e: + logger.warn(f"Fail to set RLIMIT_NOFILE: {e}") + + +def is_llama3_405b_fp8(model_config): + """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads.""" + if ( + model_config.hf_config.architectures[0] == "LlamaForCausalLM" + and model_config.hf_config.hidden_size == 16384 + and model_config.hf_config.intermediate_size == 53248 + and model_config.hf_config.num_hidden_layers == 126 + and model_config.hf_config.num_key_value_heads == 16 + and hasattr(model_config.hf_config, "quantization_config") + and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8" + ): + return True + return False + + +def monkey_patch_vllm_qvk_linear_loader(): + """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints.""" + from vllm.model_executor.layers.linear import QKVParallelLinear + + origin_weight_loader = QKVParallelLinear.weight_loader + + def get_original_weight(loaded_weight, head_dim): + n_kv_head = loaded_weight.shape[0] // (2 * head_dim) + dim = loaded_weight.shape[1] + for i in range(n_kv_head): + loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[ + 2 * i * head_dim : (2 * i + 1) * head_dim, : + ] + original_kv_weight = loaded_weight[: n_kv_head * head_dim, :] + assert original_kv_weight.shape == (n_kv_head * head_dim, dim) + return original_kv_weight + + def weight_loader_srt( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): + if ( + loaded_shard_id in ["k", "v"] + and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2 + ): + loaded_weight = get_original_weight(loaded_weight, self.head_size) + + origin_weight_loader(self, param, loaded_weight, loaded_shard_id) + + setattr(QKVParallelLinear, "weight_loader", weight_loader_srt) + + +def add_api_key_middleware(app, api_key): + @app.middleware("http") + async def authentication(request, call_next): + if request.method == "OPTIONS": + return await call_next(request) + if request.url.path.startswith("/health"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + api_key: + return JSONResponse(content={"error": "Unauthorized"}, status_code=401) + return await call_next(request) + + +def prepare_model(model_path): + if "SGLANG_USE_MODELSCOPE" in os.environ: + if not os.path.exists(model_path): + from modelscope import snapshot_download + + return snapshot_download(model_path) + return model_path + + +def prepare_tokenizer(tokenizer_path): + if "SGLANG_USE_MODELSCOPE" in os.environ: + if not os.path.exists(tokenizer_path): + from modelscope import snapshot_download + + return snapshot_download( + tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"] + ) + return tokenizer_path diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py new file mode 100644 index 00000000000..6c1f284b163 --- /dev/null +++ b/python/sglang/test/run_eval.py @@ -0,0 +1,115 @@ +""" +Usage: +python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10 +""" + +import argparse +import json +import os +import time + +from sglang.test.simple_eval_common import ( + ChatCompletionSampler, + make_report, + set_ulimit, +) + + +def run_eval(args): + if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = "EMPTY" + + base_url = ( + f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" + ) + + if args.eval_name == "mmlu": + from sglang.test.simple_eval_mmlu import MMLUEval + + 
filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv" + eval_obj = MMLUEval(filename, args.num_examples, args.num_threads) + elif args.eval_name == "math": + from sglang.test.simple_eval_math import MathEval + + equality_checker = ChatCompletionSampler(model="gpt-4-turbo") + + filename = ( + "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv" + ) + eval_obj = MathEval( + filename, equality_checker, args.num_examples, args.num_threads + ) + elif args.eval_name == "gpqa": + from sglang.test.simple_eval_gpqa import GPQAEval + + filename = ( + "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv" + ) + eval_obj = GPQAEval(filename, args.num_examples, args.num_threads) + elif args.eval_name == "humaneval": + from sglang.test.simple_eval_humaneval import HumanEval + + eval_obj = HumanEval(args.num_examples, args.num_threads) + else: + raise ValueError(f"Invalid eval name: {args.eval_name}") + + sampler = ChatCompletionSampler( + model=args.model, + max_tokens=2048, + base_url=base_url, + ) + + # Run eval + tic = time.time() + result = eval_obj(sampler) + latency = time.time() - tic + + # Dump reports + metrics = result.metrics | {"score": result.score} + file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}" + report_filename = f"/tmp/{file_stem}.html" + print(f"Writing report to {report_filename}") + with open(report_filename, "w") as fh: + fh.write(make_report(result)) + metrics = result.metrics | {"score": result.score} + print(metrics) + result_filename = f"/tmp/{file_stem}.json" + with open(result_filename, "w") as f: + f.write(json.dumps(metrics, indent=2)) + print(f"Writing results to {result_filename}") + + # Print results + print(f"Total latency: {latency:.3f} s") + print(f"Score: {metrics['score']:.3f}") + + return metrics + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument("--eval-name", type=str, default="mmlu") + parser.add_argument("--num-examples", type=int) + parser.add_argument("--num-threads", type=int, default=512) + set_ulimit() + args = parser.parse_args() + + run_eval(args) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py new file mode 100644 index 00000000000..87277ca69bb --- /dev/null +++ b/python/sglang/test/runners.py @@ -0,0 +1,234 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import json +import multiprocessing +from dataclasses import dataclass +from typing import List, Union + +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer + +from sglang.srt.server import Runtime +from sglang.srt.utils import is_generation_model + +DEFAULT_PROMPTS = [ + "The capital of France is", + "The capital of the United Kindom is", + "Today is a sunny day and I like", +] + +NUM_TOP_LOGPROBS = 5 + + +def get_dtype_str(torch_dtype): + if torch_dtype is torch.float16: + return "float16" + else: + raise NotImplementedError() + + +@dataclass +class ModelOutput: + output_strs: str = None + top_input_logprobs: torch.Tensor = None + top_output_logprobs: torch.Tensor = None + embed_logits: torch.Tensor = None + + +class HFRunner: + def __init__( + self, + model_path, + torch_dtype=torch.float16, + is_generation_model=None, + ): + self.in_queue = multiprocessing.Queue() + self.out_queue = multiprocessing.Queue() + + self.model_proc = multiprocessing.Process( + target=self.start_model_process, + args=( + self.in_queue, + self.out_queue, + model_path, + torch_dtype, + is_generation_model, + ), + ) + self.model_proc.start() + + def start_model_process( + self, in_queue, out_queue, model_path, torch_dtype, is_generation_model + ): + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + self.is_generation_model = ( + is_generation_model(model_path) + if is_generation_model is None + else is_generation_model + ) + if self.is_generation_model: + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).cuda() + else: + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer( + model_path, + model_kwargs={"torch_dtype": torch_dtype}, + ) + + while True: + prompts, max_new_tokens = in_queue.get() + if prompts is not None: + if self.is_generation_model: + output_strs = [] + prefill_logprobs = [] + for p in prompts: + if isinstance(p, str): + input_ids = self.tokenizer.encode( + p, return_tensors="pt" + ).cuda() + else: + input_ids = torch.tensor([p], device="cuda") + + output_ids = self.model.generate( + input_ids, do_sample=False, max_new_tokens=max_new_tokens + ) + output_strs.append(self.tokenizer.decode(output_ids[0])) + + logits = self.model.forward(input_ids).logits[0] + logprobs = F.log_softmax( + logits, dim=-1, dtype=torch.float32 + ).tolist() + # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1]) + # print("index", index_of_max) + logprobs = [ + sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS] + for token_logprobs in logprobs + ] + prefill_logprobs.append(logprobs) + + out_queue.put( + ModelOutput( + output_strs=output_strs, top_input_logprobs=prefill_logprobs + ) + ) + + else: + logits = self.model.encode(prompts).tolist() + + out_queue.put(ModelOutput(embed_logits=logits)) + + def forward( + self, + prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, + max_new_tokens=64, + ): + self.in_queue.put((prompts, max_new_tokens)) + return self.out_queue.get() + + def terminate(self): + self.model_proc.terminate() + self.in_queue = self.out_queue = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.model_proc.terminate() + self.in_queue = self.out_queue = None + + +class SRTRunner: + def __init__( + self, + model_path, + tp_size=1, + 
torch_dtype=torch.float16, + is_generation_model=None, + ): + self.is_generation_model = ( + is_generation_model(model_path) + if is_generation_model is None + else is_generation_model + ) + self.runtime = Runtime( + model_path=model_path, + tp_size=tp_size, + dtype=get_dtype_str(torch_dtype), + ) + + def forward( + self, + prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, + max_new_tokens=64, + ): + if self.is_generation_model: + # the return value contains logprobs from prefill + output_strs = [] + top_input_logprobs = [] + sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} + for prompt in prompts: + response = self.runtime.generate( + prompt, + sampling_params=sampling_params, + return_logprob=True, + top_logprobs_num=NUM_TOP_LOGPROBS, + ) + response = json.loads(response) + output_strs.append(response["text"]) + top_input_logprobs.append( + [ + [tup[0] for tup in x[:NUM_TOP_LOGPROBS]] + for x in response["meta_info"]["input_top_logprobs"][1:] + ] + + [ + [ + tup[0] + for tup in response["meta_info"]["output_top_logprobs"][0][ + :NUM_TOP_LOGPROBS + ] + ] + ] + ) + + return ModelOutput( + output_strs=output_strs, top_input_logprobs=top_input_logprobs + ) + else: + logits = [] + for prompt in prompts: + response = self.runtime.encode(prompt) + response = json.loads(response) + logits.append(response["embedding"]) + return ModelOutput(embed_logits=logits) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.runtime.shutdown() + del self.runtime diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py new file mode 100644 index 00000000000..4cfd3515fe2 --- /dev/null +++ b/python/sglang/test/simple_eval_common.py @@ -0,0 +1,467 @@ +# Adapted from https://github.com/openai/simple-evals/ + +import base64 +import os +import resource +import time +from collections import defaultdict +from dataclasses import dataclass, field +from multiprocessing.pool import ThreadPool +from typing import Any, Dict, List, Tuple + +import httpx +import jinja2 +import numpy as np +import openai +import requests +from openai import OpenAI +from tqdm import tqdm + +OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." +OPENAI_SYSTEM_MESSAGE_CHATGPT = ( + "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." + + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" +) + + +Message = Dict[str, Any] # keys role, content +MessageList = List[Message] + + +class SamplerBase: + """ + Base class for defining a sampling model, which can be evaluated, + or used as part of the grading process. + """ + + def __call__(self, message_list: MessageList) -> str: + raise NotImplementedError() + + +@dataclass +class EvalResult: + """ + Result of running an evaluation (usually consisting of many samples) + """ + + score: float | None # top-line metric + metrics: Dict[str, float] | None # other metrics + htmls: List[str] # strings of valid HTML + convos: List[MessageList] # sampled conversations + + +@dataclass +class SingleEvalResult: + """ + Result of evaluating a single sample + """ + + score: float | None + metrics: Dict[str, float] = field(default_factory=dict) + html: str | None = None + convo: MessageList | None = None # sampled conversation + + +class Eval: + """ + Base class for defining an evaluation. 
+ """ + + def __call__(self, sampler: SamplerBase) -> EvalResult: + raise NotImplementedError() + + +class LargerHttpxClient(httpx.Client): + def __init__(self): + timeout_config = httpx.Timeout(3600) + limits = httpx.Limits( + max_keepalive_connections=3600, + max_connections=3600, + ) + super().__init__(timeout=timeout_config, limits=limits) + + +class ChatCompletionSampler(SamplerBase): + """ + Sample from OpenAI's chat completion API + """ + + def __init__( + self, + base_url: str = None, + model: str | None = None, + system_message: str | None = None, + temperature: float = 0.0, + max_tokens: int = 2048, + ): + self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient()) + + if model is None: + model = self.client.models.list().data[0].id + + self.model = model + self.system_message = system_message + self.temperature = temperature + self.max_tokens = max_tokens + self.image_format = "url" + + def _handle_image( + self, + image: str, + encoding: str = "base64", + format: str = "png", + fovea: int = 768, + ): + new_image = { + "type": "image_url", + "image_url": { + "url": f"data:image/{format};{encoding},{image}", + }, + } + return new_image + + def _handle_text(self, text: str): + return {"type": "text", "text": text} + + def _pack_message(self, role: str, content: Any): + return {"role": str(role), "content": content} + + def __call__(self, message_list: MessageList) -> str: + if self.system_message: + message_list = [ + self._pack_message("system", self.system_message) + ] + message_list + trial = 0 + while True: + try: + response = self.client.chat.completions.create( + model=self.model, + messages=message_list, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + return response.choices[0].message.content + # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU + except openai.BadRequestError as e: + print("Bad Request Error", e) + return "" + except Exception as e: + exception_backoff = 2**trial # expontial back off + print( + f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", + e, + ) + time.sleep(exception_backoff) + trial += 1 + # unknown error shall throw exception + + +QUERY_TEMPLATE_MULTICHOICE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + +ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" +ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" + + +EQUALITY_TEMPLATE = r""" +Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. 
Only perform trivial simplifications + +Examples: + + Expression 1: $2x+3$ + Expression 2: $3+2x$ + +Yes + + Expression 1: 3/2 + Expression 2: 1.5 + +Yes + + Expression 1: $x^2+2x+1$ + Expression 2: $y^2+2y+1$ + +No + + Expression 1: $x^2+2x+1$ + Expression 2: $(x+1)^2$ + +Yes + + Expression 1: 3245/5 + Expression 2: 649 + +No +(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications) + + Expression 1: 2/(-3) + Expression 2: -2/3 + +Yes +(trivial simplifications are allowed) + + Expression 1: 72 degrees + Expression 2: 72 + +Yes +(give benefit of the doubt to units) + + Expression 1: 64 + Expression 2: 64 square feet + +Yes +(give benefit of the doubt to units) + +--- + +YOUR TASK + + +Respond with only "Yes" or "No" (without quotes). Do not include a rationale. + + Expression 1: %(expression1)s + Expression 2: %(expression2)s +""".strip() + + +HTML_JINJA = """ +

+<h3>Prompt conversation</h3>
+{% for message in prompt_messages %}
+{{ message_to_html(message) | safe }}
+{% endfor %}
+<h3>Sampled message</h3>
+{{ message_to_html(next_message) | safe }}
+<h3>Results</h3>
+<p>Correct Answer: {{ correct_answer }}</p>
+<p>Extracted Answer: {{ extracted_answer }}</p>
+<p>Score: {{ score }}</p>

+""" + + +def format_multichoice_question(row): + return QUERY_TEMPLATE_MULTICHOICE.format(**row) + + +def check_equality(sampler: SamplerBase, expr1: str, expr2: str): + prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2} + response = sampler([dict(content=prompt, role="user")]) + return response.lower().strip() == "yes" + + +def _compute_stat(values: list, stat: str): + if stat == "mean": + return np.mean(values) + elif stat == "std": + return np.std(values) + elif stat == "min": + return np.min(values) + elif stat == "max": + return np.max(values) + else: + raise ValueError(f"Unknown {stat =}") + + +def aggregate_results( + single_eval_results: List[SingleEvalResult], + default_stats: Tuple[str] = ("mean", "std"), + name2stats: Dict[str, Tuple[str]] | None = None, +) -> EvalResult: + """ + Aggregate results from multiple evaluations into a single EvalResult. + """ + name2stats = name2stats or {} + name2values = defaultdict(list) + htmls = [] + convos = [] + for single_eval_result in single_eval_results: + for name, value in single_eval_result.metrics.items(): + name2values[name].append(value) + if single_eval_result.score is not None: + name2values["score"].append(single_eval_result.score) + htmls.append(single_eval_result.html) + convos.append(single_eval_result.convo) + final_metrics = {} + for name, values in name2values.items(): + stats = name2stats.get(name, default_stats) + for stat in stats: + key = name if stat == "mean" else f"{name}:{stat}" + final_metrics[key] = _compute_stat(values, stat) + return EvalResult( + score=final_metrics.pop("score", None), + metrics=final_metrics, + htmls=htmls, + convos=convos, + ) + + +def map_with_progress(f: callable, xs: List[Any], num_threads: int): + """ + Apply f to each element of xs, using a ThreadPool, and show progress. + """ + if os.getenv("debug"): + return list(map(f, tqdm(xs, total=len(xs)))) + else: + with ThreadPool(min(num_threads, len(xs))) as pool: + return list(tqdm(pool.imap(f, xs), total=len(xs))) + + +jinja_env = jinja2.Environment( + loader=jinja2.BaseLoader(), + undefined=jinja2.StrictUndefined, + autoescape=jinja2.select_autoescape(["html", "xml"]), +) +_message_template = """ +
+<div class="message {{ role }}">
+    <div class="role">
+        {{ role }}
+        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
+    </div>
+    <div class="content">
+        <pre>{{ content }}</pre>
+    </div>
+</div>
+""" + + +def message_to_html(message: Message) -> str: + """ + Generate HTML snippet (inside a
) for a message. + """ + return jinja_env.from_string(_message_template).render( + role=message["role"], + content=message["content"], + variant=message.get("variant", None), + ) + + +jinja_env.globals["message_to_html"] = message_to_html + + +_report_template = """ + + + + + + {% if metrics %} +

+    <h1>Metrics</h1>
+    <table>
+    <tr>
+        <th>Metric</th>
+        <th>Value</th>
+    </tr>
+    <tr>
+        <td>Score</td>
+        <td>{{ score | float | round(3) }}</td>
+    </tr>
+    {% for name, value in metrics.items() %}
+    <tr>
+        <td>{{ name }}</td>
+        <td>{{ value }}</td>
+    </tr>
+    {% endfor %}
+    </table>
+    {% endif %}
+    <h1>Examples</h1>
+    {% for html in htmls %}
+    {{ html | safe }}
+    <hr>
+ {% endfor %} + + +""" + + +def make_report(eval_result: EvalResult) -> str: + """ + Create a standalone HTML report from an EvalResult. + """ + return jinja_env.from_string(_report_template).render( + score=eval_result.score, + metrics=eval_result.metrics, + htmls=eval_result.htmls, + ) + + +def make_report_from_example_htmls(htmls: List[str]): + """ + Create a standalone HTML report from a list of example htmls + """ + return jinja_env.from_string(_report_template).render( + score=None, metrics={}, htmls=htmls + ) + + +def download_dataset(path, url): + print(f"Downloading dataset {path} from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") diff --git a/python/sglang/test/simple_eval_gpqa.py b/python/sglang/test/simple_eval_gpqa.py new file mode 100644 index 00000000000..46055caa5f1 --- /dev/null +++ b/python/sglang/test/simple_eval_gpqa.py @@ -0,0 +1,92 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +GPQA: A Graduate-Level Google-Proof Q&A Benchmark +David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. 
Bowman +https://arxiv.org/abs/2311.12022 +""" + +import random +import re + +import pandas + +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ( + ANSWER_PATTERN_MULTICHOICE, + HTML_JINJA, + Eval, + EvalResult, + MessageList, + SamplerBase, + SingleEvalResult, + format_multichoice_question, +) + + +class GPQAEval(Eval): + def __init__( + self, + filename: str, + num_examples: int | None, + num_threads: int, + n_repeats: int = 1, + ): + df = pandas.read_csv(filename) + examples = [row.to_dict() for _, row in df.iterrows()] + rng = random.Random(0) + if num_examples: + assert n_repeats == 1, "n_repeats only supported for num_examples" + examples = rng.sample(examples, num_examples) + examples = examples * n_repeats + examples = [ + example | {"permutation": rng.sample(range(4), 4)} for example in examples + ] + self.examples = examples + self.n_repeats = n_repeats + self.num_threads = num_threads + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(row: dict): + choices = [ + row["Correct Answer"], + row["Incorrect Answer 1"], + row["Incorrect Answer 2"], + row["Incorrect Answer 3"], + ] + choices = [choices[i] for i in row["permutation"]] + correct_index = choices.index(row["Correct Answer"]) + correct_answer = "ABCD"[correct_index] + choices_dict = dict( + A=choices[0], + B=choices[1], + C=choices[2], + D=choices[3], + Question=row["Question"], + ) + prompt_messages = [ + sampler._pack_message( + content=format_multichoice_question(choices_dict), role="user" + ) + ] + response_text = sampler(prompt_messages) + match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) + extracted_answer = match.group(1) if match else None + score = 1.0 if extracted_answer == correct_answer else 0.0 + html = common.jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=correct_answer, + extracted_answer=extracted_answer, + ) + convo = prompt_messages + [dict(content=response_text, role="assistant")] + return SingleEvalResult( + html=html, + score=score, + convo=convo, + metrics={"chars": len(response_text)}, + ) + + results = common.map_with_progress(fn, self.examples, self.num_threads) + return common.aggregate_results(results) diff --git a/python/sglang/test/simple_eval_humaneval.py b/python/sglang/test/simple_eval_humaneval.py new file mode 100644 index 00000000000..7a0f90c4673 --- /dev/null +++ b/python/sglang/test/simple_eval_humaneval.py @@ -0,0 +1,139 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +HumanEval: Evaluating Large Language Models Trained on Code +Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. 
Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba +https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/ +""" + +import json +import logging +import multiprocessing +import random +import re +from collections import Counter, defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from io import BytesIO +from typing import Any, Dict, List, Tuple + +import blobfile as bf +import tqdm + +try: + from human_eval.data import HUMAN_EVAL, read_problems + from human_eval.evaluation import estimate_pass_at_k + from human_eval.execution import check_correctness # , unsafe_execute +except (ImportError, ModuleNotFoundError): + print("\nPlease install human-eval at https://github.com/openai/human-eval.\n") + raise + +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ( + HTML_JINJA, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, +) + + +def evaluate_functional_correctness( + sample: Dict[str, str], + completions: List[str], + n_workers: int = 4, + timeout: float = 3.0, +): + """ + Evaluates the functional correctness of generated samples, and writes + results to f"{sample_file}_results.jsonl.gz" + """ + import copy + + # Check the generated samples against test suites. + with ThreadPoolExecutor(max_workers=n_workers) as executor: + futures = [] + for i, completion in enumerate(completions): + args = (sample, completion, timeout, i) + future = executor.submit(check_correctness, *args) + futures.append(future) + results = [] + for future in as_completed(futures): + result = future.result() + results.append(result) + passed = [int(r["passed"]) for r in results] + return passed + + +class HumanEval(Eval): + def __init__( + self, + num_examples: int | None, + num_threads: int, + num_samples_per_task: int = 5, + ks_passes: List[int] = [1, 2, 5], + timeout: int = 120, + ): + self.seed = 0 + self.examples = read_problems() + self.examples = list(self.examples.values()) + + self._num_examples = num_examples + if self._num_examples: + self.examples = random.Random(self.seed).sample(self.examples, num_examples) + self._num_samples_per_task = num_samples_per_task + self._ks_passes = ks_passes + self._timeout = timeout + self._num_threads = num_threads + + def __call__(self, sampler: SamplerBase) -> EvalResult: + instruction = "Read the following function signature and docstring, and fully implement the function described. 
Your response should only contain the code for this function.\n" + + def find_code(completion): + pattern = re.compile(r"```python\n(.*?)```", re.DOTALL) + matches = pattern.findall(completion) + extracted_answer = matches[0] if len(matches) >= 1 else completion + extracted_answer = extracted_answer[ + extracted_answer.find(":\n ") + 2 : + ] # remove signature + return extracted_answer + + def fn(sample: Dict[str, str]): + prompt_messages = [ + sampler._pack_message( + role="user", content=instruction + sample["prompt"] + ) + ] + completions = [ + find_code(sampler(prompt_messages)) + for _ in range(self._num_samples_per_task) + ] + results = evaluate_functional_correctness(sample, completions) + total = len(results) + correct = sum(results) + score = sum(results) / len(results) + html = common.jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=completions[0], role="assistant"), + score=score, + correct_answer=[1] * len(results), + extracted_answer=results, + ) + convo = prompt_messages + [ + dict(content=completion, role="assistant") for completion in completions + ] + return SingleEvalResult( + html=html, + score=score, + convo=convo, + metrics={ + f"pass@{k}": estimate_pass_at_k([total], [correct], k) + # this will be aggrated so no need of .mean() + for k in self._ks_passes + if total >= k + }, + ) + + results = common.map_with_progress( + fn, self.examples, num_threads=self._num_threads + ) + return common.aggregate_results(results) diff --git a/python/sglang/test/simple_eval_math.py b/python/sglang/test/simple_eval_math.py new file mode 100644 index 00000000000..4ddb650d965 --- /dev/null +++ b/python/sglang/test/simple_eval_math.py @@ -0,0 +1,72 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +Measuring Mathematical Problem Solving With the MATH Dataset +Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt +https://arxiv.org/abs/2103.03874 +""" + +import random +import re + +import pandas + +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ( + ANSWER_PATTERN, + HTML_JINJA, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, + check_equality, +) + +QUERY_TEMPLATE = """ +Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. + +{Question} + +Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. 
+""".strip() + + +class MathEval(Eval): + def __init__( + self, + filename: str, + equality_checker: SamplerBase, + num_examples: int | None, + num_threads: int, + ): + df = pandas.read_csv(filename) + examples = [row.to_dict() for _, row in df.iterrows()] + if num_examples: + examples = random.Random(0).sample(examples, num_examples) + self.examples = examples + self.equality_checker = equality_checker + self.num_threads = num_threads + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(row: dict): + prompt_messages = [ + sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user") + ] + response_text = sampler(prompt_messages) + match = re.search(ANSWER_PATTERN, response_text) + extracted_answer = match.group(1) if match else None + score = float( + check_equality(self.equality_checker, row["Answer"], extracted_answer) + ) + html = common.jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=row["Answer"], + extracted_answer=extracted_answer, + ) + convo = prompt_messages + [dict(content=response_text, role="assistant")] + return SingleEvalResult(html=html, score=score, convo=convo) + + results = common.map_with_progress(fn, self.examples, self.num_threads) + return common.aggregate_results(results) diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py new file mode 100644 index 00000000000..3c0287510cb --- /dev/null +++ b/python/sglang/test/simple_eval_mmlu.py @@ -0,0 +1,120 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +Measuring Massive Multitask Language Understanding +Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt +https://arxiv.org/abs/2009.03300 +""" + +import random +import re + +import pandas + +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ( + ANSWER_PATTERN_MULTICHOICE, + HTML_JINJA, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, + format_multichoice_question, +) + +subject2category = { + "abstract_algebra": "stem", + "anatomy": "other", + "astronomy": "stem", + "business_ethics": "other", + "clinical_knowledge": "other", + "college_biology": "stem", + "college_chemistry": "stem", + "college_computer_science": "stem", + "college_mathematics": "stem", + "college_medicine": "other", + "college_physics": "stem", + "computer_security": "stem", + "conceptual_physics": "stem", + "econometrics": "social_sciences", + "electrical_engineering": "stem", + "elementary_mathematics": "stem", + "formal_logic": "humanities", + "global_facts": "other", + "high_school_biology": "stem", + "high_school_chemistry": "stem", + "high_school_computer_science": "stem", + "high_school_european_history": "humanities", + "high_school_geography": "social_sciences", + "high_school_government_and_politics": "social_sciences", + "high_school_macroeconomics": "social_sciences", + "high_school_mathematics": "stem", + "high_school_microeconomics": "social_sciences", + "high_school_physics": "stem", + "high_school_psychology": "social_sciences", + "high_school_statistics": "stem", + "high_school_us_history": "humanities", + "high_school_world_history": "humanities", + "human_aging": "other", + "human_sexuality": "social_sciences", + "international_law": "humanities", + "jurisprudence": "humanities", + "logical_fallacies": "humanities", + "machine_learning": "stem", + "management": "other", + 
"marketing": "other", + "medical_genetics": "other", + "miscellaneous": "other", + "moral_disputes": "humanities", + "moral_scenarios": "humanities", + "nutrition": "other", + "philosophy": "humanities", + "prehistory": "humanities", + "professional_accounting": "other", + "professional_law": "humanities", + "professional_medicine": "other", + "professional_psychology": "social_sciences", + "public_relations": "social_sciences", + "security_studies": "social_sciences", + "sociology": "social_sciences", + "us_foreign_policy": "social_sciences", + "virology": "other", + "world_religions": "humanities", +} + + +class MMLUEval(Eval): + def __init__(self, filename: str, num_examples: int | None, num_threads: int): + df = pandas.read_csv(filename) + examples = [row.to_dict() for _, row in df.iterrows()] + if num_examples: + examples = random.Random(0).sample(examples, num_examples) + self.examples = examples + self.num_threads = num_threads + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(row: dict): + prompt_messages = [ + sampler._pack_message( + content=format_multichoice_question(row), role="user" + ) + ] + response_text = sampler(prompt_messages) + match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) + extracted_answer = match.group(1) if match else None + score = 1.0 if extracted_answer == row["Answer"] else 0.0 + html = common.jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=row["Answer"], + extracted_answer=extracted_answer, + ) + convo = prompt_messages + [dict(content=response_text, role="assistant")] + category = subject2category.get(row["Subject"], "other") + return SingleEvalResult( + html=html, score=score, metrics={category: score}, convo=convo + ) + + results = common.map_with_progress(fn, self.examples, self.num_threads) + return common.aggregate_results(results) diff --git a/python/sglang/test/srt/sampling/penaltylib/utils.py b/python/sglang/test/srt/sampling/penaltylib/utils.py new file mode 100644 index 00000000000..b41eac32ba9 --- /dev/null +++ b/python/sglang/test/srt/sampling/penaltylib/utils.py @@ -0,0 +1,337 @@ +import dataclasses +import enum +import typing +import unittest + +import torch + +from sglang.srt.sampling.penaltylib.orchestrator import ( + BatchedPenalizerOrchestrator, + _BatchedPenalizer, + _BatchLike, +) + + +@dataclasses.dataclass +class MockSamplingParams: + frequency_penalty: float = 0.0 + min_new_tokens: int = 0 + stop_token_ids: typing.List[int] = None + presence_penalty: float = 0.0 + repetition_penalty: float = 1.0 + + +@dataclasses.dataclass +class MockTokenizer: + eos_token_id: int + + +@dataclasses.dataclass +class MockReq: + origin_input_ids: typing.List[int] + sampling_params: MockSamplingParams + tokenizer: MockTokenizer + + +class StepType(enum.Enum): + INPUT = "input" + OUTPUT = "output" + + +@dataclasses.dataclass +class Step: + type: StepType + token_ids: typing.List[int] + expected_tensors: typing.Dict[str, torch.Tensor] + # assume initial logits are all 1 + expected_logits: torch.Tensor + + +@dataclasses.dataclass +class Subject: + sampling_params: MockSamplingParams + # first step must be input, which will be converted to Req + steps: typing.List[Step] + eos_token_id: int = -1 + + def __post_init__(self): + if self.steps[0].type != StepType.INPUT: + raise ValueError("First step must be input") + + # each steps should have the same expected_tensors.keys() + for i in range(1, 
len(self.steps)): + if self.tensor_keys(i) != self.tensor_keys(): + raise ValueError( + f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}" + ) + + def tensor_keys(self, i: int = 0) -> typing.Set[str]: + return set(self.steps[i].expected_tensors.keys()) + + def to_req(self) -> MockReq: + return MockReq( + origin_input_ids=self.steps[0].token_ids, + sampling_params=self.sampling_params, + tokenizer=MockTokenizer(eos_token_id=self.eos_token_id), + ) + + +@dataclasses.dataclass +class Case: + enabled: bool + test_subjects: typing.List[Subject] + + def __post_init__(self): + # each test_subjects.steps should have the same expected_tensors.keys() + for i in range(1, len(self.test_subjects)): + if self.tensor_keys(i) != self.tensor_keys(): + raise ValueError( + f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}" + ) + + def tensor_keys(self, i: int = 0) -> typing.List[str]: + return set(self.test_subjects[i].tensor_keys()) + + +class BaseBatchedPenalizerTest(unittest.TestCase): + Penalizer: typing.Type[_BatchedPenalizer] + device = "cuda" + vocab_size = 5 + + enabled: Subject = None + disabled: Subject = None + + def setUp(self): + if self.__class__ == BaseBatchedPenalizerTest: + self.skipTest("Base class for penalizer tests") + + self.create_test_subjects() + self.create_test_cases() + + def tensor(self, data, **kwargs) -> torch.Tensor: + """ + Shortcut to create a tensor with device=self.device. + """ + return torch.tensor(data, **kwargs, device=self.device) + + def create_test_subjects(self) -> typing.List[Subject]: + raise NotImplementedError() + + def create_test_cases(self): + self.test_cases = [ + Case(enabled=True, test_subjects=[self.enabled]), + Case(enabled=False, test_subjects=[self.disabled]), + Case(enabled=True, test_subjects=[self.enabled, self.disabled]), + ] + + def _create_penalizer( + self, case: Case + ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]: + orchestrator = BatchedPenalizerOrchestrator( + vocab_size=self.vocab_size, + batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]), + device=self.device, + Penalizers={self.Penalizer}, + ) + + return orchestrator, orchestrator.penalizers[self.Penalizer] + + def test_is_required(self): + for case in self.test_cases: + with self.subTest(case=case): + _, penalizer = self._create_penalizer(case) + self.assertEqual(case.enabled, penalizer.is_required()) + + def test_prepare(self): + for case in self.test_cases: + with self.subTest(case=case): + orchestrator, penalizer = self._create_penalizer(case) + self.assertEqual(case.enabled, penalizer.is_prepared()) + + if case.enabled: + for key, tensor in { + key: torch.cat( + tensors=[ + subject.steps[0].expected_tensors[key] + for subject in case.test_subjects + ], + ) + for key in case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + actual = orchestrator.apply( + torch.ones( + size=(len(case.test_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ) + expected = torch.cat( + tensors=[ + subject.steps[0].expected_logits + for subject in case.test_subjects + ], + ) + torch.testing.assert_close( + actual=actual, + expected=expected, + 
msg=f"logits\nactual={actual}\nexpected={expected}", + ) + + def test_teardown(self): + for case in self.test_cases: + with self.subTest(case=case): + _, penalizer = self._create_penalizer(case) + penalizer.teardown() + + for key in case.test_subjects[0].steps[0].expected_tensors.keys(): + self.assertIsNone(getattr(penalizer, key, None)) + + def test_filter(self): + for case in self.test_cases: + with self.subTest(case=case): + orchestrator, penalizer = self._create_penalizer(case) + + indices_to_keep = [0] + orchestrator.filter(indices_to_keep=indices_to_keep) + + filtered_subjects = [case.test_subjects[i] for i in indices_to_keep] + + if penalizer.is_required(): + self.assertTrue(penalizer.is_prepared()) + for key, tensor in { + key: torch.cat( + tensors=[ + subject.steps[0].expected_tensors[key] + for subject in filtered_subjects + ], + ) + for key in case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + actual_logits = orchestrator.apply( + torch.ones( + size=(len(filtered_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ) + filtered_expected_logits = torch.cat( + tensors=[ + subject.steps[0].expected_logits + for subject in filtered_subjects + ], + ) + torch.testing.assert_close( + actual=actual_logits, + expected=filtered_expected_logits, + msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}", + ) + + def test_merge_enabled_with_disabled(self): + enabled_test_case = self.test_cases[0] + disabled_test_case = self.test_cases[1] + + orchestrator, penalizer = self._create_penalizer(enabled_test_case) + theirs, _ = self._create_penalizer(disabled_test_case) + + orchestrator.merge(theirs) + + for key, tensor in { + key: torch.cat( + tensors=[ + enabled_test_case.test_subjects[0].steps[0].expected_tensors[key], + disabled_test_case.test_subjects[0].steps[0].expected_tensors[key], + ], + ) + for key in enabled_test_case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + def test_cumulate_apply_repeat(self): + for case in self.test_cases: + with self.subTest(case=case): + orchestrator, penalizer = self._create_penalizer(case) + + max_step = max(len(subject.steps) for subject in case.test_subjects) + for i in range(1, max_step): + orchestrator.filter( + indices_to_keep=[ + j + for j, subject in enumerate(case.test_subjects) + if i < len(subject.steps) + ] + ) + + filtered_subjects = [ + subject + for subject in case.test_subjects + if i < len(subject.steps) + ] + + inputs: typing.List[typing.List[int]] = [] + outputs: typing.List[typing.List[int]] = [] + for subject in filtered_subjects: + step = subject.steps[i] + if step.type == StepType.INPUT: + inputs.append(step.token_ids) + outputs.append([]) + else: + inputs.append([]) + outputs.append(step.token_ids) + + if any(inputs): + orchestrator.cumulate_input_tokens(inputs) + + if any(outputs): + orchestrator.cumulate_output_tokens(outputs) + + if penalizer.is_required(): + self.assertTrue(penalizer.is_prepared()) + for key, tensor in { + key: torch.cat( + tensors=[ + subject.steps[i].expected_tensors[key] + for subject in filtered_subjects + ], + ) + for key in case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + 
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + actual_logits = orchestrator.apply( + torch.ones( + size=(len(filtered_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ) + filtered_expected_logits = torch.cat( + tensors=[ + subject.steps[i].expected_logits + for subject in filtered_subjects + ], + ) + torch.testing.assert_close( + actual=actual_logits, + expected=filtered_expected_logits, + msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}", + ) diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py deleted file mode 100644 index 11e837ddbde..00000000000 --- a/python/sglang/test/test_conversation.py +++ /dev/null @@ -1,46 +0,0 @@ -from sglang.srt.conversation import generate_chat_conv -from sglang.srt.managers.openai_protocol import ( - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentImageURL, - ChatCompletionMessageContentTextPart, - ChatCompletionMessageGenericParam, - ChatCompletionMessageUserParam, - ChatCompletionRequest, -) - - -def test_chat_completion_to_conv_image(): - """Test that we can convert a chat image request to a convo""" - request = ChatCompletionRequest( - model="default", - messages=[ - ChatCompletionMessageGenericParam( - role="system", content="You are a helpful AI assistant" - ), - ChatCompletionMessageUserParam( - role="user", - content=[ - ChatCompletionMessageContentTextPart( - type="text", text="Describe this image" - ), - ChatCompletionMessageContentImagePart( - type="image_url", - image_url=ChatCompletionMessageContentImageURL( - url="https://someurl.com" - ), - ), - ], - ), - ], - ) - conv = generate_chat_conv(request, "vicuna_v1.1") - assert conv.messages == [ - ["USER", "Describe this image"], - ["ASSISTANT", None], - ] - assert conv.system_message == "You are a helpful AI assistant" - assert conv.image_data == ["https://someurl.com"] - - -if __name__ == "__main__": - test_chat_completion_to_conv_image() diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py deleted file mode 100644 index 99e7a8089cf..00000000000 --- a/python/sglang/test/test_openai_protocol.py +++ /dev/null @@ -1,51 +0,0 @@ -from sglang.srt.managers.openai_protocol import ( - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentImageURL, - ChatCompletionMessageContentTextPart, - ChatCompletionMessageGenericParam, - ChatCompletionMessageUserParam, - ChatCompletionRequest, -) - - -def test_chat_completion_request_image(): - """Test that Chat Completion Requests with images can be converted.""" - - image_request = { - "model": "default", - "messages": [ - {"role": "system", "content": "You are a helpful AI assistant"}, - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe this image"}, - {"type": "image_url", "image_url": {"url": "https://someurl.com"}}, - ], - }, - ], - "temperature": 0, - "max_tokens": 64, - } - request = ChatCompletionRequest(**image_request) - assert len(request.messages) == 2 - assert request.messages[0] == ChatCompletionMessageGenericParam( - role="system", content="You are a helpful AI assistant" - ) - assert request.messages[1] == ChatCompletionMessageUserParam( - role="user", - content=[ - ChatCompletionMessageContentTextPart( - type="text", text="Describe this image" - ), - ChatCompletionMessageContentImagePart( - type="image_url", - image_url=ChatCompletionMessageContentImageURL( - url="https://someurl.com" - ), - ), - ], - ) - - -if 
__name__ == "__main__": - test_chat_completion_request_image() diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 6fa8f821433..710871ba5db 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -105,20 +105,22 @@ def test_decode_json_regex(): def decode_json(s): from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING - s += "Generate a JSON object to describe the basic information of a city.\n" + s += "Generate a JSON object to describe the basic city information of Paris.\n" with s.var_scope("json_output"): s += "{\n" s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n" - s += ' "country": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" - s += ' "timezone": ' + sgl.gen(regex=REGEX_STRING) + "\n" + s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" s += "}" - ret = decode_json.run() - js_obj = json.loads(ret["json_output"]) + ret = decode_json.run(temperature=0.0) + try: + js_obj = json.loads(ret["json_output"]) + except json.decoder.JSONDecodeError: + print("JSONDecodeError", ret["json_output"]) + raise assert isinstance(js_obj["name"], str) assert isinstance(js_obj["population"], int) @@ -126,7 +128,7 @@ def decode_json(s): def test_decode_json(): @sgl.function def decode_json(s): - s += "Generate a JSON object to describe the basic information of a city.\n" + s += "Generate a JSON object to describe the basic city information of Paris.\n" with s.var_scope("json_output"): s += "{\n" @@ -137,8 +139,12 @@ def decode_json(s): s += ' "timezone": ' + sgl.gen(dtype=str) + "\n" s += "}" - ret = decode_json.run() - js_obj = json.loads(ret["json_output"]) + ret = decode_json.run(max_new_tokens=64) + try: + js_obj = json.loads(ret["json_output"]) + except json.decoder.JSONDecodeError: + print("JSONDecodeError", ret["json_output"]) + raise assert isinstance(js_obj["name"], str) assert isinstance(js_obj["population"], int) @@ -257,6 +263,7 @@ def parallel_decoding(s, topic): s += "\nIn summary," + sgl.gen("summary", max_tokens=512) ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3) + assert isinstance(ret["summary"], str) def test_parallel_encoding(check_answer=True): @@ -306,7 +313,7 @@ def image_qa(s, question): assert ( "taxi" in state.messages()[-1]["content"] or "car" in state.messages()[-1]["content"] - ) + ), f"{state.messages()[-1]['content']}" def test_stream(): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 693bade6f2d..613645b572e 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1,16 +1,27 @@ """Common utilities for testing and benchmarking""" +import argparse import asyncio +import multiprocessing +import subprocess +import threading +import time +import unittest from functools import partial +from typing import Callable, List, Optional import numpy as np import requests +import torch +import torch.nn.functional as F -from sglang.backend.openai import OpenAI -from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.global_config import global_config +from sglang.lang.backend.openai import OpenAI +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.utils import get_exception_traceback +DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" + def 
call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): assert url is not None @@ -243,7 +254,7 @@ async def program(ctx, choices): return choices.index(answer) -def add_common_other_args_and_parse(parser): +def add_common_other_args_and_parse(parser: argparse.ArgumentParser): parser.add_argument("--parallel", type=int, default=64) parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--port", type=int, default=None) @@ -282,7 +293,7 @@ def add_common_other_args_and_parse(parser): return args -def add_common_sglang_args_and_parse(parser): +def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser): parser.add_argument("--parallel", type=int, default=64) parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--port", type=int, default=30000) @@ -292,7 +303,7 @@ def add_common_sglang_args_and_parse(parser): return args -def select_sglang_backend(args): +def select_sglang_backend(args: argparse.Namespace): if args.backend.startswith("srt"): if args.backend == "srt-no-parallel": global_config.enable_parallel_decoding = False @@ -305,7 +316,7 @@ def select_sglang_backend(args): return backend -def _get_call_generate(args): +def _get_call_generate(args: argparse.Namespace): if args.backend == "lightllm": return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate") elif args.backend == "vllm": @@ -332,7 +343,7 @@ def _get_call_generate(args): raise ValueError(f"Invalid backend: {args.backend}") -def _get_call_select(args): +def _get_call_select(args: argparse.Namespace): if args.backend == "lightllm": return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate") elif args.backend == "vllm": @@ -355,7 +366,7 @@ def _get_call_select(args): raise ValueError(f"Invalid backend: {args.backend}") -def get_call_generate(args): +def get_call_generate(args: argparse.Namespace): call_generate = _get_call_generate(args) def func(*args, **kwargs): @@ -368,7 +379,7 @@ def func(*args, **kwargs): return func -def get_call_select(args): +def get_call_select(args: argparse.Namespace): call_select = _get_call_select(args) def func(*args, **kwargs): @@ -379,3 +390,111 @@ def func(*args, **kwargs): raise return func + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, + api_key: Optional[str] = None, + other_args: tuple = (), +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + *other_args, + ] + if api_key: + command += ["--api-key", api_key] + + process = subprocess.Popen(command, stdout=None, stderr=None) + + start_time = time.time() + while time.time() - start_time < timeout: + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {api_key}", + } + response = requests.get(f"{base_url}/v1/models", headers=headers) + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(10) + raise TimeoutError("Server failed to start within the timeout period.") + + +def run_with_timeout( + func: Callable, + args: tuple = (), + kwargs: Optional[dict] = None, + timeout: float = None, +): + """Run a function with timeout.""" + ret_value = [] + + def _target_func(): + ret_value.append(func(*args, **(kwargs or {}))) + + t = threading.Thread(target=_target_func) + t.start() + t.join(timeout=timeout) + if t.is_alive(): + raise TimeoutError() + + 
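+    # An empty ret_value means _target_func raised before appending a result,
+    # so treat that as a failure instead of silently returning None.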
if not ret_value: + raise RuntimeError() + + return ret_value[0] + + +def run_unittest_files(files: List[str], timeout_per_file: float): + tic = time.time() + success = True + + for filename in files: + + def func(): + print(f"\n\nRun {filename}\n\n") + ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) + + p = multiprocessing.Process(target=func) + + def run_one_file(): + p.start() + p.join() + + try: + run_with_timeout(run_one_file, timeout=timeout_per_file) + if p.exitcode != 0: + success = False + break + except TimeoutError: + p.terminate() + time.sleep(5) + print( + f"\nTimeout after {timeout_per_file} seconds when running {filename}\n" + ) + return False + + if success: + print(f"Success. Time elapsed: {time.time() - tic:.2f}s") + else: + print(f"Fail. Time elapsed: {time.time() - tic:.2f}s") + + return 0 if success else -1 + + +def get_similarities(vec1, vec2): + return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 0f5fd439082..c880d259d53 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,16 +1,17 @@ """Common utilities.""" import base64 +import importlib import json import logging import signal import sys -import threading import traceback import urllib.request from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps +from typing import Union import numpy as np import requests @@ -24,7 +25,7 @@ def get_exception_traceback(): return err_str -def is_same_type(values): +def is_same_type(values: list): """Return whether the elements in values are of the same type.""" if len(values) <= 1: return True @@ -44,7 +45,7 @@ def read_jsonl(filename: str): return rets -def dump_state_text(filename, states, mode="w"): +def dump_state_text(filename: str, states: list, mode: str = "w"): """Dump program state in a text file.""" from sglang.lang.interpreter import ProgramState @@ -74,19 +75,13 @@ def status_code(self): return self.resp.status -def http_request( - url, json=None, stream=False, auth_token=None, api_key=None, verify=None -): +def http_request(url, json=None, stream=False, api_key=None, verify=None): """A faster version of requests.post with low-level urllib API.""" headers = {"Content-Type": "application/json; charset=utf-8"} - # add the Authorization header if an auth token is provided - if auth_token is not None: - headers["Authorization"] = f"Bearer {auth_token}" - - # add the API Key header if an API key is provided + # add the Authorization header if an api key is provided if api_key is not None: - headers["X-API-Key"] = api_key + headers["Authorization"] = f"Bearer {api_key}" if stream: return requests.post(url, json=json, stream=True, headers=headers) @@ -104,7 +99,7 @@ def http_request( return HttpResponse(e) -def encode_image_base64(image_path): +def encode_image_base64(image_path: Union[str, bytes]): """Encode an image in base64.""" if isinstance(image_path, str): with open(image_path, "rb") as image_file: @@ -143,7 +138,7 @@ def encode_frame(frame): return frame_bytes -def encode_video_base64(video_path, num_frames=16): +def encode_video_base64(video_path: str, num_frames: int = 16): import cv2 # pip install opencv-python-headless cap = cv2.VideoCapture(video_path) @@ -189,7 +184,7 @@ def encode_video_base64(video_path, num_frames=16): return video_base64 -def _is_chinese_char(cp): +def _is_chinese_char(cp: int): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as 
anything in the CJK Unicode block: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) @@ -214,7 +209,7 @@ def _is_chinese_char(cp): return False -def find_printable_text(text): +def find_printable_text(text: str): """Returns the longest printable substring of text that contains only entire words.""" # Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99 @@ -233,26 +228,7 @@ def find_printable_text(text): return text[: text.rfind(" ") + 1] -def run_with_timeout(func, args=(), kwargs=None, timeout=None): - """Run a function with timeout.""" - ret_value = [] - - def _target_func(): - ret_value.append(func(*args, **(kwargs or {}))) - - t = threading.Thread(target=_target_func) - t.start() - t.join(timeout=timeout) - if t.is_alive(): - raise TimeoutError() - - if not ret_value: - raise RuntimeError() - - return ret_value[0] - - -def graceful_registry(sub_module_name): +def graceful_registry(sub_module_name: str): def graceful_shutdown(signum, frame): logger.info( f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..." @@ -261,3 +237,26 @@ def graceful_shutdown(signum, frame): logger.info(f"{sub_module_name} recive sigterm") signal.signal(signal.SIGTERM, graceful_shutdown) + + +class LazyImport: + """Lazy import to make `import sglang` run faster.""" + + def __init__(self, module_name: str, class_name: str): + self.module_name = module_name + self.class_name = class_name + self._module = None + + def _load(self): + if self._module is None: + module = importlib.import_module(self.module_name) + self._module = getattr(module, self.class_name) + return self._module + + def __getattr__(self, name: str): + module = self._load() + return getattr(module, name) + + def __call__(self, *args, **kwargs): + module = self._load() + return module(*args, **kwargs) diff --git a/python/sglang/version.py b/python/sglang/version.py new file mode 100644 index 00000000000..5635676f6b4 --- /dev/null +++ b/python/sglang/version.py @@ -0,0 +1 @@ +__version__ = "0.2.11" diff --git a/scripts/convert_yi_vl.py b/scripts/convert_yi_vl.py index a45f83a3002..bdf37ff92bb 100644 --- a/scripts/convert_yi_vl.py +++ b/scripts/convert_yi_vl.py @@ -10,16 +10,15 @@ from transformers import AutoConfig, AutoTokenizer + def add_image_token(model_path: str): tokenizer = AutoTokenizer.from_pretrained(model_path) - tokenizer.add_tokens( - [""], - special_tokens=True - ) + tokenizer.add_tokens([""], special_tokens=True) print(tokenizer) tokenizer.save_pretrained(model_path) + def edit_model_config(model_path): config = AutoConfig.from_pretrained(model_path) @@ -29,10 +28,11 @@ def edit_model_config(model_path): print(config) config.save_pretrained(model_path) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str) args = parser.parse_args() add_image_token(args.model_path) - edit_model_config(args.model_path) \ No newline at end of file + edit_model_config(args.model_path) diff --git a/test/srt/test_curl.sh b/scripts/deprecated/test_curl.sh similarity index 86% rename from test/srt/test_curl.sh rename to scripts/deprecated/test_curl.sh index 4362eaa9355..1c83208a759 100644 --- a/test/srt/test_curl.sh +++ b/scripts/deprecated/test_curl.sh @@ -3,7 +3,7 @@ curl http://localhost:30000/generate \ -d '{ "text": "Once upon a time,", "sampling_params": { - "max_new_tokens": 16, + "max_new_tokens": 64, "temperature": 0 } }' diff --git 
a/test/srt/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py similarity index 100% rename from test/srt/test_flashinfer.py rename to scripts/deprecated/test_flashinfer.py diff --git a/test/srt/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py similarity index 100% rename from test/srt/test_httpserver_classify.py rename to scripts/deprecated/test_httpserver_classify.py diff --git a/test/srt/test_httpserver_concurrent.py b/scripts/deprecated/test_httpserver_concurrent.py similarity index 100% rename from test/srt/test_httpserver_concurrent.py rename to scripts/deprecated/test_httpserver_concurrent.py diff --git a/test/srt/test_httpserver_decode.py b/scripts/deprecated/test_httpserver_decode.py similarity index 69% rename from test/srt/test_httpserver_decode.py rename to scripts/deprecated/test_httpserver_decode.py index 7e169f3e423..57517a15b00 100644 --- a/test/srt/test_httpserver_decode.py +++ b/scripts/deprecated/test_httpserver_decode.py @@ -13,14 +13,15 @@ import requests -def test_decode(url, return_logprob, top_logprobs_num, return_text): +def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1): response = requests.post( url + "/generate", json={ "text": "The capital of France is", "sampling_params": { - "temperature": 0, + "temperature": 0 if n == 1 else 0.5, "max_new_tokens": 32, + "n": n, }, "stream": False, "return_logprob": return_logprob, @@ -41,8 +42,14 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text): url = f"{args.host}:{args.port}" - test_decode(url, False, 0, False) - test_decode(url, True, 0, False) - test_decode(url, True, 0, True) - test_decode(url, True, 3, False) - test_decode(url, True, 3, True) + test_decode(url) + test_decode(url, n=3) + + for top_logprobs_num in [0, 3]: + for return_text in [True, False]: + test_decode( + url, + return_logprob=True, + top_logprobs_num=top_logprobs_num, + return_text=return_text, + ) diff --git a/test/srt/test_httpserver_decode_stream.py b/scripts/deprecated/test_httpserver_decode_stream.py similarity index 89% rename from test/srt/test_httpserver_decode_stream.py rename to scripts/deprecated/test_httpserver_decode_stream.py index 38f090b7d1b..955c368d154 100644 --- a/test/srt/test_httpserver_decode_stream.py +++ b/scripts/deprecated/test_httpserver_decode_stream.py @@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num): data = json.loads(chunk[5:].strip("\n")) if return_logprob: - assert data["meta_info"]["prefill_token_logprobs"] is not None - assert data["meta_info"]["decode_token_logprobs"] is not None + assert data["meta_info"]["input_token_logprobs"] is not None + assert data["meta_info"]["output_token_logprobs"] is not None assert data["meta_info"]["normalized_prompt_logprob"] is not None for logprob, token_id, token_text in data["meta_info"][ - "decode_token_logprobs" + "output_token_logprobs" ][prev:]: print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True) - prev = len(data["meta_info"]["decode_token_logprobs"]) + prev = len(data["meta_info"]["output_token_logprobs"]) else: output = data["text"].strip() print(output[prev:], end="", flush=True) diff --git a/test/srt/test_httpserver_llava.py b/scripts/deprecated/test_httpserver_llava.py similarity index 97% rename from test/srt/test_httpserver_llava.py rename to scripts/deprecated/test_httpserver_llava.py index e3cf1b79931..791fc6deb1f 100644 --- a/test/srt/test_httpserver_llava.py +++ b/scripts/deprecated/test_httpserver_llava.py @@ -10,7 +10,6 @@ import 
argparse import asyncio import json -import time import aiohttp import requests @@ -37,7 +36,7 @@ async def test_concurrent(args): "image_data": "example_image.png", "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": 64, }, }, ) diff --git a/test/srt/test_httpserver_reuse.py b/scripts/deprecated/test_httpserver_reuse.py similarity index 100% rename from test/srt/test_httpserver_reuse.py rename to scripts/deprecated/test_httpserver_reuse.py diff --git a/test/srt/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py similarity index 100% rename from test/srt/test_jump_forward.py rename to scripts/deprecated/test_jump_forward.py diff --git a/test/srt/test_robust.py b/scripts/deprecated/test_robust.py similarity index 100% rename from test/srt/test_robust.py rename to scripts/deprecated/test_robust.py diff --git a/scripts/format.sh b/scripts/format.sh deleted file mode 100644 index a49aed74549..00000000000 --- a/scripts/format.sh +++ /dev/null @@ -1,8 +0,0 @@ -isort python -black python - -isort test -black test - -isort benchmark -black benchmark diff --git a/scripts/launch_tgi.sh b/scripts/launch_tgi.sh deleted file mode 100644 index eeb4054754f..00000000000 --- a/scripts/launch_tgi.sh +++ /dev/null @@ -1,6 +0,0 @@ -docker run --name tgi --rm -ti --gpus all --network host \ - -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \ - ghcr.io/huggingface/text-generation-inference:1.3.0 \ - --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \ - --max-input-length 2048 --max-total-tokens 4096 \ - --port 24000 diff --git a/playground/launch_tgi.sh b/scripts/playground/launch_tgi.sh similarity index 100% rename from playground/launch_tgi.sh rename to scripts/playground/launch_tgi.sh diff --git a/playground/load_tokenizer.py b/scripts/playground/load_tokenizer.py similarity index 61% rename from playground/load_tokenizer.py rename to scripts/playground/load_tokenizer.py index 39fa1842481..94cf34bc71f 100644 --- a/playground/load_tokenizer.py +++ b/scripts/playground/load_tokenizer.py @@ -3,11 +3,12 @@ from sglang.srt.hf_transformers_utils import get_tokenizer - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") + parser.add_argument( + "--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct" + ) args = parser.parse_args() t = get_tokenizer(args.name) - code.interact(local=locals()) \ No newline at end of file + code.interact(local=locals()) diff --git a/playground/reference_hf.py b/scripts/playground/reference_hf.py similarity index 93% rename from playground/reference_hf.py rename to scripts/playground/reference_hf.py index ca82871c9de..ac91b3bed40 100644 --- a/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -30,9 +30,12 @@ @torch.inference_mode() def normal_text(args): - t = AutoTokenizer.from_pretrained(args.model_path) + t = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) m = AutoModelForCausalLM.from_pretrained( - args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + args.model_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, ) m.cuda() diff --git a/test/README.md b/test/README.md new file mode 100644 index 00000000000..cdfbbaee81a --- /dev/null +++ b/test/README.md @@ -0,0 +1,26 @@ +# Run Unit Tests + +## Test Frontend Language +``` +cd sglang/test/lang +export OPENAI_API_KEY=sk-***** + +# Run a single file +python3 
test_openai_backend.py + +# Run a suite +python3 run_suite.py --suite minimal +``` + +## Test Backend Runtime +``` +cd sglang/test/srt + +# Run a single file +python3 test_eval_accuracy.py + +# Run a suite +python3 run_suite.py --suite minimal +``` + + diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/lang/run_all.py b/test/lang/run_all.py deleted file mode 100644 index cb5da15850b..00000000000 --- a/test/lang/run_all.py +++ /dev/null @@ -1,60 +0,0 @@ -import argparse -import glob -import multiprocessing -import os -import time -import unittest - -from sglang.utils import run_with_timeout - - -def run_unittest_files(files, args): - for filename in files: - - def func(): - print(filename) - ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) - - p = multiprocessing.Process(target=func) - - def run_one_file(): - p.start() - p.join() - - try: - run_with_timeout(run_one_file, timeout=args.time_limit_per_file) - if p.exitcode != 0: - return False - except TimeoutError: - p.terminate() - time.sleep(5) - print( - f"\nTimeout after {args.time_limit_per_file} seconds " - f"when running {filename}" - ) - return False - - return True - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument( - "--time-limit-per-file", - type=int, - default=1000, - help="The time limit for running one file in seconds.", - ) - args = arg_parser.parse_args() - - files = glob.glob("**/test_*.py", recursive=True) - - tic = time.time() - success = run_unittest_files(files, args) - - if success: - print(f"Success. Time elapsed: {time.time() - tic:.2f}s") - else: - print(f"Fail. Time elapsed: {time.time() - tic:.2f}s") - - exit(0 if success else -1) diff --git a/test/lang/run_suite.py b/test/lang/run_suite.py new file mode 100644 index 00000000000..379427afac9 --- /dev/null +++ b/test/lang/run_suite.py @@ -0,0 +1,34 @@ +import argparse +import glob + +from sglang.test.test_utils import run_unittest_files + +suites = { + "minimal": ["test_srt_backend.py", "test_openai_backend.py"], +} + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument( + "--timeout-per-file", + type=int, + default=1000, + help="The time limit for running one file in seconds.", + ) + arg_parser.add_argument( + "--suite", + type=str, + default=list(suites.keys())[0], + choices=list(suites.keys()) + ["all"], + help="The suite to run", + ) + args = arg_parser.parse_args() + + if args.suite == "all": + files = glob.glob("**/test_*.py", recursive=True) + else: + files = suites[args.suite] + + exit_code = run_unittest_files(files, args.timeout_per_file) + exit(exit_code) diff --git a/test/lang/test_anthropic_backend.py b/test/lang/test_anthropic_backend.py index 3eb4051d739..87b27a765a3 100644 --- a/test/lang/test_anthropic_backend.py +++ b/test/lang/test_anthropic_backend.py @@ -7,14 +7,11 @@ class TestAnthropicBackend(unittest.TestCase): backend = None - chat_backend = None - def setUp(self): - cls = type(self) - - if cls.backend is None: - cls.backend = Anthropic("claude-3-haiku-20240307") - set_default_backend(cls.backend) + @classmethod + def setUpClass(cls): + cls.backend = Anthropic("claude-3-haiku-20240307") + set_default_backend(cls.backend) def test_mt_bench(self): test_mt_bench() @@ -30,5 +27,5 @@ def test_stream(self): # global_config.verbosity = 2 # t = TestAnthropicBackend() - # t.setUp() + # t.setUpClass() # t.test_mt_bench() diff --git a/test/lang/test_bind_cache.py 
b/test/lang/test_bind_cache.py index 9cba14ce437..14a7e509863 100644 --- a/test/lang/test_bind_cache.py +++ b/test/lang/test_bind_cache.py @@ -1,17 +1,20 @@ import unittest import sglang as sgl -from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST class TestBind(unittest.TestCase): backend = None - def setUp(self): - cls = type(self) + @classmethod + def setUpClass(cls): + cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST) + sgl.set_default_backend(cls.backend) - if cls.backend is None: - cls.backend = RuntimeEndpoint(base_url="http://localhost:30000") + @classmethod + def tearDownClass(cls): + cls.backend.shutdown() def test_bind(self): @sgl.function @@ -48,5 +51,5 @@ def few_shot_qa(s, prompt, question): unittest.main(warnings="ignore") # t = TestBind() - # t.setUp() + # t.setUpClass() # t.test_cache() diff --git a/test/lang/test_choices.py b/test/lang/test_choices.py new file mode 100644 index 00000000000..da25e9e496f --- /dev/null +++ b/test/lang/test_choices.py @@ -0,0 +1,95 @@ +import unittest + +import numpy as np + +from sglang.lang.choices import ( + greedy_token_selection, + token_length_normalized, + unconditional_likelihood_normalized, +) + +MOCK_CHOICES_INPUT_DATA = { + "choices": [ + "organ", # ["organ"] + "organism", # ["organ", "ism"] + "antidisestablishmentarianism", # ["ant", "id", "is", "est", "ablish", "ment", "arian", "ism"] + ], + "normalized_prompt_logprobs": [-0.1, -0.2, -0.05], + "input_token_logprobs": [ + [[-0.1, 1, None]], + [[-0.1, 1, None], [-0.3, 2, None]], + [ + [-0.4, 3, None], + [-0.25, 4, None], + [-0.1, 5, None], + [-0.01, 6, None], + [-0.01, 7, None], + [-0.01, 8, None], + [-0.01, 9, None], + [-0.01, 2, None], + ], + ], + "output_token_logprobs": [ + [[-0.1, 10, None]], + [[-0.1, 10, None]], + [[-0.1, 10, None]], + ], + "unconditional_token_logprobs": [ + [[None, 1, None]], + [[None, 1, None], [-1.4, 2, None]], + [ + [None, 3, None], + [-0.25, 4, None], + [-0.1, 5, None], + [-0.01, 6, None], + [-0.01, 7, None], + [-0.01, 8, None], + [-0.01, 9, None], + [-0.01, 2, None], + ], + ], +} + + +class TestChoices(unittest.TestCase): + + def test_token_length_normalized(self): + """Confirm 'antidisestablishmentarianism' is selected due to high confidences for + its later tokens resulting in highest token length normalized prompt logprob.""" + decision = token_length_normalized(**MOCK_CHOICES_INPUT_DATA) + assert decision.decision == "antidisestablishmentarianism" + + def test_greedy_token_selection(self): + """Confirm 'organ' is selected due it having the joint highest initial token + logprob, and a higher average logprob than organism's second token.""" + decision = greedy_token_selection(**MOCK_CHOICES_INPUT_DATA) + assert decision.decision == "organ" + assert np.allclose( + decision.meta_info["greedy_logprob_matrix"], + [ + [-0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1], + [-0.1, -0.3, -0.2, -0.2, -0.2, -0.2, -0.2, -0.2], + [-0.4, -0.25, -0.1, -0.01, -0.01, -0.01, -0.01, -0.01], + ], + atol=0.01, + ) + + def test_unconditional_likelihood_normalized(self): + """Confirm 'organism' is selected due to it having the highest average token logprob + once normalized by the unconditional token logprobs.""" + decision = unconditional_likelihood_normalized(**MOCK_CHOICES_INPUT_DATA) + assert decision.decision == "organism" + assert np.allclose( + decision.meta_info["normalized_unconditional_prompt_logprobs"], + [-0.1, 0.5, -0.05], + atol=0.01, + ) + + +if __name__ == "__main__": + 
unittest.main(warnings="ignore") + + # t = TestChoices() + # t.test_token_length_normalized() + # t.test_greedy_token_selection() + # t.test_unconditional_likelihood_normalized() diff --git a/test/lang/test_litellm_backend.py b/test/lang/test_litellm_backend.py index 15d83bd517a..3c7f5db2182 100644 --- a/test/lang/test_litellm_backend.py +++ b/test/lang/test_litellm_backend.py @@ -6,15 +6,12 @@ class TestAnthropicBackend(unittest.TestCase): - backend = None chat_backend = None - def setUp(self): - cls = type(self) - - if cls.backend is None: - cls.backend = LiteLLM("gpt-3.5-turbo") - set_default_backend(cls.backend) + @classmethod + def setUpClass(cls): + cls.chat_backend = LiteLLM("gpt-3.5-turbo") + set_default_backend(cls.chat_backend) def test_mt_bench(self): test_mt_bench() diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py index d35495e4d75..b1bb47b82f6 100644 --- a/test/lang/test_openai_backend.py +++ b/test/lang/test_openai_backend.py @@ -20,20 +20,18 @@ class TestOpenAIBackend(unittest.TestCase): - backend = None + instruct_backend = None chat_backend = None chat_vision_backend = None - def setUp(self): - cls = type(self) - - if cls.backend is None: - cls.backend = OpenAI("gpt-3.5-turbo-instruct") - cls.chat_backend = OpenAI("gpt-3.5-turbo") - cls.chat_vision_backend = OpenAI("gpt-4-turbo") + @classmethod + def setUpClass(cls): + cls.instruct_backend = OpenAI("gpt-3.5-turbo-instruct") + cls.chat_backend = OpenAI("gpt-3.5-turbo") + cls.chat_vision_backend = OpenAI("gpt-4-turbo") def test_few_shot_qa(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_few_shot_qa() def test_mt_bench(self): @@ -41,35 +39,35 @@ def test_mt_bench(self): test_mt_bench() def test_select(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_select(check_answer=True) def test_decode_int(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_decode_int() def test_decode_json(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_decode_json() def test_expert_answer(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_expert_answer() def test_tool_use(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_tool_use() def test_react(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_react() def test_parallel_decoding(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_parallel_decoding() def test_parallel_encoding(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_parallel_encoding() def test_image_qa(self): @@ -77,11 +75,11 @@ def test_image_qa(self): test_image_qa() def test_stream(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_stream() def test_completion_speculative(self): - set_default_backend(self.backend) + set_default_backend(self.instruct_backend) test_completion_speculative() def test_chat_completion_speculative(self): @@ -96,5 +94,5 @@ def test_chat_completion_speculative(self): # global_config.verbosity = 2 # t = TestOpenAIBackend() - # t.setUp() - # t.test_chat_completion_speculative() + # t.setUpClass() + # t.test_stream() diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index c92568c0bf1..778cde8be4e 100644 --- a/test/lang/test_srt_backend.py 
+++ b/test/lang/test_srt_backend.py @@ -1,7 +1,3 @@ -""" -python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 -""" - import json import unittest @@ -13,24 +9,25 @@ test_few_shot_qa, test_mt_bench, test_parallel_decoding, - test_parallel_encoding, - test_react, test_regex, test_select, test_stream, test_tool_use, ) +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST class TestSRTBackend(unittest.TestCase): backend = None - def setUp(self): - cls = type(self) + @classmethod + def setUpClass(cls): + cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST) + sgl.set_default_backend(cls.backend) - if cls.backend is None: - cls.backend = sgl.RuntimeEndpoint(base_url="http://localhost:30000") - sgl.set_default_backend(cls.backend) + @classmethod + def tearDownClass(cls): + cls.backend.shutdown() def test_few_shot_qa(self): test_few_shot_qa() @@ -62,9 +59,6 @@ def test_stream(self): def test_regex(self): test_regex() - # def test_parallel_encoding(self): - # test_parallel_encoding(check_answer=False) - if __name__ == "__main__": unittest.main(warnings="ignore") @@ -73,5 +67,6 @@ def test_regex(self): # global_config.verbosity = 2 # t = TestSRTBackend() - # t.setUp() - # t.test_regex() + # t.setUpClass() + # t.test_few_shot_qa() + # t.tearDownClass() diff --git a/test/lang/test_tracing.py b/test/lang/test_tracing.py index 266ce65fe38..5f2bc1d04fe 100644 --- a/test/lang/test_tracing.py +++ b/test/lang/test_tracing.py @@ -1,7 +1,7 @@ import unittest import sglang as sgl -from sglang.backend.base_backend import BaseBackend +from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template @@ -16,7 +16,7 @@ def few_shot_qa(s, question): s += "A:" + sgl.gen("answer", stop="\n") tracer = few_shot_qa.trace() - print(tracer.last_node.print_graph_dfs() + "\n") + # print(tracer.last_node.print_graph_dfs() + "\n") def test_select(self): @sgl.function @@ -26,7 +26,7 @@ def capital(s): s += "It is a city" + sgl.gen("description", stop=".") tracer = capital.trace() - print(tracer.last_node.print_graph_dfs() + "\n") + # print(tracer.last_node.print_graph_dfs() + "\n") def test_raise_warning(self): @sgl.function @@ -66,11 +66,11 @@ def tip_suggestion(s, topic): s += "In summary" + sgl.gen("summary") compiled = tip_suggestion.compile() - compiled.print_graph() + # compiled.print_graph() sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) state = compiled.run(topic="staying healthy") - print(state.text() + "\n") + # print(state.text() + "\n") states = compiled.run_batch( [ @@ -80,8 +80,8 @@ def tip_suggestion(s, topic): ], temperature=0, ) - for s in states: - print(s.text() + "\n") + # for s in states: + # print(s.text() + "\n") def test_role(self): @sgl.function @@ -95,7 +95,7 @@ def multi_turn_chat(s): backend.chat_template = get_chat_template("llama-2-chat") compiled = multi_turn_chat.compile(backend=backend) - compiled.print_graph() + # compiled.print_graph() def test_fork(self): @sgl.function @@ -118,10 +118,10 @@ def tip_suggestion(s): s += "In summary" + sgl.gen("summary") tracer = tip_suggestion.trace() - print(tracer.last_node.print_graph_dfs()) + # print(tracer.last_node.print_graph_dfs()) a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct")) - print(a.text()) + # print(a.text()) if __name__ == "__main__": diff --git a/test/lang/test_vertexai_backend.py b/test/lang/test_vertexai_backend.py index aae840101ac..b29efaa75ad 100644 --- a/test/lang/test_vertexai_backend.py +++ 
b/test/lang/test_vertexai_backend.py @@ -17,13 +17,11 @@ class TestVertexAIBackend(unittest.TestCase): chat_backend = None chat_vision_backend = None - def setUp(self): - cls = type(self) - - if cls.backend is None: - cls.backend = VertexAI("gemini-pro") - cls.chat_backend = VertexAI("gemini-pro") - cls.chat_vision_backend = VertexAI("gemini-pro-vision") + @classmethod + def setUpClass(cls): + cls.backend = VertexAI("gemini-pro") + cls.chat_backend = VertexAI("gemini-pro") + cls.chat_vision_backend = VertexAI("gemini-pro-vision") def test_few_shot_qa(self): set_default_backend(self.backend) @@ -61,5 +59,5 @@ def test_stream(self): # global_config.verbosity = 2 # t = TestVertexAIBackend() - # t.setUp() + # t.setUpClass() # t.test_stream() diff --git a/test/srt/example_image.png b/test/srt/example_image.png deleted file mode 120000 index c8a970edd0c..00000000000 --- a/test/srt/example_image.png +++ /dev/null @@ -1 +0,0 @@ -../lang/example_image.png \ No newline at end of file diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py new file mode 100644 index 00000000000..c29c33188c0 --- /dev/null +++ b/test/srt/models/test_embedding_models.py @@ -0,0 +1,69 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest + +import torch + +from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner +from sglang.test.test_utils import get_similarities + +MODELS = [("intfloat/e5-mistral-7b-instruct", 1)] +TORCH_DTYPES = [torch.float16] + + +class TestEmbeddingModels(unittest.TestCase): + + def assert_close_prefill_logits( + self, + prompts, + model_path, + tp_size, + torch_dtype, + ) -> None: + with HFRunner( + model_path, torch_dtype=torch_dtype, is_generation_model=False + ) as hf_runner: + hf_outputs = hf_runner.forward(prompts) + + with SRTRunner( + model_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation_model=False, + ) as srt_runner: + srt_outputs = srt_runner.forward(prompts) + + for i in range(len(prompts)): + hf_logits = torch.Tensor(hf_outputs.embed_logits[i]) + srt_logits = torch.Tensor(srt_outputs.embed_logits[i]) + + similarities = torch.tensor(get_similarities(hf_logits, srt_logits)) + + tolerance = 1e-2 + assert torch.all( + abs(similarities - 1) < tolerance + ), f"embeddings not all close" + + def test_prefill_logits(self): + for model, tp_size in MODELS: + for torch_dtype in TORCH_DTYPES: + self.assert_close_prefill_logits( + DEFAULT_PROMPTS, model, tp_size, torch_dtype + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py new file mode 100644 index 00000000000..f057648020f --- /dev/null +++ b/test/srt/models/test_generation_models.py @@ -0,0 +1,68 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest + +import torch + +from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner + +MODELS = [ + ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1), +] +TORCH_DTYPES = [torch.float16] + + +class TestCausalModels(unittest.TestCase): + + def assert_close_prefill_logits( + self, + prompts, + model_path, + tp_size, + torch_dtype, + ) -> None: + with HFRunner( + model_path, torch_dtype=torch_dtype, is_generation_model=True + ) as hf_runner: + hf_outputs = hf_runner.forward(prompts) + + with SRTRunner( + model_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation_model=True, + ) as srt_runner: + srt_outputs = srt_runner.forward(prompts) + + for i in range(len(prompts)): + hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + + tolerance = 3e-2 + assert torch.all( + abs(hf_logprobs - srt_logprobs) < tolerance + ), f"prefill logprobs not all close" + + def test_prefill_logits(self): + for model, tp_size in MODELS: + for torch_dtype in TORCH_DTYPES: + self.assert_close_prefill_logits( + DEFAULT_PROMPTS, model, tp_size, torch_dtype + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py new file mode 100644 index 00000000000..d5051ffc1e2 --- /dev/null +++ b/test/srt/run_suite.py @@ -0,0 +1,54 @@ +import argparse +import glob + +from sglang.test.test_utils import run_unittest_files + +suites = { + "minimal": [ + "test_eval_accuracy.py", + "test_openai_server.py", + "test_vision_openai_server.py", + "test_chunked_prefill.py", + "test_torch_compile.py", + "test_models_from_modelscope.py", + "models/test_generation_models.py", + "models/test_embedding_models.py", + "sampling/penaltylib", + ], + "sampling/penaltylib": glob.glob( + "sampling/penaltylib/**/test_*.py", recursive=True + ), +} + +for target_suite_name, target_tests in suites.items(): + for suite_name, tests in suites.items(): + if suite_name == target_suite_name: + continue + if target_suite_name in tests: + tests.remove(target_suite_name) + tests.extend(target_tests) + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument( + "--timeout-per-file", + type=int, + default=2000, + help="The time limit for running one file in seconds.", + ) + arg_parser.add_argument( + "--suite", + type=str, + default=list(suites.keys())[0], + choices=list(suites.keys()) + ["all"], + help="The suite to run", + ) + args = arg_parser.parse_args() + + if args.suite == "all": + files = glob.glob("**/test_*.py", recursive=True) + else: + files = suites[args.suite] + + exit_code = run_unittest_files(files, args.timeout_per_file) + exit(exit_code) diff --git a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py new file mode 100644 index 00000000000..59db353abfa --- /dev/null +++ b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py @@ -0,0 +1,93 @@ +import typing +import unittest + +import torch + +from 
sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import ( + BatchedFrequencyPenalizer, +) +from sglang.test.srt.sampling.penaltylib.utils import ( + BaseBatchedPenalizerTest, + MockSamplingParams, + Step, + StepType, + Subject, +) + + +class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest): + Penalizer = BatchedFrequencyPenalizer + frequency_penalty: float + + def setUp(self): + if self.__class__ == BaseBatchedFrequencyPenalizerTest: + self.skipTest("Base class for frequency_penalty tests") + + super().setUp() + + def _create_subject(self, frequency_penalty: float) -> Subject: + return Subject( + sampling_params=MockSamplingParams( + frequency_penalty=frequency_penalty, + ), + steps=[ + Step( + type=StepType.INPUT, + token_ids=[0, 1, 2], + expected_tensors={ + "frequency_penalties": self.tensor( + [[frequency_penalty] * self.vocab_size], dtype=torch.float32 + ), + "cumulated_frequency_penalties": self.tensor( + [[0.0] * self.vocab_size], dtype=torch.float32 + ), + }, + expected_logits=self.tensor( + [[1] * self.vocab_size], dtype=torch.float32 + ), + ), + Step( + type=StepType.OUTPUT, + token_ids=[1, 2, 2], + expected_tensors={ + "frequency_penalties": self.tensor( + [[frequency_penalty] * self.vocab_size], dtype=torch.float32 + ), + "cumulated_frequency_penalties": self.tensor( + [ + [ + frequency_penalty * i if i in {1, 2} else 0.0 + for i in range(self.vocab_size) + ], + ], + dtype=torch.float32, + ), + }, + expected_logits=self.tensor( + [ + [ + 1.0 - frequency_penalty * i if i in {1, 2} else 1.0 + for i in range(self.vocab_size) + ], + ], + dtype=torch.float32, + ), + ), + ], + ) + + def create_test_subjects(self) -> typing.List[Subject]: + self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty) + self.disabled = self._create_subject(frequency_penalty=0.0) + + +class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest): + frequency_penalty = 0.12 + + +class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest): + frequency_penalty = -0.12 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py new file mode 100644 index 00000000000..1984aafe5ea --- /dev/null +++ b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py @@ -0,0 +1,152 @@ +import typing +import unittest + +import torch + +from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import ( + BatchedMinNewTokensPenalizer, +) +from sglang.test.srt.sampling.penaltylib.utils import ( + BaseBatchedPenalizerTest, + MockSamplingParams, + Step, + StepType, + Subject, +) + +MIN_NEW_TOKENS = 2 +EOS_TOKEN_ID = 4 +STOP_TOKEN_ID = 3 + +ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID} + + +class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest): + Penalizer = BatchedMinNewTokensPenalizer + + def _create_subject(self, min_new_tokens: int) -> Subject: + return Subject( + eos_token_id=EOS_TOKEN_ID, + sampling_params=MockSamplingParams( + min_new_tokens=min_new_tokens, + stop_token_ids={STOP_TOKEN_ID}, + ), + steps=[ + Step( + type=StepType.INPUT, + token_ids=[0, 1, 2], + expected_tensors={ + "min_new_tokens": self.tensor( + [[min_new_tokens]], dtype=torch.int32 + ), + "stop_token_penalties": self.tensor( + [ + [ + float("-inf") if i in ALL_STOP_TOKEN_IDS else 0 + for i in range(self.vocab_size) + ] + ], + dtype=torch.float32, + ), + "len_output_tokens": self.tensor([[0]], dtype=torch.int32), + 
}, + expected_logits=( + self.tensor( + [ + [ + float("-inf") if i in ALL_STOP_TOKEN_IDS else 1 + for i in range(self.vocab_size) + ] + ], + dtype=torch.float32, + ) + if min_new_tokens > 0 + else torch.ones( + (1, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ), + ), + Step( + type=StepType.OUTPUT, + token_ids=[0], + expected_tensors={ + "min_new_tokens": self.tensor( + [[min_new_tokens]], dtype=torch.int32 + ), + "stop_token_penalties": self.tensor( + [ + [ + float("-inf") if i in ALL_STOP_TOKEN_IDS else 0 + for i in range(self.vocab_size) + ] + ], + dtype=torch.float32, + ), + "len_output_tokens": self.tensor([[1]], dtype=torch.int32), + }, + expected_logits=( + self.tensor( + [ + [ + float("-inf") if i in ALL_STOP_TOKEN_IDS else 1 + for i in range(self.vocab_size) + ] + ], + dtype=torch.float32, + ) + if min_new_tokens > 1 + else torch.ones( + (1, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ), + ), + Step( + type=StepType.OUTPUT, + token_ids=[0], + expected_tensors={ + "min_new_tokens": self.tensor( + [[min_new_tokens]], dtype=torch.int32 + ), + "stop_token_penalties": self.tensor( + [ + [ + float("-inf") if i in ALL_STOP_TOKEN_IDS else 0 + for i in range(self.vocab_size) + ] + ], + dtype=torch.float32, + ), + "len_output_tokens": self.tensor([[2]], dtype=torch.int32), + }, + expected_logits=( + self.tensor( + [ + [ + float("-inf") if i in ALL_STOP_TOKEN_IDS else 1 + for i in range(self.vocab_size) + ] + ], + dtype=torch.float32, + ) + if min_new_tokens > 2 + else torch.ones( + (1, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ), + ), + ], + ) + + def create_test_subjects(self) -> typing.List[Subject]: + self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS) + self.disabled = self._create_subject(min_new_tokens=0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py new file mode 100644 index 00000000000..96cbf1082e5 --- /dev/null +++ b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py @@ -0,0 +1,93 @@ +import typing +import unittest + +import torch + +from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import ( + BatchedPresencePenalizer, +) +from sglang.test.srt.sampling.penaltylib.utils import ( + BaseBatchedPenalizerTest, + MockSamplingParams, + Step, + StepType, + Subject, +) + + +class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest): + Penalizer = BatchedPresencePenalizer + presence_penalty: float + + def setUp(self): + if self.__class__ == BaseBatchedPresencePenalizerTest: + self.skipTest("Base class for presence_penalty tests") + + super().setUp() + + def _create_subject(self, presence_penalty: float) -> Subject: + return Subject( + sampling_params=MockSamplingParams( + presence_penalty=presence_penalty, + ), + steps=[ + Step( + type=StepType.INPUT, + token_ids=[0, 1, 2], + expected_tensors={ + "presence_penalties": self.tensor( + [[presence_penalty] * self.vocab_size], dtype=torch.float32 + ), + "cumulated_presence_penalties": self.tensor( + [[0.0] * self.vocab_size], dtype=torch.float32 + ), + }, + expected_logits=self.tensor( + [[1] * self.vocab_size], dtype=torch.float32 + ), + ), + Step( + type=StepType.OUTPUT, + token_ids=[1, 2, 2], + expected_tensors={ + "presence_penalties": self.tensor( + [[presence_penalty] * self.vocab_size], dtype=torch.float32 + ), + "cumulated_presence_penalties": self.tensor( + [ + [ + 
presence_penalty if i in {1, 2} else 0.0 + for i in range(self.vocab_size) + ], + ], + dtype=torch.float32, + ), + }, + expected_logits=self.tensor( + [ + [ + 1.0 - presence_penalty if i in {1, 2} else 1.0 + for i in range(self.vocab_size) + ], + ], + dtype=torch.float32, + ), + ), + ], + ) + + def create_test_subjects(self) -> typing.List[Subject]: + self.enabled = self._create_subject(presence_penalty=self.presence_penalty) + self.disabled = self._create_subject(presence_penalty=0.0) + + +class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest): + presence_penalty = 0.12 + + +class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest): + presence_penalty = -0.12 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py new file mode 100644 index 00000000000..e3751c14a30 --- /dev/null +++ b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py @@ -0,0 +1,87 @@ +import typing +import unittest + +import torch + +from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( + BatchedRepetitionPenalizer, +) +from sglang.test.srt.sampling.penaltylib.utils import ( + BaseBatchedPenalizerTest, + MockSamplingParams, + Step, + StepType, + Subject, +) + +REPETITION_PENALTY = 2.0 + + +class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest): + Penalizer = BatchedRepetitionPenalizer + + def _create_subject(self, repetition_penalty: float) -> Subject: + l = 1.0 / repetition_penalty + return Subject( + sampling_params=MockSamplingParams( + repetition_penalty=repetition_penalty, + ), + steps=[ + Step( + type=StepType.INPUT, + token_ids=[0, 1, 2], + expected_tensors={ + "repetition_penalties": self.tensor( + [[repetition_penalty] * self.vocab_size], + dtype=torch.float32, + ), + "cumulated_repetition_penalties": ( + self.tensor( + [[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32 + ) + if repetition_penalty != 1.0 + else self.tensor( + [[1.0] * self.vocab_size], dtype=torch.float32 + ) + ), + }, + expected_logits=( + self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32) + if repetition_penalty != 1.0 + else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32) + ), + ), + Step( + type=StepType.OUTPUT, + token_ids=[0, 1, 3], + expected_tensors={ + "repetition_penalties": self.tensor( + [[repetition_penalty] * self.vocab_size], + dtype=torch.float32, + ), + "cumulated_repetition_penalties": ( + self.tensor( + [[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32 + ) + if repetition_penalty != 1.0 + else self.tensor( + [[1.0] * self.vocab_size], dtype=torch.float32 + ) + ), + }, + expected_logits=( + self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32) + if repetition_penalty != 1.0 + else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32) + ), + ), + ], + ) + + def create_test_subjects(self) -> typing.List[Subject]: + self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY) + self.disabled = self._create_subject(repetition_penalty=1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py new file mode 100644 index 00000000000..e72dc30f956 --- /dev/null +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -0,0 +1,110 @@ +import json +import unittest +from multiprocessing import Process + +import 
requests + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestBatchPenalizerE2E(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = f"http://127.0.0.1:{8157}" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=300, + other_args=( + "--random-seed", + "0", + ), + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode( + self, + return_logprob=True, + top_logprobs_num=5, + return_text=True, + n=1, + **sampling_params, + ): + response = requests.post( + self.base_url + "/generate", + json={ + # prompt that is supposed to generate < 32 tokens + "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + "sampling_params": { + "max_new_tokens": 32, + "n": n, + **sampling_params, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + + def test_default_values(self): + self.run_decode() + + def test_mixed(self): + """ + Sends two requests with one with penalizers disabled, and the other with penalizers enabled. + This will cause two different {ScheduleBatch} to be initialized and eventually gets merged. + + Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not. + This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge. + + This test triggers the merge of disabled + enabled. 
+ """ + + processes = [] + + p = Process( + target=self.run_decode, + ) + processes.append(p) + p.start() + + p = Process( + target=self.run_decode, + kwargs={ + "frequency_penalty": 2, + "min_new_tokens": 16, + "presence_penalty": 2, + "repetition_penalty": 2, + }, + ) + processes.append(p) + p.start() + + for p in processes: + p.join() + + def test_frequency_penalty(self): + self.run_decode(frequency_penalty=2) + + def test_min_new_tokens(self): + self.run_decode(min_new_tokens=16) + + def test_presence_penalty(self): + self.run_decode(presence_penalty=2) + + def test_repetition_penalty(self): + self.run_decode(repetition_penalty=2) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py new file mode 100644 index 00000000000..7f274926a62 --- /dev/null +++ b/test/srt/test_chunked_prefill.py @@ -0,0 +1,45 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestAccuracy(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = "http://127.0.0.1:8157" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=300, + other_args=["--chunked-prefill-size", "32"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=20, + num_threads=20, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestAccuracy() + # t.setUpClass() + # t.test_mmlu() + # t.tearDownClass() diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py new file mode 100644 index 00000000000..b6359362670 --- /dev/null +++ b/test/srt/test_eval_accuracy.py @@ -0,0 +1,40 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestAccuracy(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = "http://127.0.0.1:8157" + cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=20, + num_threads=20, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestAccuracy() + # t.setUpClass() + # t.test_mmlu() + # t.tearDownClass() diff --git a/test/srt/test_models_from_modelscope.py b/test/srt/test_models_from_modelscope.py new file mode 100644 index 00000000000..2313053b909 --- /dev/null +++ b/test/srt/test_models_from_modelscope.py @@ -0,0 +1,47 @@ +import os +import shutil +import subprocess +import unittest +from unittest import mock + +from sglang.srt.utils import prepare_model, prepare_tokenizer + + +class TestDownloadFromModelScope(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = "iic/nlp_lstmcrf_word-segmentation_chinese-news" 
+ stat, output = subprocess.getstatusoutput("pip install modelscope") + + cls.with_modelscope_environ = {k: v for k, v in os.environ.items()} + cls.with_modelscope_environ["SGLANG_USE_MODELSCOPE"] = "True" + + @classmethod + def tearDownClass(cls): + pass + + def test_prepare_model(self): + from modelscope.utils.file_utils import get_model_cache_root + + model_cache_root = get_model_cache_root() + if os.path.exists(model_cache_root): + shutil.rmtree(model_cache_root) + with mock.patch.dict(os.environ, self.with_modelscope_environ, clear=True): + model_path = prepare_model(self.model) + assert os.path.exists(os.path.join(model_path, "pytorch_model.bin")) + + def test_prepare_tokenizer(self): + from modelscope.utils.file_utils import get_model_cache_root + + model_cache_root = get_model_cache_root() + if os.path.exists(model_cache_root): + shutil.rmtree(model_cache_root) + with mock.patch.dict(os.environ, self.with_modelscope_environ, clear=True): + tokenizer_path = prepare_tokenizer(self.model) + assert not os.path.exists(os.path.join(tokenizer_path, "pytorch_model.bin")) + assert os.path.exists(os.path.join(tokenizer_path, "config.json")) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index a77319b1baa..f8f6ca63210 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -1,209 +1,407 @@ -""" -First run the following command to launch the server. -Note that TinyLlama adopts different chat templates in different versions. -For v0.4, the chat template is chatml. - -python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \ ---port 30000 --chat-template chatml - -Output example: -The capital of France is Paris. -The capital of the United States is Washington, D.C. -The capital of Canada is Ottawa. 
-The capital of Japan is Tokyo -""" - -import argparse import json +import time +import unittest import openai +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server -def test_completion(args, echo, logprobs): - client = openai.Client(api_key="EMPTY", base_url=args.base_url) - response = client.completions.create( - model="default", - prompt="The capital of France is", - temperature=0, - max_tokens=32, - echo=echo, - logprobs=logprobs, - ) - text = response.choices[0].text - print(response.choices[0].text) - if echo: - assert text.startswith("The capital of France is") - if logprobs: - print(response.choices[0].logprobs.top_logprobs) - assert response.choices[0].logprobs - if echo: - assert response.choices[0].logprobs.token_logprobs[0] == None + +class TestOpenAIServer(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = "http://127.0.0.1:8157" + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=300, api_key=cls.api_key + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_completion( + self, echo, logprobs, use_list_input, parallel_sample_num, token_input + ): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + prompt = "The capital of France is" + if token_input: + prompt_input = self.tokenizer.encode(prompt) + num_prompt_tokens = len(prompt_input) + else: + prompt_input = prompt + num_prompt_tokens = len(self.tokenizer.encode(prompt)) + + if use_list_input: + prompt_arg = [prompt_input, prompt_input] + num_choices = len(prompt_arg) + num_prompt_tokens *= 2 else: - assert response.choices[0].logprobs.token_logprobs[0] != None - assert response.id - assert response.created - assert response.usage.prompt_tokens > 0 - assert response.usage.completion_tokens > 0 - assert response.usage.total_tokens > 0 - print("=" * 100) - - -def test_completion_stream(args, echo, logprobs): - client = openai.Client(api_key="EMPTY", base_url=args.base_url) - response = client.completions.create( - model="default", - prompt="The capital of France is", - temperature=0, - max_tokens=32, - stream=True, - echo=echo, - logprobs=logprobs, - ) - first = True - for r in response: - if first: + prompt_arg = prompt_input + num_choices = 1 + + response = client.completions.create( + model=self.model, + prompt=prompt_arg, + temperature=0, + max_tokens=32, + echo=echo, + logprobs=logprobs, + n=parallel_sample_num, + ) + + assert len(response.choices) == num_choices * parallel_sample_num + + if echo: + text = response.choices[0].text + assert text.startswith(prompt) + + if logprobs: + assert response.choices[0].logprobs + assert isinstance(response.choices[0].logprobs.tokens[0], str) + assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict) + ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1]) + # FIXME: Sometimes, some top_logprobs are missing in the return value. 
The reason is that some out_put id maps to the same output token and duplicate in the map + # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" + assert ret_num_top_logprobs > 0 if echo: - assert r.choices[0].text.startswith("The capital of France is") - first = False + assert response.choices[0].logprobs.token_logprobs[0] == None + else: + assert response.choices[0].logprobs.token_logprobs[0] != None + + assert response.id + assert response.created + assert ( + response.usage.prompt_tokens == num_prompt_tokens + ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}" + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def run_completion_stream(self, echo, logprobs, token_input): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + prompt = "The capital of France is" + if token_input: + prompt_arg = self.tokenizer.encode(prompt) + else: + prompt_arg = prompt + generator = client.completions.create( + model=self.model, + prompt=prompt_arg, + temperature=0, + max_tokens=32, + echo=echo, + logprobs=logprobs, + stream=True, + stream_options={"include_usage": True}, + ) + + first = True + for response in generator: + usage = response.usage + if usage is not None: + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens > 0 + continue + if logprobs: + assert response.choices[0].logprobs + assert isinstance(response.choices[0].logprobs.tokens[0], str) + if not (first and echo): + assert isinstance( + response.choices[0].logprobs.top_logprobs[0], dict + ) + ret_num_top_logprobs = len( + response.choices[0].logprobs.top_logprobs[0] + ) + # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map + # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" + assert ret_num_top_logprobs > 0 + + if first: + if echo: + assert response.choices[0].text.startswith( + prompt + ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}" + first = False + assert response.id + assert response.created + + def run_chat_completion(self, logprobs, parallel_sample_num): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": "What is the capital of France? 
Answer in a few words.", + }, + ], + temperature=0, + logprobs=logprobs is not None and logprobs > 0, + top_logprobs=logprobs, + n=parallel_sample_num, + ) + if logprobs: - print( - f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}", - flush=True, + assert isinstance( + response.choices[0].logprobs.content[0].top_logprobs[0].token, str + ) + + ret_num_top_logprobs = len( + response.choices[0].logprobs.content[0].top_logprobs ) - print(r.choices[0].logprobs.top_logprobs) + assert ( + ret_num_top_logprobs == logprobs + ), f"{ret_num_top_logprobs} vs {logprobs}" + + assert len(response.choices) == parallel_sample_num + assert response.choices[0].message.role == "assistant" + assert isinstance(response.choices[0].message.content, str) + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def run_chat_completion_stream(self, logprobs): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + generator = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "What is the capital of France?"}, + ], + temperature=0, + logprobs=logprobs is not None and logprobs > 0, + top_logprobs=logprobs, + stream=True, + stream_options={"include_usage": True}, + ) + + is_first = True + for response in generator: + usage = response.usage + if usage is not None: + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens > 0 + continue + + data = response.choices[0].delta + + if is_first: + data.role == "assistant" + is_first = False + continue + + if logprobs: + assert response.choices[0].logprobs + assert isinstance( + response.choices[0].logprobs.content[0].top_logprobs[0].token, str + ) + assert isinstance( + response.choices[0].logprobs.content[0].top_logprobs, list + ) + ret_num_top_logprobs = len( + response.choices[0].logprobs.content[0].top_logprobs + ) + assert ( + ret_num_top_logprobs == logprobs + ), f"{ret_num_top_logprobs} vs {logprobs}" + + assert isinstance(data.content, str) + assert response.id + assert response.created + + def run_batch(self, mode): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + if mode == "completion": + input_file_path = "complete_input.jsonl" + # write content to input file + content = [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/completions", + "body": { + "model": "gpt-3.5-turbo-instruct", + "prompt": "List 3 names of famous soccer player: ", + "max_tokens": 20, + }, + }, + { + "custom_id": "request-2", + "method": "POST", + "url": "/v1/completions", + "body": { + "model": "gpt-3.5-turbo-instruct", + "prompt": "List 6 names of famous basketball player: ", + "max_tokens": 40, + }, + }, + { + "custom_id": "request-3", + "method": "POST", + "url": "/v1/completions", + "body": { + "model": "gpt-3.5-turbo-instruct", + "prompt": "List 6 names of famous tenniss player: ", + "max_tokens": 40, + }, + }, + ] + else: - print(r.choices[0].text, end="", flush=True) - assert r.id - assert r.usage.prompt_tokens > 0 - assert r.usage.completion_tokens > 0 - assert r.usage.total_tokens > 0 - print("=" * 100) - - -def test_chat_completion(args): - client = openai.Client(api_key="EMPTY", base_url=args.base_url) - response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI 
assistant"}, - {"role": "user", "content": "What is the capital of France?"}, - ], - temperature=0, - max_tokens=32, - ) - print(response.choices[0].message.content) - assert response.id - assert response.created - assert response.usage.prompt_tokens > 0 - assert response.usage.completion_tokens > 0 - assert response.usage.total_tokens > 0 - print("=" * 100) - - -def test_chat_completion_image(args): - client = openai.Client(api_key="EMPTY", base_url=args.base_url) - response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe this image"}, - { - "type": "image_url", - "image_url": { - "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg" - }, + input_file_path = "chat_input.jsonl" + content = [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo-0125", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "Hello! List 3 NBA players and tell a story", + }, + ], + "max_tokens": 30, + }, + }, + { + "custom_id": "request-2", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo-0125", + "messages": [ + {"role": "system", "content": "You are an assistant. "}, + { + "role": "user", + "content": "Hello! List three capital and tell a story", + }, + ], + "max_tokens": 50, }, - ], - }, - ], - temperature=0, - max_tokens=32, - ) - print(response.choices[0].message.content) - assert response.id - assert response.created - assert response.usage.prompt_tokens > 0 - assert response.usage.completion_tokens > 0 - assert response.usage.total_tokens > 0 - print("=" * 100) - - -def test_chat_completion_stream(args): - client = openai.Client(api_key="EMPTY", base_url=args.base_url) - response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0, - max_tokens=64, - stream=True, - ) - is_first = True - for chunk in response: - if is_first: - is_first = False - assert chunk.choices[0].delta.role == "assistant" - continue - - data = chunk.choices[0].delta - if not data.content: - continue - print(data.content, end="", flush=True) - print("=" * 100) - - -def test_regex(args): - client = openai.Client(api_key="EMPTY", base_url=args.base_url) - - regex = ( - r"""\{\n""" - + r""" "name": "[\w]+",\n""" - + r""" "population": [\d]+\n""" - + r"""\}""" - ) - - response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "Introduce the capital of France."}, - ], - temperature=0, - max_tokens=128, - extra_body={"regex": regex}, - ) - text = response.choices[0].message.content - print(json.loads(text)) - print("=" * 100) + }, + ] + with open(input_file_path, "w") as file: + for line in content: + file.write(json.dumps(line) + "\n") + with open(input_file_path, "rb") as file: + uploaded_file = client.files.create(file=file, purpose="batch") + if mode == "completion": + endpoint = "/v1/completions" + elif mode == "chat": + endpoint = "/v1/chat/completions" + completion_window = "24h" + batch_job = client.batches.create( + input_file_id=uploaded_file.id, + endpoint=endpoint, + 
completion_window=completion_window, + ) + while batch_job.status not in ["completed", "failed", "cancelled"]: + time.sleep(3) + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." + ) + batch_job = client.batches.retrieve(batch_job.id) + assert batch_job.status == "completed" + assert batch_job.request_counts.completed == len(content) + assert batch_job.request_counts.failed == 0 + assert batch_job.request_counts.total == len(content) + + result_file_id = batch_job.output_file_id + file_response = client.files.content(result_file_id) + result_content = file_response.read().decode("utf-8") # Decode bytes to string + results = [ + json.loads(line) + for line in result_content.split("\n") + if line.strip() != "" + ] + assert len(results) == len(content) + + def test_completion(self): + for echo in [False, True]: + for logprobs in [None, 5]: + for use_list_input in [True, False]: + for parallel_sample_num in [1, 2]: + for token_input in [False, True]: + self.run_completion( + echo, + logprobs, + use_list_input, + parallel_sample_num, + token_input, + ) + + def test_completion_stream(self): + # parallel sampling adn list input are not supported in streaming mode + for echo in [False, True]: + for logprobs in [None, 5]: + for token_input in [False, True]: + self.run_completion_stream(echo, logprobs, token_input) + + def test_chat_completion(self): + for logprobs in [None, 5]: + for parallel_sample_num in [1, 2]: + self.run_chat_completion(logprobs, parallel_sample_num) + + def test_chat_completion_stream(self): + for logprobs in [None, 5]: + self.run_chat_completion_stream(logprobs) + + def test_batch(self): + for mode in ["completion", "chat"]: + self.run_batch(mode) + + def test_regex(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + regex = ( + r"""\{\n""" + + r""" "name": "[\w]+",\n""" + + r""" "population": [\d]+\n""" + + r"""\}""" + ) + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "Introduce the capital of France."}, + ], + temperature=0, + max_tokens=128, + extra_body={"regex": regex}, + ) + text = response.choices[0].message.content + + try: + js_obj = json.loads(text) + except (TypeError, json.decoder.JSONDecodeError): + print("JSONDecodeError", text) + raise + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1") - parser.add_argument( - "--test-image", action="store_true", help="Enables testing image inputs" - ) - args = parser.parse_args() - - test_completion(args, echo=False, logprobs=False) - test_completion(args, echo=True, logprobs=False) - test_completion(args, echo=False, logprobs=True) - test_completion(args, echo=True, logprobs=True) - test_completion(args, echo=False, logprobs=3) - test_completion(args, echo=True, logprobs=3) - test_completion_stream(args, echo=False, logprobs=False) - test_completion_stream(args, echo=True, logprobs=False) - test_completion_stream(args, echo=False, logprobs=True) - test_completion_stream(args, echo=True, logprobs=True) - test_completion_stream(args, echo=False, logprobs=3) - test_completion_stream(args, echo=True, logprobs=3) - test_chat_completion(args) - test_chat_completion_stream(args) - test_regex(args) - if args.test_image: - test_chat_completion_image(args) + 
unittest.main(warnings="ignore") + + # t = TestOpenAIServer() + # t.setUpClass() + # t.test_completion() + # t.tearDownClass() diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py new file mode 100644 index 00000000000..b208dfa1329 --- /dev/null +++ b/test/srt/test_srt_endpoint.py @@ -0,0 +1,62 @@ +import json +import unittest + +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestSRTEndpoint(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = "http://127.0.0.1:8157" + cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode( + self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1 + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 32, + "n": n, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + + def test_simple_decode(self): + self.run_decode() + + def test_parallel_sample(self): + self.run_decode(n=3) + + def test_logprob(self): + for top_logprobs_num in [0, 3]: + for return_text in [True, False]: + self.run_decode( + return_logprob=True, + top_logprobs_num=top_logprobs_num, + return_text=return_text, + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py new file mode 100644 index 00000000000..fd2c6ebb778 --- /dev/null +++ b/test/srt/test_torch_compile.py @@ -0,0 +1,42 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestAccuracy(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = "http://127.0.0.1:8157" + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"] + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=20, + num_threads=20, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestAccuracy() + # t.setUpClass() + # t.test_mmlu() + # t.tearDownClass() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py new file mode 100644 index 00000000000..982c026dbbf --- /dev/null +++ b/test/srt/test_vision_openai_server.py @@ -0,0 +1,121 @@ +import json +import unittest + +import openai + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import popen_launch_server + + +class TestOpenAIVisionServer(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = "liuhaotian/llava-v1.6-vicuna-7b" + cls.base_url = 
"http://127.0.0.1:8157" + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=300, + api_key=cls.api_key, + other_args=[ + "--chat-template", + "vicuna_v1.1", + "--tokenizer-path", + "llava-hf/llava-1.5-7b-hf", + "--log-requests", + ], + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + }, + }, + { + "type": "text", + "text": "Describe this image in a very short sentence.", + }, + ], + }, + ], + temperature=0, + ) + + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + assert "car" in text or "taxi" in text, text + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def test_regex(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + regex = ( + r"""\{\n""" + + r""" "color": "[\w]+",\n""" + + r""" "number_of_cars": [\d]+\n""" + + r"""\}""" + ) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + }, + }, + { + "type": "text", + "text": "Describe this image in the JSON format.", + }, + ], + }, + ], + temperature=0, + extra_body={"regex": regex}, + ) + text = response.choices[0].message.content + + try: + js_obj = json.loads(text) + except (TypeError, json.decoder.JSONDecodeError): + print("JSONDecodeError", text) + raise + assert isinstance(js_obj["color"], str) + assert isinstance(js_obj["number_of_cars"], int) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestOpenAIVisionServer() + # t.setUpClass() + # t.test_chat_completion() + # t.tearDownClass()