Sampler cudagraph (sgl-project#1253)
hnyls2002 authored Aug 29, 2024
1 parent 8153168 commit 381dd57
Showing 29 changed files with 342 additions and 116 deletions.
14 changes: 7 additions & 7 deletions python/sglang/bench_latency.py
@@ -200,16 +200,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits, batch
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids.cpu().numpy())
output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits
batch.prepare_for_decode(input_token_ids)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits


@torch.inference_mode()
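For illustration only (not part of the diff), a minimal sketch of how a caller might drive the updated helpers: the model runner's forward now returns a (sample_output, logits_output) pair, and the sampled ids come back as a plain Python list that feeds straight into the next decode() call. The generate() wrapper and max_new_tokens are invented names; reqs and model_runner are assumed to be set up as elsewhere in bench_latency.py.

def generate(reqs, model_runner, max_new_tokens=8):
    # extend() prefills the prompts and samples the first token for each request.
    next_token_ids, _, batch = extend(reqs, model_runner)
    output_ids = [next_token_ids]
    for _ in range(max_new_tokens - 1):
        # decode() now takes the plain Python list returned by the previous step.
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        output_ids.append(next_token_ids)
    # Transpose the per-step lists into one sequence of generated ids per request.
    return [list(seq) for seq in zip(*output_ids)]
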
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/logits_processor.py
@@ -29,7 +29,7 @@


@dataclasses.dataclass
class LogitProcessorOutput:
class LogitsProcessorOutput:
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ def forward(

# Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob:
return LogitProcessorOutput(
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ def forward(
else:
output_top_logprobs = None

return LogitProcessorOutput(
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ def forward(
# Remove the last token logprob for the prefill tokens.
input_token_logprobs = input_token_logprobs[:-1]

return LogitProcessorOutput(
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
83 changes: 68 additions & 15 deletions python/sglang/srt/layers/sampler.py
@@ -1,4 +1,6 @@
import dataclasses
import logging
from typing import Union

import torch
from flashinfer.sampling import (
@@ -9,37 +11,80 @@
)
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.logits_processor import LogitsProcessorOutput

# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class SampleOutput:
success: torch.Tensor
probs: torch.Tensor
batch_next_token_ids: torch.Tensor


class Sampler(CustomOp):
def __init__(self):
super().__init__()

def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties

# repetition
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)

return logits

def _get_probs(
self,
logits: torch.Tensor,
sampling_info: SamplingBatchInfo,
is_torch_compile: bool = False,
):
# Post process logits
logits = logits.contiguous()
logits.div_(sampling_info.temperatures)
if is_torch_compile:
# FIXME: Temporary workaround for unknown bugs in torch.compile
logits.add_(0)

if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)

if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

logits = sampling_info.penalizer_orchestrator.apply(logits)
logits = self._apply_penalties(logits, sampling_info)

probs = torch.softmax(logits, dim=-1)
return torch.softmax(logits, dim=-1)

def forward_cuda(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits

probs = self._get_probs(logits, sampling_info)

if not global_server_args_dict["disable_flashinfer_sampling"]:
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.min_ps.any():
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -55,18 +100,23 @@ def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)

if not torch.all(success):
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
success, batch_next_token_ids, argmax_ids
)
return SampleOutput(success, probs, batch_next_token_ids)

return batch_next_token_ids
def forward_native(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits

probs = self._get_probs(logits, sampling_info, is_torch_compile=True)

batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)

def forward_native():
raise NotImplementedError("Native forward is not implemented yet.")
return SampleOutput(success, probs, batch_next_token_ids)


def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
# FIXME: torch.multinomial does not support num_samples = 1
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
:, :1
]
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(
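The torch-native fallback path filters the probability matrix per request with top-k, top-p, and min-p before drawing a token. Since the diff only shows part of top_k_top_p_min_p_sampling_from_probs_torch, here is a rough, self-contained sketch of the same filtering idea; the helper name and details below are illustrative, not the exact code from the commit.

import torch

def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
    # Illustrative re-implementation; each argument is a per-request tensor.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # top-p: drop a token once the cumulative mass of the tokens before it exceeds top_p.
    cum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(cum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # top-k: keep only the k highest-probability tokens in each row.
    ranks = torch.arange(probs_sort.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # min-p: drop tokens below min_p times the row's maximum probability.
    probs_sort[probs_sort < probs_sort[:, :1] * min_ps.view(-1, 1)] = 0.0
    # torch.multinomial accepts unnormalized non-negative weights.
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)

probs = torch.softmax(torch.randn(2, 32), dim=-1)
ids = sample_top_k_top_p_min_p(
    probs,
    top_ks=torch.tensor([20, 5]),
    top_ps=torch.tensor([0.9, 0.8]),
    min_ps=torch.tensor([0.0, 0.05]),
)

The version in the commit additionally divides the filtered weights by their row maximum and works around a torch.multinomial issue by drawing two samples with replacement and keeping the first.
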
28 changes: 20 additions & 8 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
from __future__ import annotations

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@

import logging
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import torch

@@ -29,6 +31,10 @@
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput


INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
@@ -678,11 +684,17 @@ def merge(self, other: "ScheduleBatch"):
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)

def sample(self, logits: torch.Tensor):
from sglang.srt.layers.sampler import Sampler

sampler = Sampler()

batch_next_token_ids = sampler(logits, self.sampling_info)
def check_sample_results(self, sample_output: SampleOutput):
if not torch.all(sample_output.success):
probs = sample_output.probs
batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids

return batch_next_token_ids
return sample_output.batch_next_token_ids
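The top_k=1 fallback that used to live inside the sampler's forward_cuda now sits in check_sample_results. A toy, self-contained sketch of the same logic, with invented tensor values, shows what happens to a request whose sampling failed:

import torch

# Toy SampleOutput-like values: request 0 sampled fine, request 1 failed.
success = torch.tensor([True, False])
probs = torch.tensor([[0.1, 0.9, 0.0], [float("nan"), 0.2, 0.7]])
batch_next_token_ids = torch.tensor([1, 0])

# Same fallback as ScheduleBatch.check_sample_results: mask NaNs, then
# replace failed rows with their argmax token.
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)
print(next_token_ids)  # tensor([1, 2])
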
52 changes: 32 additions & 20 deletions python/sglang/srt/managers/tp_worker.py
@@ -31,7 +31,7 @@
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -504,21 +504,29 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.EXTEND
)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)

# Move logprobs to cpu
if output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[
torch.arange(
len(next_token_ids), device=next_token_ids.device
),
next_token_ids,
].tolist()
)
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)

next_token_ids = next_token_ids.tolist()
@@ -557,12 +565,14 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
self.req_to_token_pool.free(req.req_pool_idx)

if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
self.add_logprob_return_values(
i, req, pt, next_token_ids, logits_output
)
pt += req.extend_input_len
else:
assert batch.extend_num_tokens != 0
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = output.embeddings.tolist()
logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = logits_output.embeddings.tolist()

# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -590,7 +600,7 @@ def add_logprob_return_values(
req: Req,
pt: int,
next_token_ids: List[int],
output: LogitProcessorOutput,
output: LogitsProcessorOutput,
):
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -672,15 +682,17 @@ def forward_decode_batch(self, batch: ScheduleBatch):
batch.prepare_for_decode()

# Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits)
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.DECODE
)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)

# Move logprobs to cpu
if output.next_token_logprobs is not None:
next_token_logprobs = output.next_token_logprobs[
if logits_output.next_token_logprobs is not None:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
@@ -706,7 +718,7 @@
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(output.output_top_logprobs[i])
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

self.handle_finished_requests(batch)

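The logprob handling in tp_worker.py repeatedly uses the same advanced-indexing pattern: index the [batch, vocab] logprob matrix with (row, sampled-token-id) pairs to pull out one logprob per request before calling .tolist(). A small self-contained sketch with made-up numbers:

import torch

# Made-up logprobs for a batch of 2 requests over a vocabulary of 4 tokens.
next_token_logprobs = torch.log(
    torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [0.7, 0.1, 0.1, 0.1]])
)
next_token_ids = torch.tensor([3, 0])  # token chosen for each request

# Same indexing pattern as in forward_prefill_batch/forward_decode_batch:
# row i is paired with column next_token_ids[i].
chosen = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
# chosen is approximately [log(0.4), log(0.7)]
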
(The diffs for the remaining 24 changed files are not shown here.)
