From 6c91c482cf1ec7a7e77e0343b9b4d9e3f20f7366 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 25 Sep 2024 09:54:42 -0700 Subject: [PATCH 01/31] TMP --- vllm/core/interfaces.py | 33 +- vllm/core/scheduler.py | 1493 ------------------- vllm/core/scheduler_v2.py | 358 +++++ vllm/engine/llm_engine.py | 1723 ---------------------- vllm/engine/llm_engine_v2.py | 665 +++++++++ vllm/entrypoints/llm.py | 2 +- vllm/executor/executor_base.py | 9 +- vllm/model_executor/__init__.py | 4 +- vllm/model_executor/layers/sampler.py | 1369 ++--------------- vllm/model_executor/sampling_metadata.py | 586 +------- vllm/request.py | 112 ++ vllm/sampler_output.py | 19 + vllm/sequence.py | 1329 ----------------- 13 files changed, 1268 insertions(+), 6434 deletions(-) delete mode 100644 vllm/core/scheduler.py create mode 100644 vllm/core/scheduler_v2.py delete mode 100644 vllm/engine/llm_engine.py create mode 100644 vllm/engine/llm_engine_v2.py create mode 100644 vllm/request.py create mode 100644 vllm/sampler_output.py delete mode 100644 vllm/sequence.py diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 96f8dd851b2f4..fa46a78480cd4 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -4,7 +4,7 @@ from typing import Sequence as GenericSequence from typing import Tuple -from vllm.sequence import Sequence, SequenceGroup +from vllm.request import Request from vllm.utils import Device @@ -44,53 +44,49 @@ def get_block_space_manager_class(version: str): raise ValueError(f"Unknown version {version=}") @abstractmethod - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, request: Request) -> AllocStatus: pass @abstractmethod - def allocate(self, seq_group: SequenceGroup) -> None: + def allocate(self, request: Request) -> None: pass @abstractmethod - def can_append_slots(self, seq_group: SequenceGroup, + def can_append_slots(self, request: Request, num_lookahead_slots: int) -> bool: pass @abstractmethod def append_slots( self, - seq: Sequence, + request: Request, num_lookahead_slots: int, ) -> List[Tuple[int, int]]: pass @abstractmethod - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - pass - - @abstractmethod - def can_swap_in(self, seq_group: SequenceGroup, + def can_swap_in(self, request: Request, num_lookahead_slots: int) -> AllocStatus: pass @abstractmethod - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_in(self, request: Request) -> List[Tuple[int, int]]: pass @abstractmethod - def can_swap_out(self, seq_group: SequenceGroup) -> bool: + def can_swap_out(self, request: Request) -> bool: pass @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_out(self, request: Request) -> List[Tuple[int, int]]: pass @abstractmethod - def free(self, seq: Sequence) -> None: + def free(self, request: Request) -> None: pass @abstractmethod - def get_block_table(self, seq: Sequence) -> List[int]: + def get_block_table(self, request: Request) -> List[int]: pass @abstractmethod @@ -104,19 +100,18 @@ def get_num_free_cpu_blocks(self) -> int: @abstractmethod def access_all_blocks_in_seq( self, - seq: Sequence, + request: Request, access_time: float, ) -> None: pass @abstractmethod def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: + self, reqs: List[Request]) -> GenericSequence[int]: pass @abstractmethod - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): + def mark_blocks_as_computed(self, request: Request, token_chunk_size: int): pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py deleted file mode 100644 index c3fa95f57b737..0000000000000 --- a/vllm/core/scheduler.py +++ /dev/null @@ -1,1493 +0,0 @@ -import enum -import os -import random -import time -from collections import deque -from dataclasses import dataclass, field -from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) - -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceGroupMetadataDelta, - SequenceStatus) -from vllm.utils import Device, PyObjectCache - -logger = init_logger(__name__) - -# Test-only. If configured, decode is preempted with -# ARTIFICIAL_PREEMPTION_PROB% probability. -ENABLE_ARTIFICIAL_PREEMPT = bool( - os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa -ARTIFICIAL_PREEMPTION_PROB = 0.5 -ARTIFICIAL_PREEMPTION_MAX_CNT = 500 - - -class PreemptionMode(enum.Enum): - """Preemption modes. - - 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory - and swap them back in when the sequences are resumed. - 2. Recomputation: Discard the blocks of the preempted sequences and - recompute them when the sequences are resumed, treating the sequences as - new prompts. - """ - SWAP = enum.auto() - RECOMPUTE = enum.auto() - - -@dataclass -class SchedulingBudget: - """The available slots for scheduling. - - TODO(sang): Right now, the budget is request_id-aware meaning it can ignore - budget update from the same request_id. It is because in normal scheduling - path, we update RUNNING num_seqs ahead of time, meaning it could be - updated more than once when scheduling RUNNING requests. Since this won't - happen if we only have chunked prefill scheduling, we can remove this - feature from the API when chunked prefill is enabled by default. - """ - token_budget: int - max_num_seqs: int - _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) - _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) - _num_batched_tokens: int = 0 - _num_curr_seqs: int = 0 - - def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - assert num_new_tokens != 0 - assert num_new_seqs != 0 - return (self.num_batched_tokens + num_new_tokens <= self.token_budget - and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) - - def remaining_token_budget(self): - return self.token_budget - self.num_batched_tokens - - def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): - if req_id in self._request_ids_num_batched_tokens: - return - - self._request_ids_num_batched_tokens.add(req_id) - self._num_batched_tokens += num_batched_tokens - - def subtract_num_batched_tokens(self, req_id: str, - num_batched_tokens: int): - if req_id in self._request_ids_num_batched_tokens: - self._request_ids_num_batched_tokens.remove(req_id) - self._num_batched_tokens -= num_batched_tokens - - def add_num_seqs(self, req_id: str, num_curr_seqs: int): - if req_id in self._request_ids_num_curr_seqs: - return - - self._request_ids_num_curr_seqs.add(req_id) - self._num_curr_seqs += num_curr_seqs - - def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): - if req_id in self._request_ids_num_curr_seqs: - self._request_ids_num_curr_seqs.remove(req_id) - self._num_curr_seqs -= num_curr_seqs - - @property - def num_batched_tokens(self): - return self._num_batched_tokens - - @property - def num_curr_seqs(self): - return self._num_curr_seqs - - -@dataclass -class ScheduledSequenceGroup: - # A sequence group that's scheduled. - seq_group: SequenceGroup - # The total chunk size (number of tokens) to process for next iteration. - # 1 for decoding. Same as prompt tokens for prefill, but if prefill is - # chunked, it can be smaller than that. - token_chunk_size: int - - -@dataclass -class SchedulerOutputs: - """The scheduling decision made from a scheduler.""" - # Scheduled sequence groups. - scheduled_seq_groups: Iterable[ScheduledSequenceGroup] - # Number of prefill groups scheduled. - num_prefill_groups: int - # Total number of batched tokens. - num_batched_tokens: int - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] - # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] - # Sequence groups that are going to be ignored. - ignored_seq_groups: List[SequenceGroup] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - # The number of requests in the running queue - running_queue_size: int - preempted: int - - def __post_init__(self): - # Swap in and swap out should never happen at the same time. - assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) - - self.num_loras: int = len(self.lora_requests) - if self.num_loras > 0: - self._sort_by_lora_ids() - - self.num_prompt_adapters: int = len(self.prompt_adapter_requests) - - def is_empty(self) -> bool: - # NOTE: We do not consider the ignored sequence groups. - return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) - - def _sort_by_lora_ids(self): - self.scheduled_seq_groups = sorted( - self.scheduled_seq_groups, - key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) - - @property - def lora_requests(self) -> Set[LoRARequest]: - return { - g.seq_group.lora_request - for g in self.scheduled_seq_groups - if g.seq_group.lora_request is not None - } - - @property - def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: - return { - g.seq_group.prompt_adapter_request - for g in self.scheduled_seq_groups - if g.seq_group.prompt_adapter_request is not None - } - - -@dataclass -class SchedulerRunningOutputs: - """The requests that are scheduled from a running queue. - - Could contain prefill (prefill that's chunked) or decodes. If there's not - enough memory, it can be preempted (for recompute) or swapped out. - """ - # Selected sequences that are running and in a decoding phase. - decode_seq_groups: List[ScheduledSequenceGroup] - # Selected sequences that are running and in a prefill phase. - # I.e., it means the prefill has been chunked. - prefill_seq_groups: List[ScheduledSequenceGroup] - # The preempted sequences. - preempted: List[SequenceGroup] - # Sequences that are swapped out. - swapped_out: List[SequenceGroup] - # The blocks to swap out. - blocks_to_swap_out: List[Tuple[int, int]] - # The blocks to copy. - blocks_to_copy: List[Tuple[int, int]] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - - # Optimization for fast-access to seq_group lists - decode_seq_groups_list: List[SequenceGroup] - prefill_seq_groups_list: List[SequenceGroup] - - @classmethod - def create_empty(cls) -> "SchedulerRunningOutputs": - return SchedulerRunningOutputs( - decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - decode_seq_groups_list=[], - prefill_seq_groups_list=[], - ) - - -@dataclass -class SchedulerSwappedInOutputs: - """The requests that are scheduled from a swap queue. - - Could contain prefill (prefill that's chunked) or decodes. - """ - # Selected sequences that are going to be swapped in and is in a - # decoding phase. - decode_seq_groups: List[ScheduledSequenceGroup] - # Selected sequences that are going to be swapped in and in a prefill - # phase. I.e., it means the prefill has been chunked. - prefill_seq_groups: List[ScheduledSequenceGroup] - # The blocks to swap in. - blocks_to_swap_in: List[Tuple[int, int]] - # The blocks to copy. - blocks_to_copy: List[Tuple[int, int]] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - # Infeasible sequence groups. - infeasible_seq_groups: List[SequenceGroup] - - @classmethod - def create_empty(cls) -> "SchedulerSwappedInOutputs": - return SchedulerSwappedInOutputs( - decode_seq_groups=[], - prefill_seq_groups=[], - blocks_to_swap_in=[], - blocks_to_copy=[], - num_lookahead_slots=0, - infeasible_seq_groups=[], - ) - - -@dataclass -class SchedulerPrefillOutputs: - """The requests that are scheduled from a waiting queue. - - Could contain a fresh prefill requests or preempted requests that need - to be recomputed from scratch. - """ - # Selected sequences for prefill. - seq_groups: List[ScheduledSequenceGroup] - # Ignored sequence groups. - ignored_seq_groups: List[SequenceGroup] - num_lookahead_slots: int - - @classmethod - def create_empty(cls) -> "SchedulerPrefillOutputs": - return SchedulerPrefillOutputs( - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=0, - ) - - -def seq_group_metadata_builder(): - return SequenceGroupMetadata(request_id="", - is_prompt=False, - seq_data={}, - sampling_params=None, - block_tables={}) - - -def scheduler_running_outputs_builder(): - return SchedulerRunningOutputs(decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - prefill_seq_groups_list=[], - decode_seq_groups_list=[]) - - -def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(SequenceGroup("", [], -1), - token_chunk_size=0) - # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) - - -class Scheduler: - - def __init__( - self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - output_proc_callback: Optional[Callable] = None, - ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - # Note for LoRA scheduling: the current policy is extremely - # simple and NOT fair. It can lead to starvation of some - # LoRAs. This should be improved in the future. - self.lora_config = lora_config - - version = "v1" - if self.scheduler_config.use_v2_block_manager: - version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" - - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version) - - num_gpu_blocks = cache_config.num_gpu_blocks - if num_gpu_blocks: - num_gpu_blocks //= pipeline_parallel_size - - num_cpu_blocks = cache_config.num_cpu_blocks - if num_cpu_blocks: - num_cpu_blocks //= pipeline_parallel_size - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) - - # Sequence groups in the WAITING state. - # Contain new prefill or preempted requests. - self.waiting: Deque[SequenceGroup] = deque() - # Sequence groups in the RUNNING state. - # Contain decode requests. - self.running: Deque[SequenceGroup] = deque() - # Sequence groups in the SWAPPED state. - # Contain decode requests that are swapped out. - self.swapped: Deque[SequenceGroup] = deque() - # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests - # can and must be released after the current step. - # This is used to evict the finished requests from the Mamba cache. - self._finished_requests_ids: List[str] = list() - # Time at previous scheduling step - self.prev_time = 0.0 - # Did we schedule a prompt at previous step? - self.prev_prompt = False - # Latency of the last prompt step - self.last_prompt_latency = 0.0 - # preemption mode, RECOMPUTE or SWAP - self.user_specified_preemption_mode = scheduler_config.preemption_mode - - # The following field is test-only. It is used to inject artificial - # preemption. - self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT - self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT - if self.enable_artificial_preemption - else 0) - self.num_cumulative_preemption: int = 0 - - # Used to cache python objects - self._seq_group_metadata_cache: List[PyObjectCache] = [] - self._scheduler_running_outputs_cache: List[PyObjectCache] = [] - self._scheduled_seq_group_cache: List[PyObjectCache] = [] - - # For async output processing, we need to swap cache buffers between - # iterations. I.e. since the output processing is lagged one step, - # we cannot reuse the cached objects immediately when the schedule() - # is called again, but only when schedule() is called the second time. - self.output_proc_callback = output_proc_callback - self.use_async_output_proc = self.output_proc_callback is not None - self.num_cache_iters = 2 if self.use_async_output_proc else 1 - - self.cache_id = 0 - for i in range(self.num_cache_iters): - self._seq_group_metadata_cache.append( - PyObjectCache(seq_group_metadata_builder)) - self._scheduler_running_outputs_cache.append( - PyObjectCache(scheduler_running_outputs_builder)) - self._scheduled_seq_group_cache.append( - PyObjectCache(scheduled_seq_group_builder)) - - # For async postprocessor, the extra decode run cannot be done - # when the request reaches max_model_len. In this case, the request - # will be stopped during schedule() call and added to this stop list - # for processing and deallocation by the free_finished_seq_groups() - self._async_stopped: List[SequenceGroup] = [] - - @property - def next_cache_id(self): - return (self.cache_id + 1) % self.num_cache_iters - - @property - def lora_enabled(self) -> bool: - return bool(self.lora_config) - - @property - def num_decoding_tokens_per_seq(self) -> int: - """The number of new tokens.""" - return 1 - - def add_seq_group(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the waiting queue. - self.waiting.append(seq_group) - - def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the running queue. - # Only for testing purposes. - self.running.append(seq_group) - - def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the swapped queue. - # Only for testing purposes. - self.swapped.append(seq_group) - - def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: - """Aborts a sequence group with the given ID. - - Check if the sequence group with the given ID - is present in any of the state queue. - If present, remove the sequence group from the state queue. - Also, if any of the sequences in the sequence group is not finished, - free the sequence with status `FINISHED_ABORTED`. - Otherwise, do nothing. - - Args: - request_id: The ID(s) of the sequence group to abort. - """ - if isinstance(request_id, str): - request_id = (request_id, ) - request_ids = set(request_id) - for state_queue in [self.waiting, self.running, self.swapped]: - aborted_groups: List[SequenceGroup] = [] - for seq_group in state_queue: - if not request_ids: - # Using 'break' here may add two extra iterations, - # but is acceptable to reduce complexity. - break - if seq_group.request_id in request_ids: - # Appending aborted group into pending list. - aborted_groups.append(seq_group) - request_ids.remove(seq_group.request_id) - for aborted_group in aborted_groups: - # Remove the sequence group from the state queue. - state_queue.remove(aborted_group) - # Remove the aborted request from the Mamba cache. - self._finished_requests_ids.append(aborted_group.request_id) - for seq in aborted_group.get_seqs(): - if seq.is_finished(): - continue - seq.status = SequenceStatus.FINISHED_ABORTED - self.free_seq(seq) - - self._free_seq_group_cross_attn_blocks(aborted_group) - - def _free_seq_group_cross_attn_blocks( - self, - seq_group: SequenceGroup, - ) -> None: - """ - Free a sequence group from a cross-attention block table. - Has no effect on decoder-only models. - """ - if seq_group.is_encoder_decoder(): - self.block_manager.free_cross(seq_group) - - def has_unfinished_seqs(self) -> bool: - return len(self.waiting) != 0 or len(self.running) != 0 or len( - self.swapped) != 0 - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_manager.get_prefix_cache_hit_rate(device) - - def get_num_unfinished_seq_groups(self) -> int: - return len(self.waiting) + len(self.running) + len(self.swapped) - - def get_and_reset_finished_requests_ids(self) -> List[str]: - """Flushes the list of request ids of previously finished seq_groups.""" - finished_requests_ids = self._finished_requests_ids - self._finished_requests_ids = list() - return finished_requests_ids - - def _schedule_running( - self, - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - ) -> SchedulerRunningOutputs: - """Schedule sequence groups that are running. - - Running queue should include decode and chunked prefill requests. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any decodes are preempted. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any decodes are preempted. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - - Returns: - SchedulerRunningOutputs. - """ - ret: SchedulerRunningOutputs = \ - self._scheduler_running_outputs_cache[self.cache_id].get_object() - ret.blocks_to_swap_out.clear() - ret.blocks_to_copy.clear() - ret.decode_seq_groups.clear() - ret.prefill_seq_groups.clear() - ret.preempted.clear() - ret.swapped_out.clear() - - ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False) - - ret.decode_seq_groups_list.clear() - ret.prefill_seq_groups_list.clear() - - # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out - blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy - - decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups - prefill_seq_groups: List[ - ScheduledSequenceGroup] = ret.prefill_seq_groups - preempted: List[SequenceGroup] = ret.preempted - swapped_out: List[SequenceGroup] = ret.swapped_out - - running_queue = self.running - assert len(self._async_stopped) == 0 - while running_queue: - seq_group = running_queue[0] - num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - - if num_running_tokens == 0: - # No budget => Stop - break - - running_queue.popleft() - - # With async postprocessor, an extra decode run is done - # to process the final tokens. The check below avoids this extra - # decode run when the model max len is reached, in order to avoid - # a memory overflow. - if self.use_async_output_proc and seq_group.seqs[0].get_len( - ) > self.scheduler_config.max_model_len: - self._async_stopped.append(seq_group) - continue - - # NOTE(woosuk): Preemption happens only when there is no available - # slot to keep all the sequence groups in the RUNNING state. - while not self._can_append_slots(seq_group): - budget.subtract_num_batched_tokens(seq_group.request_id, - num_running_tokens) - num_running_seqs = seq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(seq_group.request_id, - num_running_seqs) - - if (curr_loras is not None and seq_group.lora_int_id > 0 - and seq_group.lora_int_id in curr_loras): - curr_loras.remove(seq_group.lora_int_id) - - # Determine victim sequence - cont_loop = True - if running_queue: - # Preempt the lowest-priority sequence group. - victim_seq_group = running_queue.pop() - else: - # No other sequence group can be preempted. - # Preempt the current sequence group. - # Note: This is also where we stop this loop - # (since there is nothing else to preempt) - victim_seq_group = seq_group - cont_loop = False - - # With async postprocessor, before preempting a sequence - # we need to ensure it has no pending async postprocessor - do_preempt = True - if self.use_async_output_proc: - assert self.output_proc_callback is not None - self.output_proc_callback( - request_id=victim_seq_group.request_id) - - # It may be that the async pending "victim_seq_group" - # becomes finished, in which case we simply free it. - if victim_seq_group.is_finished(): - self._free_finished_seq_group(victim_seq_group) - do_preempt = False - - # Do preemption - if do_preempt: - preempted_mode = self._preempt(victim_seq_group, - blocks_to_swap_out) - if preempted_mode == PreemptionMode.RECOMPUTE: - preempted.append(victim_seq_group) - else: - swapped_out.append(victim_seq_group) - - if not cont_loop: - break - else: - self._append_slots(seq_group, blocks_to_copy) - is_prefill = seq_group.is_prefill() - - scheduled_seq_group: ScheduledSequenceGroup = \ - self._scheduled_seq_group_cache[self.cache_id].get_object() - scheduled_seq_group.seq_group = seq_group - if is_prefill: - scheduled_seq_group.token_chunk_size = num_running_tokens - prefill_seq_groups.append(scheduled_seq_group) - ret.prefill_seq_groups_list.append(seq_group) - else: - scheduled_seq_group.token_chunk_size = 1 - decode_seq_groups.append(scheduled_seq_group) - ret.decode_seq_groups_list.append(seq_group) - - budget.add_num_batched_tokens(seq_group.request_id, - num_running_tokens) - # OPTIMIZATION: Note that get_max_num_running_seqs is - # expensive. For the default scheduling chase where - # enable_chunking is False, num_seqs are updated before running - # this method, so we don't have to update it again here. - if enable_chunking: - num_running_seqs = seq_group.get_max_num_running_seqs() - budget.add_num_seqs(seq_group.request_id, num_running_seqs) - if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.add(seq_group.lora_int_id) - - self._scheduler_running_outputs_cache[self.next_cache_id].reset() - self._scheduled_seq_group_cache[self.next_cache_id].reset() - - return ret - - def _schedule_swapped( - self, - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - ) -> SchedulerSwappedInOutputs: - """Schedule sequence groups that are swapped out. - - It schedules swapped requests as long as it fits `budget` and - curr_loras <= max_lora from the scheduling config. The input arguments - `budget` and `curr_loras` are updated based on scheduled seq_groups. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are swapped in. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any requests are swapped in. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - - Returns: - SchedulerSwappedInOutputs. - """ - # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: List[Tuple[int, int]] = [] - blocks_to_copy: List[Tuple[int, int]] = [] - decode_seq_groups: List[ScheduledSequenceGroup] = [] - prefill_seq_groups: List[ScheduledSequenceGroup] = [] - infeasible_seq_groups: List[SequenceGroup] = [] - - swapped_queue = self.swapped - - leftover_swapped: Deque[SequenceGroup] = deque() - while swapped_queue: - seq_group = swapped_queue[0] - - # If the sequence group cannot be swapped in, stop. - is_prefill = seq_group.is_prefill() - alloc_status = self.block_manager.can_swap_in( - seq_group, self._get_num_lookahead_slots(is_prefill)) - if alloc_status == AllocStatus.LATER: - break - elif alloc_status == AllocStatus.NEVER: - logger.warning( - "Failing the request %s because there's not enough kv " - "cache blocks to run the entire sequence.", - seq_group.request_id) - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_IGNORED - infeasible_seq_groups.append(seq_group) - swapped_queue.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - assert curr_loras is not None - assert self.lora_config is not None - if (lora_int_id > 0 and (lora_int_id not in curr_loras) - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_swapped.appendleft(seq_group) - swapped_queue.popleft() - continue - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.SWAPPED, - enable_chunking, budget) - - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): - break - - if lora_int_id > 0 and curr_loras is not None: - curr_loras.add(lora_int_id) - swapped_queue.popleft() - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy) - is_prefill = seq_group.is_prefill() - if is_prefill: - prefill_seq_groups.append( - ScheduledSequenceGroup(seq_group, - token_chunk_size=num_new_tokens)) - else: - decode_seq_groups.append( - ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) - budget.add_num_seqs(seq_group.request_id, num_new_seqs) - - swapped_queue.extendleft(leftover_swapped) - - return SchedulerSwappedInOutputs( - decode_seq_groups=decode_seq_groups, - prefill_seq_groups=prefill_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False), - infeasible_seq_groups=infeasible_seq_groups, - ) - - def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled: - prompt_limit = self.scheduler_config.max_model_len - else: - prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) - - # Model is fine tuned with long context. Return the fine tuned max_len. - if (seq_group.lora_request - and seq_group.lora_request.long_lora_max_len): - assert prompt_limit <= seq_group.lora_request.long_lora_max_len - return seq_group.lora_request.long_lora_max_len - else: - return prompt_limit - - def _schedule_prefills( - self, - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - ) -> SchedulerPrefillOutputs: - """Schedule sequence groups that are in prefill stage. - - Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE - as a new prefill (that starts from beginning -> most recently generated - tokens). - - It schedules waiting requests as long as it fits `budget` and - curr_loras <= max_lora from the scheduling config. The input arguments - `budget` and `curr_loras` are updated based on scheduled seq_groups. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are scheduled. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any requests are scheduled. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - - Returns: - SchedulerPrefillOutputs. - """ - ignored_seq_groups: List[SequenceGroup] = [] - seq_groups: List[ScheduledSequenceGroup] = [] - - waiting_queue = self.waiting - - leftover_waiting_sequences: Deque[SequenceGroup] = deque() - while self._passed_delay(time.time()) and waiting_queue: - seq_group = waiting_queue[0] - - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) - if not enable_chunking: - num_prompt_tokens = waiting_seqs[0].get_len() - assert num_new_tokens == num_prompt_tokens - - prompt_limit = self._get_prompt_limit(seq_group) - if num_new_tokens > prompt_limit: - logger.warning( - "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", num_new_tokens, prompt_limit) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - - # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: - logger.warning( - "Input prompt (%d tokens) is too long" - " and exceeds the capacity of block_manager", - num_new_tokens) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - assert curr_loras is not None - assert self.lora_config is not None - if (self.lora_enabled and lora_int_id > 0 - and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): - break - - # Can schedule this request. - if curr_loras is not None and lora_int_id > 0: - curr_loras.add(lora_int_id) - waiting_queue.popleft() - self._allocate_and_set_running(seq_group) - seq_group.init_multi_step( - num_scheduler_steps=self._get_num_lookahead_slots( - is_prefill=True) + 1) - seq_groups.append( - ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) - budget.add_num_seqs(seq_group.request_id, num_new_seqs) - - # Queue requests that couldn't be scheduled. - waiting_queue.extendleft(leftover_waiting_sequences) - if len(seq_groups) > 0: - self.prev_prompt = True - - return SchedulerPrefillOutputs( - seq_groups=seq_groups, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) - - def _schedule_default(self) -> SchedulerOutputs: - """Schedule queued requests. - - The current policy is designed to optimize the throughput. First, - it batches as many prefill requests as possible. And it schedules - decodes. If there's a pressure on GPU memory, decode requests can - be swapped or preempted. - """ - # Include running requests to the budget. - budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) - # Make sure we include num running seqs before scheduling prefill, - # so that we don't schedule beyond max_num_seqs for prefill. - for seq_group in self.running: - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) - curr_loras = set( - seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None - - prefills = SchedulerPrefillOutputs.create_empty() - running_scheduled = SchedulerRunningOutputs.create_empty() - swapped_in = SchedulerSwappedInOutputs.create_empty() - - # If any requests are swapped, prioritized swapped requests. - if not self.swapped: - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=False) - - # Don't schedule decodes if prefills are scheduled. - # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running - # only contains decode requests, not chunked prefills. - if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=False) - - # If any sequence group is preempted, do not swap in any sequence - # group. because it means there's no slot for new running requests. - if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: - swapped_in = self._schedule_swapped(budget, curr_loras) - - assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens) - assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - - # Update waiting requests. - self.waiting.extendleft(running_scheduled.preempted) - # Update new running requests. - if len(prefills.seq_groups) > 0: - self.running.extend([s.seq_group for s in prefills.seq_groups]) - - self.running.extend(running_scheduled.decode_seq_groups_list) - - if len(swapped_in.decode_seq_groups) > 0: - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - - # Update swapped requests. - self.swapped.extend(running_scheduled.swapped_out) - preempted = (len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)) - - # There should be no prefill from running queue because this policy - # doesn't allow chunked prefills. - assert len(running_scheduled.prefill_seq_groups) == 0 - assert len(swapped_in.prefill_seq_groups) == 0 - - # Merge lists - num_prefill_groups = len(prefills.seq_groups) - if num_prefill_groups > 0: - scheduled_seq_groups = prefills.seq_groups - scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) - else: - scheduled_seq_groups = running_scheduled.decode_seq_groups - scheduled_seq_groups.extend(swapped_in.decode_seq_groups) - - blocks_to_copy = running_scheduled.blocks_to_copy - blocks_to_copy.extend(swapped_in.blocks_to_copy) - - ignored_seq_groups = prefills.ignored_seq_groups - ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) - - return SchedulerOutputs( - scheduled_seq_groups=scheduled_seq_groups, - num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, - running_queue_size=len(self.running), - preempted=preempted, - ) - - def _schedule_chunked_prefill(self) -> SchedulerOutputs: - """Schedule queued requests. - - Chunked prefill allows to chunk prefill requests, batch them together - with decode requests. This policy 1. schedule as many decoding requests - as possible. 2. schedule chunked prefill requests that are not - finished. 3. schedule swapped request. 4. schedule new prefill - requests. - - The policy can sustain the high GPU utilization because it can put - prefill and decodes requests to the same batch, while it improves - inter token latency because decodes requests don't need to be blocked - by prefill requests. - """ - budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) - curr_loras: Set[int] = set() - - prefills = SchedulerPrefillOutputs.create_empty() - swapped_in = SchedulerSwappedInOutputs.create_empty() - - # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=True) - - # Schedule swapped out requests. - # If preemption happens, it means we don't have space for swap-in. - if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: - swapped_in = self._schedule_swapped(budget, curr_loras) - - # Schedule new prefills. - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=True) - - assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens) - assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - - # Update waiting requests. - self.waiting.extendleft(running_scheduled.preempted) - - # Update new running requests. - # By default, vLLM scheduler prioritizes prefills. - # Once chunked prefill is enabled, - # the policy is changed to prioritize decode requests. - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in swapped_in.prefill_seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.prefill_seq_groups]) - self.running.extend([s.seq_group for s in prefills.seq_groups]) - - # Update swapped requests. - self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), - num_prefill_groups=(len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)), - num_batched_tokens=budget.num_batched_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups + - swapped_in.infeasible_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, - running_queue_size=len(self.running), - preempted=(len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)), - ) - - def _schedule(self) -> SchedulerOutputs: - """Schedule queued requests.""" - if self.scheduler_config.chunked_prefill_enabled: - return self._schedule_chunked_prefill() - else: - return self._schedule_default() - - def _can_append_slots(self, seq_group: SequenceGroup) -> bool: - """Determine whether or not we have enough space in the KV cache to - continue generation of the sequence group. - """ - # It is True only for testing case to trigger artificial preemption. - if (self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0): - self.artificial_preempt_cnt -= 1 - return False - - # Appending slots only occurs in decoding. - is_prefill = False - - return self.block_manager.can_append_slots( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) - - def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: - no_beam_search = seq_group.sampling_params is None or ( - seq_group.sampling_params.best_of == 1 - and not seq_group.sampling_params.use_beam_search) - return no_beam_search - - def schedule( - self - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: - # Schedule sequence groups. - # This function call changes the internal states of the scheduler - # such as self.running, self.swapped, and self.waiting. - scheduler_start_time = time.perf_counter() - - scheduler_outputs = self._schedule() - now = time.time() - - if not self.cache_config.enable_prefix_caching: - common_computed_block_nums = [] - - allow_async_output_proc: bool = self.use_async_output_proc - - # Create input data structures. - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - token_chunk_size = scheduled_seq_group.token_chunk_size - seq_group.maybe_set_first_scheduled_time(now) - - seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id].get_object() - seq_group_metadata.seq_data.clear() - seq_group_metadata.block_tables.clear() - - # seq_id -> SequenceData - seq_data: Dict[int, SequenceData] = {} - # seq_id -> physical block numbers - block_tables: Dict[int, List[int]] = {} - - if seq_group.is_encoder_decoder(): - # Encoder associated with SequenceGroup - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - encoder_seq_data = encoder_seq.data - # Block table for cross-attention - # Also managed at SequenceGroup level - cross_block_table = self.block_manager.get_cross_block_table( - seq_group) - else: - encoder_seq_data = None - cross_block_table = None - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq_id = seq.seq_id - seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) - self.block_manager.access_all_blocks_in_seq(seq, now) - - if self.cache_config.enable_prefix_caching: - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) - - do_sample = True - is_prompt = seq_group.is_prefill() - # We should send the metadata to workers when the first prefill - # is sent. Subsequent requests could be chunked prefill or decode. - is_first_prefill = False - if is_prompt: - seqs = seq_group.get_seqs() - # Prefill has only 1 sequence. - assert len(seqs) == 1 - num_computed_tokens = seqs[0].data.get_num_computed_tokens() - is_first_prefill = num_computed_tokens == 0 - # In the next iteration, all prompt tokens are not computed. - # It means the prefill is chunked, and we don't need sampling. - # NOTE: We use get_len instead of get_prompt_len because when - # a sequence is preempted, prefill includes previous generated - # output tokens. - if (token_chunk_size + num_computed_tokens < - seqs[0].data.get_len()): - do_sample = False - - # It assumes the scheduled_seq_groups is ordered by - # prefill < decoding. - if is_first_prefill or not self.scheduler_config.send_delta_data: - seq_group_metadata = SequenceGroupMetadata( - request_id=seq_group.request_id, - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=seq_group.sampling_params, - block_tables=block_tables, - do_sample=do_sample, - pooling_params=seq_group.pooling_params, - token_chunk_size=token_chunk_size, - lora_request=seq_group.lora_request, - computed_block_nums=common_computed_block_nums, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - state=seq_group.state, - # `multi_modal_data` will only be present for the 1st comm - # between engine and worker. - # the subsequent comms can still use delta, but - # `multi_modal_data` will be None. - multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 else None, - prompt_adapter_request=seq_group.prompt_adapter_request, - ) - else: - # When SPMD mode is enabled, we only send delta data except for - # the first request to reduce serialization cost. - seq_data_delta = {} - for id, data in seq_data.items(): - seq_data_delta[id] = data.get_delta_and_reset() - seq_group_metadata = SequenceGroupMetadataDelta( - seq_data_delta, - seq_group.request_id, - block_tables, - is_prompt, - do_sample=do_sample, - token_chunk_size=token_chunk_size, - computed_block_nums=common_computed_block_nums, - ) - seq_group_metadata_list.append(seq_group_metadata) - - if allow_async_output_proc: - allow_async_output_proc = self._allow_async_output_proc( - seq_group) - - # Now that the batch has been created, we can assume all blocks in the - # batch will have been computed before the next scheduling invocation. - # This is because the engine assumes that a failure in model execution - # will crash the vLLM instance / will not retry. - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - self.block_manager.mark_blocks_as_computed( - scheduled_seq_group.seq_group, - scheduled_seq_group.token_chunk_size) - - self._seq_group_metadata_cache[self.next_cache_id].reset() - - scheduler_time = time.perf_counter() - scheduler_start_time - # Add this to scheduler time to all the sequences that are currently - # running. This will help estimate if the scheduler is a significant - # component in the e2e latency. - for seq_group in self.running: - if seq_group is not None and seq_group.metrics is not None: - if seq_group.metrics.scheduler_time is not None: - seq_group.metrics.scheduler_time += scheduler_time - else: - seq_group.metrics.scheduler_time = scheduler_time - - # Move to next cache (if exists) - self.cache_id = self.next_cache_id - - # Return results - return (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) - - def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: - self.block_manager.fork(parent_seq, child_seq) - - def free_seq(self, seq: Sequence) -> None: - """Free a sequence from a block table.""" - self.block_manager.free(seq) - - def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: - """Free finished seqs in a sequence group.""" - for seq in seq_group.get_seqs(): - if seq.is_finished(): - self.free_seq(seq) - - def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: - if seq_group.is_finished(): - # Free cross-attention block table, if it exists - self._free_seq_group_cross_attn_blocks(seq_group) - - # Add the finished requests to the finished requests list. - # This list will be used to update the Mamba cache in the - # next step. - self._finished_requests_ids.append(seq_group.request_id) - - # Free finished seqs - self._free_finished_seqs(seq_group) - - def free_finished_seq_groups(self) -> None: - remaining: Deque[SequenceGroup] = deque() - for seq_group in self.running: - self._free_finished_seq_group(seq_group) - if not seq_group.is_finished(): - remaining.append(seq_group) - - self.running = remaining - - # Handle async stopped sequence groups - # (ones that reached max model len) - if self._async_stopped: - for seq_group in self._async_stopped: - self._free_seq_group_cross_attn_blocks(seq_group) - self._finished_requests_ids.append(seq_group.request_id) - - # Free finished seqs - self._free_finished_seqs(seq_group) - - self._async_stopped.clear() - - def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - seq.status = SequenceStatus.RUNNING - - def _append_slots( - self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - ) -> None: - """Appends new slots to the sequences in the given sequence group. - - Args: - seq_group (SequenceGroup): The sequence group containing the - sequences to append slots to. - blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two - ints, the first int is the source block index, and the second - int is the destination block index. This list is updated with - the new source and destination block indices for the appended - slots. - """ - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) - seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) - if len(cows) > 0: - blocks_to_copy.extend(cows) - - def _preempt( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - preemption_mode: Optional[PreemptionMode] = None, - ) -> PreemptionMode: - # If preemption mode is not specified, we determine the mode as follows: - # We use recomputation by default since it incurs lower overhead than - # swapping. However, when the sequence group has multiple sequences - # (e.g., beam search), recomputation is not currently supported. In - # such a case, we use swapping instead. - # FIXME(woosuk): This makes our scheduling policy a bit bizarre. - # As swapped sequences are prioritized over waiting sequences, - # sequence groups with multiple sequences are implicitly prioritized - # over sequence groups with a single sequence. - # TODO(woosuk): Support recomputation for sequence groups with multiple - # sequences. This may require a more sophisticated CUDA kernel. - if self.user_specified_preemption_mode is None: - if seq_group.get_max_num_running_seqs() == 1: - preemption_mode = PreemptionMode.RECOMPUTE - else: - preemption_mode = PreemptionMode.SWAP - - elif self.user_specified_preemption_mode == "swap": - preemption_mode = PreemptionMode.SWAP - else: - preemption_mode = PreemptionMode.RECOMPUTE - - if self.num_cumulative_preemption % 50 == 0: - logger.warning( - "Sequence group %s is preempted by %s mode because there is " - "not enough KV cache space. This can affect the end-to-end " - "performance. Increase gpu_memory_utilization or " - "tensor_parallel_size to provide more KV cache memory. " - "total_num_cumulative_preemption=%d", seq_group.request_id, - preemption_mode, self.num_cumulative_preemption + 1) - self.num_cumulative_preemption += 1 - - if preemption_mode == PreemptionMode.RECOMPUTE: - self._preempt_by_recompute(seq_group) - elif preemption_mode == PreemptionMode.SWAP: - self._preempt_by_swap(seq_group, blocks_to_swap_out) - else: - raise AssertionError("Invalid preemption mode.") - return preemption_mode - - def _preempt_by_recompute( - self, - seq_group: SequenceGroup, - ) -> None: - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - assert len(seqs) == 1 - for seq in seqs: - seq.status = SequenceStatus.WAITING - self.free_seq(seq) - seq.reset_state_for_recompute() - - def _preempt_by_swap( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - ) -> None: - self._swap_out(seq_group, blocks_to_swap_out) - - def _swap_in( - self, - seq_group: SequenceGroup, - blocks_to_swap_in: List[Tuple[int, int]], - ) -> None: - mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.extend(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - seq.status = SequenceStatus.RUNNING - - def _swap_out( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - ) -> None: - if not self.block_manager.can_swap_out(seq_group): - # FIXME(woosuk): Abort the sequence group instead of aborting the - # entire engine. - raise RuntimeError( - "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error.") - mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.extend(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq.status = SequenceStatus.SWAPPED - - def _passed_delay(self, now: float) -> bool: - if self.prev_prompt: - self.last_prompt_latency = now - self.prev_time - self.prev_time, self.prev_prompt = now, False - # Delay scheduling prompts to let waiting queue fill up - if self.scheduler_config.delay_factor > 0 and self.waiting: - earliest_arrival_time = min( - [e.metrics.arrival_time for e in self.waiting]) - passed_delay = ( - (now - earliest_arrival_time) > - (self.scheduler_config.delay_factor * self.last_prompt_latency) - or not self.running) - else: - passed_delay = True - return passed_delay - - def _get_num_lookahead_slots(self, is_prefill: bool) -> int: - """The number of slots to allocate per sequence per step, beyond known - token ids. Speculative decoding uses these slots to store KV activations - of tokens which may or may not be accepted. - - Speculative decoding does not yet support prefill, so we do not perform - lookahead allocation for prefill. - """ - if is_prefill: - return 0 - - return self.scheduler_config.num_lookahead_slots - - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: - """Get the next new tokens to compute for a given sequence group - that's in a given `status`. - - The API could chunk the number of tokens to compute based on `budget` - if `enable_chunking` is True. If a sequence group has multiple - sequences (e.g., running beam search), it means it is in decoding - phase, so chunking doesn't happen. - - Returns 0 if the new token cannot be computed due to token budget. - """ - num_new_tokens = 0 - seqs = seq_group.get_seqs(status=status) - for seq in seqs: - num_new_tokens += seq.get_num_new_tokens() - assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. - if enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - if self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block size - # to avoid partial block matching. - block_size = self.cache_config.block_size - reminder = budget.token_budget % block_size - if reminder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {reminder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) - return num_new_tokens diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py new file mode 100644 index 0000000000000..a94e3e7fb096e --- /dev/null +++ b/vllm/core/scheduler_v2.py @@ -0,0 +1,358 @@ +import enum +import os +import random +import time +from collections import deque +from dataclasses import dataclass, field +from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.utils import Device + +from vllm.request import Request, RequestStatus +from vllm.sampling_params import SamplingParams +from vllm.multimodal import MultiModalDataDict + +logger = init_logger(__name__) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config + + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if self.scheduler_config.embedding_mode: + version = "embedding" + BlockSpaceManagerImpl = \ + BlockSpaceManager.get_block_space_manager_class(version) + num_gpu_blocks = cache_config.num_gpu_blocks + num_cpu_blocks = cache_config.num_cpu_blocks + + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching) + self.block_size = self.cache_config.block_size + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + + # Priority queues for requests. + self.waiting: Deque[Request] = deque() + self.running: Deque[Request] = deque() + + self.finished_req_ids: Set[str] = set() + self.aborted_req_ids: Set[str] = set() + + def schedule(self) -> "SchedulerOutput": + # Finish the requests that have reached the maximum length. + self._check_stop_by_len() + + scheduled_new_reqs: List[Request] = [] + scheduled_resumed_reqs: List[Request] = [] + scheduled_running_reqs: List[Request] = [] + preempted_reqs: List[Request] = [] + + req_to_new_block_ids: Dict[str, List[int]] = {} + num_scheduled_tokens: Dict[str, int] = {} + total_num_scheduled_tokens = 0 + num_remaining_tokens = self.max_num_scheduled_tokens + + # First, schedule the RUNNING requests. + while self.running: + if num_remaining_tokens == 0: + break + + request = self.running[0] + num_tokens = request.num_tokens - request.num_computed_tokens + num_tokens = min(num_tokens, num_remaining_tokens) + + new_block_ids: List[int] = [] + while not self.block_manager.can_append_slots(request, num_tokens): + new_block_ids = self.block_manager.append_slots( + request, num_tokens) + if not new_block_ids: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.block_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + + if preempted_req == request: + break + else: + # The request can be scheduled. + self.running.popleft() + scheduled_running_reqs.append(request) + + req_to_new_block_ids[request.request_id] = new_block_ids + num_scheduled_tokens[request.request_id] = num_tokens + total_num_scheduled_tokens += num_tokens + num_remaining_tokens -= num_tokens + + request.status = RequestStatus.RUNNING + request.num_computed_tokens += num_tokens + if request.num_tokens == request.num_computed_tokens: + # TODO(woosuk): Consider speculative decoding. + request.num_output_tokens += 1 + + # Next, schedule the WAITING requests. + while self.waiting: + if preempted_reqs: + break + if len(self.running) == self.max_num_running_reqs: + break + if num_remaining_tokens == 0: + break + + request = self.waiting[0] + allocated = self.block_manager.allocate(request) + if allocated is None: + # The request cannot be scheduled. + break + + # The request can be scheduled. + computed_block_ids, new_block_ids = allocated + + # Get cached tokens. + num_computed_blocks = len(computed_block_ids) + num_computed_tokens = num_computed_blocks * self.block_size + + # Number of tokens to be scheduled. + num_tokens = request.num_tokens - num_computed_tokens + num_tokens = min(num_tokens, num_remaining_tokens) + + self.waiting.popleft() + self.running.append(request) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + assert False, f"Invalid request status: {request.status}" + + req_to_new_block_ids[request.request_id] = ( + computed_block_ids + new_block_ids) + num_scheduled_tokens[request.request_id] = num_tokens + total_num_scheduled_tokens += num_tokens + num_remaining_tokens -= num_tokens + + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + num_tokens + if request.num_tokens == request.num_computed_tokens: + request.num_output_tokens += 1 + + # Check if the scheduling constraints are satisfied. + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert num_remaining_tokens >= 0 + assert len(self.running) <= self.max_num_running_reqs + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) == len(self.running)) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + ResumedRequestData.from_request( + req, req_to_new_block_ids[req.request_id]) + for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + RunningRequestData.from_request( + req, req_to_new_block_ids[req.request_id]) + for req in scheduled_running_reqs + ] + preempted_req_ids = {req.request_id for req in preempted_reqs} + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_resumed_reqs=resumed_reqs_data, + scheduled_running_reqs=running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + preempted_req_ids=preempted_req_ids, + finished_req_ids=self.finished_req_ids, + aborted_req_ids=self.aborted_req_ids, + ) + + self.finished_req_ids = set() + self.aborted_req_ids = set() + return scheduler_output + + def add_request(self, request: Request) -> None: + self.waiting.append(request) + + def abort_requests(self, request_ids: Union[str, Iterable[str]]) -> None: + if isinstance(request_ids, str): + request_ids = (request_ids, ) + request_ids = set(request_ids) + + # TODO: Optimize this. + for queue in [self.waiting, self.running]: + aborted_reqs: List[Request] = [] + for request in queue: + if not request_ids: + break + if request.request_id in request_ids: + request.status = RequestStatus.FINISHED_ABORTED + aborted_reqs.append(request) + request_ids.remove(request.request_id) + + for request in aborted_reqs: + queue.remove(request) + self.aborted_req_ids.add(request.request_id) + self._free_request(request) + + def stop_requests(self, request_ids: Union[str, Iterable[str]]) -> None: + if isinstance(request_ids, str): + request_ids = (request_ids, ) + request_ids = set(request_ids) + + # TODO: Optimize this. + for queue in [self.waiting, self.running]: + stopped_reqs: List[Request] = [] + for request in queue: + if not request_ids: + break + if request.request_id in request_ids: + request.status = RequestStatus.FINISHED_STOPPED + stopped_reqs.append(request) + request_ids.remove(request.request_id) + + for request in stopped_reqs: + queue.remove(request) + self.finished_req_ids.add(request.request_id) + self._free_request(request) + + def _check_stop_by_len(self) -> None: + stopped_reqs: List[Request] = [] + # TODO: Optimize this. + for request in self.running: + if (request.num_tokens >= self.max_model_len + or request.num_output_tokens >= request.max_tokens): + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + stopped_reqs.append(request) + for request in stopped_reqs: + self.running.remove(request) + self.finished_req_ids.add(request.request_id) + self._free_request(request) + + def _free_request(self, request: Request) -> None: + assert request.is_finished() + self.block_manager.free(request) + + def has_unfinished_requests(self) -> bool: + return self.waiting or self.running + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional[MultiModalDataDict] + sampling_params: SamplingParams + block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + block_ids: List[int], + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.inputs["prompt_token_ids"], + prompt=request.inputs.get("prompt"), + multi_modal_data=request.inputs.get("multi_modal_data"), + sampling_params=request.sampling_params, + block_ids=block_ids, + ) + + +@dataclass +class ResumedRequestData: + + req_id: str + block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + block_ids: List[int], + ) -> "ResumedRequestData": + return cls( + req_id=request.request_id, + block_ids=block_ids, + ) + + +@dataclass +class RunningRequestData: + + req_id: str + new_block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + new_block_ids: List[int], + ) -> "RunningRequestData": + return cls( + req_id=request.request_id, + new_block_ids=new_block_ids, + ) + + +@dataclass +class SchedulerOutput: + + scheduled_new_reqs: List[NewRequestData] + scheduled_resumed_reqs: List[ResumedRequestData] + scheduled_running_reqs: List[RunningRequestData] + + num_scheduled_tokens: Dict[str, int] + total_num_scheduled_tokens: int + + preempted_req_ids: Set[str] + finished_req_ids: Set[str] + aborted_req_ids: Set[str] diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py deleted file mode 100644 index 80dde804addac..0000000000000 --- a/vllm/engine/llm_engine.py +++ /dev/null @@ -1,1723 +0,0 @@ -import time -from collections import deque -from contextlib import contextmanager -from dataclasses import dataclass -from functools import partial -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, - Iterable, List, Mapping, NamedTuple, Optional) -from typing import Sequence as GenericSequence -from typing import Set, Type, Union - -import torch -from typing_extensions import TypeVar - -import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, - SchedulerOutputs) -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase, Stats -from vllm.engine.output_processor.interfaces import ( - SequenceGroupOutputProcessor) -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.engine.output_processor.util import create_output_by_sequence_group -from vllm.executor.executor_base import ExecutorBase -from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, - RequestOutputFactory) -from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - Sequence, SequenceGroup, SequenceGroupMetadata, - SequenceStatus) -from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, - init_tracer) -from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import Counter, Device, weak_bind -from vllm.version import __version__ as VLLM_VERSION - -logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5 - - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict() - - -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) - - -@dataclass -class SchedulerOutputState: - """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - allow_async_output_proc: bool = False - last_output: Optional[SamplerOutput] = None - - -class OutputData(NamedTuple): - outputs: List[SamplerOutput] - seq_group_metadata_list: List[SequenceGroupMetadata] - scheduler_outputs: SchedulerOutputs - is_async: bool - is_last_step: bool - skip: List[int] - - -class SchedulerContext: - - def __init__(self): - self.output_queue: Deque[OutputData] = deque() - self.request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] - self.seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None - self.scheduler_outputs: Optional[SchedulerOutputs] = None - - def append_output(self, outputs: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool): - self.output_queue.append( - OutputData(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=is_async, - is_last_step=is_last_step, - skip=[])) - - -class LLMEngine: - """An LLM engine that receives requests and generates texts. - - This is the main class for the vLLM engine. It receives requests - from clients and generates texts from the LLM. It includes a tokenizer, a - language model (possibly distributed across multiple GPUs), and GPU memory - space allocated for intermediate states (aka KV cache). This class utilizes - iteration-level scheduling and efficient memory management to maximize the - serving throughput. - - The :class:`~vllm.LLM` class wraps this class for offline batched inference - and the :class:`AsyncLLMEngine` class wraps this class for online serving. - - The config arguments are derived from :class:`~vllm.EngineArgs`. (See - :ref:`engine_args`) - - Args: - model_config: The configuration related to the LLM model. - cache_config: The configuration related to the KV cache memory - management. - parallel_config: The configuration related to distributed execution. - scheduler_config: The configuration related to the request scheduler. - device_config: The configuration related to the device. - lora_config (Optional): The configuration related to serving multi-LoRA. - speculative_config (Optional): The configuration related to speculative - decoding. - executor_class: The model executor class for managing distributed - execution. - prompt_adapter_config (Optional): The configuration related to serving - prompt adapters. - log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection. - """ - - DO_VALIDATE_OUTPUT: ClassVar[bool] = False - """A flag to toggle whether to validate the type of request output.""" - - @classmethod - @contextmanager - def enable_output_validation(cls): - cls.DO_VALIDATE_OUTPUT = True - - yield - - cls.DO_VALIDATE_OUTPUT = False - - @classmethod - def validate_output( - cls, - output: object, - output_type: Type[_O], - ) -> _O: - do_validate = cls.DO_VALIDATE_OUTPUT - - if ((TYPE_CHECKING or do_validate) - and not isinstance(output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - return output - - @classmethod - def validate_outputs( - cls, - outputs: GenericSequence[object], - output_type: Type[_O], - ) -> List[_O]: - do_validate = cls.DO_VALIDATE_OUTPUT - - outputs_: List[_O] - if TYPE_CHECKING or do_validate: - outputs_ = [] - for output in outputs: - if not isinstance(output, output_type): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - outputs_.append(output) - else: - outputs_ = outputs - - return outputs_ - - tokenizer: Optional[BaseTokenizerGroup] - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - executor_class: Type[ExecutorBase], - log_stats: bool, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - input_registry: InputRegistry = INPUT_REGISTRY, - ) -> None: - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, " - "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.rope_scaling, - model_config.rope_theta, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.use_v2_block_manager, - scheduler_config.num_scheduler_steps, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - model_config.mm_processor_kwargs, - ) - # TODO(woosuk): Print more configs in debug mode. - from vllm.plugins import load_general_plugins - load_general_plugins() - - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig( - ) - self.log_stats = log_stats - - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - else: - self.tokenizer = None - self.detokenizer = None - tokenizer_group = None - - # Ensure that the function doesn't contain a reference to self, - # to avoid engine GC issues - def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: - assert tokenizer_group, ("tokenizer_group cannot be None, " - "make sure skip_tokenizer_init is False") - return tokenizer_group.get_lora_tokenizer(sequence.lora_request) - - self.seq_counter = Counter() - self.generation_config_fields = _load_generation_config_dict( - model_config) - - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) - - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - model_config) - - self.model_executor = executor_class( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - observability_config=self.observability_config, - ) - - if not self.model_config.embedding_mode: - self._initialize_kv_caches() - - # If usage stat is enabled, collect relevant info. - if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import ( - get_architecture_class_name) - usage_message.report_usage( - get_architecture_class_name(model_config), - usage_context, - extra_kvs={ - # Common configuration - "dtype": - str(model_config.dtype), - "tensor_parallel_size": - parallel_config.tensor_parallel_size, - "block_size": - cache_config.block_size, - "gpu_memory_utilization": - cache_config.gpu_memory_utilization, - - # Quantization - "quantization": - model_config.quantization, - "kv_cache_dtype": - str(cache_config.cache_dtype), - - # Feature flags - "enable_lora": - bool(lora_config), - "enable_prompt_adapter": - bool(prompt_adapter_config), - "enable_prefix_caching": - cache_config.enable_prefix_caching, - "enforce_eager": - model_config.enforce_eager, - "disable_custom_all_reduce": - parallel_config.disable_custom_all_reduce, - }) - - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - - self.cached_scheduler_outputs = [ - SchedulerOutputState() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.scheduler_contexts = [ - SchedulerContext() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - if model_config.use_async_output_proc: - process_model_outputs = weak_bind(self._process_model_outputs) - - self.async_callbacks = [ - partial(process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - else: - self.async_callbacks = [] - - # Currently used by AsyncLLMEngine to ensure quick append - # of request outputs to asyncio queues - self.process_request_outputs_callback: Optional[Callable] = None - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = [ - Scheduler( - scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] - if model_config.use_async_output_proc else None) - for v_id in range(parallel_config.pipeline_parallel_size) - ] - - # Metric Logging. - if self.log_stats: - if stat_loggers is not None: - self.stat_loggers = stat_loggers - else: - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from vllm.engine.metrics import (LoggingStatLogger, - PrometheusStatLogger) - - self.stat_loggers = { - "logging": - LoggingStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC), - "prometheus": - PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len), - } - self.stat_loggers["prometheus"].info("cache_config", - self.cache_config) - - self.tracer = None - if self.observability_config.otlp_traces_endpoint: - self.tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) - - # Create sequence output processor, e.g. for beam search or - # speculative decoding. - self.output_processor = ( - SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - get_tokenizer_for_seq, - ), - )) - - def _initialize_kv_caches(self) -> None: - """Initialize the KV cache in the worker(s). - - The workers will determine the number of blocks in both the GPU cache - and the swap CPU cache. - """ - num_gpu_blocks, num_cpu_blocks = ( - self.model_executor.determine_num_available_blocks()) - - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - @classmethod - def _get_executor_cls(cls, - engine_config: EngineConfig) -> Type[ExecutorBase]: - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) - executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutor - executor_class = NeuronExecutor - elif engine_config.device_config.device_type == "tpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_tpu_executor import RayTPUExecutor - executor_class = RayTPUExecutor - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutor - executor_class = TPUExecutor - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutor - executor_class = CPUExecutor - elif engine_config.device_config.device_type == "openvino": - from vllm.executor.openvino_executor import OpenVINOExecutor - executor_class = OpenVINOExecutor - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_xpu_executor import RayXPUExecutor - executor_class = RayXPUExecutor - elif distributed_executor_backend == "mp": - # FIXME(kunshang): - # spawn needs calling `if __name__ == '__main__':`` - # fork is not supported for xpu start new process. - logger.error( - "Both start methods (spawn and fork) have issue " - "on XPU if you use mp backend, Please try ray instead.") - else: - from vllm.executor.xpu_executor import XPUExecutor - executor_class = XPUExecutor - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_gpu_executor import RayGPUExecutor - executor_class = RayGPUExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingGPUExecutor - else: - from vllm.executor.gpu_executor import GPUExecutor - executor_class = GPUExecutor - return executor_class - - @classmethod - def from_engine_args( - cls, - engine_args: EngineArgs, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) - # Create the LLM engine. - engine = cls( - **engine_config.to_dict(), - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - - return engine - - def __reduce__(self): - # This is to ensure that the LLMEngine is not referenced in - # the closure used to initialize Ray worker actors - raise RuntimeError("LLMEngine should not be pickled!") - - def __del__(self): - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): - model_executor.shutdown() - - def get_tokenizer_group( - self, - group_type: Type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. " - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - - return tokenizer_group - - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - - def _init_tokenizer(self) -> BaseTokenizerGroup: - return init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config)) - - def _verify_args(self) -> None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: - self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) - if self.prompt_adapter_config: - self.prompt_adapter_config.verify_with_model_config( - self.model_config) - - def _add_processed_request( - self, - request_id: str, - processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], - params: Union[SamplingParams, PoolingParams], - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - ) -> None: - self._validate_model_inputs(processed_inputs) - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request, prompt_adapter_request) - - encoder_seq = None - if 'encoder_prompt_token_ids' in processed_inputs: - encoder_seq = Sequence(seq_id, - processed_inputs, - block_size, - eos_token_id, - lora_request, - prompt_adapter_request, - from_decoder_prompt=False) - - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) - else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") - - # Add the sequence group to the scheduler with least unfinished seqs. - costs = [ - scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler - ] - min_cost_scheduler = self.scheduler[costs.index(min(costs))] - min_cost_scheduler.add_seq_group(seq_group) - - def stop_remote_worker_execution_loop(self) -> None: - self.model_executor.stop_remote_worker_execution_loop() - - def add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - """Add a request to the engine's request pool. - - The request is added to the request pool and will be processed by the - scheduler as `engine.step()` is called. The exact scheduling policy is - determined by the scheduler. - - Args: - request_id: The unique ID of the request. - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` - for more details about the format of each input. - params: Parameters for sampling or pooling. - :class:`~vllm.SamplingParams` for text generation. - :class:`~vllm.PoolingParams` for pooling. - arrival_time: The arrival time of the request. If None, we use - the current monotonic time. - trace_headers: OpenTelemetry trace headers. - - Details: - - Set arrival_time to the current time if it is None. - - Set prompt_token_ids to the encoded prompt if it is None. - - Create `best_of` number of :class:`~vllm.Sequence` objects. - - Create a :class:`~vllm.SequenceGroup` object - from the list of :class:`~vllm.Sequence`. - - Add the :class:`~vllm.SequenceGroup` object to the scheduler. - - Example: - >>> # initialize engine - >>> engine = LLMEngine.from_engine_args(engine_args) - >>> # set request arguments - >>> example_prompt = "Who is the president of the United States?" - >>> sampling_params = SamplingParams(temperature=0.0) - >>> request_id = 0 - >>> - >>> # add the request to the engine - >>> engine.add_request( - >>> str(request_id), - >>> example_prompt, - >>> SamplingParams(temperature=0.0)) - >>> # continue the request processing - >>> ... - """ - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - if arrival_time is None: - arrival_time = time.time() - - preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - processed_inputs = self.input_processor(preprocessed_inputs) - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - ) - - def _create_sequence_group_with_sampling( - self, - request_id: str, - seq: Sequence, - sampling_params: SamplingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - encoder_seq: Optional[Sequence] = None, - ) -> SequenceGroup: - """Creates a SequenceGroup with SamplingParams.""" - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") - - # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.clone() - - sampling_params.update_from_generation_config( - self.generation_config_fields, seq.eos_token_id) - - # Create the sequence group. - seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) - - return seq_group - - def _create_sequence_group_with_pooling( - self, - request_id: str, - seq: Sequence, - pooling_params: PoolingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - encoder_seq: Optional[Sequence] = None, - ) -> SequenceGroup: - """Creates a SequenceGroup with PoolingParams.""" - # Defensive copy of PoolingParams, which are used by the pooler - pooling_params = pooling_params.clone() - # Create the sequence group. - seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) - return seq_group - - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - """Aborts a request(s) with the given ID. - - Args: - request_id: The ID(s) of the request to abort. - - Details: - - Refer to the - :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` - from class :class:`~vllm.core.scheduler.Scheduler`. - - Example: - >>> # initialize engine and add a request with request_id - >>> request_id = str(0) - >>> # abort the request - >>> engine.abort_request(request_id) - """ - for scheduler in self.scheduler: - scheduler.abort_seq_group(request_id) - - def get_model_config(self) -> ModelConfig: - """Gets the model configuration.""" - return self.model_config - - def get_parallel_config(self) -> ParallelConfig: - """Gets the parallel configuration.""" - return self.parallel_config - - def get_decoding_config(self) -> DecodingConfig: - """Gets the decoding configuration.""" - return self.decoding_config - - def get_scheduler_config(self) -> SchedulerConfig: - """Gets the scheduler configuration.""" - return self.scheduler_config - - def get_lora_config(self) -> LoRAConfig: - """Gets the LoRA configuration.""" - return self.lora_config - - def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return sum(scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler) - - def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return any(scheduler.has_unfinished_seqs() - for scheduler in self.scheduler) - - def has_unfinished_requests_for_virtual_engine( - self, virtual_engine: int) -> bool: - """ - Returns True if there are unfinished requests for the virtual engine. - """ - return self.scheduler[virtual_engine].has_unfinished_seqs() - - @staticmethod - def _process_sequence_group_outputs( - seq_group: SequenceGroup, - outputs: List[EmbeddingSequenceGroupOutput], - ) -> None: - seq_group.embeddings = outputs[0].embeddings - - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_STOPPED - - return - - def _process_model_outputs(self, - ctx: SchedulerContext, - request_id: Optional[str] = None) -> None: - """Apply the model output to the sequences in the scheduled seq groups - and return responses. - - ctx: The virtual engine context to work on - request_id: If provided, then only this request is going to be processed - - """ - now = time.time() - - if len(ctx.output_queue) == 0: - return None - - # Get pending async postprocessor - if request_id: - # When we process only one request, no pop is required - # (since later we will process all of the rest) - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue[0] - else: - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue.popleft() - - # Sanity check - assert len(seq_group_metadata_list) == len( - scheduler_outputs.scheduled_seq_groups) - - # Organize outputs by [step][sequence group] instead of - # [sequence group][step]. - if len(outputs) > 1: - outputs_by_sequence_group = create_output_by_sequence_group( - outputs, num_seq_groups=len(seq_group_metadata_list)) - else: - outputs_by_sequence_group = outputs - - # Determine the requests we need to operate on - if request_id: - indices = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): - if seq_group_meta.request_id == request_id: - assert i not in skip # Cannot be called twice - indices.append(i) - break - - # If the request_id was not found, then it means that - # this is a new request that has no pending async - # postprocessor - if not indices: - return - else: - indices = range(len(seq_group_metadata_list)) # type: ignore - - finished_before: List[int] = [] - finished_now: List[int] = [] - for i in indices: - if i in skip: - continue - - seq_group_meta = seq_group_metadata_list[i] - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - finished_before.append(i) - continue - - if len(outputs) > 1: - output = outputs_by_sequence_group[i] - else: - output = [outputs_by_sequence_group[0][i]] - - if not is_async: - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) - - if outputs: - for o in outputs: - if (isinstance(o, SamplerOutput) - and seq_group.metrics is not None): - if seq_group.metrics.model_forward_time is not None: - seq_group.metrics.model_forward_time += ( - o.model_forward_time) - else: - seq_group.metrics.model_forward_time = ( - o.model_forward_time) - if seq_group.metrics.model_execute_time is not None: - seq_group.metrics.model_execute_time += ( - o.model_execute_time) - else: - seq_group.metrics.model_execute_time = ( - o.model_execute_time) - - if self.model_config.embedding_mode: - self._process_sequence_group_outputs(seq_group, output) - else: - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs( - seq_group, output, is_async) - - if seq_group.is_finished(): - finished_now.append(i) - - # Generate outputs for the requests that finished this iteration - for i in finished_now: - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) - if request_output: - ctx.request_outputs.append(request_output) - - # When we process a single request, we skip it for the next time, - # and invoke the request output callback (if there was final output) - if request_id: - assert len(indices) == 1 - skip.append(indices[0]) - - if (finished_now - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - # Free currently finished requests - if finished_now: - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() - - # For multi-step, do not create outputs each iteration - if not is_last_step: - # Immediately process request outputs here (if callback is given) - if (finished_now - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - # Create the outputs - for i in indices: - if i in skip or i in finished_before or i in finished_now: - continue # Avoids double processing - - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) - if request_output: - ctx.request_outputs.append(request_output) - - for seq_group in scheduler_outputs.ignored_seq_groups: - params = seq_group.sampling_params - if params is not None and params.output_kind == ( - RequestOutputKind.DELTA) and not seq_group.is_finished(): - continue - - request_output = RequestOutputFactory.create(seq_group) - if request_output: - ctx.request_outputs.append(request_output) - - # Immediately process request outputs here (if callback is given) - if (ctx.request_outputs - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - - # For async case, we need to record the stats here. - # For non-async case, the stats are done in the - # LLMEngine/AsyncLLMEngine directly - if is_async: - # Log stats. - self.do_log_stats(scheduler_outputs, outputs, finished_before, - skip) - - # Tracing - self.do_tracing(scheduler_outputs) - - return None - - def _advance_to_next_step( - self, output: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: - """Given model output from a single run, append the tokens to the - sequences. This is normally done inside output processor, but it is - required if the worker is to perform async forward pass to next step. - """ - for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ - zip(seq_group_metadata_list, output, scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - continue - - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) - - if seq_group_metadata.do_sample: - assert len(sequence_group_outputs.samples) == 1, ( - "Async output processor expects a single sample" - " (i.e sampling_params.n == 1 and no " - "sampling_params.best_of > 1)") - sample = sequence_group_outputs.samples[0] - - assert len(seq_group.seqs) == 1 - seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) - - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: - """Performs one decoding iteration and returns newly generated results. - - .. figure:: https://i.imgur.com/sv2HssD.png - :alt: Overview of the step function - :align: center - - Overview of the step function. - - Details: - - Step 1: Schedules the sequences to be executed in the next - iteration and the token blocks to be swapped in/out/copy. - - - Depending on the scheduling policy, - sequences may be `preempted/reordered`. - - A Sequence Group (SG) refer to a group of sequences - that are generated from the same prompt. - - - Step 2: Calls the distributed executor to execute the model. - - Step 3: Processes the model output. This mainly includes: - - - Decodes the relevant outputs. - - Updates the scheduled sequence groups with model outputs - based on its `sampling parameters` (`use_beam_search` or not). - - Frees the finished sequence groups. - - - Finally, it creates and returns the newly generated results. - - Example: - >>> # Please see the example/ folder for more detailed examples. - >>> - >>> # initialize engine and request arguments - >>> engine = LLMEngine.from_engine_args(engine_args) - >>> example_inputs = [(0, "What is LLM?", - >>> SamplingParams(temperature=0.0))] - >>> - >>> # Start the engine with an event loop - >>> while True: - >>> if example_inputs: - >>> req_id, prompt, sampling_params = example_inputs.pop(0) - >>> engine.add_request(str(req_id),prompt,sampling_params) - >>> - >>> # continue the request processing - >>> request_outputs = engine.step() - >>> for request_output in request_outputs: - >>> if request_output.finished: - >>> # return or show the request output - >>> - >>> if not (engine.has_unfinished_requests() or example_inputs): - >>> break - """ - if self.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported through AsyncLLMEngine " - "as performance will be severely degraded otherwise.") - - # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0. - virtual_engine = 0 - - # These are cached outputs from previous iterations. None if on first - # iteration - cached_outputs = self.cached_scheduler_outputs[virtual_engine] - seq_group_metadata_list = cached_outputs.seq_group_metadata_list - scheduler_outputs = cached_outputs.scheduler_outputs - allow_async_output_proc = cached_outputs.allow_async_output_proc - - ctx = self.scheduler_contexts[virtual_engine] - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() - - # Skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. - if not self._has_remaining_steps(seq_group_metadata_list): - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs - - # Maybe switch from async mode to sync mode - if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - - if (self.scheduler_config.is_multi_step - and scheduler_outputs.num_lookahead_slots > 0): - # cache the scheduler outputs for the next iteration if we have - # lookahead slots - self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) - - assert seq_group_metadata_list is not None - assert scheduler_outputs is not None - - if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() - - # Check if we have a cached last_output from the previous iteration. - # For supporting PP this is probably the best way to pass the - # sampled_token_ids, as a separate broadcast over all the PP stages - # will cause one virtual engine's microbatch to block the pipeline. - last_sampled_token_ids = \ - self._get_last_sampled_token_ids(virtual_engine) - - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots, - running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids, - # We use ExecuteModelRequest to pass the last sampled_token_ids - # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids) - - if allow_async_output_proc: - execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] - - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) - - # We need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, outputs) - else: - # Nothing scheduled => If there is pending async postprocessor, - # then finish it here. - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - # No outputs in this case - outputs = [] - - # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: - for seq_group in seq_group_metadata_list: - seq_group.finish_step() - - if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps. - if self.scheduler_config.is_multi_step: - self.cached_scheduler_outputs[0] = SchedulerOutputState() - - # Add results to the output_queue - ctx.append_output(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=allow_async_output_proc, - is_last_step=True) - - if outputs and allow_async_output_proc: - assert len(outputs) == 1, ( - "Async postprocessor expects only a single output set") - - self._advance_to_next_step( - outputs[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) - - # Check if need to run the usual non-async path - if not allow_async_output_proc: - self._process_model_outputs(ctx=ctx) - - # Log stats. - self.do_log_stats(scheduler_outputs, outputs) - - # Tracing - self.do_tracing(scheduler_outputs) - else: - # Multi-step case - return ctx.request_outputs - - if not self.has_unfinished_requests(): - # Drain async postprocessor (if exists) - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - assert len(ctx.output_queue) == 0 - - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() - - return ctx.request_outputs - - def _has_remaining_steps( - self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ) -> bool: - if (not self.scheduler_config.is_multi_step - or not seq_group_metadata_list): - return False - - # TODO(will) this is a sanity check for nowto make sure that all the - # seqs are on the same steps. Eventually we will want to do some sort of - # dynamic scheduling when doing multi-step decoding. - ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps - if any([ - seq_group.state.remaining_steps != ref_remaining_steps - for seq_group in seq_group_metadata_list[1:] - ]): - raise AssertionError(("All running sequence groups should " - "have the same remaining steps.")) - - return ref_remaining_steps > 0 - - def _cache_scheduler_outputs_for_multi_step( - self, virtual_engine: int, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs, - allow_async_output_proc: bool) -> None: - co = self.cached_scheduler_outputs[virtual_engine] - - co.seq_group_metadata_list = seq_group_metadata_list - co.scheduler_outputs = scheduler_outputs - co.allow_async_output_proc = allow_async_output_proc - co.last_output = None - - def _update_cached_scheduler_output( - self, virtual_engine: int, - output: List[Optional[SamplerOutput]]) -> None: - if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 - and output[0] is not None): - last_output = output[-1] - assert last_output is not None - assert last_output.sampled_token_ids_cpu is not None - assert last_output.sampled_token_ids is None - assert last_output.sampled_token_probs is None - self.cached_scheduler_outputs[ - virtual_engine].last_output = last_output - - def _get_last_sampled_token_ids( - self, virtual_engine: int) -> Optional[torch.Tensor]: - cached_last_output = self.cached_scheduler_outputs[ - virtual_engine].last_output - if (self.scheduler_config.is_multi_step - and self.parallel_config.pipeline_parallel_size > 1 - and cached_last_output is not None - and cached_last_output.sampled_token_ids_cpu is not None): - return cached_last_output.sampled_token_ids_cpu - return None - - def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. Set `disable_log_stats=False` " - "argument to enable.") - if logger_name in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} already exists.") - self.stat_loggers[logger_name] = logger - - def remove_logger(self, logger_name: str) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. Set `disable_log_stats=False` " - "argument to enable.") - if logger_name not in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} does not exist.") - del self.stat_loggers[logger_name] - - def do_log_stats(self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> None: - """Forced log when no requests active.""" - if self.log_stats: - stats = self._get_stats(scheduler_outputs, model_output, - finished_before, skip) - for logger in self.stat_loggers.values(): - logger.log(stats) - - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs], - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> Stats: - """Get Stats to be Logged to Prometheus. - - Args: - scheduler_outputs: Optional, used to populate metrics related to - the scheduled batch, - model_output: Optional, used to emit speculative decoding metrics - which are created by the workers. - finished_before: Optional, indices of sequences that were finished - before. These sequences will be ignored. - skip: Optional, indices of sequences that were preempted. These - sequences will be ignored. - """ - now = time.time() - - # System State - # Scheduler State - num_running_sys = sum( - len(scheduler.running) for scheduler in self.scheduler) - num_swapped_sys = sum( - len(scheduler.swapped) for scheduler in self.scheduler) - num_waiting_sys = sum( - len(scheduler.waiting) for scheduler in self.scheduler) - - # KV Cache Usage in % - num_total_gpu = self.cache_config.num_gpu_blocks - gpu_cache_usage_sys = 0. - if num_total_gpu is not None: - num_free_gpu = sum( - scheduler.block_manager.get_num_free_gpu_blocks() - for scheduler in self.scheduler) - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) - - num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage_sys = 0. - if num_total_cpu is not None and num_total_cpu > 0: - num_free_cpu = sum( - scheduler.block_manager.get_num_free_cpu_blocks() - for scheduler in self.scheduler) - cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) - - # Prefix Cache Hit Rate. Note that we always use - # the cache hit rate of the first virtual engine. - cpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.CPU) - gpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.GPU) - - # Iteration stats - num_prompt_tokens_iter = 0 - num_generation_tokens_iter = 0 - time_to_first_tokens_iter: List[float] = [] - time_per_output_tokens_iter: List[float] = [] - num_preemption_iter = (0 if scheduler_outputs is None else - scheduler_outputs.preempted) - - # Request stats - # Latency - time_e2e_requests: List[float] = [] - # Metadata - num_prompt_tokens_requests: List[int] = [] - num_generation_tokens_requests: List[int] = [] - best_of_requests: List[int] = [] - n_requests: List[int] = [] - finished_reason_requests: List[str] = [] - - # NOTE: This loop assumes prefill seq_groups are before - # decode seq_groups in scheduled_seq_groups. - if scheduler_outputs is not None: - # For async postprocessor, already finished sequences need to be - # not counted (to avoid double counting) - actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - - num_generation_tokens_from_prefill_groups = 0. - # NOTE: if scheduler_outputs.num_prefill_groups > 0 and - # the len of scheduler_outputs.scheduled_seq_groups is != - # scheduler_outputs.num_prefill_groups, this means that - # chunked prefills have been detected. - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double logging when using async output proc - if finished_before and idx in finished_before: - actual_num_batched_tokens -= 1 - continue - - # Currently, skip == preempted sequences, so we need to skip - # their log stats - if skip and idx in skip: - continue - - group_was_prefill = idx < scheduler_outputs.num_prefill_groups - seq_group = scheduled_seq_group.seq_group - - # NOTE: a seq_group that completed all of its prefill tokens - # in the last iteration will have seq_group.is_prefill() = False - # with group_was_prefill = True - if group_was_prefill: - # Number of prompt tokens. - num_prompt_tokens_iter += ( - scheduled_seq_group.token_chunk_size) - - # If the seq_group just finished the prefill state - # get TTFT. - if not seq_group.is_prefill(): - latency = seq_group.get_last_latency(now) - time_to_first_tokens_iter.append(latency) - - # One generation token per finished prefill. - num_generation_tokens_from_prefill_groups += ( - seq_group.num_seqs()) - else: - # TPOTs. - latency = seq_group.get_last_latency(now) - time_per_output_tokens_iter.append(latency) - - # Because of chunked prefill, we can have a single sequence - # group that does multiple prompt_runs. To prevent logging - # the same metadata more than once per request, we standardize - # on logging request level information for finished requests, - # which can only happen once. - if seq_group.is_finished(): - # Latency timings - time_e2e_requests.append(now - - seq_group.metrics.arrival_time) - # Metadata - num_prompt_tokens_requests.append( - len(seq_group.prompt_token_ids)) - num_generation_tokens_requests.extend([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ]) - if seq_group.sampling_params is not None: - best_of_requests.append( - seq_group.sampling_params.best_of) - n_requests.append(seq_group.sampling_params.n) - finished_reason_requests.extend([ - SequenceStatus.get_finished_reason(seq.status) - for seq in seq_group.get_finished_seqs() - ]) - - # Number of generation tokens. - # num_batched_tokens equals the number of prompt_tokens plus the - # number of decode_tokens in a single iteration. So, - # num_generation_tokens = num_batched_tokens - num_prompt_tokens - # + num_generation_tokens_from_prefill_groups (since we generate - # one token on prefills on iters where the prefill finishes). - num_generation_tokens_iter = ( - actual_num_batched_tokens - num_prompt_tokens_iter + - num_generation_tokens_from_prefill_groups) - - # Spec decode, if enabled, emits specialized metrics from the worker in - # sampler output. - if model_output and (model_output[0].spec_decode_worker_metrics - is not None): - spec_decode_metrics = model_output[0].spec_decode_worker_metrics - else: - spec_decode_metrics = None - - return Stats( - now=now, - # System stats - # Scheduler State - num_running_sys=num_running_sys, - num_swapped_sys=num_swapped_sys, - num_waiting_sys=num_waiting_sys, - # KV Cache Usage in % - gpu_cache_usage_sys=gpu_cache_usage_sys, - cpu_cache_usage_sys=cpu_cache_usage_sys, - # Prefix Cache Hit Rate - cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, - gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, - - # Iteration stats - num_prompt_tokens_iter=num_prompt_tokens_iter, - num_generation_tokens_iter=num_generation_tokens_iter, - time_to_first_tokens_iter=time_to_first_tokens_iter, - time_per_output_tokens_iter=time_per_output_tokens_iter, - spec_decode_metrics=spec_decode_metrics, - num_preemption_iter=num_preemption_iter, - - # Request stats - # Latency - time_e2e_requests=time_e2e_requests, - # Metadata - num_prompt_tokens_requests=num_prompt_tokens_requests, - num_generation_tokens_requests=num_generation_tokens_requests, - best_of_requests=best_of_requests, - n_requests=n_requests, - finished_reason_requests=finished_reason_requests, - ) - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_executor.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_executor.remove_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_executor.list_loras() - - def pin_lora(self, lora_id: int) -> bool: - return self.model_executor.pin_lora(lora_id) - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.model_executor.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_executor.remove_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> List[int]: - return self.model_executor.list_prompt_adapters() - - def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() - self.model_executor.check_health() - - def start_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.start_profile() - else: - self.model_executor._run_workers("start_profile") - - def stop_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.stop_profile() - else: - self.model_executor._run_workers("stop_profile") - - def is_tracing_enabled(self) -> bool: - return self.tracer is not None - - def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None: - if self.tracer is None: - return - - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - seq_group = scheduled_seq_group.seq_group - if seq_group.is_finished(): - self.create_trace_span(seq_group) - - def create_trace_span(self, seq_group: SequenceGroup) -> None: - if self.tracer is None or seq_group.sampling_params is None: - return - arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) - - trace_context = extract_trace_context(seq_group.trace_headers) - - with self.tracer.start_as_current_span( - "llm_request", - kind=SpanKind.SERVER, - context=trace_context, - start_time=arrival_time_nano_seconds) as seq_span: - metrics = seq_group.metrics - ttft = metrics.first_token_time - metrics.arrival_time - e2e_time = metrics.finished_time - metrics.arrival_time - # attribute names are based on - # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md - seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL, - self.model_config.model) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID, - seq_group.request_id) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE, - seq_group.sampling_params.temperature) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P, - seq_group.sampling_params.top_p) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS, - seq_group.sampling_params.max_tokens) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF, - seq_group.sampling_params.best_of) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N, - seq_group.sampling_params.n) - seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, - seq_group.num_seqs()) - seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, - len(seq_group.prompt_token_ids)) - seq_span.set_attribute( - SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, - sum([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ])) - seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE, - metrics.time_in_queue) - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) - seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) - if metrics.scheduler_time is not None: - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER, - metrics.scheduler_time) - if metrics.model_forward_time is not None: - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD, - metrics.model_forward_time / 1000.0) - if metrics.model_execute_time is not None: - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE, - metrics.model_execute_time) - - def is_encoder_decoder_model(self): - return self.input_preprocessor.is_encoder_decoder_model() - - def is_embedding_model(self): - return self.model_config.is_embedding_model - - def _validate_model_inputs(self, inputs: Union[LLMInputs, - EncoderDecoderLLMInputs]): - if self.is_encoder_decoder_model(): - prompt_ids = inputs.get("encoder_prompt_token_ids") - else: - prompt_ids = inputs.get("prompt_token_ids") - - if prompt_ids is None or len(prompt_ids) == 0: - raise ValueError("Prompt cannot be empty") - - if self.model_config.is_multimodal_model: - max_prompt_len = self.model_config.max_model_len - - if len(prompt_ids) > max_prompt_len: - raise ValueError( - f"The prompt (total length {len(prompt_ids)}) is too long " - f"to fit into the model (context length {max_prompt_len}). " - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") - - # TODO: Find out how many placeholder tokens are there so we can - # check that chunked prefill does not truncate them - # max_batch_len = self.scheduler_config.max_num_batched_tokens diff --git a/vllm/engine/llm_engine_v2.py b/vllm/engine/llm_engine_v2.py new file mode 100644 index 0000000000000..e828ee440eb9c --- /dev/null +++ b/vllm/engine/llm_engine_v2.py @@ -0,0 +1,665 @@ +import time +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Mapping, NamedTuple, Optional) +from typing import Sequence as GenericSequence +from typing import Set, Type, Union + +import torch +from typing_extensions import TypeVar + +import vllm.envs as envs +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.core.scheduler_v2 import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.executor.executor_base import ExecutorBase +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, + InputRegistry, LLMInputs, PromptType) +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.request import Request +from vllm.transformers_utils.config import try_get_generation_config +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.version import __version__ as VLLM_VERSION +from vllm.request import Request + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: + config = try_get_generation_config( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.revision, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + ) -> None: + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s, mm_processor_kwargs=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.use_v2_block_manager, + scheduler_config.num_scheduler_steps, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + model_config.mm_processor_kwargs, + ) + # TODO(woosuk): Print more configs in debug mode. + from vllm.plugins import load_general_plugins + load_general_plugins() + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig( + ) + self.log_stats = log_stats + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + self.generation_config_fields = _load_generation_config_dict( + model_config) + + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + model_config) + + self.model_executor = executor_class( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, + ) + + if not self.model_config.embedding_mode: + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(model_config.dtype), + "tensor_parallel_size": + parallel_config.tensor_parallel_size, + "block_size": + cache_config.block_size, + "gpu_memory_utilization": + cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + model_config.quantization, + "kv_cache_dtype": + str(cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), + "enable_prefix_caching": + cache_config.enable_prefix_caching, + "enforce_eager": + model_config.enforce_eager, + "disable_custom_all_reduce": + parallel_config.disable_custom_all_reduce, + }) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = Scheduler( + scheduler_config, cache_config, lora_config) + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Create the LLM engine. + engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + return engine + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + + def get_tokenizer_group( + self, + group_type: Type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group + + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) + + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + ) -> None: + self._validate_model_inputs(processed_inputs) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + # TODO(woosuk): Support embedding mode. + assert isinstance(params, SamplingParams) + sampling_params = params.clone() + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + + # TODO(woosuk): Check max_logprobs + # TODO(woosuk): Support encoder-decoder models. + req = Request(request_id, processed_inputs, arrival_time, sampling_params=params) + self.scheduler.add_req(req) + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + trace_headers: OpenTelemetry trace headers. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `best_of` number of :class:`~vllm.Sequence` objects. + - Create a :class:`~vllm.SequenceGroup` object + from the list of :class:`~vllm.Sequence`. + - Add the :class:`~vllm.SequenceGroup` object to the scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + ) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to the + :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` + from class :class:`~vllm.core.scheduler.Scheduler`. + + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) + """ + self.scheduler.abort_reqs(request_id) + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_reqs() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_req() + + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + scheduler_output = self.scheduler.schedule() + sampler_output = self.model_executor.execute_model(scheduler_output) + self._process_model_outputs(sampler_output) + + if not self.has_unfinished_requests(): + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") + self.model_executor.stop_remote_worker_execution_loop() + return sampler_output + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + + def check_health(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + def _validate_model_inputs(self, inputs: Union[LLMInputs, + EncoderDecoderLLMInputs]): + if self.is_encoder_decoder_model(): + prompt_ids = inputs.get("encoder_prompt_token_ids") + else: + prompt_ids = inputs.get("prompt_token_ids") + + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + +def _get_executor_cls(engine_config: EngineConfig) -> Type[ExecutorBase]: + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + # Initialize the cluster and specify the executor class. + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutor + executor_class = NeuronExecutor + elif engine_config.device_config.device_type == "tpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import RayTPUExecutor + executor_class = RayTPUExecutor + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutor + executor_class = TPUExecutor + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor + elif engine_config.device_config.device_type == "openvino": + from vllm.executor.openvino_executor import OpenVINOExecutor + executor_class = OpenVINOExecutor + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutor + executor_class = RayXPUExecutor + elif distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. + logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, Please try ray instead.") + else: + from vllm.executor.xpu_executor import XPUExecutor + executor_class = XPUExecutor + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + executor_class = RayGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingGPUExecutor + else: + from vllm.executor.gpu_executor import GPUExecutor + executor_class = GPUExecutor + return executor_class diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a86c51d23b34d..03a58ddbf2a94 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -5,7 +5,7 @@ from tqdm import tqdm from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine +from vllm.engine.llm_engine_v2 import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_hf_chat_template, apply_mistral_chat_template, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c96cb0f2c2981..f82b13d3664ab 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -8,7 +8,6 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest class ExecutorBase(ABC): @@ -75,7 +74,8 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, execute_model_req: ExecuteModelRequest + self, + scheduler_output, ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences.""" raise NotImplementedError @@ -134,9 +134,8 @@ def __del__(self): class ExecutorAsyncBase(ExecutorBase): @abstractmethod - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + async def execute_model_async(self, + scheduler_output) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 7278c7fbe8bea..5c767e22de4d0 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,12 +1,10 @@ from vllm.model_executor.parameter import (BasevLLMParameter, PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingMetadataCache) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", - "SamplingMetadataCache", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ca86a4653cf4..4ab2bb137488b 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,8 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -import itertools -import warnings from dataclasses import dataclass -from importlib.util import find_spec from math import inf from typing import Dict, List, Optional, Tuple, Union @@ -11,1312 +8,112 @@ import torch.nn as nn import vllm.envs as envs -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors, - SequenceGroupToSample) -from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SequenceOutput) -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm.sampler_output import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata -if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): - import flashinfer.sampling - # yapf: disable - from flashinfer.sampling import ( - top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) - - # yapf: enable -else: - flashinfer_top_k_top_p_sampling = None - -# (num_token_ids, num_parent_ids) per sequence group. -SampleResultType = List[Tuple[List[int], List[int]]] - -# Types of temporary data structures used for -# computing sample_result -SampleMetadataType = Dict[SamplingType, Tuple[List[int], - List[SequenceGroupToSample]]] -MultinomialSamplesType = Dict[SamplingType, torch.Tensor] -SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] - - -# Encapsulates temporary data structures for computing -# sample_result. -# -# * For multi-step scheduling: must be returned -# by `Sampler.forward()` and used later to compute the pythonized -# sample_result -# -# * For single-step scheduling: consumed immediately -# inside `Sampler.forward()` to compute pythonized sample_result. -@dataclass -class SampleResultArgsType: - sample_metadata: SampleMetadataType - multinomial_samples: MultinomialSamplesType - sample_results_dict: SampleResultsDictType - sampling_metadata: SamplingMetadata - greedy_samples: Optional[torch.Tensor] - beam_search_logprobs: Optional[torch.Tensor] - - -# Union of non-deferred (single-step scheduling) -# vs deferred (multi-step scheduling) -# sample result types -MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] - -# Abbreviation of the _sample() return type -SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] - - -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: List[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. - logprobs: Optional["torch.Tensor"] = None - - # Holds either (1) the pythonized sampler result (single-step scheduling) - # or (2) what will be arguments for later deferred pythonization of the - # sampler result (muliti-step scheduling) - deferred_sample_results_args: Optional[SampleResultArgsType] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. - sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int): - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return ( - f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr}, " - f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") +_SAMPLING_EPS = 1e-5 class Sampler(nn.Module): - """Samples the next tokens from the model's outputs. - - This layer does the following: - 1. Discard the hidden states that are not used for sampling (i.e., all - tokens except the final one in each prompt). - 2. Compute the logits for the next tokens. - 3. Apply presence, frequency and repetition penalties. - 4. Apply temperature scaling. - 5. Apply top-p and top-k truncation. - 6. Sample the next tokens. - Here, each sequence group within the batch can have different sampling - parameters (e.g., sampling method, temperature, top-p, top-k, etc.). - - The structure of the logits tensor is coupled with the seq_groups in - sampling_metadata. Typically, each sequence in each seq_group has one row in - logits for the next token to be sampled; however, for a seq_group with a - prompt request with the prompt_logprobs sampling parameter, there are rows - in logits for each token in the input prompt. - """ def __init__(self): super().__init__() - # Whether or not the SamplerOutput should have on-device tensors - # containing the sampled token ids and probabilities. This is used by - # speculative decoding. - self.include_gpu_probs_tensor = False - self.should_modify_greedy_probs_inplace = False - - def _init_sampling_tensors( + def forward( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ): - """The goal here is to reuse sampling tensors between similar decode - runs. This is possible because sampling logic does not change between - decodes of the same sequences. - """ - _, vocab_size = logits.shape - - # First free any existing stored sampling tensors. - # This is necessary because some sampling tensors may - # have pinned memory. - self._sampling_tensors = None + ) -> SamplerOutput: + logits = self.apply_temperature(logits, sampling_metadata.temperature) + logits = self.apply_penalties(logits, sampling_metadata) - # Initialize new sampling tensors - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + probs = self.get_probs(logits) + sampled = self.sample(probs, sampling_metadata) - self._sampling_tensors = sampling_tensors - self._do_penalties = do_penalties - self._do_top_p_top_k = do_top_p_top_k - self._do_min_p = do_min_p + if sampling_metadata.max_num_logprobs > 0: + logprobs = self.get_logprobs(logits) + topk_logprobs, topk_indices = torch.topk( + logprobs, sampling_metadata.max_num_logprobs, dim=-1) + else: + topk_logprobs = None + topk_indices = None + + sampler_output = SamplerOutput( + sampled_token_ids=sampled, + logprob_token_ids=topk_indices, + logprobs=topk_logprobs, + prompt_logprob_token_ids=None, + prompt_logprobs=None, + model_forward_time=0.0, + model_execute_time=0.0, + ) + return sampler_output - def forward( + def apply_temperature( self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """ - Single-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Pythonize sampling result & logprobs tensor - - Multi-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Defer Pythonization of sampling result & logprobs - tensor - * Encapsulate arguments required for deferred Pythonization - in the :class:`SamplerOutput` structure - - Args: - logits: (num_tokens, vocab_size). - sampling_metadata: Metadata for sampling. - """ - assert logits is not None - _, vocab_size = logits.shape - - # Prepare sampling tensors with pinned memory to avoid blocking. - if not sampling_metadata.reuse_sampling_tensors: - self._init_sampling_tensors(logits, sampling_metadata) - elif self._do_penalties: - # In this case, the sampling tensors logic depends on - # "output_tokens" of a sequence. As a result, we cannot - # reuse sampling tensors, since "output_tokens" changes - # between decode runs. - self._init_sampling_tensors(logits, sampling_metadata) - - assert self._sampling_tensors is not None - sampling_tensors = self._sampling_tensors - do_penalties = self._do_penalties - do_top_p_top_k = self._do_top_p_top_k - do_min_p = self._do_min_p - - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - # Apply presence and frequency penalties. - if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - + temp: torch.Tensor, + ) -> torch.Tensor: # Use float32 to apply temperature scaling. + logits = logits.to(torch.float32) + # Avoid division by zero. + temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) # Use in-place division to avoid creating a new tensor. - logits = logits.to(torch.float) - logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) - - if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # Sample the next tokens. - maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=self.include_gpu_probs_tensor, - modify_greedy_probs=self._should_modify_greedy_probs_inplace, - ) - - if self.include_gpu_probs_tensor: - # Since we will defer sampler result Pythonization, - # preserve GPU-side tensors in support of later - # deferred pythonization of logprobs - assert maybe_sampled_tokens_tensor is not None - on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) - else: - # Since Pythonization has already happened, don't preserve - # GPU-side tensors. - on_device_tensors = None - - # Get the logprobs query results. - prompt_logprobs = None - sample_logprobs = None - if not sampling_metadata.skip_sampler_cpu_output: - # Pythonize logprobs now (GPU -> CPU); do not defer. - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - prompt_logprobs, sample_logprobs = get_logprobs( - logprobs, sampling_metadata, maybe_deferred_sample_results) - - return _build_sampler_output( - maybe_deferred_sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors, - skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) - - @property - def _should_modify_greedy_probs_inplace(self) -> bool: - """Whether or not the sampler should modify the probability distribution - of greedily-sampled tokens such that multinomial sampling would sample - the greedily-sampled token. - - In other words, if True then we set the probability of the greedily- - sampled token to 1. - - This is used by speculative decoding, which requires that the sampling - method be encoded into the probability distribution. - """ - return self.should_modify_greedy_probs_inplace - - -def _get_bin_counts_and_mask( - tokens: torch.Tensor, - vocab_size: int, - num_seqs: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - # Compute the bin counts for the tokens. - # vocab_size + 1 for padding. - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) - bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) - bin_counts = bin_counts[:, :vocab_size] - mask = bin_counts > 0 - - return bin_counts, mask - - -def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - have not been generated yet - """ - # list of indices in logits that will be set to -inf - logits_to_penalize: List[Tuple[int, int]] = [] - logits_applied = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - - sample_indices = seq_group.sample_indices - logits_applied += len(sample_indices) + len( - seq_group.prompt_logprob_indices) - if not seq_group.do_sample: - continue - - start_idx = sample_indices[0] - min_tokens = sampling_params.min_tokens - token_ids_to_penalize = sampling_params.all_stop_token_ids - if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: List[int] = [] - for j, seq_id in enumerate(seq_ids): - seq_data = seq_group.seq_data[seq_id] - if len(seq_data.output_token_ids_array) < min_tokens: - seqs_to_penalize.append(j) - - if seqs_to_penalize: - # convert to the index into logits - seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] - # itertools.product pairs each seq index with every token id - logits_to_penalize.extend( - itertools.product(seqs_to_penalize, token_ids_to_penalize)) - - if logits_to_penalize: - # use zip and * to group indices along each dimension - # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) - logits[tuple(zip(*logits_to_penalize))] = -float("inf") - - # verifies that no rows in logits were missed unexpectedly - assert logits_applied == logits.shape[0] - return logits - - -def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: - num_seqs, vocab_size = logits.shape - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, - num_seqs) - output_bin_counts, output_mask = _get_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 - logits = torch.where(logits > 0, logits / repetition_penalties, - logits * repetition_penalties) - - # We follow the definition in OpenAI API. - # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask - return logits - - -def _apply_top_k_top_p( - logits: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = torch.empty_like(logits_sort).scatter_(dim=-1, - index=logits_idx, - src=logits_sort) - return logits - - -def _apply_min_p( - logits: torch.Tensor, - min_p: torch.Tensor, -) -> torch.Tensor: - """ - Adapted from - https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 - """ - probs = torch.softmax(logits, dim=-1) - top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs - tokens_to_remove = probs < scaled_min_p - logits = logits.masked_fill_(tokens_to_remove, -float("inf")) - - return logits - - -def _greedy_sample( - selected_seq_groups: List[SequenceGroupToSample], - samples: torch.Tensor, -) -> SampleResultType: - """Run greedy sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - samples: (num_selected_samples,) A tensor of samples. The length of - samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - samples_lst = samples.tolist() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - num_parent_seqs = len(seq_ids) - assert num_parent_seqs == 1, ( - "Greedy sampling should have only one seq.") - parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples_lst[sample_idx]] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -def _random_sample( - selected_seq_groups: List[SequenceGroupToSample], - random_samples: torch.Tensor, -) -> SampleResultType: - """Run random sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - random_samples: (num_selected_samples,) A tensor of samples. The - length of samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - # Find the maximum best_of value of the prompt phase requests. - random_samples = random_samples.cpu() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - num_parent_seqs = len(seq_ids) - if is_prompt: - # Prompt phase. - parent_ids = [0] * sampling_params.best_of - next_token_ids = random_samples[ - sample_idx, :sampling_params.best_of].tolist() - else: - # Generation phase. - parent_ids = list(range(num_parent_seqs)) - next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -def _beam_search_sample( - selected_seq_groups: List[SequenceGroupToSample], - logprobs: torch.Tensor, -) -> SampleResultType: - """Run beam sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - logprobs: (num_selected_samples, vocab_size,) A tensor of logprob - on selected sample indices. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - # We sample 2 * beam_width candidates to make sure that with high - # probability we can get `beam_width` candidates in addition to - # the finished sequences for the next iteration. See - # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 - # for details. See also HF reference: - # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 - # - # NOTE: Beam search is not vectorized, so its speed can be slower than - # other sampling methods. - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - is_prompt = seq_group.is_prompt - seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params - num_parent_seqs = len(seq_ids) - beam_width = sampling_params.best_of - seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] - if is_prompt: - # Prompt phase. - assert num_parent_seqs == 1, ( - "Prompt input should have only one seq.") - parent_ids = [0] * (2 * beam_width) - _, next_token_ids = torch.topk(seq_group_logprobs[0], - 2 * beam_width) - next_token_ids = next_token_ids.tolist() - else: - # Generation phase. - cumulative_logprobs: List[float] = [ - seq_group.seq_data[seq_id].cumulative_logprob - for seq_id in seq_ids - ] - cumulative_logprobs_tensor = torch.tensor( - cumulative_logprobs, - dtype=torch.float, - device=seq_group_logprobs.device) - seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs_tensor.unsqueeze(dim=1)) - _, topk_ids = torch.topk(seq_group_logprobs.flatten(), - 2 * beam_width) - topk_ids = topk_ids.tolist() - vocab_size = seq_group_logprobs.size(-1) - parent_ids = [i // vocab_size for i in topk_ids] - next_token_ids = [i % vocab_size for i in topk_ids] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) - return results - + logits.div_(temp.unsqueeze(dim=1)) + return logits -# torch.multinomial forces a GPU<->CPU sync. -# Therefore, we use an optimized implementation instead. -# Note that we always sample with replacement. -# probs will be modified in place, but this is fine, as we pass -# in a copy already. -def _multinomial( - probs: torch.Tensor, - num_samples: int, - seq_groups: Optional[List[SequenceGroupToSample]] = None, -) -> torch.Tensor: - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - q = torch.empty_like(probs) - if seq_groups is None: - q.exponential_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - q[sample_idx:sample_idx + - stride].exponential_(generator=seq_group.generator) - sample_idx += stride - return probs.div_(q).argmax(dim=1).view(-1, num_samples) - - -def _top_k_top_p_multinomial_with_flashinfer( - probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]): - max_top_k_round = 32 - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - top_ks = top_ks.repeat_interleave(num_samples) - top_ps = top_ps.repeat_interleave(num_samples) - batch_size = probs.shape[0] - uniform_samples = torch.empty((max_top_k_round, batch_size), - device=probs.device) - if seq_groups is None: - uniform_samples.uniform_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - uniform_samples[:, sample_idx:sample_idx + - stride].uniform_(generator=seq_group.generator) - sample_idx += stride - batch_next_token_ids, success = flashinfer_top_k_top_p_sampling( - probs, - uniform_samples, - top_ks, - top_ps, - ) - if not success.all(): - warnings.warn("FlashInfer rejection sampling failed, fallback.", - stacklevel=1) - probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks) - probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps) - batch_next_token_ids = flashinfer.sampling.sampling_from_probs( - probs, uniform_samples[0]) - return batch_next_token_ids.view(-1, num_samples) - - -def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType) -> SampleResultType: - '''This function consumes GPU-side sampler results and computes - Pythonized CPU-side sampler results (GPU -> CPU sync.) - - Single-step scheduling: this function is invoked at sampling-time - for immediate Pythonization. - - Multi-step scheduling: Pythonization is deferred until after multiple - GPU-side steps have been completed. - - Args: - sample_result_args: GPU-side inputs to the Pythonization process - - Returns: - Pythonized sampler results - ''' - - ( - sample_metadata, - sampling_metadata, - greedy_samples, - multinomial_samples, - beam_search_logprobs, - sample_results_dict, - ) = ( - sample_result_args.sample_metadata, - sample_result_args.sampling_metadata, - sample_result_args.greedy_samples, - sample_result_args.multinomial_samples, - sample_result_args.beam_search_logprobs, - sample_result_args.sample_results_dict, - ) - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - return [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - - -def _sample_with_torch( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - '''Torch-oriented _sample() implementation. - - Single-step scheduling: - * Perform GPU-side sampling computation - * Immediately Pythonize sampling result - - Multi-step scheduling: - * Perform GPU-side sampling computation - * Defer Pythonization & preserve GPU-side - tensors required for Pythonization - ''' - - categorized_seq_group_ids: Dict[SamplingType, - List[int]] = {t: [] - for t in SamplingType} - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: SampleResultsDictType = {} - sample_metadata: SampleMetadataType = {} - multinomial_samples: MultinomialSamplesType = {} - greedy_samples: Optional[torch.Tensor] = None - beam_search_logprobs: Optional[torch.Tensor] = None - - # Create output tensor for sampled token ids. - if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) - else: - sampled_token_ids_tensor = None - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups) - long_sample_indices = sample_indices.long() - if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[long_sample_indices], - dim=-1) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[ - long_sample_indices] = greedy_samples.unsqueeze(-1) - - if modify_greedy_probs: - # If required, modify the probabilities such that sampling from - # the modified distribution would always sample the argmax - # token id. - _modify_greedy_probs_inplace(logprobs, probs, - long_sample_indices, - greedy_samples) + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + return logits - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_best_of_in_batch = 1 - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_best_of_in_batch = max(max_best_of_in_batch, - sampling_params.best_of) - seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else - seq_groups) + def get_probs(self, logits: torch.Tensor) -> torch.Tensor: + return torch.softmax(logits, dim=-1, dtype=torch.float32) - if flashinfer_top_k_top_p_sampling is not None: - multinomial_samples[ - sampling_type] = _top_k_top_p_multinomial_with_flashinfer( - probs[long_sample_indices], - sampling_tensors.top_ks[long_sample_indices], - sampling_tensors.top_ps[long_sample_indices], - max_best_of_in_batch, - seq_groups_arg, - ) - else: - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], - max_best_of_in_batch, - seq_groups=seq_groups_arg) + def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return torch.log_softmax(logits, dim=-1, dtype=torch.float32) - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[long_sample_indices] = \ - multinomial_samples[sampling_type].to(torch.long) + def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: + return probs.argmax(dim=1).view(-1) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] + def random_sample( + self, + probs: torch.Tensor, + generators: Optional[List[torch.Generator]], + no_generator: bool, + ) -> torch.Tensor: + q = torch.empty_like(probs) + if no_generator: + q.exponential_() else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - # Encapsulate arguments for computing Pythonized sampler - # results, whether deferred or otherwise. - maybe_deferred_args = SampleResultArgsType( - sampling_metadata=sampling_metadata, - sample_metadata=sample_metadata, - multinomial_samples=multinomial_samples, - greedy_samples=greedy_samples, - beam_search_logprobs=beam_search_logprobs, - sample_results_dict=sample_results_dict) - - if not sampling_metadata.skip_sampler_cpu_output: - # GPU<->CPU sync happens here. - # This also converts the sampler output to a Python object. - # Return Pythonized sampler result & sampled token ids - return get_pythonized_sample_results( - maybe_deferred_args), sampled_token_ids_tensor - else: - # Defer sampler result Pythonization; return deferred - # Pythonization args & sampled token ids - return ( - maybe_deferred_args, - sampled_token_ids_tensor, - ) - - -def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - """ - Args: - probs: (num_query_tokens_in_batch, num_vocab) - logprobs: (num_query_tokens_in_batch, num_vocab) - sampling_metadata: The metadata for a batch for sampling. - sampling_tensors: Tensors that include sampling related metadata. - - Returns: - (next_token_ids, parent_seq_ids) for each seq group in a batch. - If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. - """ - return _sample_with_torch( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=include_gpu_probs_tensor, - modify_greedy_probs=modify_greedy_probs, - ) - - -def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - """ - This function calculates the ranks of the chosen tokens in a logprob tensor. - - Args: - x (torch.Tensor): 2D logprob tensor of shape (N, M) - where N is the no. of tokens and M is the vocab dim. - indices (torch.Tensor): List of chosen token indices. - - Returns: - torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank - of the chosen token in the input logprob tensor. - """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - result = (x > vals[:, None]) - del vals - return result.sum(1).add_(1) - - -def get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: SampleResultType, -) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: - """Return sample lobprobs and prompt logprobs. - - The logic consists of 3 parts. - - Select indices to compute logprob from, ranks of token ids, and - the top k token ids from logprobs. - - Compute prompt logprobs if required. - - Compute sample logprobs if required. - - Args: - logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's - logprob per vocab. Sequence groups' query tokens are batched in a - single flattened tensor. For example, assuming there are N - seq groups, it is sorted by prefill tokens for seq_group_1 (if - prompt logprob is enabled), decode tokens for seq_group_1 (if - sampling is required), prefill tokens for seq_group_2, ... - sampling_metadata: The sampling metadata. - sample_results: (num_seq_groups) The tuple of (next_token_ids, - parent_ids) for each sequence group. When beam search is enabled, - sample_results can contain different number of seq_ids from - sampling_metadata.seq_groups. It is because beam search creates - 2 * BEAM_WIDTH number of samples (whereas there are only up to - BEAM_WIDTH number of seq_ids). - - Returns: - A tuple of prompt and sample logprobs per sequence group in a batch. - """ - # The index of query token to calculate logprobs. It includes both - # prompt and sample logprob indices. - query_indices: List[int] = [] - # The next token ids to get the logprob value from. - next_token_ids: List[int] = [] - # The largest requested number of logprobs. We find logprobs as many as the - # largest num logprobs in this API. If every logprobs is None, it will be - # set to -1. - largest_num_logprobs = -1 - # If beam search is enabled. - use_beam_search = False - - # Select indices to compute logprob from, ranks of token ids, and the top - # k token ids from logprobs. - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - sample_results): - sampling_params = seq_group.sampling_params - - # Update indices and tokens for prompt logprobs. - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.prompt_logprobs) - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - query_indices.extend(seq_group.prompt_logprob_indices) - next_token_ids.extend(next_prompt_tokens) - - # Update indices and next tokenes for sample logprob. - if seq_group.do_sample: - token_ids, parent_seq_ids = sample_result - # NOTE: We cannot directly use sample_indices because - # sample_indices only contain parent seq_ids of a previous step. - # The current step may have different number of seq_ids, and - # we can obtain it from `sample_result[1]`. - query_idx = seq_group.sample_indices[0] - query_indices.extend( - [query_idx + parent_id for parent_id in parent_seq_ids]) - next_token_ids.extend(token_ids) - - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - - use_beam_search = use_beam_search or sampling_params.use_beam_search + assert generators is not None and len(generators) == probs.shape[0] + # TODO(woosuk): Optimize this. + for i, generator in enumerate(generators): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=1).view(-1) - assert len(next_token_ids) == len(query_indices) - - if len(query_indices) == 0: - empty_sampled_logprob: SampleLogprobs = [] - empty_prompt_logprob: Optional[PromptLogprobs] = None - return [empty_prompt_logprob], [empty_sampled_logprob] - - selected_logprobs, ranks = None, None - top_logprobs, top_token_ids = None, None - - # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can - # skip the whole logprob calculation. - if largest_num_logprobs >= 0 or use_beam_search: - query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) - next_token_ids_gpu = torch.tensor(next_token_ids, - device=logprobs.device) - - # (num_selected_query_tokens, num_logprobs). Note that query_indices can - # contain duplicates if beam search is enabled. - selected_logprobs = logprobs[[ - query_indices_gpu, - next_token_ids_gpu, - ]] - ranks = _get_ranks( - logprobs[query_indices_gpu], - next_token_ids_gpu, + def sample( + self, + probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + assert not (sampling_metadata.all_greedy + and sampling_metadata.all_random) + if sampling_metadata.all_greedy: + return self.greedy_sample(probs) + if sampling_metadata.all_random: + return self.random_sample(probs, sampling_metadata.generators, + sampling_metadata.no_generator) + + greedy_sampled = self.greedy_sample(probs) + random_sampled = self.random_sample(probs, + sampling_metadata.generators, + sampling_metadata.no_generator) + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, ) - assert selected_logprobs.shape[0] == ranks.shape[0] - - # We need to compute top k only if there exists logprobs > 0. - if largest_num_logprobs > 0: - # Logprobs of topk tokens for a batch of sequence groups. - # (num_query_tokens_across_batch). - top_logprobs, top_token_ids = torch.topk(logprobs, - largest_num_logprobs, - dim=-1) - top_logprobs = top_logprobs.to('cpu') - top_token_ids = top_token_ids.to('cpu') - - selected_logprobs = selected_logprobs.to('cpu') - ranks = ranks.to('cpu') - - # Find prompt/sample logprobs. - prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: List[SampleLogprobs] = [] - top_logprob_idx = 0 - selected_logprobs_idx = 0 - - for seq_group, sample_result in zip(sampling_metadata.seq_groups, - sample_results): - (prompt_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_prompt_logprob_if_needed( - seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, - selected_logprobs_idx, top_logprob_idx) - prompt_logprobs_per_seq_group.append(prompt_logprobs) - - (sampled_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_sampled_logprob_if_needed( - seq_group, sample_result, selected_logprobs, ranks, top_token_ids, - top_logprobs, selected_logprobs_idx, top_logprob_idx) - sample_logprobs_per_seq_group.append(sampled_logprobs) - - return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group - - -def _get_prompt_logprob_if_needed( - seq_group: SequenceGroupToSample, - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the prompt logprob from a sequence group if needed.""" - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - - # Find prompt logprobs - prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [] - num_logprobs = sampling_params.prompt_logprobs - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - # Pre-select indexes and create a list. It is faster than calling .item - # repetitively. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - - for idx, token_id in enumerate(next_prompt_tokens): - # Calculate the prompt logprob of the real prompt tokens. - # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { - token_id: (selected_logprob_items[idx], rank_items[idx]) - } - - # Add top K prompt logprobs along with its rank. - if num_logprobs > 0: - top_ids = top_token_ids[ - top_logprob_idx, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - prompt_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip(top_ids, top_probs, - top_ranks) - }) - prompt_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in prompt_logprobs_dict.items() - }) - # + 1 to go to the next prompt token. - top_logprob_idx += 1 - - # + len(next_prompt_tokens) to go to the next prompt. - selected_logprobs_idx += len(next_prompt_tokens) - return prompt_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _get_sampled_logprob_if_needed( - seq_group: SequenceGroupToSample, - sample_result: Tuple[List[int], List[int]], - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the sample logprob if needed.""" - seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - use_beam_search = seq_group.sampling_params.use_beam_search - sampled_logprobs: SampleLogprobs = [] - next_token_ids, parent_seq_ids = sample_result - - if seq_group.do_sample: - assert len(next_token_ids) > 0 - if num_logprobs is None and not use_beam_search: - for next_token_id in next_token_ids: - # Use a dummy logprob - sampled_logprobs.append({next_token_id: Logprob(inf)}) - else: - # Pre-select items from tensor. tolist() is faster than repetitive - # `.item()` calls. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - for idx, (next_token_id, parent_id) in enumerate( - zip(next_token_ids, parent_seq_ids)): - # Get the logprob of a sampled token. - sampled_logprobs_dict = { - next_token_id: - (selected_logprob_items[idx], rank_items[idx]) - } - if num_logprobs is not None and num_logprobs > 0: - # Get top K logprobs. - top_ids = top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx + parent_id, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - sampled_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip( - top_ids, top_probs, top_ranks) - }) - - sampled_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in - sampled_logprobs_dict.items() - }) - - # NOTE: This part of code is not intuitive. `selected_logprobs` include - # logprobs for the current step, which has len(next_token_ids) tokens - # per sequence group. `logprobs` includes logprobs from the previous - # steps, which has len(seq_ids) tokens per sequence group. - - # Iterate to the next sequence group in a batch. - selected_logprobs_idx += len(next_token_ids) - # Iterate to the next sequence group in a batch. - top_logprob_idx += len(seq_ids) - return sampled_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, - sample_indices: torch.Tensor, - greedy_samples: torch.Tensor) -> None: - """Modify the probability distributions of the greedily-sampled tokens such - that each sampled token has a "probability" of 1.0. This is required by - speculative decoding, which depends on the sampling method being encoded - within the probability distribution for correctness. - - # Why do we only need to do this for greedy sampling? - - vLLM's sampler performs the following steps for greedy or multinomial - (random) sampling: - 1. Get logits from model. - 2. Modify logits according to per-sequence sampling parameters. - - Multiply by temperature, top-k and top-p masking, penalize tokens - according to their frequency, etc. - 3. Sample a token. - - Random sampling simply samples from the modified probability - distribution. - - Greedy sampling performs `argmax` to obtain the token with the - highest likelihood. - - Ignoring greedy sampling for a moment, we find that the computed probability - distribution has the following property: we can sample from it independently - and find that the token sampled by the Sampler has a frequency corresponding - to how often we see it in our sampling. In other words, for tokens sampled - with vLLM's random SamplingType, the computed probability distribution - encodes the sampling methodology completely. - - Greedy sampling does not normally have this property. vLLM modifies logits - according to sampling params, then performs `argmax`, then returns the - sampled token and the computed probability distribution. If we sample from - the distribution, we'll find the likelihood of the greedily-sampled token - is not always 1.0. - - Since lossless speculative decoding requires that the sampling methodology - be encoded within the probability distribution, we are motivated to modify - the probability distribution such that the sampled token has probability 1 - when speculative decoding is used. - - NOTE: Alternatively, we could use an extremely low temperature to achieve - greedy sampling using multinomial computation and unite the codepaths. This - has implications on the overall design of the sampler, e.g. how to record - accurate logprobs for the user, so this improvement is deferred to later. - """ - # NOTE: logprobs are not modified so they can be returned to the user. - probs[sample_indices, :] = 0 - probs[sample_indices, greedy_samples] = 1.0 - - -def _build_sampler_output( - maybe_deferred_sample_results: MaybeDeferredSampleResultType, - sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], - sample_logprobs: Optional[List[SampleLogprobs]], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, - torch.Tensor]], - skip_sampler_cpu_output: bool = False, -) -> SamplerOutput: - """Construct Python objects with the output of sampling. - - Args: - on_device_tensors: Tuple containing on-device tensors with the - probabilities used in sampling and the sampled token ids. This - allows post-processing without copies to CPU/serialization, e.g. in - speculative decoding rejection sampling. - """ - sampler_output: List[CompletionSequenceGroupOutput] = [] - - if skip_sampler_cpu_output: - assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) - deferred_sample_results_args = maybe_deferred_sample_results - else: - assert prompt_logprobs is not None - assert sample_logprobs is not None - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - deferred_sample_results_args = None - - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - maybe_deferred_sample_results, - prompt_logprobs, sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip( - parent_ids, next_token_ids, group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, - group_prompt_logprobs)) - - # If not specified, store None values in SamplerOutput. - if on_device_tensors is not None: - (sampled_token_probs, logprobs_tensor, - sampled_token_ids) = on_device_tensors - else: - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, - None) - - return SamplerOutput( - outputs=sampler_output, - sampled_token_probs=sampled_token_probs, - sampled_token_ids=sampled_token_ids, - logprobs=logprobs_tensor, - deferred_sample_results_args=deferred_sample_results_args) - - -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: - """Get a list of next prompt tokens to compute logprob from a - given sequence group. - - It is used to compute prompt logprob. Imagine you have logprob for each - query token. Query token needs to know the next prompt token id to compute - prompt logprob. This is a helper to obtain next prompt token ids. - - This API has to be used only when the caller knows seq_group is in prefill - stage. - - Returns: - A list of next prompt tokens to compute logprob. - """ - assert seq_group.is_prompt, ( - "Caller should ensure the sequence group is in a prefill stage.") - seq_ids = seq_group.seq_ids - query_len = seq_group.query_len - assert query_len is not None - # prompt has only 1 seq id. - assert len(seq_ids) == 1 - seq_data = seq_group.seq_data[seq_ids[0]] - computed_len = seq_data.get_num_computed_tokens() - prompt_tokens = seq_data.prompt_token_ids - # +1 because we are looking for a next prompt token. - next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + query_len + 1, - len(prompt_tokens)) - next_prompt_tokens = prompt_tokens[ - next_token_index_start:next_token_index_end] - return next_prompt_tokens + return sampled diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 97d36d31f2b11..28614377b27b9 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,586 +1,22 @@ -from array import array from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional import torch -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) - -_SAMPLING_EPS = 1e-5 - @dataclass -class SequenceGroupToSample: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Sequence ids for the sequence group in a previous step. - seq_ids: List[int] - sampling_params: SamplingParams - # seq_id -> sequence data. - seq_data: Dict[int, SequenceData] - # The length of the sequence (all tokens seen in the past + new token to - # compute attention) of the sequence group. None if it is in a decode - # stage. - seq_len: Optional[int] - # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seq_len if chunked - # prefill is enabled. - query_len: Optional[int] - # A random number generator for sampling. - generator: Optional[torch.Generator] - # True if the sequence group is in prefill stage. False if it is in a - # decode stage. - is_prompt: bool - # Query token indices from logits. to compute prompt logprob. Empty if - # prompt logprob is not required. - prompt_logprob_indices: List[int] - # Sample token indices from logits. Empty if sampling is not required. - sample_indices: List[int] - - @property - def do_sample(self): - return len(self.sample_indices) > 0 - - def __post_init__(self): - if len(self.prompt_logprob_indices) > 0: - assert self.sampling_params.prompt_logprobs is not None - if self.is_prompt: - assert self.seq_len is not None - assert self.query_len is not None - - -def gen_seq_group_to_sample_builder(num_seqs: int): - return lambda: SequenceGroupToSample( - seq_ids=[0] * num_seqs, - sampling_params=None, - seq_data=None, # type: ignore - seq_len=0, - query_len=0, - generator=None, - is_prompt=True, - prompt_logprob_indices=[], - sample_indices=[], - ) - - -class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations""" - - def __init__(self): - self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} - - def get_cached_seq_group_to_sample(self, num_seqs): - if num_seqs not in self._seq_group_to_sample_cache: - self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( - gen_seq_group_to_sample_builder(num_seqs)) - - obj = self._seq_group_to_sample_cache[num_seqs].get_object() - return obj - - def reset(self): - for cache in self._seq_group_to_sample_cache.values(): - cache.reset() - - class SamplingMetadata: - """Metadata for input sequences. Used in sampler. - - The usage is as follow; - ``` - hidden_states = execute_model(...) - logits = hidden_states[sampling_metadata.selected_token_indices] - sample(logits) - - def sample(logits): - # Use categorized_sample_indices for sampling.... - ``` - - Args: - seq_groups: List of batched sequence groups. - selected_token_indices: (num_query_tokens_to_logprob). Indices to find - logits from the initial model output hidden states. - categorized_sample_indices: SamplingType -> token indices to sample. - Each token indices is 2D tensor of (num_indices, num_indices) where - the first item means the sample index within the returned logit - (before pruning padding), and the second item means the sample - index after pruning using selected_token_indices. - For example, if the returned logit is [1, 2, 3], and we select - [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, - The first tuple is [1, 2] (sampled index within original logit), - and the second tuple is [0, 1] (sampled index within pruned logit). - num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU - serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling - tensors that are part of the sampler forward pass. Currently, - it is mainly used for multi-step decode. - - """ - - def __init__( - self, - seq_groups: List[SequenceGroupToSample], - selected_token_indices: torch.Tensor, - categorized_sample_indices: Dict[SamplingType, torch.Tensor], - num_prompts: int, - skip_sampler_cpu_output: bool = False, - reuse_sampling_tensors: bool = False, - ) -> None: - self.seq_groups = seq_groups - self.selected_token_indices = selected_token_indices - self.categorized_sample_indices = categorized_sample_indices - self.num_prompts = num_prompts - self.skip_sampler_cpu_output = skip_sampler_cpu_output - self.reuse_sampling_tensors = reuse_sampling_tensors - - @staticmethod - def prepare( - seq_group_metadata_list: List[SequenceGroupMetadata], - seq_lens: List[int], - query_lens: Optional[List[int]], - device: str, - pin_memory: bool, - generators: Optional[Dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, - ) -> "SamplingMetadata": - ( - seq_groups, - selected_token_indices, - categorized_sample_indices, - num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device, generators, cache) - selected_token_indices = async_tensor_h2d( - selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory, - ) - categorized_sample_indices = { - t: async_tensor_h2d( - seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory, - ) - for t, seq_ids in categorized_sample_indices.items() - } - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - num_prompts=num_prompts, - ) - return sampling_metadata - - def __repr__(self) -> str: - return ( - "SamplingMetadata(" - f"seq_groups={self.seq_groups}, " - f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices}), ") - - -def _prepare_seq_groups( - seq_group_metadata_list: List[SequenceGroupMetadata], - seq_lens: List[int], - query_lens: Optional[List[int]], - device: str, - generators: Optional[Dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, -) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType, - List[int]], int, ]: - """Prepare sequence groups and indices for sampling. - - Args: - seq_group_metadata_list: A list of sequence group to batch. - seq_lens: A list of sequence lens per sequence group. - Index of prompt len should match with seq_group_metadata_list. - query_lens: A list of query lengths. Prompt lens include the length - of entire prompt tokens, and it could be shorter. - device: A device to use for random number generators, - `SequenceGroupToSample.generator`. - generators: A store of per-request random number generators used - for seeded requests. - - Returns: - seq_groups: A list of sequence group to sample. - selected_token_indices: See the definition from `SamplingMetadata`. - categorized_sample_indices: See the definition from `SamplingMetadata`. - num_prompts: Total number of prompts from `seq_group_metadata_list`. - """ - # Batched sequence groups for the current model forward stsep. - seq_groups: List[SequenceGroupToSample] = [] - # A list of token indices to sample/compute logprob. It is used to - # prune the outcome logits from the model for the performance. - selected_token_indices: List[int] = [] - # Used for selected_token_indices. - model_output_idx = 0 - - # Sampling type -> ( - # indices to sample/prompt logprob within pruned output logits, - # indices to sample within pruned logits) - categorized_sample_indices: Dict[SamplingType, List[int]] = { - t: [] - for t in SamplingType - } - # Index of logits to compute logprob. Logits include both prompt logprob - # and sample logprob indices. - logit_idx = 0 - # Total number of prompts from given sequence groups. - num_prompts = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = seq_group_metadata.seq_data.keys() - - if cache is not None: - sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) - - for j, seq_id in enumerate(seq_ids): - sample_obj.seq_ids[j] = seq_id - - sample_obj.prompt_logprob_indices.clear() - sample_obj.sample_indices.clear() - - sampling_params = seq_group_metadata.sampling_params - is_prompt = seq_group_metadata.is_prompt - generator: Optional[torch.Generator] = None - # If the current seq group is in decode stage, it is None. - seq_len: Optional[int] = None - query_len: Optional[int] = None - prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices - if cache is not None else []) - sample_indices: List[int] = (sample_obj.sample_indices - if cache is not None else []) - do_sample = seq_group_metadata.do_sample - - if seq_group_metadata.is_prompt: - if sampling_params.seed is not None: - generator = torch.Generator(device=device).manual_seed( - sampling_params.seed) - if generators is not None: - generators[seq_group_metadata.request_id] = generator - - num_prompts += 1 - num_prefill_sample = len(seq_ids) - assert num_prefill_sample == 1 - assert query_lens is not None and seq_lens is not None - query_len, seq_len = query_lens[i], seq_lens[i] - # If we need sampling, exclude num_prefill_sample tokens from - # prompt logprob. - prompt_logprob_len = (query_len - num_prefill_sample - if do_sample else query_len) - sample_len = num_prefill_sample if do_sample else 0 - else: - # Decode - prompt_logprob_len = 0 - sample_len = len(seq_ids) if do_sample else 0 - - if sampling_params.seed is not None and generators is not None: - generator = generators.get(seq_group_metadata.request_id) - - # Update indices to select from the model output. - """ - This blocks computes selected_token_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - """ - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + prompt_logprob_len)) - model_output_idx += prompt_logprob_len - if do_sample: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + sample_len)) - model_output_idx += sample_len - - # We now find indices for logprob computation and sampling. - """ - This block computes categorized_sample_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - def sample(logits): - # Use categorized_sample_indices for sampling. - # prompt_logprob_indices to find prompt logprob indices. - # sample_indices to find sample indices. - """ - - if sampling_params.prompt_logprobs is not None: - prompt_logprob_indices.extend( - range(logit_idx, logit_idx + prompt_logprob_len)) - logit_idx += prompt_logprob_len - if do_sample: - sample_indices.extend(range(logit_idx, logit_idx + sample_len)) - categorized_sample_indices[sampling_params.sampling_type].extend( - list(range(logit_idx, logit_idx + sample_len))) - logit_idx += sample_len - - if cache is not None: - sample_obj.sampling_params = sampling_params - sample_obj.seq_data = seq_group_metadata.seq_data - sample_obj.seq_len = seq_len - sample_obj.query_len = query_len - sample_obj.generator = generator - sample_obj.is_prompt = is_prompt - else: - sample_obj = SequenceGroupToSample( - seq_ids=list(seq_ids), - sampling_params=sampling_params, - seq_data=seq_group_metadata.seq_data, - seq_len=seq_len, - query_len=query_len, - generator=generator, - is_prompt=is_prompt, - prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices), - ) - - seq_groups.append(sample_obj) - - if cache is not None: - cache.reset() - - return (seq_groups, selected_token_indices, categorized_sample_indices, - num_prompts) - - -@dataclass -class SamplingTensors: - """Tensors for sampling.""" - - temperatures: torch.Tensor - top_ps: torch.Tensor - top_ks: torch.Tensor - min_ps: torch.Tensor - presence_penalties: torch.Tensor - frequency_penalties: torch.Tensor - repetition_penalties: torch.Tensor - prompt_tokens: torch.Tensor - output_tokens: torch.Tensor - - @classmethod - def from_sampling_metadata( - cls, - sampling_metadata: "SamplingMetadata", - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> Tuple["SamplingTensors", bool, bool, bool]: - prompt_tokens: List[array] = [] - output_tokens: List[array] = [] - top_ks: List[int] = [] - temperatures: List[float] = [] - top_ps: List[float] = [] - min_ps: List[float] = [] - presence_penalties: List[float] = [] - frequency_penalties: List[float] = [] - repetition_penalties: List[float] = [] - do_penalties = False - do_top_p_top_k = False - do_min_p = False - - assert sampling_metadata.seq_groups is not None - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - temperature = sampling_params.temperature - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - top_p = sampling_params.top_p - min_p = sampling_params.min_p - - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k == -1 else top_k - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. - temperature = 1.0 - if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS - or top_k != vocab_size): - do_top_p_top_k = True - if not do_min_p and min_p > _SAMPLING_EPS: - do_min_p = True - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True - - is_prompt = seq_group.is_prompt - if is_prompt and sampling_params.prompt_logprobs is not None: - # For tokens in the prompt that we only need to get - # their logprobs - query_len = seq_group.query_len - assert query_len is not None - prefill_len = len(seq_group.prompt_logprob_indices) - temperatures += [temperature] * prefill_len - top_ps += [top_p] * prefill_len - top_ks += [top_k] * prefill_len - min_ps += [min_p] * prefill_len - presence_penalties += [0] * prefill_len - frequency_penalties += [0] * prefill_len - repetition_penalties += [1] * prefill_len - - if seq_group.do_sample: - sample_lens = len(seq_group.sample_indices) - assert sample_lens == len(seq_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) - - if do_penalties: - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - prefill_len = len(seq_group.prompt_logprob_indices) - prompt_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - output_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - if seq_group.do_sample: - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids_array) - output_tokens.append(seq_data.output_token_ids_array) - - sampling_tensors = SamplingTensors.from_lists( - temperatures, - top_ps, - top_ks, - min_ps, - presence_penalties, - frequency_penalties, - repetition_penalties, - prompt_tokens, - output_tokens, - vocab_size, - device, - dtype, - ) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) - - @classmethod - def from_lists( - cls, - temperatures: List[float], - top_ps: List[float], - top_ks: List[int], - min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - prompt_tokens: List[array], - output_tokens: List[array], - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> "SamplingTensors": - # Note that the performance will be very bad without - # pinned memory. - pin_memory = is_pin_memory_available() - do_penalties = prompt_tokens or output_tokens + temperature: torch.Tensor + all_greedy: bool + all_random: bool - if do_penalties: - prompt_t = make_tensor_with_pad( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - output_t = make_tensor_with_pad( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - else: - empty_tensor = torch.empty(0, device=device, dtype=torch.long) - prompt_t = empty_tensor - output_t = empty_tensor + top_p: torch.Tensor + top_k: torch.Tensor + no_top_p: bool + no_top_k: bool - temperatures_t = torch.tensor( - temperatures, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ps_t = torch.tensor( - top_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - min_ps_t = torch.tensor( - min_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - presence_penalties_t = torch.tensor( - presence_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - frequency_penalties_t = torch.tensor( - frequency_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - repetition_penalties_t = torch.tensor( - repetition_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ks_t = torch.tensor( - top_ks, - device="cpu", - dtype=torch.int, - pin_memory=pin_memory, - ) - # Because the memory is pinned, we can do non-blocking - # transfer to device. + generators: List[Optional[torch.Generator]] + no_generator: bool - return cls( - temperatures=temperatures_t.to(device=device, non_blocking=True), - top_ps=top_ps_t.to(device=device, non_blocking=True), - top_ks=top_ks_t.to(device=device, non_blocking=True), - min_ps=min_ps_t.to(device=device, non_blocking=True), - presence_penalties=presence_penalties_t.to(device=device, - non_blocking=True), - frequency_penalties=frequency_penalties_t.to(device=device, - non_blocking=True), - repetition_penalties=repetition_penalties_t.to(device=device, - non_blocking=True), - prompt_tokens=prompt_t.to(device=device, non_blocking=True), - output_tokens=output_t.to(device=device, non_blocking=True), - ) + max_num_logprobs: int diff --git a/vllm/request.py b/vllm/request.py new file mode 100644 index 0000000000000..69e4264249a65 --- /dev/null +++ b/vllm/request.py @@ -0,0 +1,112 @@ +import enum +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, + Optional, Union) + +import torch + +from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest + +if TYPE_CHECKING: + from vllm.inputs import LLMInputs + from vllm.multimodal.base import MultiModalDataDict + + +class Request: + + def __init__( + self, + request_id: str, + inputs: "LLMInputs", + arrival_time: float, + sampling_params: SamplingParams, + lora_request: Optional[LoRARequest] = None, + ) -> None: + self.request_id = request_id + self.inputs = inputs + self.metrics = RequestMetrics(arrival_time=arrival_time, + last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None) + self.sampling_params = sampling_params + self.lora_request = lora_request + + self.status = RequestStatus.WAITING + self.stop_reason: Union[int, str, None] = None + self.max_tokens = sampling_params.max_tokens + + self.num_prompt_tokens = len(inputs["prompt_token_ids"]) + self.num_output_tokens = 0 + self.num_computed_tokens = 0 + + @property + def num_tokens(self) -> int: + return self.num_prompt_tokens + self.num_output_tokens + + def is_finished(self) -> bool: + return RequestStatus.is_finished(self.status) + + +class RequestStatus(enum.IntEnum): + """Status of a sequence.""" + WAITING = 0 + RUNNING = 1 + PREEMPTED = 2 + # Note: anything after SWAPPED (2) will be considered + # as a finished status. + FINISHED_STOPPED = 3 + FINISHED_LENGTH_CAPPED = 4 + FINISHED_ABORTED = 5 + FINISHED_IGNORED = 6 + + @staticmethod + def is_finished(status: "RequestStatus") -> bool: + return status > RequestStatus.PREEMPTED + + @staticmethod + def get_finished_reason(status: "RequestStatus") -> Union[str, None]: + if status == RequestStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == RequestStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == RequestStatus.FINISHED_ABORTED: + finish_reason = "abort" + elif status == RequestStatus.FINISHED_IGNORED: + # The ignored sequences are the sequences whose prompt lengths + # are longer than the model's length cap. Therefore, the stop + # reason should also be "length" as in OpenAI API. + finish_reason = "length" + else: + finish_reason = None + return finish_reason + + +@dataclass +class RequestMetrics: + """Metrics associated with a request. + + Attributes: + arrival_time: The time when the request arrived. + first_scheduled_time: The time when the request was first scheduled. + first_token_time: The time when the first token was generated. + time_in_queue: The time the request spent in the queue. + finished_time: The time when the request was finished. + scheduler_time: The time spent in the scheduler when this request was + being considered by the scheduler. + model_forward_time: The time spent in the model forward pass when this + request was in the batch. + model_execute_time: The time spent in the model execute function. This + will include model forward, block/sync across + workers, cpu-gpu sync time and sampling time. + """ + arrival_time: float + last_token_time: float + first_scheduled_time: Optional[float] + first_token_time: Optional[float] + time_in_queue: Optional[float] + finished_time: Optional[float] = None + scheduler_time: Optional[float] = None + model_forward_time: Optional[float] = None + model_execute_time: Optional[float] = None diff --git a/vllm/sampler_output.py b/vllm/sampler_output.py new file mode 100644 index 0000000000000..1fbc4ed8f6e3a --- /dev/null +++ b/vllm/sampler_output.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class SamplerOutput: + + sampled_token_ids: torch.Tensor + + logprob_token_ids: Optional[torch.Tensor] + logprobs: Optional[torch.Tensor] + + prompt_logprob_token_ids: Optional[torch.Tensor] + prompt_logprobs: Optional[torch.Tensor] + + model_forward_time: float + model_execute_time: float diff --git a/vllm/sequence.py b/vllm/sequence.py deleted file mode 100644 index d8e54ff1fc708..0000000000000 --- a/vllm/sequence.py +++ /dev/null @@ -1,1329 +0,0 @@ -"""Sequence and its related classes.""" -import copy -import enum -from abc import ABC, abstractmethod -from array import array -from collections import defaultdict -from dataclasses import dataclass -from functools import cached_property, reduce -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union, cast - -import msgspec -import torch - -from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs -from vllm.lora.request import LoRARequest -from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics - -if TYPE_CHECKING: - from vllm.inputs import LLMInputs - from vllm.multimodal.base import MultiModalDataDict - -VLLM_TOKEN_ID_ARRAY_TYPE = "l" - - -# We use dataclass for now because it is used for -# openai server output, and msgspec is not serializable. -# TODO(sang): Fix it. -@dataclass -class Logprob: - """Infos for supporting OpenAI compatible logprobs and token ranks. - - Attributes: - logprob: The logprob of chosen token - rank: The vocab rank of chosen token (>=1) - decoded_token: The decoded chosen token index - """ - logprob: float - rank: Optional[int] = None - decoded_token: Optional[str] = None - - -# {token_id -> logprob} per each sequence group. None if the corresponding -# sequence group doesn't require prompt logprob. -PromptLogprobs = List[Optional[Dict[int, Logprob]]] -# {token_id -> logprob} for each sequence group. -SampleLogprobs = List[Dict[int, Logprob]] - - -class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered - # as a finished status. - FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: - return status > SequenceStatus.SWAPPED - - @staticmethod - def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: - if status == SequenceStatus.FINISHED_STOPPED: - finish_reason = "stop" - elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: - finish_reason = "length" - elif status == SequenceStatus.FINISHED_ABORTED: - finish_reason = "abort" - elif status == SequenceStatus.FINISHED_IGNORED: - # The ignored sequences are the sequences whose prompt lengths - # are longer than the model's length cap. Therefore, the stop - # reason should also be "length" as in OpenAI API. - finish_reason = "length" - else: - finish_reason = None - return finish_reason - - -class SequenceStage(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - - -@dataclass -class RequestMetrics: - """Metrics associated with a request. - - Attributes: - arrival_time: The time when the request arrived. - first_scheduled_time: The time when the request was first scheduled. - first_token_time: The time when the first token was generated. - time_in_queue: The time the request spent in the queue. - finished_time: The time when the request was finished. - scheduler_time: The time spent in the scheduler when this request was - being considered by the scheduler. - model_forward_time: The time spent in the model forward pass when this - request was in the batch. - model_execute_time: The time spent in the model execute function. This - will include model forward, block/sync across - workers, cpu-gpu sync time and sampling time. - """ - arrival_time: float - last_token_time: float - first_scheduled_time: Optional[float] - first_token_time: Optional[float] - time_in_queue: Optional[float] - finished_time: Optional[float] = None - scheduler_time: Optional[float] = None - model_forward_time: Optional[float] = None - model_execute_time: Optional[float] = None - - -class SequenceDataDelta( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta SequenceData to send to workers per step.""" - # A new token to be appended to existing SequenceData. - new_output_token_ids: List[int] - # Overwriting existing `cumulative_logprob` - new_cumulative_logprob: float - # Overwriting existing `num_computed_tokens`. - new_num_computed_tokens: int - # Overwriting existing `stage`. - new_stage: SequenceStage - - -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - - Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ - # NOTE: we cannot use Union[List, array] because msgspec cannot support - # union of 2 list types. - _prompt_token_ids: array - _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - - ### The below fields should not be passed as an argument ### - _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: Tuple[int, - ...] = msgspec.field(default_factory=tuple) - # The number of tokens that are computed (that run against the model). - _num_computed_tokens: int = 0 - _stage: SequenceStage = SequenceStage.PREFILL - _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) - - # It is used to get delta input. It is reset when `get_delta_and_reset` - # is called. - _new_appended_tokens: List[int] = msgspec.field(default_factory=list) - - # It is used to compute mrope_position_ids. - _mrope_position_delta: Optional[int] = None - - @staticmethod - def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": - if len(token_counts) == 0: - return SequenceData.from_seqs([]) - - arrs = [ - array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - for token_id, count in token_counts - ] - - return SequenceData(reduce(array.__add__, arrs)) - - @staticmethod - def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, - ) -> "SequenceData": - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) - - if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) - - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) - - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) - - def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( - self._prompt_token_ids) - self._update_cached_all_tokens() - - def _update_cached_all_tokens(self): - assert isinstance(self._prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + - self._output_token_ids) - - @property - def cumulative_logprob(self) -> float: - return self._cumulative_logprob - - @property - def prompt_token_ids(self) -> Tuple[int, ...]: - return self._prompt_token_ids_tuple - - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._prompt_token_ids - - @property - def output_token_ids(self) -> Tuple[int, ...]: - return tuple(self._output_token_ids) - - @output_token_ids.setter - def output_token_ids(self, new_output_token_ids: List[int]) -> None: - self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids) - self._update_cached_all_tokens() - - @property - def output_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - assert isinstance(self._output_token_ids, array) - return self._output_token_ids - - @property - def mrope_position_delta(self) -> Optional[int]: - return self._mrope_position_delta - - @mrope_position_delta.setter - def mrope_position_delta(self, new_mrope_position_delta): - self._mrope_position_delta = new_mrope_position_delta - - def append_token_id(self, token_id: int, logprob: float) -> None: - self._output_token_ids.append(token_id) - self._new_appended_tokens.append(token_id) - self._cached_all_token_ids.append(token_id) - self._cumulative_logprob += logprob - - def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) - - def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) - - def get_output_len(self) -> int: - return len(self._output_token_ids) - - def get_token_ids(self) -> List[int]: - return self._cached_all_token_ids - - def get_prefix_token_ids( - self, num_tokens: int - ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) - - def get_num_computed_tokens(self) -> int: - """Return the number of prefill tokens that are already computed.""" - return self._num_computed_tokens - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) - # If all tokens are computed, it means it is in decoding phase. - if self.get_num_uncomputed_tokens() == 0: - self._stage = SequenceStage.DECODE - - def reset_state_for_recompute(self) -> None: - """Reset the number of computed tokens from this sequence. It is - supposed to be called when a sequence needs to be started from - the beginning again (e.g., sequence is preempted). - """ - self._num_computed_tokens = 0 - self._stage = SequenceStage.PREFILL - self._new_appended_tokens = [] - - def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefill tokens that are not computed.""" - # we use `get_len()` which includes prompt_len + output_len instead - # of prompt_len here. This is because during recompute we need to - # prefill for both prompt and output. - return self.get_len() - self.get_num_computed_tokens() - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> Tuple[int, ...]: - return self.prompt_token_ids - - def get_output_token_ids(self) -> Tuple[int, ...]: - return self.output_token_ids - - def get_delta_and_reset(self) -> SequenceDataDelta: - delta = SequenceDataDelta(self._new_appended_tokens, - self._cumulative_logprob, - self.get_num_computed_tokens(), self.stage) - # Reset delta state. - self._new_appended_tokens = [] - return delta - - def apply_delta(self, delta: SequenceDataDelta): - self._num_computed_tokens = delta.new_num_computed_tokens - self._cumulative_logprob = delta.new_cumulative_logprob - self._stage = delta.new_stage - self._output_token_ids.extend(delta.new_output_token_ids) - self._cached_all_token_ids.extend(delta.new_output_token_ids) - - @property - def stage(self) -> SequenceStage: - return self._stage - - def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()}") - - -class Sequence: - """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the LLMInputs instance passed - in through the `inputs` constructor argument. - - For encoder/decoder models, LLMInputs encapsulates both a - decoder and encoder prompt, creating an ambiguity about which - prompt to construct the sequence from. The `from_decoder_prompt` - constructor argument signals whether to construct the Sequence - from the LLMInputs decoder prompt, or encoder prompt. - - Args: - seq_id: The ID of the sequence. - inputs: The inputs of the sequence. - block_size: The block size of the sequence. Should be the same as the - block size used by the block manager and cache engine. - eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. - lora_request: LoRA request. - prompt_adapter_request: Prompt Adapter request. - from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt - (True) or encoder prompt (False.) Must be True - for decoder-only model. - - """ - - def __init__( - self, - seq_id: int, - inputs: "LLMInputs", - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - from_decoder_prompt: bool = True, - ) -> None: - self.seq_id = seq_id - self.inputs = inputs - self.block_size = block_size - self.eos_token_id = eos_token_id - self.lora_request = lora_request - self.prompt_adapter_request = prompt_adapter_request - self.from_decoder_prompt = from_decoder_prompt - - # For decoder-only models, a Sequence is constructed - # from an LLMInputs instance (the `inputs` arg.) - # - # For encoder/decoder models the same `inputs` - # instance could be utilized to construct either an - # encoder sequence or a decoder sequence, because - # `LLMInputs` has both decoder- and encoder-oriented - # member variables (i.e. it encapsulates both an encoder - # and a decoder prompt.) The decision of which type of sequence - # to generate is determined by the `from_decoder_prompt` argument. - # - # When constructing a encoder sequence - # (`from_decoder_prompt` False) it matters that - # the `LLMInputs` instance stored in `inputs` is valid - # in the sense that its encoder-related member variables are - # populated; below, an exception is raised if this is - # not the case. - # - # When constructing a decoder sequence (`from_decoder_prompt` True) - # it does not matter whether `inputs` has its encoder-related - # member variables populated. - if not (from_decoder_prompt - or is_valid_encoder_decoder_llm_inputs(inputs)): - raise ValueError("Cannot extract encoder input prompt from " - f"invalid input {inputs}; did you forget the " - "encoder input prompt fields?") - - self.data = SequenceData.from_seqs(self.prompt_token_ids) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" - - self.status = SequenceStatus.WAITING - self.stop_reason: Union[int, str, None] = None - - # These are used to keep track of delta outputs - self._last_token_ids_offset: int = 0 - self._last_output_text_offset: int = 0 - - # Used for incremental detokenization - self.prefix_offset = 0 - self.read_offset = 0 - # Input + output tokens - self.tokens: Optional[List[str]] = None - - @property - def n_blocks(self) -> int: - return (self.get_len() + self.block_size - 1) // self.block_size - - @cached_property - def prompt(self) -> Optional[str]: - # Select decoder or encoder input prompt str, as appropriate - prompt_key: str = ("prompt" - if self.from_decoder_prompt else "encoder_prompt") - - return cast(Optional[str], self.inputs.get(prompt_key)) - - @cached_property - def prompt_token_ids(self) -> List[int]: - # Select decoder or encoder input prompt token ids, as appropriate - prompt_token_ids_key: str = ("prompt_token_ids" - if self.from_decoder_prompt else - "encoder_prompt_token_ids") - - # Cache computed prompt token ids - return cast(List[int], self.inputs.get(prompt_token_ids_key)) - - @property - def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.get("multi_modal_data") or {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - @property - def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 - - def get_output_text_to_return(self, buffer_length: int, - delta: bool) -> str: - """If delta is True, only new text since the last call to - this method is returned""" - - # We return the full output text if the sequence is finished. - truncate = buffer_length and not self.is_finished() - if not delta: - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) - length = len(self.output_text) - if truncate: - length -= buffer_length - last_offset = self._last_output_text_offset - if last_offset < length: - self._last_output_text_offset = length - return self.output_text[last_offset:length] - return "" - - def get_output_token_ids_to_return(self, - delta: bool) -> GenericSequence[int]: - """If delta is True, only new tokens since the last call to - this method are returned""" - if not delta: - return self.get_output_token_ids() - length = self.get_output_len() - last_offset = self._last_token_ids_offset - if last_offset < length: - self._last_token_ids_offset = length - return self.data._output_token_ids[last_offset:] - return () - - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size - - # Compute the number of tokens in the sequence - # TODO: The current hashing function is O(L^2). We should optimize - # this in the future. - num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) - - def num_hashed_tokens_of_block(self, logical_idx: int): - return logical_idx * self.block_size + self.block_size - - def reset_state_for_recompute(self): - """Reset the sequence states for recomputation.""" - self.data.reset_state_for_recompute() - - def append_token_id(self, token_id: int, logprobs: Dict[int, - Logprob]) -> None: - assert token_id in logprobs - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) - - def get_len(self) -> int: - return self.data.get_len() - - def get_prompt_len(self) -> int: - return self.data.get_prompt_len() - - def get_output_len(self) -> int: - return self.data.get_output_len() - - def get_token_ids(self) -> List[int]: - return self.data.get_token_ids() - - def get_prompt_token_ids(self) -> Tuple[int, ...]: - return self.data.get_prompt_token_ids() - - def get_last_token_id(self) -> int: - return self.data.get_last_token_id() - - def get_output_token_ids(self) -> Tuple[int, ...]: - return self.data.get_output_token_ids() - - def get_cumulative_logprob(self) -> float: - return self.data.cumulative_logprob - - def get_beam_search_score(self, - length_penalty: float = 1.0, - seq_len: Optional[int] = None, - eos_token_id: Optional[int] = None) -> float: - """Calculate the beam search score with length penalty. - - Adapted from - - https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 - """ - if seq_len is None: - seq_len = self.get_len() - # NOTE: HF implementation does not count the EOS token - # towards the length, we align with that here for testing. - if (eos_token_id is not None - and self.get_last_token_id() == eos_token_id): - seq_len -= 1 - return self.get_cumulative_logprob() / (seq_len**length_penalty) - - def is_finished(self) -> bool: - return SequenceStatus.is_finished(self.status) - - def fork(self, new_seq_id: int) -> "Sequence": - new_seq = copy.deepcopy(self) - new_seq.seq_id = new_seq_id - return new_seq - - def get_num_new_tokens(self) -> int: - """Get the number of new tokens to be computed. - - Returns: - The new number of tokens to be computed. I.e., 1 for decode, or - the remaining prompt size for prefill. - """ - if self.data.stage == SequenceStage.DECODE: - return 1 - return self.data.get_num_uncomputed_tokens() - - def is_prefill(self) -> bool: - return self.data.stage == SequenceStage.PREFILL - - def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={self.n_blocks}, ") - - -class SequenceGroupState(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Mutable state tied to a specific sequence group""" - - # for multi-step decoding - num_steps: int = 1 - current_step: int = 0 - - @property - def remaining_steps(self) -> int: - return self.num_steps - self.current_step - - -class SequenceGroup: - """A group of sequences that are generated from the same prompt. - - Args: - request_id: The ID of the request. - seqs: The list of sequences. - sampling_params: The sampling parameters used to generate the outputs. - arrival_time: The arrival time of the request. - lora_request: LoRA request. - embeddings: The embeddings vectors of the prompt of the sequence group - for an embedding model. - pooling_params: The pooling parameters used to generate the pooling - for an embedding model. - encoder_seq: Optional, the single encoder sequence. Should be None - unless you are working with an encoder/decoder model. - trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request. - """ - - def __init__( - self, - request_id: str, - seqs: List[Sequence], - arrival_time: float, - sampling_params: Optional[SamplingParams] = None, - lora_request: Optional[LoRARequest] = None, - embeddings: Optional[List[float]] = None, - pooling_params: Optional[PoolingParams] = None, - encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - self.request_id = request_id - self.seqs = seqs - self.is_single_seq = len(seqs) == 1 - self.seqs_dict = {seq.seq_id: seq for seq in seqs} - - self.sampling_params = sampling_params - self.metrics = RequestMetrics(arrival_time=arrival_time, - last_token_time=arrival_time, - first_scheduled_time=None, - first_token_time=None, - time_in_queue=None) - self.lora_request = lora_request - self.prompt_logprobs: Optional[PromptLogprobs] = None - self.state = SequenceGroupState() - self.embeddings = embeddings - self.pooling_params = pooling_params - self.prompt_adapter_request = prompt_adapter_request - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - - @property - def prompt(self) -> Optional[str]: - # All sequences in the group should have the same prompt. - # We use the prompt of an arbitrary sequence. - return self.seqs[0].prompt - - @property - def prompt_token_ids(self) -> List[int]: - # All sequences in the group should have the same prompt. - # We use the prompt of an arbitrary sequence. - return self.seqs[0].prompt_token_ids - - @property - def encoder_prompt(self) -> Optional[str]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt is distinct - # from the decoder's. - return (self.encoder_seq.prompt - if self.encoder_seq is not None else None) - - @property - def encoder_prompt_token_ids(self) -> Optional[List[int]]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt token ids are - # distinct from the decoder's. - return (self.encoder_seq.prompt_token_ids - if self.encoder_seq is not None else None) - - @property - def multi_modal_data(self) -> "MultiModalDataDict": - # All sequences in the group should have the same multi-modal data. - # We use the multi-modal data of an arbitrary sequence. - return self.seqs[0].multi_modal_data - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - @property - def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 - - @property - def prompt_adapter_num_virtual_tokens(self) -> int: - return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - if self.prompt_adapter_request else 0 - - def init_multi_step(self, num_scheduler_steps: int) -> None: - self.state.num_steps = num_scheduler_steps - self.state.current_step = 0 - - def get_last_latency(self, now: float) -> Optional[float]: - """Sets the last token time for Request level timings.""" - # If still in prefill phase, raise Error. - if self.is_prefill(): - raise ValueError( - "seq_group.get_last_latency() should not be called " - "if the seq_group is in prefill phase.") - - # Otherwise return token latency. - latency = now - self.metrics.last_token_time - self.metrics.last_token_time = now - return latency - - def maybe_set_first_token_time(self, time: float) -> None: - """Sets the first token time for Request level timings.""" - # Note: in a case where a sequence_group is swapped and - # recomputed, the time between iterations is counted - # in TPOT, rather than recalculating TTFT (since from the ) - # POV of the user, there is simply a long generation delay. - if (self.metrics.first_token_time is None - and self.seqs[0].get_output_len() == 1): - self.metrics.first_token_time = time - - def maybe_set_first_scheduled_time(self, time: float) -> None: - """Sets the first scheduled time and time in queue for Request - level timings.""" - if self.metrics.first_scheduled_time is None: - self.metrics.first_scheduled_time = time - self.metrics.time_in_queue = time - self.metrics.arrival_time - - def set_finished_time(self, time: Optional[float]) -> None: - """Sets the finished time for Request level timings.""" - self.metrics.finished_time = time - - def get_max_num_running_seqs(self) -> int: - """The maximum number of sequences running in parallel in the remaining - lifetime of the request.""" - if self.sampling_params and self.sampling_params.use_beam_search: - # For beam search, maximally there will always be `best_of` beam - # candidates running in the future. - best_of = self.sampling_params.best_of - assert isinstance(best_of, int) - return best_of - else: - if self.sampling_params: - best_of = self.sampling_params.best_of - assert isinstance(best_of, int) - if best_of > self.num_seqs(): - # At prompt stage, the sequence group is not yet filled up - # and only have one sequence running. However, in the - # generation stage, we will have `best_of` sequences - # running. - return best_of - # At sampling stages, return the number of actual sequences - # that are not finished yet. - return self.num_unfinished_seqs() - - def get_seqs( - self, - status: Optional[SequenceStatus] = None, - ) -> List[Sequence]: - if status is None: - return self.seqs - - if self.is_single_seq: - return self.seqs if self.seqs[0].status == status else [] - - return [seq for seq in self.seqs if seq.status == status] - - def is_encoder_decoder(self) -> bool: - return self.encoder_seq is not None - - def get_encoder_seq(self) -> Optional[Sequence]: - return self.encoder_seq - - def get_unfinished_seqs(self) -> List[Sequence]: - if self.is_single_seq: - return self.seqs if not self.seqs[0].is_finished() else [] - - return [seq for seq in self.seqs if not seq.is_finished()] - - def get_finished_seqs(self) -> List[Sequence]: - if self.is_single_seq: - return self.seqs if self.seqs[0].is_finished() else [] - - return [seq for seq in self.seqs if seq.is_finished()] - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - for seq in self.seqs: - if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) - - def get_num_uncomputed_tokens(self) -> int: - num_uncomputed_tokens = 0 - for seq in self.seqs: - if not seq.is_finished(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() - return num_uncomputed_tokens - - def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: - # Optimization. We don't need to call get_seqs if we don't need to - # filter by states. - if status is None: - return len(self.seqs) - - if self.is_single_seq: - return 1 if self.seqs[0].status == status else 0 - - return len(self.get_seqs(status)) - - def num_unfinished_seqs(self) -> int: - if self.is_single_seq: - return 1 if not self.seqs[0].is_finished() else 0 - - return len(self.get_unfinished_seqs()) - - def num_finished_seqs(self) -> int: - if self.is_single_seq: - return 1 if self.seqs[0].is_finished() else 0 - - return len(self.get_finished_seqs()) - - def find(self, seq_id: int) -> Sequence: - if seq_id not in self.seqs_dict: - raise ValueError(f"Sequence {seq_id} not found.") - return self.seqs_dict[seq_id] - - def add(self, seq: Sequence) -> None: - if seq.seq_id in self.seqs_dict: - raise ValueError(f"Sequence {seq.seq_id} already exists.") - self.seqs_dict[seq.seq_id] = seq - self.seqs.append(seq) - self.is_single_seq = len(self.seqs) == 1 - - def remove(self, seq_id: int) -> None: - seq = self.seqs_dict.pop(seq_id, None) - if seq is None: - raise ValueError(f"Sequence {seq_id} not found.") - self.seqs.remove(seq) - self.is_single_seq = len(self.seqs) == 1 - - def is_finished(self) -> bool: - if self.is_single_seq: - return self.seqs[0].is_finished() - - return all(seq.is_finished() for seq in self.seqs) - - def is_prefill(self) -> bool: - # Every sequence should be in the same stage. - return self.seqs[0].is_prefill() - - def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs)})") - - -class SequenceGroupMetadataDelta( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta of SequenceGroupMetadata. - - After sending the first SequenceGroupMetadata, vLLM scheduler - only sends delta to reduce the data payload size. - """ - seq_data_delta: Dict[int, SequenceDataDelta] - request_id: str - block_tables: Dict[int, List[int]] - is_prompt: bool - do_sample: bool = True - token_chunk_size: Optional[int] = None - computed_block_nums: Optional[List[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - - -class SequenceGroupMetadata( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Metadata for a sequence group. Used to create `AttentionMetadata`. - - Args: - request_id: The ID of the request. - is_prompt: Whether the request is at prompt stage. - seq_data: The sequence data. (Seq id -> sequence data) - sampling_params: The sampling parameters used to generate the outputs. - block_tables: The block tables. (Seq id -> list of physical block - numbers) - do_sample: True if sampling is required. Sampling is not required when - e.g., prefill is chunked, and the current iteration only computes - query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. - state: Internal state tied to this sequence group. - multi_modal_data: Multi modal data. - encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - cross_block_table: Optional cross-attention block table associated - with the encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - prompt_adapter_request: Prompt Adapter request. - """ - - request_id: str - is_prompt: bool - seq_data: Dict[int, SequenceData] - sampling_params: Optional[SamplingParams] - block_tables: Dict[int, List[int]] - do_sample: bool = True - pooling_params: Optional[PoolingParams] = None - lora_request: Optional[LoRARequest] = None - computed_block_nums: Optional[List[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - # "MultiModalDataDict" types. We have to use Any due to msgspec - # doesn't allow to have union of 2 different dicts. - multi_modal_data: Optional[Any] = None - encoder_seq_data: Optional[SequenceData] = None - cross_block_table: Optional[List[int]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - token_chunk_size: Optional[int] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - num_speculative_tokens: Optional[int] = None - - def __post_init__(self): - if self.seq_data is not None and self.token_chunk_size is None: - if self.is_prompt: - self.token_chunk_size = next(iter( - self.seq_data.values())).get_len() - else: - self.token_chunk_size = 1 - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - @property - def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 - - @property - def prompt_adapter_num_virtual_tokens(self) -> int: - return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ - if self.prompt_adapter_request else 0 - - def apply_delta(self, - sequence_group_metadata_delta: SequenceGroupMetadataDelta): - for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): - self.seq_data[id].apply_delta(delta) - assert self.request_id == sequence_group_metadata_delta.request_id - self.block_tables = sequence_group_metadata_delta.block_tables - self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size - self.do_sample = sequence_group_metadata_delta.do_sample - self.is_prompt = sequence_group_metadata_delta.is_prompt - - def finish_step(self) -> None: - assert self.state is not None - assert self.state.current_step < self.state.num_steps - self.state.current_step += 1 - - -class SequenceOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a sequence. - - Args: - parent_seq_id: The ID of the parent sequence (for forking in beam - search). - output_token: The output token ID. - logprobs: The logprobs of the output token. - (Token id -> logP(x_i+1 | x_0, ..., x_i)) - """ - parent_seq_id: int - output_token: int - logprobs: Dict[int, Logprob] - - def __repr__(self) -> str: - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"logprobs={self.logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutput): - raise NotImplementedError() - equal = (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token) - log_probs_equal = other.logprobs == self.logprobs - return equal and log_probs_equal - - -class SequenceGroupOutput(ABC): - """The base class for model outputs associated with a sequence group.""" - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractmethod - def __eq__(self, other: object) -> bool: - pass - - -class CompletionSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - __metaclass__ = SequenceGroupOutput - """The model output associated with a completion sequence group.""" - samples: List[SequenceOutput] - # Prompt logprob for each prompt query token. - prompt_logprobs: Optional[PromptLogprobs] - - def __repr__(self) -> str: - return (f"CompletionSequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CompletionSequenceGroupOutput): - raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) - - -class EmbeddingSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] -): - """The model output associated with an embedding sequence group.""" - __metaclass__ = SequenceGroupOutput - embeddings: List[int] - - def __repr__(self) -> str: - return (f"EmbeddingSequenceGroupOutput(" - f"embeddings_shape={len(self.embeddings)})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, EmbeddingSequenceGroupOutput): - raise NotImplementedError() - return self.embeddings == other.embeddings - - -class IntermediateTensors( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For all pipeline stages except the last, we need to return the hidden - states and residuals to be sent to the next stage. This data structure - contains the hidden states and residuals for a request. - """ - - tensors: Dict[str, torch.Tensor] - - def __getitem__(self, key: Union[str, slice]): - if isinstance(key, str): - return self.tensors[key] - elif isinstance(key, slice): - return self.__class__({k: v[key] for k, v in self.tensors.items()}) - - def __setitem__(self, key: str, value): - self.tensors[key] = value - - def __len__(self): - return len(self.tensors) - - def __eq__(self, other: object): - return isinstance(other, self.__class__) and self - - def __repr__(self) -> str: - return f"IntermediateTensors(tensors={self.tensors})" - - -class PoolerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The output from a pooling operation in the embedding model.""" - outputs: List[EmbeddingSequenceGroupOutput] - - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - - def __getitem__(self, idx: int): - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - -def get_all_seq_ids( - seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] - - -def get_all_seq_ids_and_request_ids( - seq_group_metadata_list: List[SequenceGroupMetadata] -) -> Tuple[List[int], Dict[str, Set[int]]]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - seq_ids: List[int] = [] - request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) - for sg in seq_group_metadata_list: - for seq_id in sg.seq_data: - seq_ids.append(seq_id) - request_id_seq_ids_mapping[sg.request_id].add(seq_id) - return seq_ids, request_id_seq_ids_mapping - - -class HiddenStates(msgspec.Struct, array_like=True, - omit_defaults=True): # type: ignore[call-arg] - """Hidden states corresponding to in-progress sequences. - Used in speculative decoding to pass hidden states from - the target model to the proposer model. - - seq_ids are the sequence ids of each entry of the batch - dimension of the hidden_states tensor""" - # Scorer hidden states. For prefill step, it is used for hidden states of - # all tokens, whereas for decode step, it use used for last accepted tokens. - hidden_states: torch.Tensor - # The sequence group metadata list. Only needed for decode step. - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - # Scorer hidden states of the 2nd last token proposed by the proposer ( - # irrespective of whether it was accepted or not). Only used for cases when - # last proposed token is accepted (i.e., in case of bonus tokens). For the - # case of no bonus tokens, these are ignored. - second_last_token_hidden_states: Optional[torch.Tensor] = None - - _seq_ids: List[int] = msgspec.field(default_factory=list) - - def __post_init__(self): - if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) - self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) - - @property - def seq_ids(self) -> List[int]: - return self._seq_ids - - def update(self, - hidden_states: torch.Tensor, - seq_group_metadata_list: List[SequenceGroupMetadata], - second_last_token_hidden_states: Optional[torch.Tensor] = None): - """Update hidden states from target model invocation. Only used for - decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) - self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - - if self.second_last_token_hidden_states is not None: - # Adding dummy hidden_states to this to maintain same shape - self.second_last_token_hidden_states = torch.cat([ - self.second_last_token_hidden_states, - torch.zeros_like(hidden_states) - if second_last_token_hidden_states is None else - second_last_token_hidden_states - ]) - - def prune(self, - seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids. Only used for decode steps. - """ - # Currently this prunes all seq_ids not present in - # seq_group_metadata_list which might cause problems where a sequence - # may be "paused" then "resumed" later. This should only prune sequences - # which are confirmed to be aborted. - seq_ids = get_all_seq_ids(seq_group_metadata_list) - if seq_ids != self._seq_ids: - # Batch contents changed - prune removed sequences. - index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] - self._seq_ids = seq_ids - - def expand_with_bonus_tokens( - self, seq_with_bonus_token_in_last_step: set) -> None: - """Expand hidden states for sequences with bonus tokens. This is in - alignment with `MultiStepWorker._expand_execute_model_request`.""" - if self.second_last_token_hidden_states is None \ - or not seq_with_bonus_token_in_last_step: - return - - index = [] - for seq_id in self._seq_ids: - i = self._seq_ids.index(seq_id) - if seq_id in seq_with_bonus_token_in_last_step: - index.append(i + len(self._seq_ids)) - index.append(i) - - self.hidden_states = torch.cat( - [self.hidden_states, self.second_last_token_hidden_states])[index] - - -class ExecuteModelRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """The model execution request, containing CPU metadata only. The LLM - engine should create an instance of this class for each request batch.""" - # The sequence group metadata list. - seq_group_metadata_list: List[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) - # Virtual engine ID for pipeline parallel. - virtual_engine: int = 0 - # The number of slots for lookahead decoding. - num_lookahead_slots: int = 0 - # The number of requests in the running queue. - running_queue_size: int = 0 - # Optional hidden states from prior step. - previous_hidden_states: Optional[HiddenStates] = None - # The number of forward steps to run. - num_steps: int = 1 - # Finished request ids since last step. - finished_requests_ids: List[str] = msgspec.field(default_factory=list) - # The last sampled token ids for multi step decoding. - last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None - - @property - def is_first_multi_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.current_step == 0 - - @property - def is_last_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.remaining_steps == 1 - - @property - def current_step(self) -> int: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - state = self.seq_group_metadata_list[0].state - assert state is not None - return state.current_step - - def clone( - self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - ) -> "ExecuteModelRequest": - """Clone the request with a new sequence group metadata list.""" - return ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=self.blocks_to_swap_in.copy(), - blocks_to_swap_out=self.blocks_to_swap_out.copy(), - blocks_to_copy=self.blocks_to_copy.copy(), - virtual_engine=self.virtual_engine, - num_lookahead_slots=self.num_lookahead_slots, - running_queue_size=self.running_queue_size, - previous_hidden_states=self.previous_hidden_states, - num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids, - last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) From edbcd70311a1510b22dbdef9992c21f77fb20404 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Sep 2024 09:15:21 -0700 Subject: [PATCH 02/31] Minor --- vllm/core/scheduler_v2.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py index a94e3e7fb096e..4b05a69be54c7 100644 --- a/vllm/core/scheduler_v2.py +++ b/vllm/core/scheduler_v2.py @@ -80,16 +80,16 @@ def schedule(self) -> "SchedulerOutput": req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} total_num_scheduled_tokens = 0 - num_remaining_tokens = self.max_num_scheduled_tokens + token_budget = self.max_num_scheduled_tokens # First, schedule the RUNNING requests. while self.running: - if num_remaining_tokens == 0: + if token_budget == 0: break request = self.running[0] num_tokens = request.num_tokens - request.num_computed_tokens - num_tokens = min(num_tokens, num_remaining_tokens) + num_tokens = min(num_tokens, token_budget) new_block_ids: List[int] = [] while not self.block_manager.can_append_slots(request, num_tokens): @@ -116,7 +116,7 @@ def schedule(self) -> "SchedulerOutput": req_to_new_block_ids[request.request_id] = new_block_ids num_scheduled_tokens[request.request_id] = num_tokens total_num_scheduled_tokens += num_tokens - num_remaining_tokens -= num_tokens + token_budget -= num_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens += num_tokens @@ -130,7 +130,7 @@ def schedule(self) -> "SchedulerOutput": break if len(self.running) == self.max_num_running_reqs: break - if num_remaining_tokens == 0: + if token_budget == 0: break request = self.waiting[0] @@ -148,7 +148,7 @@ def schedule(self) -> "SchedulerOutput": # Number of tokens to be scheduled. num_tokens = request.num_tokens - num_computed_tokens - num_tokens = min(num_tokens, num_remaining_tokens) + num_tokens = min(num_tokens, token_budget) self.waiting.popleft() self.running.append(request) @@ -163,7 +163,7 @@ def schedule(self) -> "SchedulerOutput": computed_block_ids + new_block_ids) num_scheduled_tokens[request.request_id] = num_tokens total_num_scheduled_tokens += num_tokens - num_remaining_tokens -= num_tokens + token_budget -= num_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens + num_tokens @@ -172,7 +172,7 @@ def schedule(self) -> "SchedulerOutput": # Check if the scheduling constraints are satisfied. assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens - assert num_remaining_tokens >= 0 + assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) == len(self.running)) From a817fe42d191ff018dedce7d3ab4ef17ecba56f2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Sep 2024 02:11:43 -0700 Subject: [PATCH 03/31] Minor --- vllm/core/scheduler_v2.py | 54 ++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py index 4b05a69be54c7..a656d8fda8cc2 100644 --- a/vllm/core/scheduler_v2.py +++ b/vllm/core/scheduler_v2.py @@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.utils import Device - from vllm.request import Request, RequestStatus from vllm.sampling_params import SamplingParams from vllm.multimodal import MultiModalDataDict @@ -79,7 +78,6 @@ def schedule(self) -> "SchedulerOutput": req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} - total_num_scheduled_tokens = 0 token_budget = self.max_num_scheduled_tokens # First, schedule the RUNNING requests. @@ -115,14 +113,8 @@ def schedule(self) -> "SchedulerOutput": req_to_new_block_ids[request.request_id] = new_block_ids num_scheduled_tokens[request.request_id] = num_tokens - total_num_scheduled_tokens += num_tokens token_budget -= num_tokens - request.status = RequestStatus.RUNNING - request.num_computed_tokens += num_tokens - if request.num_tokens == request.num_computed_tokens: - # TODO(woosuk): Consider speculative decoding. - request.num_output_tokens += 1 # Next, schedule the WAITING requests. while self.waiting: @@ -159,18 +151,14 @@ def schedule(self) -> "SchedulerOutput": else: assert False, f"Invalid request status: {request.status}" - req_to_new_block_ids[request.request_id] = ( - computed_block_ids + new_block_ids) + req_to_new_block_ids[request.request_id] = (computed_block_ids + + new_block_ids) num_scheduled_tokens[request.request_id] = num_tokens - total_num_scheduled_tokens += num_tokens token_budget -= num_tokens - request.status = RequestStatus.RUNNING - request.num_computed_tokens = num_computed_tokens + num_tokens - if request.num_tokens == request.num_computed_tokens: - request.num_output_tokens += 1 # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs @@ -179,18 +167,19 @@ def schedule(self) -> "SchedulerOutput": # Construct the scheduler output. new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_block_ids[req.request_id]) + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id], + num_computed_tokens) for req in scheduled_new_reqs ] resumed_reqs_data = [ ResumedRequestData.from_request( - req, req_to_new_block_ids[req.request_id]) + req, req_to_new_block_ids[req.request_id], num_computed_tokens) for req in scheduled_resumed_reqs ] running_reqs_data = [ RunningRequestData.from_request( - req, req_to_new_block_ids[req.request_id]) + req, req_to_new_block_ids[req.request_id], num_computed_tokens) for req in scheduled_running_reqs ] preempted_req_ids = {req.request_id for req in preempted_reqs} @@ -207,6 +196,12 @@ def schedule(self) -> "SchedulerOutput": self.finished_req_ids = set() self.aborted_req_ids = set() + for request in self.running: + num_tokens = num_scheduled_tokens[request.request_id] + request.num_computed_tokens = num_computed_tokens + num_tokens + if request.num_tokens == request.num_computed_tokens: + # TODO: Consider speculative decoding. + request.num_output_tokens += 1 return scheduler_output def add_request(self, request: Request) -> None: @@ -248,7 +243,7 @@ def stop_requests(self, request_ids: Union[str, Iterable[str]]) -> None: request.status = RequestStatus.FINISHED_STOPPED stopped_reqs.append(request) request_ids.remove(request.request_id) - + for request in stopped_reqs: queue.remove(request) self.finished_req_ids.add(request.request_id) @@ -258,8 +253,9 @@ def _check_stop_by_len(self) -> None: stopped_reqs: List[Request] = [] # TODO: Optimize this. for request in self.running: - if (request.num_tokens >= self.max_model_len - or request.num_output_tokens >= request.max_tokens): + assert request.max_tokens is not None + if (request.num_tokens >= self.max_model_len + or request.num_output_tokens >= request.max_tokens): request.status = RequestStatus.FINISHED_LENGTH_CAPPED stopped_reqs.append(request) for request in stopped_reqs: @@ -269,14 +265,14 @@ def _check_stop_by_len(self) -> None: def _free_request(self, request: Request) -> None: assert request.is_finished() - self.block_manager.free(request) - - def has_unfinished_requests(self) -> bool: - return self.waiting or self.running + self.block_manager.free(request) def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) + def has_unfinished_requests(self) -> bool: + return self.get_num_unfinished_requests() > 0 + @dataclass class NewRequestData: @@ -294,6 +290,7 @@ def from_request( cls, request: Request, block_ids: List[int], + num_computed_tokens: int, ) -> "NewRequestData": return cls( req_id=request.request_id, @@ -302,6 +299,7 @@ def from_request( multi_modal_data=request.inputs.get("multi_modal_data"), sampling_params=request.sampling_params, block_ids=block_ids, + num_computed_tokens=num_computed_tokens, ) @@ -317,10 +315,12 @@ def from_request( cls, request: Request, block_ids: List[int], + num_computed_tokens: int, ) -> "ResumedRequestData": return cls( req_id=request.request_id, block_ids=block_ids, + num_computed_tokens=num_computed_tokens, ) @@ -336,10 +336,12 @@ def from_request( cls, request: Request, new_block_ids: List[int], + num_computed_tokens: int, ) -> "RunningRequestData": return cls( req_id=request.request_id, new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, ) From 46bc43541ba17aa0f887ccd8c4b0ca311399a3ff Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Sep 2024 02:12:37 -0700 Subject: [PATCH 04/31] Minor --- vllm/engine/llm_engine_v2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine_v2.py b/vllm/engine/llm_engine_v2.py index e828ee440eb9c..65fa8cf2da6eb 100644 --- a/vllm/engine/llm_engine_v2.py +++ b/vllm/engine/llm_engine_v2.py @@ -266,8 +266,7 @@ def __init__( # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = Scheduler( - scheduler_config, cache_config, lora_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # Metric Logging. if self.log_stats: @@ -410,7 +409,10 @@ def _add_processed_request( # TODO(woosuk): Check max_logprobs # TODO(woosuk): Support encoder-decoder models. - req = Request(request_id, processed_inputs, arrival_time, sampling_params=params) + req = Request(request_id, + processed_inputs, + arrival_time, + sampling_params=params) self.scheduler.add_req(req) def stop_remote_worker_execution_loop(self) -> None: From 78d89667ffd706d46f93f589215450ad77e83145 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Sep 2024 02:15:26 -0700 Subject: [PATCH 05/31] Add worker v2 --- vllm/worker/worker.py | 488 --------------------------------------- vllm/worker/worker_v2.py | 221 ++++++++++++++++++ 2 files changed, 221 insertions(+), 488 deletions(-) delete mode 100644 vllm/worker/worker.py create mode 100644 vllm/worker/worker_v2.py diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py deleted file mode 100644 index 3851843afc960..0000000000000 --- a/vllm/worker/worker.py +++ /dev/null @@ -1,488 +0,0 @@ -"""A GPU worker class.""" -import gc -import os -from typing import Dict, List, Optional, Set, Tuple, Type, Union - -import torch -import torch.distributed - -import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.platforms import current_platform -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput - -logger = init_logger(__name__) - - -class Worker(LocalOrDistributedWorkerBase): - """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single GPU. The worker is responsible for - maintaining the KV cache and executing the model on the GPU. In case of - distributed inference, each worker is assigned a partition of the model. - """ - - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - observability_config: Optional[ObservabilityConfig] = None, - ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.is_driver_worker = is_driver_worker - if parallel_config and is_driver_worker: - assert rank % parallel_config.tensor_parallel_size == 0, \ - "Driver worker should be rank 0 of tensor parallel group." - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - self.observability_config = observability_config - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ - else {"return_hidden_states": True} - - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_runner_cls is not None: - ModelRunnerClass = model_runner_cls - elif self._is_embedding_model(): - ModelRunnerClass = EmbeddingModelRunner - elif self._is_encoder_decoder_model(): - ModelRunnerClass = EncoderDecoderModelRunner - self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=load_config, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config, - **speculative_args, - ) - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CacheEngine] - # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[List[torch.Tensor]]] = None - self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() - - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - - def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model - - def _is_embedding_model(self): - return self.model_config.is_embedding_model - - def init_device(self) -> None: - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # This env var set by Ray causes exceptions with graph building. - os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) - - _check_if_gpu_supports_dtype(self.model_config.dtype) - gc.collect() - torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") - # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) - # Set random seed. - set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - self.model_runner.save_sharded_state( - path, - pattern=pattern, - max_size=max_size, - ) - - def save_tensorized_model( - self, - tensorizer_config: TensorizerConfig, - ) -> None: - self.model_runner.save_tensorized_model( - tensorizer_config=tensorizer_config, ) - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - torch.cuda.empty_cache() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - peak_memory = self.init_gpu_memory - free_gpu_memory - assert peak_memory > 0, ( - "Error in memory profiling. " - f"Initial free memory {self.init_gpu_memory}, current free memory" - f" {free_gpu_memory}. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") - - cache_block_size = self.get_cache_block_size_bytes() - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - gc.collect() - torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks. - - This also warms up the model, which may record CUDA graphs. - """ - raise_if_cache_size_invalid(num_gpu_blocks, - self.cache_config.block_size, - self.model_config.max_model_len) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._init_cache_engine() - self._warm_up_model() - - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.gpu_cache = [ - self.cache_engine[ve].gpu_cache - for ve in range(self.parallel_config.pipeline_parallel_size) - ] - - def _warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - @property - def do_metadata_broadcast(self) -> bool: - return self.parallel_config.tensor_parallel_size > 1 - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return self.gpu_cache - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - virtual_engine = execute_model_req.virtual_engine - num_steps = execute_model_req.num_steps - num_seq_groups = len(execute_model_req.seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) - # `blocks_to_copy` is a gpu tensor. The src and tgt of - # blocks to copy are in the same device, and `blocks_to_copy` - # can be used directly within cuda kernels. - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, - dtype=torch.int64).view(-1, 2) - - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - num_steps=num_steps, - ) - - @torch.inference_mode() - def execute_worker(self, worker_input: WorkerInput) -> None: - virtual_engine = worker_input.virtual_engine - # Issue cache operations. - if (worker_input.blocks_to_swap_in is not None - and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine[virtual_engine].swap_in( - worker_input.blocks_to_swap_in) - if (worker_input.blocks_to_swap_out is not None - and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine[virtual_engine].swap_out( - worker_input.blocks_to_swap_out) - if (worker_input.blocks_to_copy is not None - and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) - - def _get_cached_seq_group_metadata( - self, - seq_group_metadata_list: List[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]], - finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: - """Return a list of cached Sequence Group Metadata after updating its - state. - - It is used because scheduler only sends delta to workers to reduce - the data payload size. The function also cleans up cache based on - a given `finished_request_ids`. - """ - new_seq_group_metadata_list = [] - for metadata_or_delta in seq_group_metadata_list: - request_id = metadata_or_delta.request_id - if request_id not in self._seq_group_metadata_cache: - # The first prefill. - assert isinstance(metadata_or_delta, SequenceGroupMetadata) - self._seq_group_metadata_cache[request_id] = metadata_or_delta - else: - # The first prefill is already cached. - if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): - self._seq_group_metadata_cache[request_id].apply_delta( - metadata_or_delta) - else: - # If metadata snapshot is sent again, it is - # preempted. Reset the cache because we need to start - # from scratch. - assert isinstance(metadata_or_delta, SequenceGroupMetadata) - self._seq_group_metadata_cache[ - request_id] = metadata_or_delta - - new_seq_group_metadata_list.append( - self._seq_group_metadata_cache[request_id]) - - # Clean up finished ids - for finished_id in finished_request_ids: - del self._seq_group_metadata_cache[finished_id] - - return new_seq_group_metadata_list - - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Optional[List[SamplerOutput]]: - if execute_model_req is not None: - new_seq_group_metadata_list = self._get_cached_seq_group_metadata( - execute_model_req.seq_group_metadata_list, - execute_model_req.finished_requests_ids) - - execute_model_req.seq_group_metadata_list = ( - new_seq_group_metadata_list) - output = super()._execute_model_spmd(execute_model_req, - intermediate_tensors) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.model_runner.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.remove_lora(prompt_adapter_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.pin_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> Set[int]: - return self.model_runner.list_prompt_adapters() - - @property - def max_model_len(self) -> int: - return self.model_config.max_model_len - - @property - def vocab_size(self) -> int: - return self.model_runner.vocab_size - - def get_cache_block_size_bytes(self) -> int: - """Get the size of the KV cache block size in bytes. - """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - - -def init_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank) - - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the" - "`dtype` flag in CLI, for example: --dtype=half.") - - -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, - max_model_len) -> None: - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") diff --git a/vllm/worker/worker_v2.py b/vllm/worker/worker_v2.py new file mode 100644 index 0000000000000..561f052913cf5 --- /dev/null +++ b/vllm/worker/worker_v2.py @@ -0,0 +1,221 @@ +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.worker.model_runner_v2 import GPUModelRunner + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.core.scheduler_v2 import SchedulerOutput + + +class Worker: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = GPUModelRunner( + model_config, + parallel_config, + device_config, + cache_config, + load_config, + kv_cache_dtype=cache_config.kv_cache_dtype, + lora_config=lora_config, + ) + + def initialize(self): + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self) -> None: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks.""" + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_gpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # TODO(woosuk): Create KV cache. + self.model_runner.initialize_kv_cache() + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> None: + sampler_output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") From 0d27d3db1f5d2574caecd8183d8fbff6c7ae9df2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Sep 2024 02:15:36 -0700 Subject: [PATCH 06/31] Add model runner v2 --- vllm/worker/model_runner.py | 1853 -------------------------------- vllm/worker/model_runner_v2.py | 411 +++++++ 2 files changed, 411 insertions(+), 1853 deletions(-) delete mode 100644 vllm/worker/model_runner.py create mode 100644 vllm/worker/model_runner_v2.py diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py deleted file mode 100644 index 0a90f767567d6..0000000000000 --- a/vllm/worker/model_runner.py +++ /dev/null @@ -1,1853 +0,0 @@ -import dataclasses -import gc -import inspect -import itertools -import time -import warnings -import weakref -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn - -import vllm.envs as envs -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) -from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group -from vllm.distributed.parallel_state import graph_capture -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata, SamplingMetadataCache -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models.interfaces import (supports_lora, - supports_multimodal) -from vllm.model_executor.models.utils import set_cpu_offload_max_bytes -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) -from vllm.prompt_adapter.layers import PromptAdapterMapping -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.prompt_adapter.worker_manager import ( - LRUCacheWorkerPromptAdapterManager) -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -LORA_WARMUP_RANK = 8 -_BATCH_SIZE_ALIGNMENT = 8 -# all the token sizes that **can** be captured by cudagraph. -# they can be arbitrarily large. -# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. -# the actual sizes to capture will be determined by the model, -# depending on the model's max_num_seqs. -# NOTE: _get_graph_batch_size needs to be updated if this list is changed. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) -] -_NUM_WARMUP_ITERS = 2 - -TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") - -# For now, bump up cache limits for recompilations during CUDA graph warmups. -torch._dynamo.config.cache_size_limit = 128 -torch._dynamo.config.accumulated_cache_size_limit = 128 - - -@dataclass(frozen=True) -class ModelInputForGPU(ModelRunnerInputBase): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - attn_metadata: Optional["AttentionMetadata"] = None - prompt_adapter_mapping: Optional[PromptAdapterMapping] = None - prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None - request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None - finished_requests_ids: Optional[List[str]] = None - virtual_engine: int = 0 - async_callback: Optional[Callable] = None - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "prompt_adapter_mapping": self.prompt_adapter_mapping, - "prompt_adapter_requests": self.prompt_adapter_requests, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForGPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForGPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -@dataclass(frozen=True) -class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - # Used for speculative decoding. We do not broadcast it because it is only - # used by the driver worker. - is_prompt: Optional[bool] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "prompt_adapter_mapping": self.prompt_adapter_mapping, - "prompt_adapter_requests": self.prompt_adapter_requests, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForGPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): - """Build ModelInputForGPU from SequenceGroupMetadata.""" - - # Note: ideally we would be using a dataclass(kw_only=True) - # here, so that this can be subclassed easily, - # but kw_only is not supported in python<3.10. - class InterDataForSeqGroup: - """Intermediate data for the current sequence group.""" - - def simple_reinit(self): - self.input_tokens[0].clear() # type: ignore - self.input_positions[0].clear() # type: ignore - self.mrope_input_positions = None # type: ignore - self.seq_lens[0] = 0 # type: ignore - self.orig_seq_lens[0] = 0 # type: ignore - self.query_lens[0] = 0 # type: ignore - self.context_lens[0] = 0 # type: ignore - self.curr_sliding_window_blocks[0] = 0 # type: ignore - self.lora_index_mapping.clear() # type: ignore - self.lora_prompt_mapping.clear() # type: ignore - self.lora_requests.clear() # type: ignore - self.prompt_adapter_index_mapping.clear() # type: ignore - self.prompt_adapter_prompt_mapping.clear() # type: ignore - - def __init__( - self, - *, - # From sequence group metadata. - request_id: str, - seq_ids: List[int], - is_prompt: bool, - block_tables: Optional[Dict[int, List[int]]], - computed_block_nums: List[int], - n_seqs: int = 0, - - # Input tokens and positions. - input_tokens: Optional[List[List[int]]] = None, - input_positions: Optional[List[List[int]]] = None, - mrope_input_positions: Optional[List[List[List[int]]]] = None, - - # The sequence length (may be capped to the sliding window). - seq_lens: Optional[List[int]] = None, - # The original sequence length (before applying sliding window). - # This is used to compute slot mapping. - orig_seq_lens: Optional[List[int]] = None, - # The query length. - query_lens: Optional[List[int]] = None, - # The number of tokens that are already computed. - context_lens: Optional[List[int]] = None, - # The current sliding window block. - curr_sliding_window_blocks: Optional[List[int]] = None, - - # LoRA inputs. - lora_index_mapping: Optional[List[List[int]]] = None, - lora_prompt_mapping: Optional[List[List[int]]] = None, - lora_requests: Optional[Set[LoRARequest]] = None, - - # Prompt adapter inputs. - prompt_adapter_index_mapping: Optional[List[int]] = None, - prompt_adapter_prompt_mapping: Optional[List[int]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - - # Multi-modal inputs. - multi_modal_inputs: Optional[MultiModalInputs] = None, - - # Whether the prefix cache is hit (prefill only). - prefix_cache_hit: bool = False, - reinit: bool = False, - reinit_use_defaults: bool = False, - encoder_seq_len: int = 0, - ): - if reinit: - assert len(self.seq_ids) == len(seq_ids) # type: ignore - for i, seq_id in enumerate(seq_ids): - self.seq_ids[i] = seq_id # type: ignore - else: - self.seq_ids = seq_ids - - self.request_id = request_id - self.is_prompt = is_prompt - self.block_tables = block_tables - self.computed_block_nums = computed_block_nums - self.n_seqs = n_seqs - self.encoder_seq_len = encoder_seq_len - - if reinit: - if len(self.seq_ids) == 1 and reinit_use_defaults: - self.simple_reinit() - else: - if input_tokens: - self.input_tokens = input_tokens - else: - for seq_id in range(len(self.seq_ids)): - self.input_tokens[seq_id].clear() - - if input_positions: - self.input_positions = input_positions - else: - for seq_id in range(len(self.seq_ids)): - self.input_positions[seq_id].clear() - - self.mrope_input_positions = None - - if seq_lens: - self.seq_lens = seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.seq_lens[seq_id] = 0 - - if orig_seq_lens: - self.orig_seq_lens = orig_seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.orig_seq_lens[seq_id] = 0 - - if query_lens: - self.query_lens = query_lens - else: - for seq_id in range(len(self.seq_ids)): - self.query_lens[seq_id] = 0 - - if context_lens: - self.context_lens = context_lens - else: - for seq_id in range(len(self.seq_ids)): - self.context_lens[seq_id] = 0 - - if curr_sliding_window_blocks: - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks - else: - for seq_id in range(len(self.seq_ids)): - self.curr_sliding_window_blocks[seq_id] = 0 - - if lora_index_mapping: - self.lora_index_mapping = lora_index_mapping - else: - self.lora_index_mapping.clear() - - if lora_prompt_mapping: - self.lora_prompt_mapping = lora_prompt_mapping - else: - self.lora_prompt_mapping.clear() - - if lora_requests: - self.lora_requests = lora_requests - else: - self.lora_requests.clear() - - if prompt_adapter_index_mapping: - self.prompt_adapter_index_mapping = \ - prompt_adapter_index_mapping - else: - self.prompt_adapter_index_mapping.clear() - - if prompt_adapter_prompt_mapping: - self.prompt_adapter_prompt_mapping = \ - prompt_adapter_prompt_mapping - else: - self.prompt_adapter_prompt_mapping.clear() - - else: - self.input_tokens = input_tokens or [] - self.input_positions = input_positions or [] - self.mrope_input_positions = mrope_input_positions or None - self.seq_lens = seq_lens or [] - self.orig_seq_lens = orig_seq_lens or [] - self.query_lens = query_lens or [] - self.context_lens = context_lens or [] - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks or [] - - self.lora_index_mapping = lora_index_mapping or [] - self.lora_prompt_mapping = lora_prompt_mapping or [] - self.lora_requests = lora_requests or set() - - self.prompt_adapter_index_mapping = ( - prompt_adapter_index_mapping or []) - self.prompt_adapter_prompt_mapping = ( - prompt_adapter_prompt_mapping or []) - - self.prompt_adapter_request = prompt_adapter_request - self.multi_modal_inputs = multi_modal_inputs - self.prefix_cache_hit = prefix_cache_hit - - self.n_seqs = len(self.seq_ids) - - if not reinit: - self.__post_init__() - - def __post_init__(self): - self.n_seqs = len(self.seq_ids) - - self.input_tokens = [[] for _ in range(self.n_seqs)] - self.input_positions = [[] for _ in range(self.n_seqs)] - self.mrope_input_positions = None - self.seq_lens = [0] * self.n_seqs - self.orig_seq_lens = [0] * self.n_seqs - self.query_lens = [0] * self.n_seqs - self.context_lens = [0] * self.n_seqs - self.curr_sliding_window_blocks = [0] * self.n_seqs - - self.lora_index_mapping = [] - self.lora_prompt_mapping = [] - - def gen_inter_data_builder(self, num_seqs: int): - return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( - request_id="", - seq_ids=[0] * num_seqs, - is_prompt=True, - block_tables=None, - computed_block_nums=[]) - - def init_cached_inter_data(self, *args, **kwargs): - assert len(args) == 0 - assert "seq_ids" in kwargs - seq_ids = kwargs["seq_ids"] - num_seqs = len(seq_ids) - - # The inter-data cache is per model_runner - inter_data_cache = self.runner.inter_data_cache - if num_seqs not in inter_data_cache: - inter_data_cache[num_seqs] = PyObjectCache( - self.gen_inter_data_builder(num_seqs)) - - obj = inter_data_cache[num_seqs].get_object() - obj.__init__(*args, **kwargs) - return obj - - def reset_cached_inter_data(self): - for cache in self.runner.inter_data_cache.values(): - cache.reset() - - def __init__(self, - runner: "GPUModelRunnerBase", - finished_requests_ids: Optional[List[str]] = None): - super().__init__() - # Compute functions for each sequence in a sequence group. - # WARNING: The order of the functions matters! - self.per_seq_compute_fns = [ - self._compute_lens, - self._compute_for_prefix_cache_hit, - self._compute_for_sliding_window, - self._compute_lora_input, - ] - # Compute functions for each sequence group. - # WARNING: The order of the functions matters! - self.per_seq_group_compute_fns = [ - self._compute_prompt_adapter_input, - self._compute_multi_modal_input, - ] - - self.runner = runner - self.model_input_cls = self.runner._model_input_cls - self.attn_backend = self.runner.attn_backend - self.scheduler_config = self.runner.scheduler_config - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.enable_lora = self.runner.lora_config is not None - self.enable_prompt_adapter = (self.runner.prompt_adapter_config - is not None) - self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper - self.finished_requests_ids = finished_requests_ids - self.decode_only = True - - # Intermediate data (data in CPU before going to GPU) for - # the current sequence group. - self.inter_data_list: List[ - ModelInputForGPUBuilder.InterDataForSeqGroup] = [] - - # Attention metadata inputs. - self.attn_metadata_builder = self.attn_backend.make_metadata_builder( - weakref.proxy(self)) - - # Engine/Model configurations. - self.chunked_prefill_enabled = ( - self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled) - if self.sliding_window is not None: - self.sliding_window_blocks = ( - self.sliding_window + self.block_size - 1) // self.block_size - self.block_aligned_sliding_window = \ - self.sliding_window_blocks * self.block_size - - def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Compute context length, sequence length and tokens - for the given sequence data. - """ - seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] - token_chunk_size = seq_group_metadata.token_chunk_size - - # Compute context length (the number of tokens that are - # already computed) and sequence length (total number of tokens). - seq_len = seq_data.get_len() - if inter_data.is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_len - 1 - seq_len = min(seq_len, context_len + token_chunk_size) - - # Compute tokens. - if inter_data.is_prompt: - tokens = seq_data.get_token_ids() - if context_len != 0 or seq_len < len(tokens): - tokens = tokens[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = seq_data.get_last_token_id() - - inter_data.seq_lens[seq_idx] = seq_len - inter_data.orig_seq_lens[seq_idx] = seq_len - inter_data.context_lens[seq_idx] = context_len - - if isinstance(tokens, list): - inter_data.input_tokens[seq_idx].extend(tokens) - else: - inter_data.input_tokens[seq_idx].append(tokens) - - if (seq_len - context_len) == 1: - inter_data.input_positions[seq_idx].append(seq_len - 1) - else: - inter_data.input_positions[seq_idx].extend( - range(context_len, seq_len)) - - inter_data.query_lens[ - seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 - - if seq_data.mrope_position_delta is not None: - if inter_data.mrope_input_positions is None: - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - - inter_data.mrope_input_positions[ - seq_idx] = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - - def _compute_for_prefix_cache_hit( - self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Check if hit prefix cache (i.e., some blocks are already computed). - If hit, update input tokens and positions to only compute the - remaining blocks. - """ - computed_block_nums = inter_data.computed_block_nums - - # Note that prefix caching does not support sliding window. - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and inter_data.is_prompt) - inter_data.prefix_cache_hit = prefix_cache_hit - - if not prefix_cache_hit: - return - - assert computed_block_nums is not None - # The cache hit prompt tokens in this sequence. Note that - # this may be larger than the sequence length if chunked - # prefill is enabled. - prefix_cache_len = len(computed_block_nums) * self.block_size - # The number of so far computed prompt tokens in this sequence. - context_len = inter_data.context_lens[seq_idx] - # The total number of prompt tokens in this sequence. - # When chunked prefill is enabled, this is the token number of - # computed chunks + current chunk. - seq_len = inter_data.seq_lens[seq_idx] - if prefix_cache_len <= context_len: - # We already passed the cache hit region, - # so do normal computation. - pass - elif context_len < prefix_cache_len < seq_len: - # Partial hit. Compute the missing part. - uncomputed_start = prefix_cache_len - context_len - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][uncomputed_start:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][uncomputed_start:] - context_len = prefix_cache_len - - inter_data.context_lens[seq_idx] = context_len - inter_data.query_lens[ - seq_idx] = inter_data.seq_lens[seq_idx] - context_len - elif seq_len <= prefix_cache_len: - # Full hit. Only compute the last token to avoid - # erroneous behavior. FIXME: Ideally we should directly - # mark all tokens as computed in the scheduler and do not - # schedule this sequence, so this case should not happen. - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][-1:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][-1:] - inter_data.query_lens[seq_idx] = 1 - inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 - - def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Update seq_len and curr_sliding_window_block for the given - sequence data (only required by decoding) if sliding window is enabled. - """ - curr_sliding_window_block = 0 - sliding_seq_len = inter_data.seq_lens[seq_idx] - if not inter_data.is_prompt and self.sliding_window is not None: - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - curr_sliding_window_block = self.sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = inter_data.seq_lens[seq_idx] % self.block_size - sliding_seq_len = min( - inter_data.seq_lens[seq_idx], - self.block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_block += 1 - else: - sliding_seq_len = min(inter_data.seq_lens[seq_idx], - self.sliding_window) - - inter_data.curr_sliding_window_blocks[ - seq_idx] = curr_sliding_window_block - inter_data.seq_lens[seq_idx] = sliding_seq_len - - def _compute_lora_input(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """If LoRA is enabled, compute LoRA index and prompt mapping.""" - if not self.enable_lora: - return - - lora_id = seq_group_metadata.lora_int_id - if lora_id > 0: - inter_data.lora_requests.add(seq_group_metadata.lora_request) - query_len = inter_data.query_lens[seq_idx] - inter_data.lora_index_mapping.append([lora_id] * query_len) - inter_data.lora_prompt_mapping.append( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs is not None - else 1)) - - def _compute_prompt_adapter_input( - self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): - """If prompt adapter is enabled, compute index and prompt mapping. - """ - # Note that when is_prompt=True, we expect only one sequence - # in the group. - if not self.enable_prompt_adapter: - return - - prompt_adapter_id = seq_group_metadata.prompt_adapter_id - if prompt_adapter_id <= 0 or not inter_data.is_prompt: - return - - # We expect only one sequence in the group when is_prompt=True. - assert inter_data.n_seqs == 1 - query_len = inter_data.query_lens[0] - inter_data.prompt_adapter_request = ( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens - inter_data.prompt_adapter_index_mapping = [ - prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( - query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1) - - def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): - """If multi-modal data is given, add it to the input.""" - mm_data = seq_group_metadata.multi_modal_data - if not mm_data: - return - - mm_kwargs = self.multi_modal_input_mapper(mm_data) - inter_data.multi_modal_inputs = mm_kwargs - - # special processing for mrope position deltas. - if self.runner.model_is_mrope: - image_grid_thw = mm_kwargs.get("image_grid_thw", None) - video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") - - hf_config = self.runner.model_config.hf_config - - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - for seq_idx in range(inter_data.n_seqs): - seq_data = seq_group_metadata.seq_data[ - inter_data.seq_ids[seq_idx]] - token_ids = seq_data.get_token_ids() - - mrope_input_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions( - token_ids, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - image_token_id=hf_config.image_token_id, - video_token_id=hf_config.video_token_id, - vision_start_token_id=hf_config.vision_start_token_id, - vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, - context_len=inter_data.context_lens[seq_idx], - ) - - seq_data.mrope_position_delta = mrope_position_delta - inter_data.mrope_input_positions[ - seq_idx] = mrope_input_positions - - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): - """Add a sequence group to the builder.""" - seq_ids = seq_group_metadata.seq_data.keys() - n_seqs = len(seq_ids) - is_prompt = seq_group_metadata.is_prompt - - if is_prompt: - assert n_seqs == 1 - self.decode_only = False - - encoder_seq_len = 0 - - if self.runner.model_config.is_encoder_decoder_model: - encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() - - inter_data = self.init_cached_inter_data( - request_id=seq_group_metadata.request_id, - seq_ids=seq_ids, - is_prompt=is_prompt, - block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums, - reinit=True, - reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) - - self.inter_data_list.append(inter_data) - - for seq_idx in range(n_seqs): - for per_seq_fn in self.per_seq_compute_fns: - per_seq_fn(inter_data, seq_idx, seq_group_metadata) - for per_seq_group_fn in self.per_seq_group_compute_fns: - per_seq_group_fn(inter_data, seq_group_metadata) - - def _use_captured_graph(self, - batch_size: int, - max_decode_seq_len: int, - max_encoder_seq_len: int = 0) -> bool: - return (self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_decode_seq_len <= self.runner.max_seq_len_to_capture - and max_encoder_seq_len <= self.runner.max_seq_len_to_capture - and batch_size <= self.runner.max_batchsize_to_capture) - - def build(self) -> ModelInputForGPU: - """Finalize the builder intermediate data and - create on-device tensors. - """ - # Combine and flatten intermediate data. - input_tokens = [] - for inter_data in self.inter_data_list: - for cur_input_tokens in inter_data.input_tokens: - input_tokens.extend(cur_input_tokens) - - if not input_tokens: - # This may happen when all prefill requests hit - # prefix caching and there is no decode request. - return self.model_input_cls() - - mrope_input_positions: Optional[List[List[int]]] = None - if any(inter_data.mrope_input_positions is not None - for inter_data in self.inter_data_list): - mrope_input_positions = [[] for _ in range(3)] - for idx in range(3): - for inter_data in self.inter_data_list: - msections = inter_data.mrope_input_positions - if msections is None: - for _seq_input_positions in inter_data.input_positions: - mrope_input_positions[idx].extend( - _seq_input_positions) - else: - for _seq_mrope_input_positions in msections: - mrope_input_positions[idx].extend( - _seq_mrope_input_positions[idx]) - input_positions = None - else: - input_positions = [] - for inter_data in self.inter_data_list: - for cur_input_positions in inter_data.input_positions: - input_positions.extend(cur_input_positions) - - seq_lens = [] - query_lens = [] - max_decode_seq_len = 0 - max_encoder_seq_len = 0 - for inter_data in self.inter_data_list: - seq_lens.extend(inter_data.seq_lens) - query_lens.extend(inter_data.query_lens) - if not inter_data.is_prompt: - max_decode_seq_len = max(max_decode_seq_len, - max(inter_data.seq_lens)) - if self.runner.model_config.is_encoder_decoder_model: - max_encoder_seq_len = max(max_encoder_seq_len, - inter_data.encoder_seq_len) - - # Mapping from request IDs to sequence IDs. Used for Jamba models - # that manages the cache by itself. - request_ids_to_seq_ids = { - data.request_id: data.seq_ids - for data in self.inter_data_list - } - - batch_size = len(input_tokens) - use_captured_graph = self._use_captured_graph( - batch_size, - max_decode_seq_len, - max_encoder_seq_len=max_encoder_seq_len) - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - cuda_graph_pad_size = -1 - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - cuda_graph_pad_size = graph_batch_size - batch_size - batch_size = graph_batch_size - - # Tokens and positions. - if cuda_graph_pad_size: - input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) - assert self.runner.device is not None - input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, - self.runner.device, - self.runner.pin_memory) - if mrope_input_positions is not None: - for idx in range(3): - mrope_input_positions[idx].extend( - itertools.repeat(0, cuda_graph_pad_size)) - input_positions_tensor = async_tensor_h2d(mrope_input_positions, - torch.long, - self.runner.device, - self.runner.pin_memory) - else: - input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) - input_positions_tensor = async_tensor_h2d(input_positions, - torch.long, - self.runner.device, - self.runner.pin_memory) - # Sequence and query lengths. - if cuda_graph_pad_size: - seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) - - # Attention metadata. - attn_metadata = self.attn_metadata_builder.build( - seq_lens, query_lens, cuda_graph_pad_size, batch_size) - - # LoRA data. - lora_requests = set() - lora_mapping = None - if self.enable_lora: - lora_requests = set(r for data in self.inter_data_list - for r in data.lora_requests) - lora_index_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_index_mapping) - for inter_data in self.inter_data_list - ]) - if cuda_graph_pad_size: - lora_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - lora_prompt_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_prompt_mapping) - for inter_data in self.inter_data_list - ]) - - lora_mapping = LoRAMapping( - **dict(index_mapping=lora_index_mapping, - prompt_mapping=lora_prompt_mapping, - is_prefill=not self.decode_only)) - - # Prompt adapter data. - prompt_adapter_requests: Set[PromptAdapterRequest] = set() - prompt_adapter_mapping = None - if self.enable_prompt_adapter: - prompt_adapter_requests = set( - data.prompt_adapter_request for data in self.inter_data_list - if data.prompt_adapter_request is not None) - prompt_adapter_index_mapping = flatten_2d_lists([ - inter_data.prompt_adapter_index_mapping - for inter_data in self.inter_data_list - ]) - if cuda_graph_pad_size: - prompt_adapter_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - prompt_adapter_prompt_mapping = flatten_2d_lists([ - inter_data.prompt_adapter_prompt_mapping - for inter_data in self.inter_data_list - ]) - prompt_adapter_mapping = PromptAdapterMapping( - prompt_adapter_index_mapping, - prompt_adapter_prompt_mapping, - ) - - # Multi-modal data. - multi_modal_inputs_list = [ - data.multi_modal_inputs for data in self.inter_data_list - if data.multi_modal_inputs is not None - ] - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) - - return self.model_input_cls( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=self.finished_requests_ids, - prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=prompt_adapter_requests) - - -class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): - """ - Helper class for shared methods between GPU model runners. - """ - _model_input_cls: Type[TModelInputForGPU] - _builder_cls: Type[ModelInputForGPUBuilder] - - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.is_driver_worker = is_driver_worker - self.prompt_adapter_config = prompt_adapter_config - self.return_hidden_states = return_hidden_states - self.observability_config = observability_config - - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = _get_max_graph_batch_size( - self.scheduler_config.max_num_seqs) - - self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ - {} for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.graph_memory_pool: Optional[Tuple[ - int, int]] = None # Set during graph capture. - - self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( - parallel_config) - - # When using CUDA graph, the input block tables must be padded to - # max_seq_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables = np.zeros( - (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - self.attn_backend = get_attn_backend( - num_attn_heads, - self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_sliding_window(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - ) if num_attn_heads else None - if self.attn_backend: - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) - else: - self.attn_state = CommonAttentionState(weakref.proxy(self)) - - # Multi-modal data support - self.input_registry = input_registry - self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry \ - .create_input_mapper(model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) - - # Lazy initialization - self.model: nn.Module # Set after load_model - # Set after load_model. - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None - - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - - # Used to cache python objects - self.inter_data_cache: Dict[int, PyObjectCache] = {} - self.sampling_metadata_cache: SamplingMetadataCache = \ - SamplingMetadataCache() - - def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: - self.model = get_model(model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) - - self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) - - if self.lora_config: - assert supports_lora(self.model), "Model does not support LoRA" - assert not supports_multimodal( - self.model - ), "To be tested: Multi-modal model with LoRA settings." - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=self.model.config. - max_position_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - if self.prompt_adapter_config: - self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.device, - self.prompt_adapter_config) - self.model = ( - self.prompt_adapter_manager.create_prompt_adapter_manager( - self.model)) - - if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently only ROCm accepts kv-cache scaling factors - # via quantization_param_path and this will be deprecated - # in the future. - if self.model_config.quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - warnings.warn( - "Loading kv cache scaling factor from JSON is " - "deprecated and will be removed. Please include " - "kv cache scaling factors in the model checkpoint.", - FutureWarning, - stacklevel=2) - self.model.load_kv_cache_scales( - self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", - self.model_config.quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling factors provided but " - "model %s does not support loading scaling factors.", - self.model.__class__) - else: - logger.warning( - "Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") - - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): - from vllm.compilation.backends import vllm_backend - from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() or vllm_backend - self.model = torch.compile( - self.model, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - from vllm.model_executor.model_loader.loader import ShardedStateLoader - ShardedStateLoader.save_model( - self.model, - path, - pattern=pattern, - max_size=max_size, - ) - - def save_tensorized_model( - self, - tensorizer_config: TensorizerConfig, - ) -> None: - from vllm.model_executor.model_loader.loader import TensorizerLoader - TensorizerLoader.save_model( - self.model, - tensorizer_config=tensorizer_config, - ) - - def get_max_block_per_batch(self) -> int: - block_size = self.block_size - return (self.max_seq_len_to_capture + block_size - 1) // block_size - - def _prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_requests_ids: Optional[List[str]] = None - ) -> TModelInputForGPU: - """Helper method to prepare the model input based on a given sequence - group. Prepares metadata needed for the base model forward pass but not - metadata for possible additional steps, e.g., sampling. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. - """ - builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) - for seq_group_metadata in seq_group_metadata_list: - builder.add_seq_group(seq_group_metadata) - - builder.reset_cached_inter_data() - - return builder.build() # type: ignore - - @torch.inference_mode() - def profile_run(self) -> None: - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - # This represents the maximum number of different requests - # that will have unique loras, an therefore the max amount of memory - # consumption create dummy lora request copies from the lora request - # passed in, which contains a lora from the lora warmup path. - dummy_lora_requests: List[LoRARequest] = [] - dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: - assert self.lora_manager is not None - with self.lora_manager.dummy_lora_cache(): - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] - - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. - seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for multi-modal encoding, which - # needs to be accounted for when calculating the GPU blocks for - # vLLM blocker manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - max_num_seqs_orig = max_num_seqs - max_num_seqs = min(max_num_seqs, - max_num_batched_tokens // max_mm_tokens) - if max_num_seqs < 1: - expr = (f"min({max_num_seqs_orig}, " - f"{max_num_batched_tokens} // {max_mm_tokens})") - logger.warning( - "Computed max_num_seqs (%s) to be less than 1. " - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - seq_data, dummy_multi_modal_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_multi_modal_data, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - kv_caches = [None] * num_layers - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() - return - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - - def remove_all_prompt_adapters(self): - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.remove_all_adapters() - - def set_active_prompt_adapters( - self, prompt_adapter_requests: Set[PromptAdapterRequest], - prompt_adapter_mapping: PromptAdapterMapping) -> None: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.set_active_adapters( - prompt_adapter_requests, prompt_adapter_mapping) - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> Set[int]: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.list_adapters() - - @property - def model_is_mrope(self) -> bool: - """Detect if the model has "mrope" rope_scaling type. - mrope requires keep "rope_deltas" between prompt and decoding phases.""" - rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) - if rope_scaling is None: - return False - return rope_scaling.get("type", None) == "mrope" - - @torch.inference_mode() - def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: - """Cuda graph capture a model. - - Note that CUDA graph's performance gain is negligible if number - of batched tokens are larger than 200. And since CUDA graph - requires fixed sized tensors, supporting large/variable batch - size requires high GPU memory overhead. Thus, vLLM only captures - decoding requests. Mixed batch (chunked prefill + decoding) or - prefill requests are not captured. - - Since it is used for decoding-only, it assumes there's only 1 token - per sequence in the batch. - """ - assert not self.model_config.enforce_eager - logger.info("Capturing the model for CUDA graphs. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI.") - logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " - "If you are running out of memory, consider decreasing " - "`gpu_memory_utilization` or enforcing eager mode. " - "You can also reduce the `max_num_seqs` as needed " - "to decrease memory usage.") - start_time = time.perf_counter() - - # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = self.max_batchsize_to_capture - input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - if self.model_is_mrope: - input_positions = torch.tile(input_positions, (3, 1)) - # Prepare dummy previous_hidden_states only if needed by the model. - # This is used by draft models such as EAGLE. - previous_hidden_states = None - if "previous_hidden_states" in inspect.signature( - self.model.forward).parameters: - previous_hidden_states = torch.empty( - [max_batch_size, - self.model_config.get_hidden_size()], - dtype=self.model_config.dtype, - device=self.device) - - intermediate_inputs = None - if not get_pp_group().is_first_rank: - intermediate_inputs = self.model.make_empty_intermediate_tensors( - batch_size=max_batch_size, - dtype=self.model_config.dtype, - device=self.device) - - # Prepare buffer for outputs. These will be reused for all batch sizes. - # It will be filled after the first graph capture. - hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [ - None - ] * self.parallel_config.pipeline_parallel_size - - graph_batch_size = self.max_batchsize_to_capture - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - - with self.attn_state.graph_capture( - max_batch_size), graph_capture() as graph_capture_context: - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. - for virtual_engine in range( - self.parallel_config.pipeline_parallel_size): - for batch_size in reversed(batch_size_capture_list): - attn_metadata = ( - self.attn_state.graph_capture_get_metadata_for_batch( - batch_size, - is_encoder_decoder_model=self.model_config. - is_encoder_decoder_model)) - - if self.lora_config: - lora_mapping = LoRAMapping( - **dict(index_mapping=[0] * batch_size, - prompt_mapping=[0] * batch_size, - is_prefill=False)) - self.set_active_loras(set(), lora_mapping) - - if self.prompt_adapter_config: - prompt_adapter_mapping = PromptAdapterMapping( - [-1] * batch_size, - [-1] * batch_size, - ) - self.set_active_prompt_adapters( - set(), prompt_adapter_mapping) - graph_runner = CUDAGraphRunner( - self.model, self.attn_backend.get_name(), - self.attn_state.graph_clone(batch_size), - self.model_config.is_encoder_decoder_model) - - capture_inputs = { - "input_ids": - input_tokens[:batch_size], - "positions": - input_positions[..., :batch_size], - "hidden_or_intermediate_states": - hidden_or_intermediate_states[ - virtual_engine] # type: ignore - [:batch_size] - if hidden_or_intermediate_states[virtual_engine] - is not None else None, - "intermediate_inputs": - intermediate_inputs[:batch_size] - if intermediate_inputs is not None else None, - "kv_caches": - kv_caches[virtual_engine], - "attn_metadata": - attn_metadata, - "memory_pool": - self.graph_memory_pool, - "stream": - graph_capture_context.stream - } - if previous_hidden_states is not None: - capture_inputs[ - "previous_hidden_states"] = previous_hidden_states[: - batch_size] - - if self.has_seqlen_agnostic: - # Only used by Mamba-based models CUDA graph atm (Jamba) - capture_inputs.update({ - "seqlen_agnostic_capture_inputs": - self.model.get_seqlen_agnostic_capture_inputs( - batch_size) - }) - if self.model_config.is_encoder_decoder_model: - # add the additional inputs to capture for - # encoder-decoder models. - self._update_inputs_to_capture_for_enc_dec_model( - capture_inputs) - - graph_runner.capture(**capture_inputs) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][batch_size] = ( - graph_runner) - - end_time = time.perf_counter() - elapsed_time = end_time - start_time - # This usually takes < 10 seconds. - logger.info("Graph capturing finished in %.0f secs.", elapsed_time) - - def _update_inputs_to_capture_for_enc_dec_model(self, - capture_inputs: Dict[str, - Any]): - """ - Updates the set of input tensors needed for CUDA graph capture in an - encoder-decoder model. - - This method modifies the provided `capture_inputs` dictionary by - adding tensors specific to encoder-decoder specific models that - need to be captured for CUDA Graph replay. - """ - # During the decode phase encoder_input_ids and encoder_positions are - # unset. Do the same thing for graph capture. - capture_inputs["encoder_input_ids"] = torch.tensor( - [], dtype=torch.long).cuda() - capture_inputs["encoder_positions"] = torch.tensor( - [], dtype=torch.long).cuda() - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - -class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - """ - GPU model runner with sampling step. - """ - _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( - ModelInputForGPUWithSamplingMetadata) - _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> ModelInputForGPUWithSamplingMetadata: - model_input = \ - ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - return model_input - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - if get_pp_group().is_last_rank: - # Sampling metadata is only required for the final pp group - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, model_input.seq_lens, - model_input.query_lens, self.device, self.pin_memory, - generators, self.sampling_metadata_cache) - else: - sampling_metadata = None - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() - @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) - def execute_model( - self, - model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in ModelRunner") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) - - self.attn_state.begin_forward(model_input) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - # TODO(andoorve): We can remove this once all - # virtual engines share the same kv cache. - virtual_engine = model_input.virtual_engine - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[virtual_engine][ - graph_batch_size] - else: - model_executable = self.model - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_seqlen_agnostic else {} - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_start = torch.cuda.Event(enable_timing=True) - model_forward_end = torch.cuda.Event(enable_timing=True) - model_forward_start.record() - - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.record() - - # Compute the logits in the last pipeline stage. - if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) - return hidden_or_intermediate_states - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if not self.is_driver_worker: - return [] - - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. - output: SamplerOutput = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the latency - # from the start time of the driver worker to the end time of the - # driver worker. The model forward time will then end up covering - # the communication time as well. - output.model_forward_time = (orig_model_forward_time + - model_forward_time) - - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert model_input.sampling_metadata is not None - indices = model_input.sampling_metadata.selected_token_indices - if model_input.is_prompt: - hidden_states = hidden_or_intermediate_states.index_select( - 0, indices) - output.prefill_hidden_states = hidden_or_intermediate_states - elif decode_meta.use_cuda_graph: - hidden_states = hidden_or_intermediate_states[:len(indices)] - else: - hidden_states = hidden_or_intermediate_states - - output.hidden_states = hidden_states - - return [output] - - -class CUDAGraphRunner: - - def __init__(self, model: nn.Module, backend_name: str, - attn_state: AttentionState, is_encoder_decoder_model: bool): - self.model = model - self.backend_name = backend_name - self.attn_state = attn_state - - self.input_buffers: Dict[str, torch.Tensor] = {} - self.output_buffers: Dict[str, torch.Tensor] = {} - - self._graph: Optional[torch.cuda.CUDAGraph] = None - self._is_encoder_decoder_model = is_encoder_decoder_model - - @property - def graph(self): - assert self._graph is not None - return self._graph - - def capture( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - hidden_or_intermediate_states: Optional[Union[IntermediateTensors, - torch.Tensor]], - intermediate_inputs: Optional[IntermediateTensors], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - memory_pool: Optional[Tuple[int, int]], - stream: torch.cuda.Stream, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - assert self._graph is None - # Run the model a few times without capturing the graph. - # This is to make sure that the captured graph does not include the - # kernel launches for initial benchmarking (e.g., Triton autotune). - # Note one iteration is not enough for torch.jit.script - for _ in range(_NUM_WARMUP_ITERS): - self.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_inputs, - **kwargs, - ) - # Wait for the warm up operations to finish before proceeding with - # Graph Capture. - torch.cuda.synchronize() - # Capture the graph. - self._graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - output_hidden_or_intermediate_states = self.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_inputs, - **kwargs, - ) - if hidden_or_intermediate_states is not None: - if get_pp_group().is_last_rank: - hidden_or_intermediate_states.copy_( - output_hidden_or_intermediate_states) - else: - for key in hidden_or_intermediate_states.tensors: - hidden_or_intermediate_states[key].copy_( - output_hidden_or_intermediate_states[key]) - else: - hidden_or_intermediate_states = ( - output_hidden_or_intermediate_states) - - del output_hidden_or_intermediate_states - # make sure `output_hidden_states` is deleted - # in the graph's memory pool - gc.collect() - torch.cuda.synchronize() - - # Save the input and output buffers. - self.input_buffers = { - "input_ids": - input_ids, - "positions": - positions, - "kv_caches": - kv_caches, - **self.attn_state.get_graph_input_buffers( - attn_metadata, self._is_encoder_decoder_model), - **kwargs, - } - if intermediate_inputs is not None: - self.input_buffers.update(intermediate_inputs.tensors) - if get_pp_group().is_last_rank: - self.output_buffers = { - "hidden_states": hidden_or_intermediate_states - } - else: - self.output_buffers = hidden_or_intermediate_states - return hidden_or_intermediate_states - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - **kwargs, - ) -> torch.Tensor: - # KV caches are fixed tensors, so we don't need to copy them. - del kv_caches - - # Copy the input tensors to the input buffers. - self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) - self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) - self.attn_state.prepare_graph_input_buffers( - self.input_buffers, attn_metadata, self._is_encoder_decoder_model) - if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_inputs_before_cuda_graphs(self.input_buffers, - **kwargs) - - if "previous_hidden_states" in self.input_buffers: - self.input_buffers["previous_hidden_states"].copy_( - kwargs["previous_hidden_states"], non_blocking=True) - - if intermediate_tensors is not None: - for key in intermediate_tensors.tensors: - if key != "model_execute_time" and key != "model_forward_time": - self.input_buffers[key].copy_(intermediate_tensors[key], - non_blocking=True) - if self._is_encoder_decoder_model: - self.input_buffers["encoder_input_ids"].copy_( - kwargs['encoder_input_ids'], non_blocking=True) - self.input_buffers["encoder_positions"].copy_( - kwargs['encoder_positions'], non_blocking=True) - - # Run the graph. - self.graph.replay() - # Return the output tensor. - if get_pp_group().is_last_rank: - return self.output_buffers["hidden_states"] - - return self.output_buffers - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - -def _get_graph_batch_size(batch_size: int) -> int: - """Returns the padded batch size given actual batch size. - - Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, - 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... - """ - if batch_size <= 2: - return batch_size - elif batch_size <= 4: - return 4 - else: - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _get_max_graph_batch_size(max_num_seqs: int) -> int: - """ - max_num_seqs: Maximum number of sequences in a batch. - _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. - - pad the max_num_seqs if necessary by calling _get_graph_batch_size, - which will deal with some edge cases like 1, 2, 4. - - if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. - if not, it means the padded size is larger than the largest size in - _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. - """ - padded_size = _get_graph_batch_size(max_num_seqs) - if padded_size in _BATCH_SIZES_TO_CAPTURE: - return padded_size - assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py new file mode 100644 index 0000000000000..06ac2f024da07 --- /dev/null +++ b/vllm/worker/model_runner_v2.py @@ -0,0 +1,411 @@ +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +import vllm.envs as envs +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import (DeviceMemoryProfiler, is_pin_memory_available) +from vllm.worker.model_runner_base import dump_input_when_exception +from vllm.multimodal import MultiModalDataDict + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + from vllm.core.scheduler_v2 import SchedulerOutput + +logger = init_logger(__name__) + + +class GPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + kv_cache_dtype: Optional[str] = "auto", + lora_config: Optional[LoRAConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_num_blocks_per_req = ( + (self.model_config.max_model_len + self.block_size - 1) + // self.block_size) + + # Lazy initialization + self.model: nn.Module # Set after load_model + self.kv_caches: List[torch.Tensor] = [] + + # Request states. + self.requests: Dict[str, RequestState] = {} + self.batched_states = BatchedRequestStates( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + for req_id in scheduler_output.aborted_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the batched states. + stopped_req_ids = ( + scheduler_output.preempted_req_ids + + scheduler_output.finished_req_ids + + scheduler_output.aborted_req_ids) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.batched_states.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + num_prev_blocks: Dict[str, int] = {} + new_block_ids: Dict[str, List[int]] = {} + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + num_prev_blocks[req_id] = len(req_state.block_ids) + new_block_ids[req_id] = req_data.new_block_ids + req_state.block_ids.extend(req_data.new_block_ids) + req_state.num_computed_tokens = req_data.num_computed + # Update the block table and the number of computed tokens + # of the running requests. + for req_id in self.batched_states.req_ids: + if req_id is None: + continue + start_block_index = num_prev_blocks[req_id] + block_ids = new_block_ids[req_id] + end_block_index = start_block_index + len(block_ids) + self.batched_states.block_table_cpu[ + req_index, start_block_index:end_block_index] = block_ids + self.batched_states.num_computed_tokens_cpu[req_index] = ( + self.requests[req_id].num_computed_tokens) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + prompt_token_ids_cpu = torch.tensor( + req_data.prompt_token_ids, device="cpu", pin_memory=self.pin_memory) + prompt_token_ids = prompt_token_ids_cpu.to(self.device, non_blocking=True) + + self.requests[req_id] = RequestState( + req_id=req_id, + prompt_token_ids=prompt_token_ids, + prompt_token_ids_cpu=prompt_token_ids_cpu, + prompt=req_data.prompt, + multi_modal_data=req_data.multi_modal_data, + sampling_params=req_data.sampling_params, + generator=None, # TODO + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # Add the new or resumed requests to the batched states. + # The smaller empty indices are filled first. + removed_req_indices.sort(reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # TODO(woosuk): Consider LoRA. + req_index = removed_req_indices.pop() + else: + req_index = self.batched_states.num_reqs + self.batched_states.add_request(req_state, req_index) + + # Condense the batched states. + self.batched_states.condense(removed_req_indices) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + + pass + + @torch.inference_mode() + @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplerOutput: + self._update_states(scheduler_output) + inputs = self._prepare_inputs(scheduler_output) + input_ids, position_ids, attn_metadata = inputs + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata, + kv_caches=self.kv_caches, + ) + + logits = self.model.compute_logits(hidden_states, + sampling_metadata) + # Create the sampling metadata. + sampling_metadata = self.batched_states.get_sampling_metadata() + # Sample the next token and get logprobs if needed. + sampler_output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return sampler_output + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @torch.inference_mode() + def profile_run(self) -> None: + return + + def initialize_kv_cache(self) -> None: + ... + + +class BatchedRequestStates: + + def __init__( + self, + max_num_reqs: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.num_reqs = 0 + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + + self.num_computed_tokens = torch.empty( + (max_num_reqs,), dtype=torch.int32, device=device) + self.num_computed_tokens_cpu = torch.empty( + (max_num_reqs,), dtype=torch.int32, + device="cpu", pin_memory=pin_memory) + + # Attention-related. + self.block_table = torch.empty( + (max_num_reqs, max_num_blocks_per_req), + device=self.device, dtype=torch.int32) + self.block_table_cpu = torch.empty( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", dtype=torch.int32, pin_memory=pin_memory) + + # Sampling-related. + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device) + self.temperature_cpu = torch.empty( + (max_num_reqs,), dtype=torch.float32, + device="cpu", pin_memory=pin_memory) + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu = torch.empty( + (max_num_reqs,), dtype=torch.float32, + device="cpu", pin_memory=pin_memory) + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device) + self.top_k_cpu = torch.empty( + (max_num_reqs,), dtype=torch.float32, + device="cpu", pin_memory=pin_memory) + self.top_k_reqs: Set[str] = set() + + self.generators: List[Optional[torch.Generator]] = [None] * max_num_reqs + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request(self, request: "RequestState", req_index: int) -> None: + assert req_index < self.max_num_reqs + self.req_ids[req_index] = request.req_id + self.num_reqs += 1 + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + self.block_table_cpu[req_index, :len(request.block_ids)] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_index) + elif sampling_params.sampling_type == SamplingType.RANDOM: + self.random_reqs.add(req_index) + elif sampling_params.sampling_type == SamplingType.RANDOM_SEED: + assert False + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_index) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_index) + + # TODO + self.generators[req_index] = None + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[request.req_id] = num_logprobs + if sampling_params.prompt_logprob: + self.prompt_logprob_reqs.add(req_index) + + def remove_request(self, req_id: str) -> Optional[int]: + if not req_id in self.req_ids: + return None + req_index = self.req_ids.index(req_id) + self.req_ids[req_index] = None + self.num_reqs -= 1 + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators[req_index] = None + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def condense(self, empty_req_indices: List[int]) -> None: + # TODO(woosuk): Consider LoRA. + while empty_req_indices: + empty_index = empty_req_indices.pop() + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + if empty_index == last_req_index: + assert len(empty_req_indices) == 0 + break + + # Swap the last request with the empty slot. + self.req_ids[empty_index] = self.req_ids[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.generators[empty_index] = self.generators[last_req_index] + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def no_generator(self) -> bool: + return len(self.generators) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 + + def get_sampling_metadata(self) -> SamplingMetadata: + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators[:self.num_reqs], + no_generator=self.no_generator, + max_num_logprobs=self.max_num_logprobs, + ) + + +@dataclass +class RequestState: + + req_id: str + prompt_token_ids: torch.Tensor + prompt_token_ids_cpu: torch.Tensor + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] From c86ce2c8f5d9edfc617b0d744dd13a76539b54a3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Sep 2024 02:15:46 -0700 Subject: [PATCH 07/31] Minor --- vllm/core/interfaces.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index fa46a78480cd4..109458372c456 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -64,23 +64,6 @@ def append_slots( ) -> List[Tuple[int, int]]: pass - @abstractmethod - def can_swap_in(self, request: Request, - num_lookahead_slots: int) -> AllocStatus: - pass - - @abstractmethod - def swap_in(self, request: Request) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def can_swap_out(self, request: Request) -> bool: - pass - - @abstractmethod - def swap_out(self, request: Request) -> List[Tuple[int, int]]: - pass - @abstractmethod def free(self, request: Request) -> None: pass From 50e4af201b2b50928910bc1f8a8d2791300d8818 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Sep 2024 02:21:50 -0700 Subject: [PATCH 08/31] yapf --- vllm/worker/model_runner_v2.py | 95 ++++++++++++++++++++-------------- vllm/worker/worker_v2.py | 8 +-- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index 06ac2f024da07..9b9a402c62597 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -63,8 +63,8 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_num_blocks_per_req = ( - (self.model_config.max_model_len + self.block_size - 1) - // self.block_size) + (self.model_config.max_model_len + self.block_size - 1) // + self.block_size) # Lazy initialization self.model: nn.Module # Set after load_model @@ -88,10 +88,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests.pop(req_id, None) # Remove the requests from the batched states. - stopped_req_ids = ( - scheduler_output.preempted_req_ids - + scheduler_output.finished_req_ids - + scheduler_output.aborted_req_ids) + stopped_req_ids = (scheduler_output.preempted_req_ids + + scheduler_output.finished_req_ids + + scheduler_output.aborted_req_ids) removed_req_indices: List[int] = [] for req_id in stopped_req_ids: req_index = self.batched_states.remove_request(req_id) @@ -126,9 +125,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add new requests to the cached states. for req_data in scheduler_output.scheduled_new_reqs: req_id = req_data.req_id - prompt_token_ids_cpu = torch.tensor( - req_data.prompt_token_ids, device="cpu", pin_memory=self.pin_memory) - prompt_token_ids = prompt_token_ids_cpu.to(self.device, non_blocking=True) + prompt_token_ids_cpu = torch.tensor(req_data.prompt_token_ids, + device="cpu", + pin_memory=self.pin_memory) + prompt_token_ids = prompt_token_ids_cpu.to(self.device, + non_blocking=True) self.requests[req_id] = RequestState( req_id=req_id, @@ -188,8 +189,7 @@ def execute_model( kv_caches=self.kv_caches, ) - logits = self.model.compute_logits(hidden_states, - sampling_metadata) + logits = self.model.compute_logits(hidden_states, sampling_metadata) # Create the sampling metadata. sampling_metadata = self.batched_states.get_sampling_metadata() # Sample the next token and get logprobs if needed. @@ -239,44 +239,55 @@ def __init__( self.num_reqs = 0 self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.num_computed_tokens = torch.empty( - (max_num_reqs,), dtype=torch.int32, device=device) - self.num_computed_tokens_cpu = torch.empty( - (max_num_reqs,), dtype=torch.int32, - device="cpu", pin_memory=pin_memory) + self.num_computed_tokens = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.num_computed_tokens_cpu = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) # Attention-related. - self.block_table = torch.empty( - (max_num_reqs, max_num_blocks_per_req), - device=self.device, dtype=torch.int32) + self.block_table = torch.empty((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) self.block_table_cpu = torch.empty( (max_num_reqs, max_num_blocks_per_req), - device="cpu", dtype=torch.int32, pin_memory=pin_memory) + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory) # Sampling-related. - self.temperature = torch.empty( - (max_num_reqs,), dtype=torch.float32, device=device) - self.temperature_cpu = torch.empty( - (max_num_reqs,), dtype=torch.float32, - device="cpu", pin_memory=pin_memory) + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) self.greedy_reqs: Set[str] = set() self.random_reqs: Set[str] = set() - self.top_p = torch.empty( - (max_num_reqs,), dtype=torch.float32, device=device) - self.top_p_cpu = torch.empty( - (max_num_reqs,), dtype=torch.float32, - device="cpu", pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) self.top_p_reqs: Set[str] = set() - self.top_k = torch.empty( - (max_num_reqs,), dtype=torch.float32, device=device) - self.top_k_cpu = torch.empty( - (max_num_reqs,), dtype=torch.float32, - device="cpu", pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_k_cpu = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) self.top_k_reqs: Set[str] = set() - self.generators: List[Optional[torch.Generator]] = [None] * max_num_reqs + self.generators: List[Optional[torch.Generator]] = [None + ] * max_num_reqs self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() @@ -287,7 +298,8 @@ def add_request(self, request: "RequestState", req_index: int) -> None: self.num_reqs += 1 self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table_cpu[req_index, :len(request.block_ids)] = request.block_ids + self.block_table_cpu[ + req_index, :len(request.block_ids)] = request.block_ids sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature @@ -341,9 +353,12 @@ def condense(self, empty_req_indices: List[int]) -> None: # Swap the last request with the empty slot. self.req_ids[empty_index] = self.req_ids[last_req_index] - self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] self.generators[empty_index] = self.generators[last_req_index] diff --git a/vllm/worker/worker_v2.py b/vllm/worker/worker_v2.py index 561f052913cf5..255bd569f254a 100644 --- a/vllm/worker/worker_v2.py +++ b/vllm/worker/worker_v2.py @@ -143,15 +143,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks + return num_gpu_blocks, num_cpu_blocks def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Allocate GPU and CPU KV cache with the specified number of blocks.""" if num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") max_seq_len = self.cache_config.block_size * num_gpu_blocks max_model_len = self.model_config.max_model_len @@ -181,7 +181,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> None: - sampler_output = self.model_runner.execute_model(scheduler_output) + sampler_output = self.model_runner.execute_model(scheduler_output) # TODO(woosuk): Send the output to the engine process. From 9d14fd1214da683b8ab943744d80f5dc1b6f64da Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 29 Sep 2024 16:55:58 -0700 Subject: [PATCH 09/31] Minor --- vllm/worker/model_runner_v2.py | 47 ++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index 9b9a402c62597..27237fbffe6d6 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -97,6 +97,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if req_index is not None: removed_req_indices.append(req_index) + # Condense the batched states. + self.batched_states.condense(removed_req_indices) + # Update the states of the running requests. num_prev_blocks: Dict[str, int] = {} new_block_ids: Dict[str, List[int]] = {} @@ -155,19 +158,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_ids_to_add.append(req_id) # Add the new or resumed requests to the batched states. - # The smaller empty indices are filled first. - removed_req_indices.sort(reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] - if removed_req_indices: - # TODO(woosuk): Consider LoRA. - req_index = removed_req_indices.pop() - else: - req_index = self.batched_states.num_reqs - self.batched_states.add_request(req_state, req_index) - - # Condense the batched states. - self.batched_states.condense(removed_req_indices) + self.batched_states.add_request(req_state) def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): @@ -292,8 +285,15 @@ def __init__( self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() - def add_request(self, request: "RequestState", req_index: int) -> None: + def add_request( + self, + request: "RequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs assert req_index < self.max_num_reqs + self.req_ids[req_index] = request.req_id self.num_reqs += 1 @@ -308,6 +308,7 @@ def add_request(self, request: "RequestState", req_index: int) -> None: elif sampling_params.sampling_type == SamplingType.RANDOM: self.random_reqs.add(req_index) elif sampling_params.sampling_type == SamplingType.RANDOM_SEED: + # TODO assert False self.top_p_cpu[req_index] = sampling_params.top_p @@ -344,14 +345,26 @@ def remove_request(self, req_id: str) -> Optional[int]: def condense(self, empty_req_indices: List[int]) -> None: # TODO(woosuk): Consider LoRA. + if not empty_req_indices: + # The batched states are already condensed. + return + if self.num_reqs == 0: + # The batched states are empty. + return + + empty_req_indices = sorted(empty_req_indices, reverse=True) + last_req_index = self.num_reqs + len(empty_req_indices) - 1 while empty_req_indices: - empty_index = empty_req_indices.pop() - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - if empty_index == last_req_index: - assert len(empty_req_indices) == 0 + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: break - # Swap the last request with the empty slot. + # Swap the states. self.req_ids[empty_index] = self.req_ids[last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] From 23152fa575d69f1ddaf0d67a48ebbf912e4af6a8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 29 Sep 2024 17:05:52 -0700 Subject: [PATCH 10/31] Add clear --- vllm/worker/model_runner_v2.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index 27237fbffe6d6..c32c42bca50bf 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -98,6 +98,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: removed_req_indices.append(req_index) # Condense the batched states. + # We condense the states before adding new/resumed requests + # because the attention backend may require it. self.batched_states.condense(removed_req_indices) # Update the states of the running requests. @@ -163,7 +165,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.batched_states.add_request(req_state) def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - pass @torch.inference_mode() @@ -343,6 +344,16 @@ def remove_request(self, req_id: str) -> Optional[int]: self.prompt_logprob_reqs.discard(req_id) return req_index + def clear(self) -> None: + self.num_reqs = 0 + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + def condense(self, empty_req_indices: List[int]) -> None: # TODO(woosuk): Consider LoRA. if not empty_req_indices: From 27a2683ac024cd65cc8f03875da51731c4de8d06 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 29 Sep 2024 17:17:27 -0700 Subject: [PATCH 11/31] Minor --- vllm/worker/model_runner_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index c32c42bca50bf..b63d5d2825f9b 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -165,7 +165,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.batched_states.add_request(req_state) def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - pass + assert scheduler_output.total_num_scheduled_tokens > 0 + num_scheduled_tokens = scheduler_output.num_scheduled_tokens @torch.inference_mode() @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) @@ -176,6 +177,8 @@ def execute_model( self._update_states(scheduler_output) inputs = self._prepare_inputs(scheduler_output) input_ids, position_ids, attn_metadata = inputs + # Create the sampling metadata. + sampling_metadata = self.batched_states.get_sampling_metadata() hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, @@ -184,8 +187,6 @@ def execute_model( ) logits = self.model.compute_logits(hidden_states, sampling_metadata) - # Create the sampling metadata. - sampling_metadata = self.batched_states.get_sampling_metadata() # Sample the next token and get logprobs if needed. sampler_output = self.model.sample( logits=logits, @@ -319,8 +320,7 @@ def add_request( if sampling_params.top_k > 0: self.top_k_reqs.add(req_index) - # TODO - self.generators[req_index] = None + self.generators[req_index] = request.generator num_logprobs = sampling_params.logprobs if num_logprobs is not None and num_logprobs > 0: From 025aeb8fcf13701ed0aa8833d75b6aa4c5b35284 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 29 Sep 2024 17:30:53 -0700 Subject: [PATCH 12/31] Minor --- vllm/worker/model_runner_v2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index b63d5d2825f9b..1901101beb51f 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -100,7 +100,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Condense the batched states. # We condense the states before adding new/resumed requests # because the attention backend may require it. - self.batched_states.condense(removed_req_indices) + if removed_req_indices: + self.batched_states.condense(removed_req_indices) # Update the states of the running requests. num_prev_blocks: Dict[str, int] = {} @@ -356,9 +357,6 @@ def clear(self) -> None: def condense(self, empty_req_indices: List[int]) -> None: # TODO(woosuk): Consider LoRA. - if not empty_req_indices: - # The batched states are already condensed. - return if self.num_reqs == 0: # The batched states are empty. return From 8d29ffc6ec0dd5c388bf9188500c3e2a08cb2070 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 00:57:01 -0700 Subject: [PATCH 13/31] Working --- benchmarks/benchmark_throughput.py | 4 +- examples/offline_inference.py | 5 +- vllm/__init__.py | 7 +- vllm/attention/backends/flash_attn.py | 666 +--------------- vllm/config.py | 4 +- vllm/core/block/utils.py | 4 +- vllm/core/block_manager_v1.py | 738 ------------------ vllm/core/block_manager_v2.py | 503 ------------ vllm/core/kv_cache_manager.py | 89 +++ vllm/core/scheduler_v2.py | 152 ++-- vllm/engine/arg_utils.py | 2 +- vllm/engine/async_llm_engine.py | 12 +- vllm/engine/llm_engine_v2.py | 58 +- vllm/entrypoints/llm.py | 44 +- vllm/entrypoints/openai/protocol.py | 3 +- .../{gpu_executor.py => executor_v2.py} | 82 +- vllm/executor/msgspec_utils.py | 2 +- vllm/executor/ray_utils.py | 136 ++-- .../model_executor/layers/logits_processor.py | 14 +- vllm/model_executor/layers/sampler.py | 22 +- .../model_executor/model_loader/tensorizer.py | 2 +- vllm/model_executor/models/llama.py | 7 +- vllm/model_executor/models/utils.py | 1 - vllm/outputs.py | 270 ------- vllm/outputs_v2.py | 82 ++ vllm/request.py | 16 +- vllm/sampler_output.py | 19 - vllm/transformers_utils/detokenizer.py | 6 +- vllm/worker/model_runner_base.py | 267 ------- vllm/worker/model_runner_v2.py | 459 ++++++++--- vllm/worker/worker_base.py | 485 ------------ vllm/worker/worker_v2.py | 53 +- 32 files changed, 841 insertions(+), 3373 deletions(-) delete mode 100644 vllm/core/block_manager_v1.py delete mode 100644 vllm/core/block_manager_v2.py create mode 100644 vllm/core/kv_cache_manager.py rename vllm/executor/{gpu_executor.py => executor_v2.py} (62%) delete mode 100644 vllm/outputs.py create mode 100644 vllm/outputs_v2.py delete mode 100644 vllm/sampler_output.py delete mode 100644 vllm/worker/model_runner_base.py delete mode 100644 vllm/worker/worker_base.py diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e1a5d4ee28ea1..299ea3b649347 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -12,8 +12,8 @@ PreTrainedTokenizerBase) from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) +# from vllm.entrypoints.openai.api_server import ( +# build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser, merge_async_iterators diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..a5225b1378c89 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -8,10 +8,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, logprobs=5) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="meta-llama/Llama-3.1-8B") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -20,3 +20,4 @@ prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(output.outputs[0]) diff --git a/vllm/__init__.py b/vllm/__init__.py index 59af68fb493e5..3b5ed587af280 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -2,13 +2,12 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.llm_engine import LLMEngine +from vllm.engine.llm_engine_v2 import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry -from vllm.outputs import (CompletionOutput, EmbeddingOutput, - EmbeddingRequestOutput, RequestOutput) +from vllm.outputs_v2 import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -25,8 +24,6 @@ "SamplingParams", "RequestOutput", "CompletionOutput", - "EmbeddingOutput", - "EmbeddingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 084e8113cd421..9e93e931bf077 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -6,159 +6,8 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.utils import async_tensor_h2d, make_tensor_with_pad - -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) - -# yapf: disable -from vllm.vllm_flash_attn import ( - flash_attn_varlen_func as _flash_attn_varlen_func) -from vllm.vllm_flash_attn import ( - flash_attn_with_kvcache as _flash_attn_with_kvcache) - -# yapf: enable - - -@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) -def flash_attn_varlen_func( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Optional[List[int]] = None, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # custom op does not support tuple input - real_window_size: Tuple[int, int] - if window_size is None: - real_window_size = (-1, -1) - else: - assert len(window_size) == 2 - real_window_size = (window_size[0], window_size[1]) - return _flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=causal, - window_size=real_window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - block_table=block_table, - ) - - -@flash_attn_varlen_func.register_fake # type: ignore -def _( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Optional[List[int]] = None, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return torch.empty_like(q) - - -@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[]) -def flash_attn_with_kvcache( - decode_query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_seqlens: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - alibi_slopes: Optional[torch.Tensor] = None, - softcap: float = 0.0, -) -> torch.Tensor: - return _flash_attn_with_kvcache( - decode_query, - key_cache, - value_cache, - cache_seqlens=cache_seqlens, - block_table=block_table, - softmax_scale=softmax_scale, - causal=causal, - alibi_slopes=alibi_slopes, - softcap=softcap, - ) - - -@flash_attn_with_kvcache.register_fake # type: ignore -def _( - decode_query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_seqlens: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - alibi_slopes: Optional[torch.Tensor] = None, - softcap: float = 0.0, -) -> torch.Tensor: - return torch.empty_like(decode_query) - - -@torch.library.custom_op("vllm::reshape_and_cache_flash", - mutates_args=["kv_cache"]) -def reshape_and_cache_flash( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - """Inductor cannot deal with inplace operations on views. - See https://github.com/pytorch/pytorch/issues/131192 - and https://github.com/pytorch/pytorch/issues/130174 - This is a workaround to hide the view operation from the inductor. - """ - return torch.ops._C_cache_ops.reshape_and_cache_flash( - key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, - k_scale, v_scale) - - -@reshape_and_cache_flash.register_fake # type: ignore -def _( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - pass + AttentionMetadata, AttentionType) +from vllm.vllm_flash_attn import flash_attn_varlen_func class FlashAttentionBackend(AttentionBackend): @@ -179,14 +28,6 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata - @staticmethod - def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: - return FlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -198,45 +39,9 @@ def get_kv_cache_shape( raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) - @dataclass -class FlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - +class FlashAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -245,362 +50,15 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None - - self._cached_prefill_metadata = FlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = FlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - ) - return self._cached_decode_metadata - - def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - - -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - max_blocks = input_block_tables.shape[1] - for i, block_table in enumerate(self.block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - input_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - input_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - block_tables = torch.from_numpy(input_block_tables).to( - device=device, non_blocking=True) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - return FlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_start_loc: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor class FlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ def __init__( self, @@ -692,93 +150,35 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - torch.ops.vllm.reshape_and_cache_flash( + torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, - kv_cache, - attn_metadata.slot_mapping.flatten(), + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping, self.kv_cache_dtype, k_scale, v_scale, ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - prefill_output = torch.ops.vllm.flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) - output = torch.cat([prefill_output, decode_output], dim=0) + if (attn_metadata.block_table is None + or attn_metadata.block_table.numel() == 0): + # Profiling run. + output = torch.empty_like(query) + return output + + output = flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=attn_metadata.block_table, + softcap=self.logits_soft_cap, + ) return output.view(num_tokens, hidden_size) diff --git a/vllm/config.py b/vllm/config.py index 960a8d3928584..9cb3b39f89c12 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -963,9 +963,7 @@ def __init__(self, send_delta_data: bool = False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: - # It is the values that have the best balance between ITL - # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + max_num_batched_tokens = 4096 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 28839437c33c5..c96200131539f 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,5 +1,4 @@ """Block manager utils.""" -from vllm.sequence import SequenceGroup from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) @@ -26,8 +25,7 @@ def _get_block_mgr_sliding_window_attr(block_mgr): "max_block_sliding_window attributes.") -def check_no_caching_or_swa_for_blockmgr_encdec( - block_mgr, seq_group: SequenceGroup) -> None: +def check_no_caching_or_swa_for_blockmgr_encdec(block_mgr, seq_group) -> None: ''' Enforce that prefix caching & sliding-window attention (SWA) are currently unsupported *specifically* for encoder/decoder models. diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py deleted file mode 100644 index 24ab9eb66194d..0000000000000 --- a/vllm/core/block_manager_v1.py +++ /dev/null @@ -1,738 +0,0 @@ -"""A block manager that manages token blocks.""" -import math -from abc import ABC, abstractmethod -from itertools import count, takewhile -from os.path import commonprefix -from typing import Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple - -from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.block.common import CacheMetricData -from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.logger import init_logger -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -logger = init_logger(__name__) - - -class BlockAllocatorBase(ABC): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - @abstractmethod - def __init__(self, - device: Device, - block_size: int, - num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU): - pass - - @abstractmethod - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - pass - - @abstractmethod - def free(self, block: PhysicalTokenBlock) -> None: - pass - - @abstractmethod - def get_num_free_blocks(self) -> int: - pass - - @abstractmethod - def get_num_total_blocks(self) -> int: - pass - - @abstractmethod - def contains_block(self, block_hash: int) -> bool: - pass - - @abstractmethod - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - pass - - @abstractmethod - def get_prefix_cache_hit_rate(self) -> float: - """Prefix cache hit rate. -1 means not supported or disabled.""" - pass - - -class CachedBlockAllocator(BlockAllocatorBase): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - def __init__(self, - device: Device, - block_size: int, - num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None: - self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - self.current_num_blocks = 0 - self.cached_blocks: Dict[int, PhysicalTokenBlock] = {} - - self.evictor: Evictor = make_evictor(eviction_policy) - - self.default_hash_ctr = count() - - self.cache_metric_data = CacheMetricData() - - def allocate_block(self, block_hash: int, - num_hashed_tokens: int) -> PhysicalTokenBlock: - if self.current_num_blocks == self.num_blocks: - block = self.evictor.evict() - block.block_hash = block_hash - block.num_hashed_tokens = num_hashed_tokens - return block - block = PhysicalTokenBlock(device=self.device, - block_number=self.current_num_blocks, - block_size=self.block_size, - block_hash=block_hash, - num_hashed_tokens=num_hashed_tokens) - self.current_num_blocks += 1 - return block - - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - if block_hash is None: - block_hash = next(self.default_hash_ctr) - - if block_hash in self.evictor: - assert block_hash not in self.cached_blocks - block = self.evictor.remove(block_hash) - assert block.ref_count == 0 - self.cached_blocks[block_hash] = block - - if block_hash in self.cached_blocks: - self.cache_metric_data.query(hit=True) - else: - self.cache_metric_data.query(hit=False) - self.cached_blocks[block_hash] = self.allocate_block( - block_hash, num_hashed_tokens) - block = self.cached_blocks[block_hash] - assert block.block_hash == block_hash - block.ref_count += 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError(f"Double free! {block} is already freed.") - block.ref_count -= 1 - if block.ref_count == 0: - assert block.block_hash not in self.evictor - self.evictor.add(block) - - # Remove the block from the cached_blocks - del self.cached_blocks[block.block_hash] - - def get_num_free_blocks(self) -> int: - return (self.num_blocks - self.current_num_blocks + - self.evictor.num_blocks) - - def get_num_total_blocks(self) -> int: - return self.num_blocks - - def contains_block(self, block_hash: int) -> bool: - return block_hash in self.cached_blocks or block_hash in self.evictor - - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - # Update the hash of block and the cached_blocks dictionary. - assert not self.contains_block(block_hash) - old_hash = block.block_hash - block.block_hash = block_hash - del self.cached_blocks[old_hash] - self.cached_blocks[block_hash] = block - - def get_prefix_cache_hit_rate(self) -> float: - return self.cache_metric_data.get_hit_rate() - - -class UncachedBlockAllocator(BlockAllocatorBase): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - def __init__( - self, - device: Device, - block_size: int, - num_blocks: int, - ) -> None: - self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - # Initialize the free blocks. - self.free_blocks: List[PhysicalTokenBlock] = [] - for i in range(num_blocks): - block = PhysicalTokenBlock(device=device, - block_number=i, - block_size=block_size, - block_hash=-1, - num_hashed_tokens=0) - self.free_blocks.append(block) - - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - if not self.free_blocks: - raise ValueError("Out of memory! No free blocks are available.") - block = self.free_blocks.pop() - block.ref_count = 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError(f"Double free! {block} is already freed.") - block.ref_count -= 1 - if block.ref_count == 0: - self.free_blocks.append(block) - - def get_num_free_blocks(self) -> int: - return len(self.free_blocks) - - def get_num_total_blocks(self) -> int: - return self.num_blocks - - def contains_block(self, block_hash: int) -> bool: - raise NotImplementedError( - "Invalid codepath for uncached block allocator.") - - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - raise NotImplementedError( - "Invalid codepath for uncached block allocator.") - - def get_prefix_cache_hit_rate(self) -> float: - return -1 - - -class BlockSpaceManagerV1(BlockSpaceManager): - """Manages the mapping between logical and physical token blocks.""" - - def __init__( - self, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, - watermark: float = 0.01, - sliding_window: Optional[int] = None, - enable_caching: bool = False, - ) -> None: - self.block_size = block_size - self.num_total_gpu_blocks = num_gpu_blocks - self.num_total_cpu_blocks = num_cpu_blocks - - if enable_caching and sliding_window is not None: - raise NotImplementedError( - "Sliding window is not allowed with prefix caching enabled!") - - self.block_sliding_window = None - if sliding_window is not None: - # Round up to nearest block size to regularize sliding window - # allocation sizes. - self.block_sliding_window = math.ceil(sliding_window / block_size) - - self.watermark = watermark - assert watermark >= 0.0 - - self.enable_caching = enable_caching - - self.watermark_blocks = int(watermark * num_gpu_blocks) - - if self.enable_caching: - logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( - Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( - Device.CPU, block_size, num_cpu_blocks) - else: - self.gpu_allocator = UncachedBlockAllocator( - Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator = UncachedBlockAllocator( - Device.CPU, block_size, num_cpu_blocks) - # Mapping: seq_id -> BlockTable. - self.block_tables: Dict[int, BlockTable] = {} - - # Mapping: req_id -> BlockTable - # Note that each SequenceGroup has a unique - # request ID - self.cross_block_tables: Dict[str, BlockTable] = {} - - def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: - return 0 if seq is None else seq.n_blocks - - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - self_num_required_blocks = self._get_seq_num_required_blocks( - seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self._get_seq_num_required_blocks( - seq_group.get_encoder_seq()) - num_required_blocks = self_num_required_blocks + \ - cross_num_required_blocks - - if self.block_sliding_window is not None: - - num_required_blocks = min(num_required_blocks, - self.block_sliding_window) - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - - # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _allocate_sequence(self, \ - seq: Optional[Sequence], \ - ref_count: int, \ - is_encoder_decoder: bool = True) -> BlockTable: - # Allocate new physical token blocks that will store the prompt tokens. - num_prompt_blocks = self._get_seq_num_required_blocks(seq) - - block_table: BlockTable = BlockTable() - assert seq is not None - for logical_idx in range(num_prompt_blocks): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % self.block_sliding_window] - # Set the reference counts of the token blocks. - block.ref_count = ref_count - elif not is_encoder_decoder and self.enable_caching: - block = self.gpu_allocator.allocate( - seq.hash_of_block(logical_idx), - seq.num_hashed_tokens_of_block(logical_idx)) - else: - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - block.ref_count = ref_count - block_table.append(block) - - return block_table - - def allocate(self, seq_group: SequenceGroup) -> None: - is_encoder_decoder = seq_group.is_encoder_decoder() - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - # Allocate decoder sequences - # - # NOTE: Here we assume that all sequences in the group have the same - # decoder prompt. - wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - seq = wait_seqs[0] - block_table: BlockTable = \ - self._allocate_sequence(seq, - seq_group.num_seqs(), - is_encoder_decoder) - - # Assign the self-attention block tables for each sequence. - if len(wait_seqs) == 1: - self.block_tables[seq.seq_id] = block_table - else: - for seq in wait_seqs: - self.block_tables[seq.seq_id] = block_table.copy() - - # Allocate encoder sequence - if is_encoder_decoder: - # A SequenceGroup has only a single encoder sequence (at most), - # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(seq_group.get_encoder_seq(), - 1, is_encoder_decoder) - # Assign the cross-attention block table for the SequenceGroup. - self.cross_block_tables[seq_group.request_id] = block_table - - def can_append_slots(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: - assert (num_lookahead_slots == 0 - ), "lookahead allocation not supported in BlockSpaceManagerV1" - - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks - - def _promote_last_block( - self, - seq: Sequence, - last_block: PhysicalTokenBlock, - ) -> PhysicalTokenBlock: - assert self.enable_caching - - # Compute a new hash for the block so that it can be shared by other - # Sequences - new_hash = seq.hash_of_block(seq.n_blocks - 1) - - # if new_hash is already in the cached table, then free last_block - # and return the cached version - if self.gpu_allocator.contains_block(new_hash): - self.gpu_allocator.free(last_block) - return self.gpu_allocator.allocate(new_hash) - else: - self.gpu_allocator.update_hash(new_hash, last_block) - return last_block - - def _is_last_block_full( - self, - seq: Sequence, - ) -> bool: - token_ids_len = seq.data.get_len() - return token_ids_len > 0 and token_ids_len % seq.block_size == 0 - - def _maybe_promote_last_block( - self, - seq: Sequence, - last_block: PhysicalTokenBlock, - ) -> PhysicalTokenBlock: - if self._is_last_block_full(seq): - return self._promote_last_block(seq, last_block) - else: - return last_block - - def _allocate_last_physical_block( - self, - seq: Sequence, - ) -> PhysicalTokenBlock: - # Called before a new block is appended. - # This is in charge of allocating a new physical block (to be appended). - - # None if the last block is not full. Otherwise, we set it to the - # content hash. - if not self.enable_caching: - return self.gpu_allocator.allocate() - block_hash: Optional[int] = None - n_blocks = seq.n_blocks - if (self._is_last_block_full(seq)): - block_hash = seq.hash_of_block(n_blocks - 1) - num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1) - - # num_hashed_tokens is used to compute future hashes - # (e.g. in the hashing function, it is used to ask the sequence for - # prefix tokens) - new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) - - # If the block has is None, then the block is not full. - # If the block is not full, then we expect it to have a refcount of 1. - if block_hash is None: - assert new_block.ref_count == 1 - return new_block - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int = 0, - ) -> List[Tuple[int, int]]: - """Allocate a physical slot for a new token.""" - n_blocks = seq.n_blocks - block_table = self.block_tables[seq.seq_id] - # If we need to allocate a new physical block - if len(block_table) < n_blocks: - # Currently this code only supports adding one physical block - assert len(block_table) == n_blocks - 1 - - if (self.block_sliding_window - and len(block_table) >= self.block_sliding_window): - # reuse a block - block_table.append(block_table[len(block_table) % - self.block_sliding_window]) - else: - # The sequence hash a new logical block. - # Allocate a new physical block. - new_block = self._allocate_last_physical_block(seq) - block_table.append(new_block) - return [] - - # We want to append the token to the last physical block. - last_block = block_table[-1] - assert last_block.device == Device.GPU - if last_block.ref_count == 1: - # Not shared with other sequences. Appendable. - if self.enable_caching: - # If the last block is now complete, we may reuse an old block - # to save memory. - maybe_new_block = self._maybe_promote_last_block( - seq, last_block) - block_table[-1] = maybe_new_block - return [] - else: - # The last block is shared with other sequences. - # Copy on Write: Allocate a new block and copy the tokens. - new_block = self._allocate_last_physical_block(seq) - - block_table[-1] = new_block - self.gpu_allocator.free(last_block) - return [(last_block.block_number, new_block.block_number)] - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - # NOTE: fork does not allocate a new physical block. - # Thus, it is always safe from OOM. - if parent_seq.seq_id not in self.block_tables: - # Parent sequence has either been freed or never existed. - return - src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.copy() - - # When using a sliding window, blocks will be eventually reused. - # In this case the block tables will contain repeated blocks. - # When forking, we must make sure that each block's `ref_count` - # is only incremented by one, so we deduplicate them by wrapping - # them in a set. - for block in set(src_block_table): - block.ref_count += 1 - - def _get_physical_blocks( - self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: - - # NOTE: Here, we assume that the physical blocks are only shared by - # the sequences in the same group. - request_id = seq_group.request_id - blocks: Set[PhysicalTokenBlock] = set() - for seq in seq_group.get_seqs(): - if seq.is_finished(): - continue - blocks.update(self.block_tables[seq.seq_id]) - # Cross-attention blocks - if seq_group.is_encoder_decoder(): - blocks.update(self.cross_block_tables[request_id]) - return list(blocks) - - def can_swap_in(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - assert (num_lookahead_slots == 0 - ), "BlockSpaceManagerV1 does not support lookahead allocation" - - blocks = self._get_physical_blocks(seq_group) - num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if seq_group.is_encoder_decoder(): - num_swapped_seqs += 1 - num_free_blocks = self.gpu_allocator.get_num_free_blocks() - # NOTE: Conservatively, we assume that every sequence will allocate - # at least one free block right after the swap-in. - # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) + num_swapped_seqs - if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: - return AllocStatus.NEVER - elif num_free_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _swap_block_table( - self, block_table: BlockTable, src_allocator: BlockAllocatorBase, - dest_allocator: BlockAllocatorBase, - mapping: Dict[PhysicalTokenBlock, - PhysicalTokenBlock]) -> BlockTable: - new_block_table: BlockTable = BlockTable() - - for from_block in block_table: - if from_block in mapping: - to_block = mapping[from_block] - to_block.ref_count += 1 - else: - to_block = dest_allocator.allocate( - from_block.block_hash, from_block.num_hashed_tokens) - mapping[from_block] = to_block - new_block_table.append(to_block) - # Free the source block swapped in to destination. - src_allocator.free(from_block) - - return new_block_table - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - - request_id = seq_group.request_id - - # CPU block -> GPU block. - # dict is efficient in lookup `if cpu_block in mapping` - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.cpu_allocator, self.gpu_allocator, - mapping) - - if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.cpu_allocator, - self.gpu_allocator, - mapping) - - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - blocks = self._get_physical_blocks(seq_group) - return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - request_id = seq_group.request_id - - # GPU block -> CPU block. - # dict is efficient in lookup `if gpu_block in mapping` - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.gpu_allocator, self.cpu_allocator, - mapping) - - if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.gpu_allocator, - self.cpu_allocator, - mapping) - - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] - - def _free_block_table(self, block_table: BlockTable) -> None: - # when using a sliding window, each seq will only use up - # to `self.block_sliding_window` blocks. When freeing - # the block table, we must make sure to not free blocks more - # than once. If no sliding window is used, there is no block - # reuse in the block table, so we must free all blocks. - blocks_to_free = (block_table[-self.block_sliding_window:] - if self.block_sliding_window is not None else - block_table) - for block in set(blocks_to_free): - if block.device == Device.GPU: - self.gpu_allocator.free(block) - else: - self.cpu_allocator.free(block) - - def free(self, seq: Sequence) -> None: - if seq.seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. - return - block_table = self.block_tables[seq.seq_id] - self._free_block_table(block_table) - del self.block_tables[seq.seq_id] - - def free_cross(self, seq_group: SequenceGroup) -> None: - if seq_group.request_id not in self.cross_block_tables: - # Already freed or hasn't ben scheduled yet. - return - block_table = self.cross_block_tables[seq_group.request_id] - self._free_block_table(block_table) - del self.cross_block_tables[seq_group.request_id] - - def reset(self) -> None: - # Free decoder block tables - for block_table in self.block_tables.values(): - self._free_block_table(block_table) - self.block_tables.clear() - # Free cross-attention block tables - for block_table in self.cross_block_tables.values(): - self._free_block_table(block_table) - self.cross_block_tables.clear() - - def get_block_table(self, seq: Sequence) -> List[int]: - return self.block_tables[seq.seq_id].ids() - - def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: - block_table = self.cross_block_tables[seq_group.request_id] - return [block.block_number for block in block_table] - - def get_num_free_gpu_blocks(self) -> int: - return self.gpu_allocator.get_num_free_blocks() - - def get_num_free_cpu_blocks(self) -> int: - return self.cpu_allocator.get_num_free_blocks() - - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.block_tables[seq.seq_id] - for block in block_table: - block.last_accessed = access_time - - def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int): - if seq.seq_id not in self.block_tables: - return - - # When chunked prefill is enabled, the computed full blocks - # should be calculated based on the number of computed tokens. - max_computed_tokens = (seq.data.get_num_computed_tokens() + - token_chunk_size) - computed_full_blocks = max_computed_tokens // self.block_size - - block_table = self.block_tables[seq.seq_id] - if computed_full_blocks == 0: - return - for i in reversed(range(computed_full_blocks)): - if block_table[i].computed: - break - block_table[i].computed = True - - def get_all_computed_blocks(self, seq: Sequence) -> List[int]: - if seq.seq_id not in self.block_tables: - return [] - block_table = self.block_tables[seq.seq_id] - # NOTE We exclude the last block to avoid the case where the entire - # prompt is cached. This would cause erroneous behavior in model - # runner. - return [ - b.block_number - for b in takewhile(lambda b: b.computed, block_table[:-1]) - ] - - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - """Return the block ids that are common for a given sequence group. - - Used in prefill (can skip prefill of some blocks). - """ - # Can return non-empty result only with prefix caching enabled. - if not self.enable_caching: - return [] - - ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] - return commonprefix([ids for ids in ids_list if ids != []]) - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - if self.enable_caching: - for seq in seq_group.get_seqs(): - self.compute_full_blocks_in_seq(seq, token_chunk_size) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - if device == Device.GPU: - return self.gpu_allocator.get_prefix_cache_hit_rate() - if device == Device.CPU: - return self.cpu_allocator.get_prefix_cache_hit_rate() - raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py deleted file mode 100644 index 54818c7e3e9a6..0000000000000 --- a/vllm/core/block_manager_v2.py +++ /dev/null @@ -1,503 +0,0 @@ -"""A block manager that manages token blocks.""" -from itertools import chain -from typing import Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple - -from vllm.core.block.block_table import BlockTable -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block -from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - LastAccessBlocksTracker) -from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -SeqId = int -EncoderSeqId = str - - -class BlockSpaceManagerV2(BlockSpaceManager): - """BlockSpaceManager which manages the allocation of KV cache. - - It owns responsibility for allocation, swapping, allocating memory for - autoregressively-generated tokens, and other advanced features such as - prefix caching, forking/copy-on-write, and sliding-window memory allocation. - - The current implementation is partial; in particular prefix caching and - sliding-window are not feature complete. This class implements the design - described in https://github.com/vllm-project/vllm/pull/3492. - - Lookahead slots - The block manager has the notion of a "lookahead slot". These are slots - in the KV cache that are allocated for a sequence. Unlike the other - allocated slots, the content of these slots is undefined -- the worker - may use the memory allocations in any way. - - In practice, a worker could use these lookahead slots to run multiple - forward passes for a single scheduler invocation. Each successive - forward pass would write KV activations to the corresponding lookahead - slot. This allows low inter-token latency use-cases, where the overhead - of continuous batching scheduling is amortized over >1 generated tokens. - - Speculative decoding uses lookahead slots to store KV activations of - proposal tokens. - - See https://github.com/vllm-project/vllm/pull/3250 for more information - on lookahead scheduling. - - Args: - block_size (int): The size of each memory block. - num_gpu_blocks (int): The number of memory blocks allocated on GPU. - num_cpu_blocks (int): The number of memory blocks allocated on CPU. - watermark (float, optional): The threshold used for memory swapping. - Defaults to 0.01. - sliding_window (Optional[int], optional): The size of the sliding - window. Defaults to None. - enable_caching (bool, optional): Flag indicating whether caching is - enabled. Defaults to False. - """ - - def __init__( - self, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, - watermark: float = 0.01, - sliding_window: Optional[int] = None, - enable_caching: bool = False, - ) -> None: - self.block_size = block_size - self.num_total_gpu_blocks = num_gpu_blocks - self.num_total_cpu_blocks = num_cpu_blocks - - self.sliding_window = sliding_window - # max_block_sliding_window is the max number of blocks that need to be - # allocated - self.max_block_sliding_window = None - if sliding_window is not None: - # +1 here because // rounds down - num_blocks = sliding_window // block_size + 1 - # +1 here because the last block may not be full, - # and so the sequence stretches one more block at the beginning - # For example, if sliding_window is 3 and block_size is 4, - # we may need 2 blocks when the second block only holds 1 token. - self.max_block_sliding_window = num_blocks + 1 - - self.watermark = watermark - assert watermark >= 0.0 - - self.enable_caching = enable_caching - - self.watermark_blocks = int(watermark * num_gpu_blocks) - - self.block_allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching" if enable_caching else "naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - self.block_tables: Dict[SeqId, BlockTable] = {} - self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} - - self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator) - self._last_access_blocks_tracker = LastAccessBlocksTracker( - self.block_allocator) - - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( - seq.get_token_ids(), - block_size=self.block_size, - ) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - num_required_blocks += BlockTable.get_num_required_blocks( - encoder_seq.get_token_ids(), - block_size=self.block_size, - ) - - if self.max_block_sliding_window is not None: - num_required_blocks = min(num_required_blocks, - self.max_block_sliding_window) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU) - - # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _allocate_sequence(self, seq: Sequence) -> BlockTable: - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - max_block_sliding_window=self.max_block_sliding_window, - ) - block_table.allocate(seq.get_token_ids()) - - return block_table - - def allocate(self, seq_group: SequenceGroup) -> None: - - # Allocate self-attention block tables for decoder sequences - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - assert not (set(seq.seq_id for seq in waiting_seqs) - & self.block_tables.keys()), "block table already exists" - - # NOTE: Here we assume that all sequences in the group have the same - # prompt. - seq = waiting_seqs[0] - block_table: BlockTable = self._allocate_sequence(seq) - self.block_tables[seq.seq_id] = block_table - - # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) - self._last_access_blocks_tracker.add_seq(seq.seq_id) - - # Assign the block table for each sequence. - for seq in waiting_seqs[1:]: - self.block_tables[seq.seq_id] = block_table.fork() - - # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) - self._last_access_blocks_tracker.add_seq(seq.seq_id) - - # Allocate cross-attention block table for encoder sequence - # - # NOTE: Here we assume that all sequences in the group have the same - # encoder prompt. - request_id = seq_group.request_id - - assert (request_id - not in self.cross_block_tables), \ - "block table already exists" - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - block_table = self._allocate_sequence(encoder_seq) - self.cross_block_tables[request_id] = block_table - - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - """Determine if there is enough space in the GPU KV cache to continue - generation of the specified sequence group. - - We use a worst-case heuristic: assume each touched block will require a - new allocation (either via CoW or new block). We can append slots if the - number of touched blocks is less than the number of free blocks. - - "Lookahead slots" are slots that are allocated in addition to the slots - for known tokens. The contents of the lookahead slots are not defined. - This is used by speculative decoding when speculating future tokens. - """ - - num_touched_blocks = 0 - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - block_table = self.block_tables[seq.seq_id] - - num_touched_blocks += ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids( - seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - )) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - Device.GPU) - return num_touched_blocks <= num_free_gpu_blocks - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - - block_table = self.block_tables[seq.seq_id] - - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - num_computed_slots=seq.data.get_num_computed_tokens(), - ) - # Return any new copy-on-writes. - new_cows = self.block_allocator.clear_copy_on_writes() - return new_cows - - def free(self, seq: Sequence) -> None: - seq_id = seq.seq_id - - if seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. - return - - # Update seq block ids with the latest access time - self._last_access_blocks_tracker.update_seq_blocks_last_access( - seq_id, self.block_tables[seq.seq_id].physical_block_ids) - - # Untrack seq - self._last_access_blocks_tracker.remove_seq(seq_id) - self._computed_blocks_tracker.remove_seq(seq_id) - - # Free table/blocks - self.block_tables[seq_id].free() - del self.block_tables[seq_id] - - def free_cross(self, seq_group: SequenceGroup) -> None: - request_id = seq_group.request_id - if request_id not in self.cross_block_tables: - # Already freed or hasn't been scheduled yet. - return - self.cross_block_tables[request_id].free() - del self.cross_block_tables[request_id] - - def get_block_table(self, seq: Sequence) -> List[int]: - block_ids = self.block_tables[seq.seq_id].physical_block_ids - return block_ids # type: ignore - - def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: - request_id = seq_group.request_id - assert request_id in self.cross_block_tables - block_ids = self.cross_block_tables[request_id].physical_block_ids - assert all(b is not None for b in block_ids) - return block_ids # type: ignore - - def access_all_blocks_in_seq(self, seq: Sequence, now: float): - if self.enable_caching: - # Record the latest access time for the sequence. The actual update - # of the block ids is deferred to the sequence free(..) call, since - # only during freeing of block ids, the blocks are actually added to - # the evictor (which is when the most updated time is required) - # (This avoids expensive calls to mark_blocks_as_accessed(..)) - self._last_access_blocks_tracker.update_last_access( - seq.seq_id, now) - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - # If prefix caching is enabled, mark immutable blocks as computed - # right after they have been scheduled (for prefill). This assumes - # the scheduler is synchronous so blocks are actually computed when - # scheduling the next batch. - self.block_allocator.mark_blocks_as_computed([]) - - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - """Determine which blocks for which we skip prefill. - - With prefix caching we can skip prefill for previously-generated blocks. - Currently, the attention implementation only supports skipping cached - blocks if they are a contiguous prefix of cached blocks. - - This method determines which blocks can be safely skipped for all - sequences in the sequence group. - """ - computed_seq_block_ids = [] - for seq in seqs: - computed_seq_block_ids.append( - self._computed_blocks_tracker. - get_cached_computed_blocks_and_update( - seq.seq_id, - self.block_tables[seq.seq_id].physical_block_ids)) - - # NOTE(sang): This assumes seq_block_ids doesn't contain any None. - return self.block_allocator.get_common_computed_block_ids( - computed_seq_block_ids) # type: ignore - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - if parent_seq.seq_id not in self.block_tables: - # Parent sequence has either been freed or never existed. - return - src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.fork() - - # Track child seq - self._computed_blocks_tracker.add_seq(child_seq.seq_id) - self._last_access_blocks_tracker.add_seq(child_seq.seq_id) - - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - """Returns the AllocStatus for the given sequence_group - with num_lookahead_slots. - - Args: - sequence_group (SequenceGroup): The sequence group to swap in. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - AllocStatus: The AllocStatus for the given sequence group. - """ - return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, - num_lookahead_slots) - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - """Returns the block id mapping (from CPU to GPU) generated by - swapping in the given seq_group with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap in. - - Returns: - List[Tuple[int, int]]: The mapping of swapping block from CPU - to GPU. - """ - physical_block_id_mapping = [] - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - blocks = self.block_tables[seq.seq_id].blocks - if len(blocks) == 0: - continue - - seq_swap_mapping = self.block_allocator.swap(blocks=blocks, - src_device=Device.CPU, - dst_device=Device.GPU) - - # Refresh the block ids of the table (post-swap) - self.block_tables[seq.seq_id].update(blocks) - - seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id): - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id) - for cpu_block_id, gpu_block_id in seq_swap_mapping.items() - } - - physical_block_id_mapping.extend( - list(seq_physical_block_id_mapping.items())) - - return physical_block_id_mapping - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - """Returns whether we can swap out the given sequence_group - with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap in. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - bool: Whether it's possible to swap out current sequence group. - """ - alloc_status = self._can_swap(seq_group, Device.CPU, - SequenceStatus.RUNNING) - return alloc_status == AllocStatus.OK - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - """Returns the block id mapping (from GPU to CPU) generated by - swapping out the given sequence_group with num_lookahead_slots. - - Args: - sequence_group (SequenceGroup): The sequence group to swap in. - - Returns: - List[Tuple[int, int]]: The mapping of swapping block from - GPU to CPU. - """ - physical_block_id_mapping = [] - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - blocks = self.block_tables[seq.seq_id].blocks - if len(blocks) == 0: - continue - - seq_swap_mapping = self.block_allocator.swap(blocks=blocks, - src_device=Device.GPU, - dst_device=Device.CPU) - - # Refresh the block ids of the table (post-swap) - self.block_tables[seq.seq_id].update(blocks) - - seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id): - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id) - for gpu_block_id, cpu_block_id in seq_swap_mapping.items() - } - - physical_block_id_mapping.extend( - list(seq_physical_block_id_mapping.items())) - - return physical_block_id_mapping - - def get_num_free_gpu_blocks(self) -> int: - return self.block_allocator.get_num_free_blocks(Device.GPU) - - def get_num_free_cpu_blocks(self) -> int: - return self.block_allocator.get_num_free_blocks(Device.CPU) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_allocator.get_prefix_cache_hit_rate(device) - - def _can_swap(self, - seq_group: SequenceGroup, - device: Device, - status: SequenceStatus, - num_lookahead_slots: int = 0) -> AllocStatus: - """Returns the AllocStatus for swapping in/out the given sequence_group - on to the 'device'. - - Args: - sequence_group (SequenceGroup): The sequence group to swap in. - device (Device): device to swap the 'seq_group' on. - status (SequenceStatus): The status of sequence which is needed - for action. RUNNING for swap out and SWAPPED for swap in - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - AllocStatus: The AllocStatus for swapping in/out the given - sequence_group on to the 'device'. - """ - blocks = self._get_blocks_for_swap(seq_group, status) - num_blocks_touched = self.block_allocator.get_num_blocks_touched( - blocks, device, num_lookahead_slots) - watermark_blocks = 0 - if device == Device.GPU: - watermark_blocks = self.watermark_blocks - if self.block_allocator.get_num_total_blocks( - device) < num_blocks_touched: - return AllocStatus.NEVER - elif self.block_allocator.get_num_free_blocks( - device) - num_blocks_touched >= watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _get_blocks_for_swap(self, seq_group: SequenceGroup, - status: SequenceStatus) -> List[Block]: - """Returns the list of blocks those are touched by the seq_group - - Args: - sequence_group (SequenceGroup): The sequence group to swap in. - status (SequenceStatus): The status of sequence which is needed - for action. RUNNING for swap out and SWAPPED for swap in - - Returns: - The list of blocks those are touched by the seq_group. - """ - blocks: Dict[int, List[Block]] = {} - for seq in seq_group.get_seqs(status=status): - block_table = self.block_tables[seq.seq_id] - if block_table.blocks is not None: - blocks[seq.seq_id] = block_table.blocks - combined_blocks = list(chain(*blocks.values())) - return combined_blocks diff --git a/vllm/core/kv_cache_manager.py b/vllm/core/kv_cache_manager.py new file mode 100644 index 0000000000000..69c5e2f07ed09 --- /dev/null +++ b/vllm/core/kv_cache_manager.py @@ -0,0 +1,89 @@ +from typing import Dict, List, Optional, Set, Tuple + +from vllm.request import Request +from vllm.logger import init_logger +from vllm.utils import cdiv + +logger = init_logger(__name__) + + +class KVCacheManager: + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + sliding_window: Optional[int] = None, + enable_caching: bool = True, + watermark: float = 0.01, + ) -> None: + self.block_size = block_size + self.num_gpu_blocks = num_gpu_blocks + self.sliding_window = sliding_window + self.enable_caching = enable_caching + self.watermark = watermark + + # Reserve block id 0 for padding. + self.free_block_ids = list(range(num_gpu_blocks)) + self.req_to_block_ids: Dict[str, List[int]] = {} + self.block_id_to_reqs: List[Set[str]] = [ + set() for _ in range(num_gpu_blocks) + ] + + def get_computed_blocks(self, request: Request) -> List[int]: + return [] + + def append_slots( + self, + request: Request, + num_tokens: int, + ) -> Optional[List[int]]: + num_blocks = cdiv(request.num_computed_tokens + num_tokens, + self.block_size) + req_block_ids = self.req_to_block_ids[request.request_id] + num_new_blocks = num_blocks - len(req_block_ids) + if num_new_blocks > len(self.free_block_ids): + # Cannot allocate new blocks. + return None + if num_new_blocks == 0: + # No new block is needed. + return [] + # Allocate new blocks. + new_block_ids = self._get_new_blocks(num_new_blocks) + req_block_ids.extend(new_block_ids) + for block_id in new_block_ids: + self.block_id_to_reqs[block_id].add(request.request_id) + return new_block_ids + + def allocate_slots( + self, + request: Request, + num_tokens: int, + computed_block_ids: List[int], + ) -> Optional[List[int]]: + num_new_blocks = cdiv(num_tokens, self.block_size) + if (len(self.free_block_ids) - num_new_blocks < + self.watermark * self.num_gpu_blocks): + # Cannot allocate new blocks. + return None + + new_block_ids = self._get_new_blocks(num_new_blocks) + self.req_to_block_ids[request.request_id] = (computed_block_ids + + new_block_ids) + for block_id in new_block_ids: + self.block_id_to_reqs[block_id].add(request.request_id) + return new_block_ids + + def free(self, request: Request) -> None: + block_ids = self.req_to_block_ids.pop(request.request_id) + for block_id in block_ids: + reqs = self.block_id_to_reqs[block_id] + reqs.remove(request.request_id) + if not reqs: + self.free_block_ids.append(block_id) + + def _get_new_blocks(self, num_blocks: int) -> List[int]: + assert num_blocks <= len(self.free_block_ids) + new_block_ids = self.free_block_ids[-num_blocks:] + self.free_block_ids = self.free_block_ids[:-num_blocks] + return new_block_ids diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py index a656d8fda8cc2..c0a2476dd5ed6 100644 --- a/vllm/core/scheduler_v2.py +++ b/vllm/core/scheduler_v2.py @@ -8,7 +8,8 @@ Tuple, Union) from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.core.kv_cache_manager import KVCacheManager +from vllm.outputs_v2 import ModelRunnerOutput from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest @@ -35,23 +36,14 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - version = "v1" - if self.scheduler_config.use_v2_block_manager: - version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" - BlockSpaceManagerImpl = \ - BlockSpaceManager.get_block_space_manager_class(version) num_gpu_blocks = cache_config.num_gpu_blocks - num_cpu_blocks = cache_config.num_cpu_blocks - + assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( + self.kv_cache_manager = KVCacheManager( block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + enable_caching=True) self.block_size = self.cache_config.block_size # Scheduling constraints. @@ -67,10 +59,10 @@ def __init__( self.finished_req_ids: Set[str] = set() self.aborted_req_ids: Set[str] = set() - def schedule(self) -> "SchedulerOutput": - # Finish the requests that have reached the maximum length. - self._check_stop_by_len() + self.cum = 0 + def schedule(self) -> "SchedulerOutput": + start = time.time() scheduled_new_reqs: List[Request] = [] scheduled_resumed_reqs: List[Request] = [] scheduled_running_reqs: List[Request] = [] @@ -81,6 +73,7 @@ def schedule(self) -> "SchedulerOutput": token_budget = self.max_num_scheduled_tokens # First, schedule the RUNNING requests. + new_running: Deque[Request] = deque() while self.running: if token_budget == 0: break @@ -88,33 +81,36 @@ def schedule(self) -> "SchedulerOutput": request = self.running[0] num_tokens = request.num_tokens - request.num_computed_tokens num_tokens = min(num_tokens, token_budget) + assert num_tokens > 0 - new_block_ids: List[int] = [] - while not self.block_manager.can_append_slots(request, num_tokens): - new_block_ids = self.block_manager.append_slots( + while True: + new_block_ids = self.kv_cache_manager.append_slots( request, num_tokens) - if not new_block_ids: + if new_block_ids is None: # The request cannot be scheduled. # Preempt the lowest-priority request. preempted_req = self.running.pop() - self.block_manager.free(preempted_req) + self.kv_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 self.waiting.appendleft(preempted_req) preempted_reqs.append(preempted_req) - if preempted_req == request: + # No more request to preempt. break - else: - # The request can be scheduled. - self.running.popleft() - scheduled_running_reqs.append(request) - - req_to_new_block_ids[request.request_id] = new_block_ids - num_scheduled_tokens[request.request_id] = num_tokens - token_budget -= num_tokens - request.status = RequestStatus.RUNNING + else: + # The request can be scheduled. + self.running.popleft() + new_running.append(request) + scheduled_running_reqs.append(request) + + req_to_new_block_ids[request.request_id] = new_block_ids + num_scheduled_tokens[request.request_id] = num_tokens + token_budget -= num_tokens + request.status = RequestStatus.RUNNING + break + self.running = new_running # Next, schedule the WAITING requests. while self.waiting: @@ -126,21 +122,25 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] - allocated = self.block_manager.allocate(request) - if allocated is None: - # The request cannot be scheduled. - break - - # The request can be scheduled. - computed_block_ids, new_block_ids = allocated - - # Get cached tokens. - num_computed_blocks = len(computed_block_ids) - num_computed_tokens = num_computed_blocks * self.block_size - + # Get already-cached tokens. + computed_block_ids = self.kv_cache_manager.get_computed_blocks( + request) + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_block_ids) * self.block_size # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of `request.num_prompt_tokens` + # to consider the resumed requests, which have output tokens. num_tokens = request.num_tokens - num_computed_tokens num_tokens = min(num_tokens, token_budget) + assert num_tokens > 0 + new_block_ids = self.kv_cache_manager.allocate_slots( + request, num_tokens, computed_block_ids) + if new_block_ids is None: + # The request cannot be scheduled. + break + request.num_computed_tokens = num_computed_tokens self.waiting.popleft() self.running.append(request) @@ -169,18 +169,18 @@ def schedule(self) -> "SchedulerOutput": new_reqs_data = [ NewRequestData.from_request(req, req_to_new_block_ids[req.request_id], - num_computed_tokens) + req.num_computed_tokens) for req in scheduled_new_reqs ] resumed_reqs_data = [ ResumedRequestData.from_request( - req, req_to_new_block_ids[req.request_id], num_computed_tokens) - for req in scheduled_resumed_reqs + req, req_to_new_block_ids[req.request_id], + req.num_computed_tokens) for req in scheduled_resumed_reqs ] running_reqs_data = [ RunningRequestData.from_request( - req, req_to_new_block_ids[req.request_id], num_computed_tokens) - for req in scheduled_running_reqs + req, req_to_new_block_ids[req.request_id], + req.num_computed_tokens) for req in scheduled_running_reqs ] preempted_req_ids = {req.request_id for req in preempted_reqs} scheduler_output = SchedulerOutput( @@ -193,17 +193,45 @@ def schedule(self) -> "SchedulerOutput": finished_req_ids=self.finished_req_ids, aborted_req_ids=self.aborted_req_ids, ) + end = time.time() + self.cum += (end - start) + print(f"Scheduler time: {(end - start) * 1000:.3f} ms") + print(f"Cumulative scheduler time: {self.cum * 1000:.3f} ms") self.finished_req_ids = set() self.aborted_req_ids = set() - for request in self.running: - num_tokens = num_scheduled_tokens[request.request_id] - request.num_computed_tokens = num_computed_tokens + num_tokens - if request.num_tokens == request.num_computed_tokens: - # TODO: Consider speculative decoding. - request.num_output_tokens += 1 return scheduler_output + def update_from_output( + self, + scheduler_output: "SchedulerOutput", + model_runner_output: "ModelRunnerOutput", + ) -> List[Request]: + sampled_token_ids = model_runner_output.sampled_token_ids_cpu.numpy() + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + new_running: Deque[Request] = deque() + finished_reqs: List[Request] = [] + for request in self.running: + req_id = request.request_id + # TODO: Consider speculative decoding. + request.num_computed_tokens += num_scheduled_tokens[req_id] + if request.num_computed_tokens >= request.num_prompt_tokens: + req_index = model_runner_output.req_id_to_index[req_id] + token_id = sampled_token_ids[req_index] + request.output_token_ids.append(token_id) + # TODO: Update the KV cache manager for prefix caching. + + if (request.num_tokens >= self.max_model_len + or request.num_output_tokens >= request.max_tokens): + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + self.finished_req_ids.add(req_id) + finished_reqs.append(request) + self._free_request(request) + continue + new_running.append(request) + self.running = new_running + return finished_reqs + def add_request(self, request: Request) -> None: self.waiting.append(request) @@ -249,23 +277,9 @@ def stop_requests(self, request_ids: Union[str, Iterable[str]]) -> None: self.finished_req_ids.add(request.request_id) self._free_request(request) - def _check_stop_by_len(self) -> None: - stopped_reqs: List[Request] = [] - # TODO: Optimize this. - for request in self.running: - assert request.max_tokens is not None - if (request.num_tokens >= self.max_model_len - or request.num_output_tokens >= request.max_tokens): - request.status = RequestStatus.FINISHED_LENGTH_CAPPED - stopped_reqs.append(request) - for request in stopped_reqs: - self.running.remove(request) - self.finished_req_ids.add(request.request_id) - self._free_request(request) - def _free_request(self, request: Request) -> None: assert request.is_finished() - self.block_manager.free(request) + self.kv_cache_manager.free(request) def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca6034ddbe5c5..31e5ed7b39509 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -112,7 +112,7 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None - max_num_seqs: int = 256 + max_num_seqs: int = 2048 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f108751056ab5..e1ac611769c5b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -9,29 +9,29 @@ import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.core.scheduler import SchedulerOutputs +from vllm.core.scheduler_v2 import SchedulerOutput from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState +from vllm.engine.llm_engine_v2 import LLMEngine from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.gpu_executor import GPUExecutorAsync +from vllm.executor.executor_v2 import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S +EmbeddingRequestOutput = None +RequestOutput = None class AsyncEngineDeadError(RuntimeError): @@ -1017,7 +1017,7 @@ async def get_lora_config(self) -> LoRAConfig: async def do_log_stats( self, - scheduler_outputs: Optional[SchedulerOutputs] = None, + scheduler_outputs: Optional[SchedulerOutput] = None, model_output: Optional[List[SamplerOutput]] = None) -> None: self.engine.do_log_stats() diff --git a/vllm/engine/llm_engine_v2.py b/vllm/engine/llm_engine_v2.py index 65fa8cf2da6eb..1a41526027f24 100644 --- a/vllm/engine/llm_engine_v2.py +++ b/vllm/engine/llm_engine_v2.py @@ -6,7 +6,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union +from typing import Set, Type, Union, Tuple import torch from typing_extensions import TypeVar @@ -28,8 +28,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, - RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -62,7 +60,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) class LLMEngine: @@ -312,7 +309,6 @@ def _initialize_kv_caches(self) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) @classmethod @@ -325,7 +321,7 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) + executor_class = _get_executor_cls(engine_config) # Create the LLM engine. engine = cls( **engine_config.to_dict(), @@ -334,20 +330,8 @@ def from_engine_args( usage_context=usage_context, stat_loggers=stat_loggers, ) - return engine - def __reduce__(self): - # This is to ensure that the LLMEngine is not referenced in - # the closure used to initialize Ray worker actors - raise RuntimeError("LLMEngine should not be pickled!") - - def __del__(self): - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): - model_executor.shutdown() - def get_tokenizer_group( self, group_type: Type[_G] = BaseTokenizerGroup, @@ -413,7 +397,7 @@ def _add_processed_request( processed_inputs, arrival_time, sampling_params=params) - self.scheduler.add_req(req) + self.scheduler.add_request(req) def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() @@ -510,7 +494,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: >>> # abort the request >>> engine.abort_request(request_id) """ - self.scheduler.abort_reqs(request_id) + self.scheduler.abort_requests(request_id) def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" @@ -534,26 +518,19 @@ def get_lora_config(self) -> LoRAConfig: def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" - return self.scheduler.get_num_unfinished_reqs() + return self.scheduler.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" - return self.scheduler.has_unfinished_req() + return self.scheduler.has_unfinished_requests() - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def step(self) -> Tuple[List[Request], List[Request]]: scheduler_output = self.scheduler.schedule() - sampler_output = self.model_executor.execute_model(scheduler_output) - self._process_model_outputs(sampler_output) - - if not self.has_unfinished_requests(): - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() - return sampler_output + output = self.model_executor.execute_model(scheduler_output) + finished_reqs = self.scheduler.update_from_output( + scheduler_output, output) + running_reqs = self.scheduler.running + return finished_reqs, running_reqs def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: if not self.log_stats: @@ -580,10 +557,11 @@ def check_health(self) -> None: def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - if self.is_encoder_decoder_model(): - prompt_ids = inputs.get("encoder_prompt_token_ids") - else: - prompt_ids = inputs.get("prompt_token_ids") + # if self.is_encoder_decoder_model(): + # prompt_ids = inputs.get("encoder_prompt_token_ids") + # else: + # prompt_ids = inputs.get("prompt_token_ids") + prompt_ids = inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") @@ -662,6 +640,6 @@ def _get_executor_cls(engine_config: EngineConfig) -> Type[ExecutorBase]: "support VLLM_USE_RAY_SPMD_WORKER=1") executor_class = MultiprocessingGPUExecutor else: - from vllm.executor.gpu_executor import GPUExecutor + from vllm.executor.executor_v2 import GPUExecutor executor_class = GPUExecutor return executor_class diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 03a58ddbf2a94..7257bc4e3c966 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -17,7 +17,7 @@ from vllm.model_executor.guided_decoding import ( GuidedDecodingRequest, get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs_v2 import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -28,6 +28,7 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of logger = init_logger(__name__) +EmbeddingRequestOutput = RequestOutput # FIXME class LLM: @@ -352,7 +353,8 @@ def generate( guided_options=guided_options_request) outputs = self._run_engine(use_tqdm=use_tqdm) - return LLMEngine.validate_outputs(outputs, RequestOutput) + # return LLMEngine.validate_outputs(outputs, RequestOutput) + return outputs def chat( self, @@ -704,8 +706,10 @@ def _add_guided_processor( return params def _run_engine( - self, *, use_tqdm: bool - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + self, + *, + use_tqdm: bool, + ) -> List[Union[RequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -722,24 +726,20 @@ def _run_engine( total_in_toks = 0 total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): - step_outputs = self.llm_engine.step() - for output in step_outputs: - if output.finished: - outputs.append(output) - if use_tqdm: - if isinstance(output, RequestOutput): - # Calculate tokens only for RequestOutput - assert output.prompt_token_ids is not None - total_in_toks += len(output.prompt_token_ids) - in_spd = total_in_toks / pbar.format_dict["elapsed"] - total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs) - out_spd = (total_out_toks / - pbar.format_dict["elapsed"]) - pbar.postfix = ( - f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s") - pbar.update(1) + finished_reqs, _ = self.llm_engine.step() + for req in finished_reqs: + output = RequestOutput.from_request(req) + outputs.append(output) + if use_tqdm: + # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += len(output.outputs[0].token_ids) + out_spd = (total_out_toks / pbar.format_dict["elapsed"]) + pbar.postfix = (f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s") + pbar.update(1) if use_tqdm: pbar.close() diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e9f53b1816d1..148aa2e591db7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -14,10 +14,11 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, SamplingParams) -from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid +Logprob = float + # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/executor_v2.py similarity index 62% rename from vllm/executor/gpu_executor.py rename to vllm/executor/executor_v2.py index 2185c9cf6cead..bde0eda1aeb82 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/executor_v2.py @@ -1,30 +1,18 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +import os from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs_v2 import ModelRunnerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) -from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase +from vllm.worker.worker_v2 import Worker logger = init_logger(__name__) -def create_worker(worker_module_name: str, worker_class_name: str, - worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], - **kwargs): - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) - wrapper.init_worker(**kwargs) - return wrapper.worker - - class GPUExecutor(ExecutorBase): uses_ray: bool = False @@ -36,19 +24,22 @@ def _init_executor(self) -> None: "GPUExecutor only supports single GPU.") self.driver_worker = self._create_worker() - self.driver_worker.init_device() + self.driver_worker.initialize() self.driver_worker.load_model() - def _get_worker_kwargs( + def _create_worker( self, local_rank: int = 0, rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + distributed_init_method: Optional[str] = None) -> Worker: """Return worker init args for a given rank.""" + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + if distributed_init_method is None: distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - return dict( + return Worker( model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, @@ -66,47 +57,6 @@ def _get_worker_kwargs( observability_config=self.observability_config, ) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_worker" - worker_class_name = "MultiStepWorker" - elif self.speculative_config: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_create_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict: - worker_kwargs = self._get_worker_kwargs(local_rank, rank, - distributed_init_method) - - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - worker_kwargs.update( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) - - return worker_kwargs - - def _create_worker(self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None): - return create_worker(**self._get_create_worker_kwargs( - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method)) - def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. @@ -125,9 +75,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - output = self.driver_worker.execute_model(execute_model_req) + self, scheduler_output, + ) -> Optional[List[Union[ModelRunnerOutput]]]: + output = self.driver_worker.execute_model(scheduler_output) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -180,8 +130,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, - execute_model_req: ExecuteModelRequest, - ) -> List[Union[SamplerOutput, PoolerOutput]]: + scheduler_output, + ) -> List[Union[ModelRunnerOutput]]: output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req) + )(scheduler_output=scheduler_output) return output diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index c467115f124ca..ce08272e1db66 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -1,7 +1,7 @@ from array import array from typing import Any, Type -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE +# from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE def encode_hook(obj: Any) -> Any: diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 59e9854393b6b..7332b2b2a5ec7 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -9,9 +9,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip, is_hip, is_xpu -from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) PG_WAIT_TIMEOUT = 1800 @@ -22,73 +20,73 @@ from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup - class RayWorkerWrapper(WorkerWrapperBase): - """Ray wrapper for vllm.worker.Worker, allowing Worker to be - lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - # Since the compiled DAG runs a main execution - # in a different thread that calls cuda.set_device. - # The flag indicates is set_device is called on - # that thread. - self.compiled_dag_cuda_device_set = False - - self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, - dec_hook=decode_hook) - self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - - def get_node_ip(self) -> str: - return get_ip() - - def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: - node_id = ray.get_runtime_context().get_node_id() - gpu_ids = ray.get_gpu_ids() - return node_id, gpu_ids - - def execute_model_spmd( - self, req_or_tuple: Union[bytes, - Tuple[bytes, - Optional[IntermediateTensors]]] - ) -> bytes: - """Execute model in SPMD fashion: used only when SPMD worker and - compiled DAG are both enabled. - - Args: - req_or_tuple: A request or a tuple containing the - request and intermediate tensors. Intermediate tensors are - None unless if it is provided because it is > 0 pipeline - stage. The request is serialized by msgspec. - """ - if isinstance(req_or_tuple, bytes): - serialized_req, intermediate_tensors = req_or_tuple, None - else: - serialized_req, intermediate_tensors = req_or_tuple - - execute_model_req = self.input_decoder.decode(serialized_req) - - # TODO(swang): This is needed right now because Ray aDAG executes - # on a background thread, so we need to reset torch's current - # device. - import torch - if not self.compiled_dag_cuda_device_set: - torch.cuda.set_device(self.worker.device) - self.compiled_dag_cuda_device_set = True - - output = self.worker._execute_model_spmd(execute_model_req, - intermediate_tensors) - # Pipeline model request and output to the next pipeline stage. - if isinstance(output, IntermediateTensors): - output = serialized_req, output - else: - output = self.output_encoder.encode(output) - - return output - - def override_env_vars(self, vars: Dict[str, str]): - os.environ.update(vars) - - ray_import_err = None + # class RayWorkerWrapper(WorkerWrapperBase): + # """Ray wrapper for vllm.worker.Worker, allowing Worker to be + # lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" + + # def __init__(self, *args, **kwargs) -> None: + # super().__init__(*args, **kwargs) + # # Since the compiled DAG runs a main execution + # # in a different thread that calls cuda.set_device. + # # The flag indicates is set_device is called on + # # that thread. + # self.compiled_dag_cuda_device_set = False + + # self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + # dec_hook=decode_hook) + # self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + + # def get_node_ip(self) -> str: + # return get_ip() + + # def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + # node_id = ray.get_runtime_context().get_node_id() + # gpu_ids = ray.get_gpu_ids() + # return node_id, gpu_ids + + # def execute_model_spmd( + # self, req_or_tuple: Union[bytes, + # Tuple[bytes, + # Optional[IntermediateTensors]]] + # ) -> bytes: + # """Execute model in SPMD fashion: used only when SPMD worker and + # compiled DAG are both enabled. + + # Args: + # req_or_tuple: A request or a tuple containing the + # request and intermediate tensors. Intermediate tensors are + # None unless if it is provided because it is > 0 pipeline + # stage. The request is serialized by msgspec. + # """ + # if isinstance(req_or_tuple, bytes): + # serialized_req, intermediate_tensors = req_or_tuple, None + # else: + # serialized_req, intermediate_tensors = req_or_tuple + + # execute_model_req = self.input_decoder.decode(serialized_req) + + # # TODO(swang): This is needed right now because Ray aDAG executes + # # on a background thread, so we need to reset torch's current + # # device. + # import torch + # if not self.compiled_dag_cuda_device_set: + # torch.cuda.set_device(self.worker.device) + # self.compiled_dag_cuda_device_set = True + + # output = self.worker._execute_model_spmd(execute_model_req, + # intermediate_tensors) + # # Pipeline model request and output to the next pipeline stage. + # if isinstance(output, IntermediateTensors): + # output = serialized_req, output + # else: + # output = self.output_encoder.encode(output) + + # return output + + # def override_env_vars(self, vars: Dict[str, str]): + # os.environ.update(vars) + + # ray_import_err = None except ImportError as e: ray = None # type: ignore diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 1d5b6fad2e160..a5355131ca6d4 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -48,15 +48,11 @@ def forward( self, lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: @@ -69,7 +65,7 @@ def forward( logits *= self.scale # Apply logits processors (if any). - logits = _apply_logits_processors(logits, sampling_metadata) + # logits = _apply_logits_processors(logits, sampling_metadata) return logits @@ -105,14 +101,6 @@ def extra_repr(self) -> str: return s -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - return hidden_states.index_select(0, - sampling_metadata.selected_token_indices) - - def _apply_logits_processors( logits: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4ab2bb137488b..0f6cdb5536ce5 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -8,7 +8,7 @@ import torch.nn as nn import vllm.envs as envs -from vllm.sampler_output import SamplerOutput +from vllm.outputs_v2 import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata _SAMPLING_EPS = 1e-5 @@ -25,13 +25,14 @@ def forward( sampling_metadata: SamplingMetadata, ) -> SamplerOutput: logits = self.apply_temperature(logits, sampling_metadata.temperature) - logits = self.apply_penalties(logits, sampling_metadata) probs = self.get_probs(logits) sampled = self.sample(probs, sampling_metadata) if sampling_metadata.max_num_logprobs > 0: logprobs = self.get_logprobs(logits) + # FIXME: Mask the sampled token_id, get topk logprobs, + # and concatenate the topk with the sampled token_id. topk_logprobs, topk_indices = torch.topk( logprobs, sampling_metadata.max_num_logprobs, dim=-1) else: @@ -44,8 +45,6 @@ def forward( logprobs=topk_logprobs, prompt_logprob_token_ids=None, prompt_logprobs=None, - model_forward_time=0.0, - model_execute_time=0.0, ) return sampler_output @@ -62,13 +61,6 @@ def apply_temperature( logits.div_(temp.unsqueeze(dim=1)) return logits - def apply_penalties( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - return logits - def get_probs(self, logits: torch.Tensor) -> torch.Tensor: return torch.softmax(logits, dim=-1, dtype=torch.float32) @@ -76,23 +68,23 @@ def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return torch.log_softmax(logits, dim=-1, dtype=torch.float32) def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: - return probs.argmax(dim=1).view(-1) + return probs.argmax(dim=-1).view(-1) def random_sample( self, probs: torch.Tensor, - generators: Optional[List[torch.Generator]], + generators: List[Optional[torch.Generator]], no_generator: bool, ) -> torch.Tensor: q = torch.empty_like(probs) if no_generator: q.exponential_() else: - assert generators is not None and len(generators) == probs.shape[0] + assert len(generators) == probs.shape[0] # TODO(woosuk): Optimize this. for i, generator in enumerate(generators): q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=1).view(-1) + return probs.div_(q).argmax(dim=-1).view(-1) def sample( self, diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 36f33d6d139ee..b11f73124e6df 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -15,7 +15,7 @@ import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine +from vllm.engine.llm_engine_v2 import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec9..0f53112809ea6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,12 +48,13 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors from vllm.utils import is_hip from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +IntermediateTensors = Dict[str, torch.Tensor] + class LlamaMLP(nn.Module): @@ -452,10 +453,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def sample( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 38d6a4653ebd6..83119a1a4e82a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -14,7 +14,6 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry from vllm.multimodal.base import NestedTensors -from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available diff --git a/vllm/outputs.py b/vllm/outputs.py deleted file mode 100644 index 85ea9196b25df..0000000000000 --- a/vllm/outputs.py +++ /dev/null @@ -1,270 +0,0 @@ -import time -from dataclasses import dataclass -from typing import List, Optional -from typing import Sequence as GenericSequence -from typing import Union - -from vllm.lora.request import LoRARequest -from vllm.sampling_params import RequestOutputKind -from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceStatus) - - -@dataclass -class CompletionOutput: - """The output data of one completion output of a request. - - Args: - index: The index of the output in the request. - text: The generated output text. - token_ids: The token IDs of the generated output text. - cumulative_logprob: The cumulative log probability of the generated - output text. - logprobs: The log probabilities of the top probability words at each - position if the logprobs are requested. - finish_reason: The reason why the sequence is finished. - stop_reason: The stop string or token id that caused the completion - to stop, None if the completion finished for some other reason - including encountering the EOS token. - lora_request: The LoRA request that was used to generate the output. - """ - - index: int - text: str - token_ids: GenericSequence[int] - cumulative_logprob: Optional[float] - logprobs: Optional[SampleLogprobs] - finish_reason: Optional[str] = None - stop_reason: Union[int, str, None] = None - lora_request: Optional[LoRARequest] = None - - def finished(self) -> bool: - return self.finish_reason is not None - - def __repr__(self) -> str: - return (f"CompletionOutput(index={self.index}, " - f"text={self.text!r}, " - f"token_ids={self.token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs}, " - f"finish_reason={self.finish_reason}, " - f"stop_reason={self.stop_reason})") - - -@dataclass -class EmbeddingOutput: - """The output data of one completion output of a request. - - Args: - embedding: The embedding vector, which is a list of floats. The - length of vector depends on the model as listed in the embedding guide. - """ - - embedding: List[float] - - def __repr__(self) -> str: - return (f"EmbeddingOutput(" - f"embedding={len(self.embedding)})") - - -class RequestOutput: - """The output data of a completion request to the LLM. - - Args: - request_id: The unique ID of the request. - prompt: The prompt string of the request. - For encoder/decoder models, this is the - decoder input prompt. - prompt_token_ids: The token IDs of the prompt. - For encoder/decoder models, this is the - decoder input prompt token ids. - prompt_logprobs: The log probabilities to return per prompt token. - outputs: The output sequences of the request. - finished: Whether the whole request is finished. - metrics: Metrics associated with the request. - lora_request: The LoRA request that was used to generate the output. - encoder_prompt: The encoder prompt string of the request; - None if decoder-only - encoder_prompt_token_ids: The token IDs of the encoder prompt; - None if decoder-only - """ - - def __init__( - self, - request_id: str, - prompt: Optional[str], - prompt_token_ids: Optional[List[int]], - prompt_logprobs: Optional[PromptLogprobs], - outputs: List[CompletionOutput], - finished: bool, - metrics: Optional[RequestMetrics] = None, - lora_request: Optional[LoRARequest] = None, - encoder_prompt: Optional[str] = None, - encoder_prompt_token_ids: Optional[List[int]] = None, - ) -> None: - self.request_id = request_id - self.prompt = prompt - self.prompt_token_ids = prompt_token_ids - self.prompt_logprobs = prompt_logprobs - self.outputs = outputs - self.finished = finished - self.metrics = metrics - self.lora_request = lora_request - self.encoder_prompt = encoder_prompt - self.encoder_prompt_token_ids = encoder_prompt_token_ids - - @classmethod - def from_seq_group(cls, - seq_group: SequenceGroup) -> Optional["RequestOutput"]: - sampling_params = seq_group.sampling_params - if sampling_params is None: - raise ValueError( - "Sampling parameters are missing for a CompletionRequest.") - finished = seq_group.is_finished() - if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( - not finished): - return None - - seqs = seq_group.get_seqs() - if len(seqs) == 1: - top_n_seqs = seqs - else: - # Get the top-n sequences. - n = sampling_params.n - if sampling_params.use_beam_search: - sorting_key = lambda seq: seq.get_beam_search_score( - sampling_params.length_penalty) - else: - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - - # Create the outputs. - # NOTE: We need omit logprobs here explicitly because the sequence - # always has the logprobs of the sampled tokens even if the - # logprobs are not requested. - include_logprobs = sampling_params.logprobs is not None - text_buffer_length = sampling_params.output_text_buffer_length - delta = sampling_params.output_kind == RequestOutputKind.DELTA - - outputs = [] - include_prompt = True - for seq in top_n_seqs: - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) - output_token_ids = seq.get_output_token_ids_to_return(delta) - output_logprobs = seq.output_logprobs if include_logprobs else None - - if delta: - # Slice logprobs delta if applicable - if output_logprobs: - output_logprobs = output_logprobs[-len(output_token_ids):] - # Don't include prompt if this is after the first output - # containing decode token ids - if include_prompt and seq.get_output_len() > len( - output_token_ids): - include_prompt = False - - outputs.append( - CompletionOutput( - seqs.index(seq), output_text, output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - output_logprobs, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason)) - - # Every sequence in the sequence group should have the same prompt. - if include_prompt: - prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - else: - prompt = None - prompt_token_ids = None - encoder_prompt = None - encoder_prompt_token_ids = None - prompt_logprobs = None - finished_time = time.time() if finished else None - seq_group.set_finished_time(finished_time) - return cls(seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) - - def __repr__(self) -> str: - return (f"RequestOutput(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"encoder_prompt={self.encoder_prompt!r}, " - f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"outputs={self.outputs}, " - f"finished={self.finished}, " - f"metrics={self.metrics}, " - f"lora_request={self.lora_request})") - - -class EmbeddingRequestOutput: - """ - The output data of an embedding request to the LLM. - - Args: - request_id (str): A unique identifier for the embedding request. - outputs (EmbeddingOutput): The embedding results for the given input. - prompt_token_ids (List[int]): A list of token IDs used in the prompt. - finished (bool): A flag indicating whether the embedding is completed. - """ - - def __init__(self, request_id: str, outputs: "EmbeddingOutput", - prompt_token_ids: List[int], finished: bool): - self.request_id = request_id - self.prompt_token_ids = prompt_token_ids - self.finished = finished - self.outputs = outputs - - @classmethod - def from_seq_group(cls, - seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": - if seq_group.embeddings is None: - raise ValueError( - "Embeddings are missing in seq_group for EmbeddingRequest.") - output = EmbeddingOutput(seq_group.embeddings) - prompt_token_ids = seq_group.prompt_token_ids - finished = seq_group.is_finished() - - return cls(seq_group.request_id, output, prompt_token_ids, finished) - - def __repr__(self): - """ - Returns a string representation of an EmbeddingRequestOutput instance. - - The representation includes the request_id and the number of outputs, - providing a quick overview of the embedding request's results. - - Returns: - str: A string representation of the EmbeddingRequestOutput instance. - """ - return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " - f"outputs={repr(self.outputs)}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"finished={self.finished})") - - -class RequestOutputFactory: - - @staticmethod - def create(seq_group): - # Determine the type based on a condition, for example: - if hasattr(seq_group, - 'embeddings') and seq_group.embeddings is not None: - return EmbeddingRequestOutput.from_seq_group(seq_group) - else: - return RequestOutput.from_seq_group(seq_group) diff --git a/vllm/outputs_v2.py b/vllm/outputs_v2.py new file mode 100644 index 0000000000000..a893907968643 --- /dev/null +++ b/vllm/outputs_v2.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import torch + +from vllm.request import Request + + +@dataclass +class SamplerOutput: + + # [num_reqs] + sampled_token_ids: torch.Tensor + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: Optional[torch.Tensor] + # [num_reqs, max_num_logprobs + 1] + logprobs: Optional[torch.Tensor] + + # TODO: Support prompt logprobs. + prompt_logprob_token_ids: Optional[torch.Tensor] + prompt_logprobs: Optional[torch.Tensor] + + +@dataclass +class ModelRunnerOutput: + + # [num_reqs] + req_ids: List[str] + # req_id -> index + req_id_to_index: Dict[str, int] + + # [num_reqs] + sampled_token_ids_cpu: torch.Tensor + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids_cpu: Optional[torch.Tensor] + # [num_reqs, max_num_logprobs + 1] + logprobs_cpu: Optional[torch.Tensor] + + +@dataclass +class CompletionOutput: + + index: int + text: str + token_ids: List[int] + logprobs: Optional[Dict[int, float]] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + + def finished(self) -> bool: + return self.finish_reason is not None + + +@dataclass +class RequestOutput: + + request_id: str + prompt: Optional[str] + prompt_token_ids: List[int] + outputs: List[CompletionOutput] + finished: bool + + @classmethod + def from_request(cls, request: Request) -> "RequestOutput": + # TODO: Support `n` > 1. + completion_output = CompletionOutput( + index=0, + text="", + token_ids=request.output_token_ids, + logprobs=None, + finish_reason=request.get_finished_reason(), + stop_reason=request.stop_reason, + ) + return cls( + request_id=request.request_id, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + outputs=[completion_output], + finished=request.is_finished(), + ) diff --git a/vllm/request.py b/vllm/request.py index 69e4264249a65..7c0a25835ae95 100644 --- a/vllm/request.py +++ b/vllm/request.py @@ -35,19 +35,29 @@ def __init__( self.status = RequestStatus.WAITING self.stop_reason: Union[int, str, None] = None + assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - self.num_prompt_tokens = len(inputs["prompt_token_ids"]) - self.num_output_tokens = 0 + self.prompt = inputs.get("prompt") + self.prompt_token_ids = inputs["prompt_token_ids"] + self.num_prompt_tokens = len(self.prompt_token_ids) + self.output_token_ids: List[int] = [] self.num_computed_tokens = 0 @property def num_tokens(self) -> int: - return self.num_prompt_tokens + self.num_output_tokens + return self.num_prompt_tokens + len(self.output_token_ids) + + @property + def num_output_tokens(self) -> int: + return len(self.output_token_ids) def is_finished(self) -> bool: return RequestStatus.is_finished(self.status) + def get_finished_reason(self) -> Union[str, None]: + return RequestStatus.get_finished_reason(self.status) + class RequestStatus(enum.IntEnum): """Status of a sequence.""" diff --git a/vllm/sampler_output.py b/vllm/sampler_output.py deleted file mode 100644 index 1fbc4ed8f6e3a..0000000000000 --- a/vllm/sampler_output.py +++ /dev/null @@ -1,19 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -import torch - - -@dataclass -class SamplerOutput: - - sampled_token_ids: torch.Tensor - - logprob_token_ids: Optional[torch.Tensor] - logprobs: Optional[torch.Tensor] - - prompt_logprob_token_ids: Optional[torch.Tensor] - prompt_logprobs: Optional[torch.Tensor] - - model_forward_time: float - model_execute_time: float diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index d27d7ba9e67bb..fba9ea69a5f48 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,12 +1,14 @@ from typing import Dict, List, Optional, Tuple -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup - from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup +from vllm.sampling_params import SamplingParams # Used eg. for marking rejected tokens in spec decoding. INVALID_TOKEN_ID = -1 +Sequence = None +SequenceGroup = None +Logprob = None class Detokenizer: diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py deleted file mode 100644 index 975b88c0e79a2..0000000000000 --- a/vllm/worker/model_runner_base.py +++ /dev/null @@ -1,267 +0,0 @@ -import dataclasses -import pickle -from abc import ABC, abstractmethod -from datetime import datetime -from functools import wraps -from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, - Optional, Type, TypeVar) - -import torch -from torch import is_tensor - -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata - -if TYPE_CHECKING: - from vllm.attention import AttentionMetadata - from vllm.attention.backends.abstract import AttentionBackend - from vllm.model_executor import SamplingMetadata - -logger = init_logger(__name__) - -T = TypeVar('T', bound="BroadcastableModelInput") - - -def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Any], - attn_metadata: Optional["AttentionMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - AttentionMetadata fields. - """ - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - -def _init_attn_metadata_from_tensor_dict( - attn_backend: "AttentionBackend", - tensor_dict: Dict[str, Any], -) -> Dict[str, Any]: - """ - Helper method to initialize AttentionMetadata based on an - AttentionBackend and broadcastable AttentionMetadata fields. - """ - # Extract the fields used to create AttentionMetadata. - valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - tensor_dict["attn_metadata"] = attn_metadata - return tensor_dict - - -def _init_sampling_metadata_from_tensor_dict( # type: ignore - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to initialize SamplingMetadata based on broadcastable - SamplingMetadata fields. - """ - from vllm.model_executor import SamplingMetadata - - selected_token_indices = tensor_dict.pop("selected_token_indices", None) - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - if selected_token_indices is not None: - tensor_dict["sampling_metadata"] = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - return tensor_dict - - -def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Any], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - SamplingMetadata fields. - """ - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) - - -def _init_frozen_model_input_from_tensor_dict( - frozen_model_input_cls: Type["ModelRunnerInputBase"], - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to initialize a frozen ModelInput based on broadcastable - """ - valid_tensor_kwargs = {} - for field in dataclasses.fields(frozen_model_input_cls): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_tensor_kwargs[field.name] = val - - frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) - tensor_dict["frozen_model_input"] = frozen_model_input - return tensor_dict - - -def dump_input_when_exception(exclude_args: Optional[List[int]] = None, - exclude_kwargs: Optional[List[str]] = None): - - def _inner(func): - - @wraps(func) - def _wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as err: - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl" - logger.info("Writing input of failed execution to %s...", - filename) - with open(filename, "wb") as filep: - dumped_inputs = { - k: v - for k, v in kwargs.items() - if k not in (exclude_kwargs or []) - } - for i, arg in enumerate(args): - if i not in (exclude_args or []): - dumped_inputs[f"arg_{i}"] = arg - - # Only persist dtype and shape for kvcache tensors - # (can be way to big otherwise) - if (kv_caches := dumped_inputs.get("kv_caches")) \ - and isinstance(kv_caches, Iterable): - dumped_inputs["kv_caches"] = [(t.dtype, t.shape) - for t in kv_caches - if is_tensor(t)] - - pickle.dump(dumped_inputs, filep) - logger.info( - "Completed writing input of failed execution to %s.", - filename) - raise type(err)( - f"Error in model execution (input dumped to {filename}): " - f"{str(err)}") from err - - return _wrapper - - return _inner - - -class BroadcastableModelInput(ABC): - - @abstractmethod - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def from_broadcasted_tensor_dict( - cls: Type[T], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> T: - """ - Pop fields from the given tensor_dict and populate a new instance of - BroadcastableModelInput. - """ - raise NotImplementedError - - -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(BroadcastableModelInput): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. - """ - pass - - -class ModelRunnerInputBuilderBase(ABC, Generic[T]): - """A builder to create ModelRunnerInputBase objects. - """ - - @abstractmethod - def add_seq_group(self, seq_group_metadata): - """TBA""" - raise NotImplementedError - - @abstractmethod - def build(self, *args, **kwargs) -> T: - """Build metadata with on-device tensors.""" - raise NotImplementedError - - -class ModelRunnerBase(ABC, Generic[T]): - """ - Model runner interface that abstracts a particular hardware and/or type of - model. Model execution may communicate data with model runners in other - processes, but it should not include control plane metadata communication. - - Each ModelRunnerBase subclass should define a corresponding - ModelRunnerInputBase subclass. - """ - - # Map of request_id -> generator used for seeded random sampling - generators: Dict[str, torch.Generator] = {} - - @abstractmethod - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> T: - """ - Make an instance of a ModelRunnerInputBase from the broadcasted tensor - dict. - """ - raise NotImplementedError - - @abstractmethod - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> T: - """ - Prepare the inputs to ModelRunnerBase.execute_model from an execution - request. This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. - """ - raise NotImplementedError - - @current_platform.inference_mode() - def execute_model( - self, - model_input: T, - kv_caches: Optional[List[torch.Tensor]], - intermediate_tensors: Optional[IntermediateTensors], - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - """ - Execute the model on the given input. - """ - raise NotImplementedError - - def get_generators(self, finished_request_ids: Optional[List[str]] = None): - """ - Return dict of per-request generators used for random sampling. - """ - - # Clean up generators from completed requests - if finished_request_ids: - for request_id in finished_request_ids: - self.generators.pop(request_id, None) - - return self.generators diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index 1901101beb51f..bbdb76dce2fd5 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -8,24 +8,20 @@ import torch.nn as nn import vllm.envs as envs -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs_v2 import SamplerOutput, ModelRunnerOutput from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import (DeviceMemoryProfiler, is_pin_memory_available) -from vllm.worker.model_runner_base import dump_input_when_exception +from vllm.utils import (DeviceMemoryProfiler, is_pin_memory_available, cdiv, + STR_DTYPE_TO_TORCH_DTYPE) from vllm.multimodal import MultiModalDataDict +from vllm.attention.backends.flash_attn import FlashAttentionMetadata if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend from vllm.core.scheduler_v2 import SchedulerOutput logger = init_logger(__name__) @@ -41,7 +37,6 @@ def __init__( device_config: DeviceConfig, cache_config: CacheConfig, load_config: LoadConfig, - kv_cache_dtype: Optional[str] = "auto", lora_config: Optional[LoRAConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None, observability_config: Optional[ObservabilityConfig] = None, @@ -58,22 +53,28 @@ def __init__( self.device = self.device_config.device self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] - self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size - self.max_num_blocks_per_req = ( - (self.model_config.max_model_len + self.block_size - 1) // - self.block_size) + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: List[torch.Tensor] = [] # Request states. - self.requests: Dict[str, RequestState] = {} - self.batched_states = BatchedRequestStates( + self.requests: Dict[str, CachedRequestState] = {} + self.persistent_batch = PersistentBatch( max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, max_num_blocks_per_req=self.max_num_blocks_per_req, device=self.device, pin_memory=self.pin_memory, @@ -87,60 +88,43 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.aborted_req_ids: self.requests.pop(req_id, None) - # Remove the requests from the batched states. - stopped_req_ids = (scheduler_output.preempted_req_ids + - scheduler_output.finished_req_ids + - scheduler_output.aborted_req_ids) + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + scheduler_output.aborted_req_ids, + ) removed_req_indices: List[int] = [] for req_id in stopped_req_ids: - req_index = self.batched_states.remove_request(req_id) + req_index = self.persistent_batch.remove_request(req_id) if req_index is not None: removed_req_indices.append(req_index) - # Condense the batched states. - # We condense the states before adding new/resumed requests - # because the attention backend may require it. - if removed_req_indices: - self.batched_states.condense(removed_req_indices) - # Update the states of the running requests. - num_prev_blocks: Dict[str, int] = {} - new_block_ids: Dict[str, List[int]] = {} for req_data in scheduler_output.scheduled_running_reqs: req_id = req_data.req_id req_state = self.requests[req_id] - - num_prev_blocks[req_id] = len(req_state.block_ids) - new_block_ids[req_id] = req_data.new_block_ids req_state.block_ids.extend(req_data.new_block_ids) - req_state.num_computed_tokens = req_data.num_computed - # Update the block table and the number of computed tokens - # of the running requests. - for req_id in self.batched_states.req_ids: - if req_id is None: - continue - start_block_index = num_prev_blocks[req_id] - block_ids = new_block_ids[req_id] - end_block_index = start_block_index + len(block_ids) - self.batched_states.block_table_cpu[ - req_index, start_block_index:end_block_index] = block_ids - self.batched_states.num_computed_tokens_cpu[req_index] = ( - self.requests[req_id].num_computed_tokens) + req_state.num_computed_tokens = req_data.num_computed_tokens + + # Update the block table and num_computed_tokens. + req_index = self.persistent_batch.req_id_to_index[req_id] + end_block_index = len(req_state.block_ids) + start_block_index = end_block_index - len(req_data.new_block_ids) + self.persistent_batch.block_table_cpu[ + req_index, + start_block_index:end_block_index] = torch.as_tensor( + req_data.new_block_ids, dtype=torch.int32, device="cpu") + self.persistent_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) req_ids_to_add: List[str] = [] # Add new requests to the cached states. for req_data in scheduler_output.scheduled_new_reqs: req_id = req_data.req_id - prompt_token_ids_cpu = torch.tensor(req_data.prompt_token_ids, - device="cpu", - pin_memory=self.pin_memory) - prompt_token_ids = prompt_token_ids_cpu.to(self.device, - non_blocking=True) - - self.requests[req_id] = RequestState( + self.requests[req_id] = CachedRequestState( req_id=req_id, - prompt_token_ids=prompt_token_ids, - prompt_token_ids_cpu=prompt_token_ids_cpu, + prompt_token_ids=req_data.prompt_token_ids, prompt=req_data.prompt, multi_modal_data=req_data.multi_modal_data, sampling_params=req_data.sampling_params, @@ -160,40 +144,217 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.num_computed_tokens = req_data.num_computed_tokens req_ids_to_add.append(req_id) - # Add the new or resumed requests to the batched states. + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] - self.batched_states.add_request(req_state) + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.persistent_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.persistent_batch.condense(removed_req_indices) def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - assert scheduler_output.total_num_scheduled_tokens > 0 - num_scheduled_tokens = scheduler_output.num_scheduled_tokens + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.persistent_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.persistent_batch.block_table[:num_reqs].copy_( + self.persistent_batch.block_table_cpu[:num_reqs], + non_blocking=True) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for req_id in self.persistent_batch.req_ids[:num_reqs]: + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + assert max_num_scheduled_tokens > 0 + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + indices = np.arange(num_reqs) + req_indices = np.repeat(indices, num_scheduled_tokens) + req_indices = torch.from_numpy(req_indices) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), + (num_reqs, 1)) + mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] + arange = arange_matrix[mask] + arange = torch.from_numpy(arange) + + # Get positions. + positions = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.add(self.persistent_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions + req_indices * self.max_model_len + input_ids = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.index_select(self.persistent_batch.token_ids_cpu.flatten(), + 0, + token_indices, + out=input_ids) + + # Calculate the slot mapping. + block_numbers = self.persistent_batch.block_table_cpu.flatten()[ + token_indices // self.block_size] + block_offsets = token_indices % self.block_size + slot_mapping = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.add(block_numbers * self.block_size, + block_offsets, + out=slot_mapping) + + # Prepare the attention metadata. + num_scheduled_tokens = torch.from_numpy(num_scheduled_tokens) + query_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + query_start_loc[0] = 0 + torch.cumsum(num_scheduled_tokens, dim=0, out=query_start_loc[1:]) + + seq_lens = (self.persistent_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + max_seq_len = seq_lens.max().item() + seq_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + seq_start_loc[0] = 0 + torch.cumsum(seq_lens, dim=0, out=seq_start_loc[1:]) + + # Move the tensors to the device. + input_ids = input_ids.to(self.device, non_blocking=True) + positions = positions.to(self.device, non_blocking=True).long() + query_start_loc = query_start_loc.to(self.device, non_blocking=True) + seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) + slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() + attn_metadata = FlashAttentionMetadata( + max_query_len=max_num_scheduled_tokens, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_start_loc=seq_start_loc, + block_table=self.persistent_batch.block_table[:num_reqs], + slot_mapping=slot_mapping, + ) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicty. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + return input_ids, positions, attn_metadata, logits_indices + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.aborted_req_ids + or scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + sampling_metadata = self.persistent_batch.make_sampling_metadata( + skip_copy) + return sampling_metadata @torch.inference_mode() - @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> SamplerOutput: + ) -> ModelRunnerOutput: self._update_states(scheduler_output) inputs = self._prepare_inputs(scheduler_output) - input_ids, position_ids, attn_metadata = inputs - # Create the sampling metadata. - sampling_metadata = self.batched_states.get_sampling_metadata() + input_ids, positions, attn_metadata, logits_indices = inputs hidden_states = self.model( input_ids=input_ids, - position_ids=position_ids, - attn_metadata=attn_metadata, + positions=positions, kv_caches=self.kv_caches, + attn_metadata=attn_metadata, ) + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states) - logits = self.model.compute_logits(hidden_states, sampling_metadata) # Sample the next token and get logprobs if needed. + sampling_metadata = self._prepare_sampling(scheduler_output) sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, ) - return sampler_output + + # CPU-GPU synchronization happens here. + # TODO: Optimize. + sampled_token_ids = sampler_output.sampled_token_ids.cpu() + num_reqs = self.persistent_batch.num_reqs + for i, req_id in enumerate(self.persistent_batch.req_ids[:num_reqs]): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + num_prompt_tokens = len(req_state.prompt_token_ids) + if seq_len >= num_prompt_tokens: + # Append the sampled token to the output token ids. + token_id = sampled_token_ids[i] + self.persistent_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. + generator = self.persistent_batch.generators[i] + if generator is not None: + offset = generator.get_offset() + generator = generator.set_offset(offset - 1) + self.persistent_batch.generators[i] = generator + + if sampler_output.logprob_token_ids is None: + logprob_token_ids = None + else: + logprob_token_ids = sampler_output.logprob_token_ids.cpu() + if sampler_output.logprobs is None: + logprobs = None + else: + logprobs = sampler_output.logprobs.cpu() + model_runner_output = ModelRunnerOutput( + req_ids=self.persistent_batch.req_ids[:num_reqs], + req_id_to_index=self.persistent_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=logprob_token_ids, + logprobs_cpu=logprobs, + ) + return model_runner_output def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) @@ -212,36 +373,75 @@ def load_model(self) -> None: @torch.inference_mode() def profile_run(self) -> None: + # FIXME + hidden_size = self.model_config.get_hidden_size() + intermediate_size = int(hidden_size * 3.5) + tp_size = self.parallel_config.tensor_parallel_size + d = max(6 * hidden_size, 4 * intermediate_size // tp_size) + tmp = torch.empty((self.max_num_tokens, d), + dtype=self.dtype, + device=self.device) + return + + def capture_model(self) -> None: return - def initialize_kv_cache(self) -> None: - ... + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + # Models like Jamba, have mixed typed layers, E.g Mamba + num_attn_layers = self.model_config.get_num_attention_layers( + self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + for _ in range(num_attn_layers): + kv_cache_shape = (2, num_blocks, self.block_size, num_kv_heads, + head_size) + self.kv_caches.append( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device)) -class BatchedRequestStates: +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + +class PersistentBatch: def __init__( self, max_num_reqs: int, + max_model_len: int, max_num_blocks_per_req: int, device: torch.device, pin_memory: bool, ): self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory - self.num_reqs = 0 self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} - self.num_computed_tokens = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) + self.token_ids_cpu = torch.empty((max_num_reqs, max_model_len), + dtype=torch.int32, + device="cpu") self.num_computed_tokens_cpu = torch.empty((max_num_reqs, ), dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + device="cpu") # Attention-related. self.block_table = torch.empty((max_num_reqs, max_num_blocks_per_req), @@ -290,7 +490,7 @@ def __init__( def add_request( self, - request: "RequestState", + request: "CachedRequestState", req_index: Optional[int] = None, ) -> None: if req_index is None: @@ -298,11 +498,20 @@ def add_request( assert req_index < self.max_num_reqs self.req_ids[req_index] = request.req_id - self.num_reqs += 1 + self.req_id_to_index[request.req_id] = req_index + + # Copy the prompt token ids and output token ids. + # TODO: Optimize. + for i, token_id in enumerate(request.prompt_token_ids): + self.token_ids_cpu[req_index, i] = token_id + num_prompt_tokens = len(request.prompt_token_ids) + for i, token_id in enumerate(request.output_token_ids): + self.token_ids_cpu[req_index, num_prompt_tokens + i] = token_id self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table_cpu[ - req_index, :len(request.block_ids)] = request.block_ids + # TODO: Optimize. + for i, block_id in enumerate(request.block_ids): + self.block_table_cpu[req_index, i] = block_id sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature @@ -326,15 +535,14 @@ def add_request( num_logprobs = sampling_params.logprobs if num_logprobs is not None and num_logprobs > 0: self.num_logprobs[request.req_id] = num_logprobs - if sampling_params.prompt_logprob: + if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_index) def remove_request(self, req_id: str) -> Optional[int]: - if not req_id in self.req_ids: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: return None - req_index = self.req_ids.index(req_id) self.req_ids[req_index] = None - self.num_reqs -= 1 self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) @@ -346,7 +554,8 @@ def remove_request(self, req_id: str) -> Optional[int]: return req_index def clear(self) -> None: - self.num_reqs = 0 + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() self.greedy_reqs.clear() self.random_reqs.clear() self.top_p_reqs.clear() @@ -356,12 +565,12 @@ def clear(self) -> None: self.prompt_logprob_reqs.clear() def condense(self, empty_req_indices: List[int]) -> None: - # TODO(woosuk): Consider LoRA. if self.num_reqs == 0: # The batched states are empty. return - empty_req_indices = sorted(empty_req_indices, reverse=True) + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. last_req_index = self.num_reqs + len(empty_req_indices) - 1 while empty_req_indices: # Find the largest non-empty index. @@ -369,12 +578,20 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index -= 1 # Find the smallest empty index. - empty_index = empty_req_indices.pop() + empty_index = empty_req_indices.pop() if empty_index >= last_req_index: break # Swap the states. - self.req_ids[empty_index] = self.req_ids[last_req_index] + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table_cpu[empty_index] = self.block_table_cpu[ @@ -385,6 +602,37 @@ def condense(self, empty_req_indices: List[int]) -> None: self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] self.generators[empty_index] = self.generators[last_req_index] + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_(self.top_p_cpu[:self.num_reqs], + non_blocking=True) + self.top_k[:self.num_reqs].copy_(self.top_k_cpu[:self.num_reqs], + non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators[:self.num_reqs], + no_generator=self.no_generator, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + @property def all_greedy(self) -> bool: return len(self.random_reqs) == 0 @@ -407,7 +655,10 @@ def no_generator(self) -> bool: @property def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) + if self.num_logprobs: + return max(self.num_logprobs.values()) + else: + return 0 @property def no_logprob(self) -> bool: @@ -416,33 +667,3 @@ def no_logprob(self) -> bool: @property def no_prompt_logprob(self) -> bool: return len(self.prompt_logprob_reqs) == 0 - - def get_sampling_metadata(self) -> SamplingMetadata: - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators[:self.num_reqs], - no_generator=self.no_generator, - max_num_logprobs=self.max_num_logprobs, - ) - - -@dataclass -class RequestState: - - req_id: str - prompt_token_ids: torch.Tensor - prompt_token_ids_cpu: torch.Tensor - prompt: Optional[str] - multi_modal_data: Optional["MultiModalDataDict"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py deleted file mode 100644 index 6ba4f272315ce..0000000000000 --- a/vllm/worker/worker_base.py +++ /dev/null @@ -1,485 +0,0 @@ -import dataclasses -import importlib -import os -import time -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union - -import torch - -from vllm.config import ObservabilityConfig -from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform -from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import (enable_trace_function_call_for_thread, - update_environment_variables) -from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) - -logger = init_logger(__name__) - - -class WorkerBase(ABC): - """Worker interface that allows vLLM to cleanly separate implementations for - different hardware. Also abstracts control plane communication, e.g., to - communicate request metadata to other workers. - """ - - @abstractmethod - def init_device(self) -> None: - """Initialize device state, such as loading the model or other on-device - memory allocations. - """ - raise NotImplementedError - - @abstractmethod - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - The implementation may run profiling or other heuristics to determine - the size of caches. - - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - raise NotImplementedError - - @abstractmethod - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. - """ - raise NotImplementedError - - @current_platform.inference_mode() - def start_worker_execution_loop(self) -> None: - """Execute model loop in parallel worker. - - You can stop the loop by executing a driver worker with an empty output. - See `stop_remote_worker_execution_loop` for more details. - """ - while True: - output = self.execute_model(execute_model_req=None) - if output is None: - return None - - @abstractmethod - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - raise NotImplementedError - - @abstractmethod - def get_cache_block_size_bytes(self) -> int: - """Return the size of a single cache block, in bytes. Used in - speculative decoding. - """ - raise NotImplementedError - - @abstractmethod - def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def list_loras(self) -> Set[int]: - raise NotImplementedError - - -class LoraNotSupportedWorkerBase(WorkerBase): - """Partial implementation of WorkerBase that raises exceptions when LoRA - methods are invoked. - """ - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def remove_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def pin_lora(self, lora_id: int) -> bool: - return ValueError( - f"{type(self)} does not support LoRA") # type: ignore - - def list_loras(self) -> Set[int]: - raise ValueError(f"{type(self)} does not support LoRA") - - -@dataclasses.dataclass(frozen=True) -class WorkerInput: - """Local inputs to each worker. May contain device-specific data. These - fields should be broadcastable to other workers. - """ - - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - virtual_engine: int = 0 - num_steps: int = 1 - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], - tensor_dict: Dict[str, Any], - ) -> "WorkerInput": - """ - Pop fields from the given tensor_dict and populate a new instance of - WorkerInput. - """ - return cls( - num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - ) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. - """ - tensor_dict = { - "num_seq_groups": self.num_seq_groups, - "blocks_to_swap_in": self.blocks_to_swap_in, - "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, - } - - return tensor_dict - - -class LocalOrDistributedWorkerBase(WorkerBase): - """ - Partial implementation of WorkerBase that has a default `execute_model` - definition to perform metadata transfer between workers when in distributed - mode. Subclasses of this interface should use model runners that inherit - from ModelRunnerBase, and should only need to implement worker-local logic. - If custom control plane logic is needed to transfer metadata, or if the - model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. - """ - is_driver_worker: bool - model_runner: ModelRunnerBase - observability_config: Optional[ObservabilityConfig] = None - - @property - @abstractmethod - def do_metadata_broadcast(self) -> bool: - """ - Used by the default `execute_model` to check whether broadcast is - needed to transfer request inputs from the driver worker to other - workers in the TP group. If WorkerBase subclass only supports - single-worker execution, then this method should return False. - """ - raise NotImplementedError - - @property - @abstractmethod - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - """ - Gets the list of kv caches to pass to the worker's model runner. Each - element in the list is a kv cache corresponding to a particular virtual - engine (PP stream). Used by the default `execute_model`. If the worker's - model runner does not follow the ModelRunnerBase interface, then inherit - from WorkerBase instead. - """ - raise NotImplementedError - - @abstractmethod - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - """ - Prepare the inputs to WorkerBase.execute_worker from an execution - request. This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. - """ - raise NotImplementedError - - @abstractmethod - def execute_worker(self, worker_input: WorkerInput) -> None: - """ - Process an execution request. - """ - raise NotImplementedError - - def _get_worker_input_from_broadcast( - self - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ Get the worker input from the broadcasted tensor dict. """ - assert self.do_metadata_broadcast - assert not self.is_driver_worker - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - - worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) - model_input = ( - self.model_runner.make_model_input_from_broadcasted_tensor_dict( - broadcast_data)) - - kwargs = extract_previous_hidden_states(broadcast_data) - - return model_input, worker_input, kwargs - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: - """ Get the driver input and broadcast it to other workers. """ - assert self.is_driver_worker - - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - kwargs = extract_previous_hidden_states(execute_model_req) - - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update(model_input.as_broadcastable_tensor_dict()) - broadcast_data.update(kwargs) - broadcast_tensor_dict(broadcast_data, src=0) - - if execute_model_req.async_callback: - model_input = dataclasses.replace( # type: ignore - model_input, - async_callback=execute_model_req.async_callback) - - return model_input, worker_input, kwargs - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ - Prepare the inputs to ModelRunner and workers. - """ - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - return self._get_driver_input_and_broadcast(execute_model_req) - else: - return self._get_worker_input_from_broadcast() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" - start_time = time.perf_counter() - - inputs = self.prepare_input(execute_model_req) - if inputs is None: - return None - - model_input, worker_input, kwargs = inputs - num_steps = worker_input.num_steps - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - intermediate_tensors = None - orig_model_execute_time = 0.0 - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() - - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) - - model_execute_time = time.perf_counter() - start_time - if not get_pp_group().is_last_rank: - # output is IntermediateTensors - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return [None] - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time - and output is not None): - for o in output: - o.model_execute_time = (orig_model_execute_time + - model_execute_time) - - # output is List[SamplerOutput] - return output - - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None - ) -> Optional[List[SamplerOutput]]: - """ - Execute model in Single Program Multiple Data (SPMD) fashion. - All workers take the same request, prepare the input and - execute the model. - """ - assert execute_model_req is not None, ( - "_execute_model_spmd() requires each worker to take in an " - "ExecuteModelRequest") - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - kwargs = extract_previous_hidden_states(execute_model_req) - - return self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - **kwargs, - ) - - -class WorkerWrapperBase: - """ - The whole point of this class is to lazily initialize the worker. - We first instantiate the WorkerWrapper, which remembers the worker module - and class name. Then, when we call `update_environment_variables`, and the - real initialization happens in `init_worker`. - - If worker_class_fn is specified, it will be executed to get the worker - class. - Otherwise, the worker class will be obtained by dynamically importing it - using worker_module_name and worker_class_name. - """ - - def __init__( - self, - worker_module_name: str, - worker_class_name: str, - trust_remote_code: bool = False, - worker_class_fn: Optional[Callable[[], - Type[WorkerBase]]] = None) -> None: - self.worker_module_name = worker_module_name - self.worker_class_name = worker_class_name - self.worker_class_fn = worker_class_fn - self.worker: Optional[WorkerBase] = None - if trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - @staticmethod - def update_environment_variables(envs: Dict[str, str]) -> None: - key = 'CUDA_VISIBLE_DEVICES' - if key in envs and key in os.environ: - # overwriting CUDA_VISIBLE_DEVICES is desired behavior - # suppress the warning in `update_environment_variables` - del os.environ[key] - update_environment_variables(envs) - - def init_worker(self, *args, **kwargs): - """ - Here we inject some common logic before initializing the worker. - Arguments are passed to the worker class constructor. - """ - enable_trace_function_call_for_thread() - - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ['NCCL_CUMEM_ENABLE'] = '0' - - from vllm.plugins import load_general_plugins - load_general_plugins() - - if self.worker_class_fn: - worker_class = self.worker_class_fn() - else: - mod = importlib.import_module(self.worker_module_name) - worker_class = getattr(mod, self.worker_class_name) - - self.worker = worker_class(*args, **kwargs) - assert self.worker is not None - - def execute_method(self, method, *args, **kwargs): - try: - target = self if self.worker is None else self.worker - executor = getattr(target, method) - return executor(*args, **kwargs) - except Exception as e: - # if the driver worker also execute methods, - # exceptions in the rest worker may cause deadlock in rpc like ray - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - - -def extract_previous_hidden_states( - data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ - Dict[str, torch.Tensor]: - """If data contains previous_hidden_states, extract it. This returns a dict - which can be used directly as additional kwargs in any following - execute_model calls. This is used in draft models like EAGLE.""" - output = {} - - # When called from non-driver worker, data is dict but when called from - # driver worker, data is ExecuteModelRequest. - if isinstance(data, dict): - if "previous_hidden_states" in data: - output["previous_hidden_states"] = data["previous_hidden_states"] - elif data.previous_hidden_states is not None: - output["previous_hidden_states"] = data.previous_hidden_states\ - .hidden_states - - return output diff --git a/vllm/worker/worker_v2.py b/vllm/worker/worker_v2.py index 255bd569f254a..1ece3e76af822 100644 --- a/vllm/worker/worker_v2.py +++ b/vllm/worker/worker_v2.py @@ -16,9 +16,10 @@ set_custom_all_reduce) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs_v2 import ModelRunnerOutput from vllm.platforms import current_platform from vllm.worker.model_runner_v2 import GPUModelRunner +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size logger = init_logger(__name__) @@ -32,23 +33,33 @@ def __init__( self, model_config: ModelConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, + is_driver_worker: bool, + speculative_config: Optional[SpeculativeConfig] = None, lora_config: Optional[LoRAConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config + self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker self.lora_config = lora_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -58,10 +69,10 @@ def __init__( self.model_runner = GPUModelRunner( model_config, parallel_config, + scheduler_config, device_config, cache_config, load_config, - kv_cache_dtype=cache_config.kv_cache_dtype, lora_config=lora_config, ) @@ -131,7 +142,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") - cache_block_size = self.get_cache_block_size_bytes() + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) num_gpu_blocks = int( (total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) @@ -139,8 +152,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() + # if self.model_runner.lora_manager: + # self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks @@ -165,9 +178,7 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - - # TODO(woosuk): Create KV cache. - self.model_runner.initialize_kv_cache() + self.model_runner.initialize_kv_cache(num_gpu_blocks) def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -180,9 +191,10 @@ def compile_or_warm_up_model(self) -> None: def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> None: - sampler_output = self.model_runner.execute_model(scheduler_output) + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) # TODO(woosuk): Send the output to the engine process. + return output def init_worker_distributed_environment( @@ -219,3 +231,24 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") + + +def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total From 788d3f4286dc12f8298c3f5c53a6c23fc2fd3a6c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 01:48:20 -0700 Subject: [PATCH 14/31] Deque -> List --- vllm/core/scheduler_v2.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py index c0a2476dd5ed6..9c04ae046bd63 100644 --- a/vllm/core/scheduler_v2.py +++ b/vllm/core/scheduler_v2.py @@ -54,7 +54,7 @@ def __init__( # Priority queues for requests. self.waiting: Deque[Request] = deque() - self.running: Deque[Request] = deque() + self.running: List[Request] = [] self.finished_req_ids: Set[str] = set() self.aborted_req_ids: Set[str] = set() @@ -62,7 +62,6 @@ def __init__( self.cum = 0 def schedule(self) -> "SchedulerOutput": - start = time.time() scheduled_new_reqs: List[Request] = [] scheduled_resumed_reqs: List[Request] = [] scheduled_running_reqs: List[Request] = [] @@ -73,12 +72,12 @@ def schedule(self) -> "SchedulerOutput": token_budget = self.max_num_scheduled_tokens # First, schedule the RUNNING requests. - new_running: Deque[Request] = deque() - while self.running: + req_index = 0 + while req_index < len(self.running): if token_budget == 0: break - request = self.running[0] + request = self.running[req_index] num_tokens = request.num_tokens - request.num_computed_tokens num_tokens = min(num_tokens, token_budget) assert num_tokens > 0 @@ -101,17 +100,15 @@ def schedule(self) -> "SchedulerOutput": break else: # The request can be scheduled. - self.running.popleft() - new_running.append(request) scheduled_running_reqs.append(request) req_to_new_block_ids[request.request_id] = new_block_ids num_scheduled_tokens[request.request_id] = num_tokens token_budget -= num_tokens - request.status = RequestStatus.RUNNING + req_index += 1 break - self.running = new_running + start = time.time() # Next, schedule the WAITING requests. while self.waiting: if preempted_reqs: @@ -209,7 +206,7 @@ def update_from_output( ) -> List[Request]: sampled_token_ids = model_runner_output.sampled_token_ids_cpu.numpy() num_scheduled_tokens = scheduler_output.num_scheduled_tokens - new_running: Deque[Request] = deque() + new_running: List[Request] = [] finished_reqs: List[Request] = [] for request in self.running: req_id = request.request_id From a7912ce45ab4606c5a4983ac607701243f641710 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 02:21:38 -0700 Subject: [PATCH 15/31] Minor --- vllm/core/kv_cache_manager.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/vllm/core/kv_cache_manager.py b/vllm/core/kv_cache_manager.py index 69c5e2f07ed09..2b3a2f15d3d11 100644 --- a/vllm/core/kv_cache_manager.py +++ b/vllm/core/kv_cache_manager.py @@ -1,5 +1,7 @@ from typing import Dict, List, Optional, Set, Tuple +import numpy as np + from vllm.request import Request from vllm.logger import init_logger from vllm.utils import cdiv @@ -26,9 +28,7 @@ def __init__( # Reserve block id 0 for padding. self.free_block_ids = list(range(num_gpu_blocks)) self.req_to_block_ids: Dict[str, List[int]] = {} - self.block_id_to_reqs: List[Set[str]] = [ - set() for _ in range(num_gpu_blocks) - ] + self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32) def get_computed_blocks(self, request: Request) -> List[int]: return [] @@ -51,8 +51,7 @@ def append_slots( # Allocate new blocks. new_block_ids = self._get_new_blocks(num_new_blocks) req_block_ids.extend(new_block_ids) - for block_id in new_block_ids: - self.block_id_to_reqs[block_id].add(request.request_id) + self.ref_cnts[new_block_ids] += 1 return new_block_ids def allocate_slots( @@ -68,18 +67,17 @@ def allocate_slots( return None new_block_ids = self._get_new_blocks(num_new_blocks) - self.req_to_block_ids[request.request_id] = (computed_block_ids + - new_block_ids) - for block_id in new_block_ids: - self.block_id_to_reqs[block_id].add(request.request_id) + block_ids = computed_block_ids + new_block_ids + self.req_to_block_ids[request.request_id] = block_ids + self.ref_cnts[block_ids] += 1 return new_block_ids def free(self, request: Request) -> None: block_ids = self.req_to_block_ids.pop(request.request_id) + self.ref_cnts[block_ids] -= 1 for block_id in block_ids: - reqs = self.block_id_to_reqs[block_id] - reqs.remove(request.request_id) - if not reqs: + ref_cnt = self.ref_cnts[block_id] + if ref_cnt == 0: self.free_block_ids.append(block_id) def _get_new_blocks(self, num_blocks: int) -> List[int]: From 1ff2463fa4447e1aa1d082bb18fccbdc603e52e0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 03:02:10 -0700 Subject: [PATCH 16/31] Top-k Top-p sampling --- vllm/model_executor/layers/sampler.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 0f6cdb5536ce5..ccff777fd2a93 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -25,6 +25,7 @@ def forward( sampling_metadata: SamplingMetadata, ) -> SamplerOutput: logits = self.apply_temperature(logits, sampling_metadata.temperature) + logits = self.apply_top_k_top_p(logits, sampling_metadata) probs = self.get_probs(logits) sampled = self.sample(probs, sampling_metadata) @@ -61,6 +62,20 @@ def apply_temperature( logits.div_(temp.unsqueeze(dim=1)) return logits + def apply_top_k_top_p( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + logits = _apply_top_k_top_p( + logits, + sampling_metadata.no_top_k, + sampling_metadata.top_k, + sampling_metadata.no_top_p, + sampling_metadata.top_p, + ) + return logits + def get_probs(self, logits: torch.Tensor) -> torch.Tensor: return torch.softmax(logits, dim=-1, dtype=torch.float32) @@ -109,3 +124,36 @@ def sample( random_sampled, ) return sampled + + +def _apply_top_k_top_p( + logits: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, +) -> torch.Tensor: + if no_top_k and no_top_p: + return logits + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if not no_top_k: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if not no_top_p: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits From 256ac81337a4c76f58861f853a44af0d0313ef98 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 03:02:47 -0700 Subject: [PATCH 17/31] Add comment --- vllm/core/kv_cache_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/core/kv_cache_manager.py b/vllm/core/kv_cache_manager.py index 2b3a2f15d3d11..ea5e1edc486ae 100644 --- a/vllm/core/kv_cache_manager.py +++ b/vllm/core/kv_cache_manager.py @@ -38,6 +38,8 @@ def append_slots( request: Request, num_tokens: int, ) -> Optional[List[int]]: + # NOTE(woosuk): This method takes up to 5% of the total runtime. + # OPTIMIZE THIS. num_blocks = cdiv(request.num_computed_tokens + num_tokens, self.block_size) req_block_ids = self.req_to_block_ids[request.request_id] @@ -51,7 +53,7 @@ def append_slots( # Allocate new blocks. new_block_ids = self._get_new_blocks(num_new_blocks) req_block_ids.extend(new_block_ids) - self.ref_cnts[new_block_ids] += 1 + self.ref_cnts[new_block_ids] += 1 return new_block_ids def allocate_slots( From a53cfae61cba8160c29a691e71d7186b7744a480 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 03:04:49 -0700 Subject: [PATCH 18/31] Optimize --- vllm/worker/model_runner_v2.py | 44 ++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index bbdb76dce2fd5..aba8f00342bdb 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -1,3 +1,4 @@ +import time from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union) @@ -80,6 +81,10 @@ def __init__( pin_memory=self.pin_memory, ) + self.cum1 = 0 + self.cum2 = 0 + self.cum3 = 0 + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. # Keep the states of the pre-empted requests. @@ -104,20 +109,24 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_data in scheduler_output.scheduled_running_reqs: req_id = req_data.req_id req_state = self.requests[req_id] - req_state.block_ids.extend(req_data.new_block_ids) - req_state.num_computed_tokens = req_data.num_computed_tokens - - # Update the block table and num_computed_tokens. req_index = self.persistent_batch.req_id_to_index[req_id] - end_block_index = len(req_state.block_ids) - start_block_index = end_block_index - len(req_data.new_block_ids) - self.persistent_batch.block_table_cpu[ - req_index, - start_block_index:end_block_index] = torch.as_tensor( - req_data.new_block_ids, dtype=torch.int32, device="cpu") + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens self.persistent_batch.num_computed_tokens_cpu[req_index] = ( req_data.num_computed_tokens) + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_block_index = len(req_state.block_ids) + req_state.block_ids.extend(req_data.new_block_ids) + for i, block_id in enumerate(req_data.new_block_ids): + self.persistent_batch.block_table_cpu[req_index, + start_block_index + + i] = block_id + req_ids_to_add: List[str] = [] # Add new requests to the cached states. for req_data in scheduler_output.scheduled_new_reqs: @@ -297,9 +306,17 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: + start = time.time() self._update_states(scheduler_output) + end = time.time() + self.cum1 += end - start + + start = time.time() inputs = self._prepare_inputs(scheduler_output) input_ids, positions, attn_metadata, logits_indices = inputs + end = time.time() + self.cum2 += end - start + hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -318,6 +335,8 @@ def execute_model( # CPU-GPU synchronization happens here. # TODO: Optimize. + # torch.cuda.synchronize() + start = time.time() sampled_token_ids = sampler_output.sampled_token_ids.cpu() num_reqs = self.persistent_batch.num_reqs for i, req_id in enumerate(self.persistent_batch.req_ids[:num_reqs]): @@ -354,6 +373,11 @@ def execute_model( logprob_token_ids_cpu=logprob_token_ids, logprobs_cpu=logprobs, ) + end = time.time() + self.cum3 += end - start + # print(f"cum1: {self.cum1 * 1000:.3f} ms") + # print(f"cum2: {self.cum2 * 1000:.3f} ms") + # print(f"cum3: {self.cum3 * 1000:.3f} ms") return model_runner_output def load_model(self) -> None: From 105ceaa5546bd37e91d27d4a7939996a6354edc8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 04:06:54 -0700 Subject: [PATCH 19/31] Minor --- vllm/core/scheduler_v2.py | 7 +++---- vllm/worker/model_runner_v2.py | 24 +++++++++++++----------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py index 9c04ae046bd63..7c2f72bf43d0a 100644 --- a/vllm/core/scheduler_v2.py +++ b/vllm/core/scheduler_v2.py @@ -83,8 +83,11 @@ def schedule(self) -> "SchedulerOutput": assert num_tokens > 0 while True: + start = time.time() new_block_ids = self.kv_cache_manager.append_slots( request, num_tokens) + end = time.time() + self.cum += (end - start) if new_block_ids is None: # The request cannot be scheduled. # Preempt the lowest-priority request. @@ -190,10 +193,6 @@ def schedule(self) -> "SchedulerOutput": finished_req_ids=self.finished_req_ids, aborted_req_ids=self.aborted_req_ids, ) - end = time.time() - self.cum += (end - start) - print(f"Scheduler time: {(end - start) * 1000:.3f} ms") - print(f"Cumulative scheduler time: {self.cum * 1000:.3f} ms") self.finished_req_ids = set() self.aborted_req_ids = set() diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index aba8f00342bdb..f62bd76c05348 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -106,6 +106,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: removed_req_indices.append(req_index) # Update the states of the running requests. + start = time.time() for req_data in scheduler_output.scheduled_running_reqs: req_id = req_data.req_id req_state = self.requests[req_id] @@ -126,6 +127,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.persistent_batch.block_table_cpu[req_index, start_block_index + i] = block_id + end = time.time() + self.cum2 += end - start req_ids_to_add: List[str] = [] # Add new requests to the cached states. @@ -315,7 +318,7 @@ def execute_model( inputs = self._prepare_inputs(scheduler_output) input_ids, positions, attn_metadata, logits_indices = inputs end = time.time() - self.cum2 += end - start + # self.cum2 += end - start hidden_states = self.model( input_ids=input_ids, @@ -333,11 +336,11 @@ def execute_model( sampling_metadata=sampling_metadata, ) - # CPU-GPU synchronization happens here. - # TODO: Optimize. - # torch.cuda.synchronize() - start = time.time() + # NOTE: CPU-GPU synchronization happens here. sampled_token_ids = sampler_output.sampled_token_ids.cpu() + start = time.time() + sampled_token_ids_list = sampled_token_ids.tolist() + # TODO: Optimize. num_reqs = self.persistent_batch.num_reqs for i, req_id in enumerate(self.persistent_batch.req_ids[:num_reqs]): req_state = self.requests[req_id] @@ -346,7 +349,7 @@ def execute_model( num_prompt_tokens = len(req_state.prompt_token_ids) if seq_len >= num_prompt_tokens: # Append the sampled token to the output token ids. - token_id = sampled_token_ids[i] + token_id = sampled_token_ids_list[i] self.persistent_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids.append(token_id) else: @@ -357,6 +360,8 @@ def execute_model( offset = generator.get_offset() generator = generator.set_offset(offset - 1) self.persistent_batch.generators[i] = generator + end = time.time() + self.cum3 += end - start if sampler_output.logprob_token_ids is None: logprob_token_ids = None @@ -373,8 +378,6 @@ def execute_model( logprob_token_ids_cpu=logprob_token_ids, logprobs_cpu=logprobs, ) - end = time.time() - self.cum3 += end - start # print(f"cum1: {self.cum1 * 1000:.3f} ms") # print(f"cum2: {self.cum2 * 1000:.3f} ms") # print(f"cum3: {self.cum3 * 1000:.3f} ms") @@ -525,10 +528,9 @@ def add_request( self.req_id_to_index[request.req_id] = req_index # Copy the prompt token ids and output token ids. - # TODO: Optimize. - for i, token_id in enumerate(request.prompt_token_ids): - self.token_ids_cpu[req_index, i] = token_id num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[req_index, :num_prompt_tokens] = torch.as_tensor( + request.prompt_token_ids, dtype=torch.int32, device="cpu") for i, token_id in enumerate(request.output_token_ids): self.token_ids_cpu[req_index, num_prompt_tokens + i] = token_id From a5ca329b2984a6c3740de73a147d9ddac863be6a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 04:07:29 -0700 Subject: [PATCH 20/31] Use int32 instead of int64 --- vllm/model_executor/layers/sampler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index ccff777fd2a93..5de2d869f9aef 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -29,6 +29,8 @@ def forward( probs = self.get_probs(logits) sampled = self.sample(probs, sampling_metadata) + # Use int32 to reduce the tensor size. + sampled = sampled.to(torch.int32) if sampling_metadata.max_num_logprobs > 0: logprobs = self.get_logprobs(logits) @@ -36,6 +38,8 @@ def forward( # and concatenate the topk with the sampled token_id. topk_logprobs, topk_indices = torch.topk( logprobs, sampling_metadata.max_num_logprobs, dim=-1) + # Use int32 to reduce the tensor size. + topk_indices = topk_indices.to(torch.int32) else: topk_logprobs = None topk_indices = None From 438dc092dab60b560ecff4540089ea3f07a64a24 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 10:19:33 -0700 Subject: [PATCH 21/31] Remove ref cnt --- vllm/core/kv_cache_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/core/kv_cache_manager.py b/vllm/core/kv_cache_manager.py index ea5e1edc486ae..7d705e86a1af3 100644 --- a/vllm/core/kv_cache_manager.py +++ b/vllm/core/kv_cache_manager.py @@ -28,7 +28,7 @@ def __init__( # Reserve block id 0 for padding. self.free_block_ids = list(range(num_gpu_blocks)) self.req_to_block_ids: Dict[str, List[int]] = {} - self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32) + # self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32) def get_computed_blocks(self, request: Request) -> List[int]: return [] @@ -53,7 +53,7 @@ def append_slots( # Allocate new blocks. new_block_ids = self._get_new_blocks(num_new_blocks) req_block_ids.extend(new_block_ids) - self.ref_cnts[new_block_ids] += 1 + # self.ref_cnts[new_block_ids] += 1 return new_block_ids def allocate_slots( @@ -71,16 +71,16 @@ def allocate_slots( new_block_ids = self._get_new_blocks(num_new_blocks) block_ids = computed_block_ids + new_block_ids self.req_to_block_ids[request.request_id] = block_ids - self.ref_cnts[block_ids] += 1 + # self.ref_cnts[block_ids] += 1 return new_block_ids def free(self, request: Request) -> None: block_ids = self.req_to_block_ids.pop(request.request_id) - self.ref_cnts[block_ids] -= 1 + # self.ref_cnts[block_ids] -= 1 for block_id in block_ids: - ref_cnt = self.ref_cnts[block_id] - if ref_cnt == 0: - self.free_block_ids.append(block_id) + # ref_cnt = self.ref_cnts[block_id] + # if ref_cnt == 0: + self.free_block_ids.append(block_id) def _get_new_blocks(self, num_blocks: int) -> List[int]: assert num_blocks <= len(self.free_block_ids) From 09a7fa428815cdd447ee25b491d7403736a6aa90 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 11:38:07 -0700 Subject: [PATCH 22/31] Use numpy --- vllm/worker/model_runner_v2.py | 120 +++++++++++++++++---------------- 1 file changed, 63 insertions(+), 57 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index f62bd76c05348..d6ebb5285bc2b 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -84,6 +84,7 @@ def __init__( self.cum1 = 0 self.cum2 = 0 self.cum3 = 0 + self.cum4 = 0 def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. @@ -106,29 +107,28 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: removed_req_indices.append(req_index) # Update the states of the running requests. - start = time.time() for req_data in scheduler_output.scheduled_running_reqs: req_id = req_data.req_id req_state = self.requests[req_id] req_index = self.persistent_batch.req_id_to_index[req_id] # Update the num_computed_tokens. + start = time.time() req_state.num_computed_tokens = req_data.num_computed_tokens self.persistent_batch.num_computed_tokens_cpu[req_index] = ( req_data.num_computed_tokens) + end = time.time() + self.cum4 += end - start # Update the block table. num_new_blocks = len(req_data.new_block_ids) if num_new_blocks == 0: continue - start_block_index = len(req_state.block_ids) + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks req_state.block_ids.extend(req_data.new_block_ids) - for i, block_id in enumerate(req_data.new_block_ids): - self.persistent_batch.block_table_cpu[req_index, - start_block_index + - i] = block_id - end = time.time() - self.cum2 += end - start + self.persistent_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids req_ids_to_add: List[str] = [] # Add new requests to the cached states. @@ -182,7 +182,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.persistent_batch.block_table[:num_reqs].copy_( - self.persistent_batch.block_table_cpu[:num_reqs], + self.persistent_batch.block_table_cpu_tensor[:num_reqs], non_blocking=True) # Get the number of scheduled tokens for each request. @@ -201,7 +201,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] indices = np.arange(num_reqs) req_indices = np.repeat(indices, num_scheduled_tokens) - req_indices = torch.from_numpy(req_indices) # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -209,33 +208,35 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): (num_reqs, 1)) mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] arange = arange_matrix[mask] - arange = torch.from_numpy(arange) # Get positions. positions = torch.empty((total_num_scheduled_tokens, ), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) - torch.add(self.persistent_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions) + positions_np = positions.numpy() + np.add(self.persistent_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = positions + req_indices * self.max_model_len + token_indices = positions_np + req_indices * self.max_model_len + token_indices = torch.from_numpy(token_indices) input_ids = torch.empty((total_num_scheduled_tokens, ), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) - torch.index_select(self.persistent_batch.token_ids_cpu.flatten(), + torch.index_select(torch.from_numpy( + self.persistent_batch.token_ids_cpu).flatten(), 0, token_indices, out=input_ids) # Calculate the slot mapping. - block_numbers = self.persistent_batch.block_table_cpu.flatten()[ + block_numbers = self.persistent_batch.block_table_cpu_tensor.flatten()[ token_indices // self.block_size] block_offsets = token_indices % self.block_size slot_mapping = torch.empty((total_num_scheduled_tokens, ), @@ -245,27 +246,28 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): torch.add(block_numbers * self.block_size, block_offsets, out=slot_mapping) + slot_mapping = block_numbers * self.block_size + block_offsets # Prepare the attention metadata. - num_scheduled_tokens = torch.from_numpy(num_scheduled_tokens) query_start_loc = torch.empty((num_reqs + 1, ), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) - query_start_loc[0] = 0 - torch.cumsum(num_scheduled_tokens, dim=0, out=query_start_loc[1:]) + query_start_loc_np = query_start_loc.numpy() + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) seq_lens = (self.persistent_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) - max_seq_len = seq_lens.max().item() + max_seq_len = seq_lens.max() seq_start_loc = torch.empty((num_reqs + 1, ), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) - seq_start_loc[0] = 0 - torch.cumsum(seq_lens, dim=0, out=seq_start_loc[1:]) + seq_start_loc_np = seq_start_loc.numpy() + seq_start_loc_np[0] = 0 + np.cumsum(seq_lens, out=seq_start_loc_np[1:]) - # Move the tensors to the device. input_ids = input_ids.to(self.device, non_blocking=True) positions = positions.to(self.device, non_blocking=True).long() query_start_loc = query_start_loc.to(self.device, non_blocking=True) @@ -318,7 +320,7 @@ def execute_model( inputs = self._prepare_inputs(scheduler_output) input_ids, positions, attn_metadata, logits_indices = inputs end = time.time() - # self.cum2 += end - start + self.cum2 += end - start hidden_states = self.model( input_ids=input_ids, @@ -381,6 +383,7 @@ def execute_model( # print(f"cum1: {self.cum1 * 1000:.3f} ms") # print(f"cum2: {self.cum2 * 1000:.3f} ms") # print(f"cum3: {self.cum3 * 1000:.3f} ms") + # print(f"cum4: {self.cum4 * 1000:.3f} ms") return model_runner_output def load_model(self) -> None: @@ -463,50 +466,52 @@ def __init__( self.req_ids: List[Optional[str]] = [None] * max_num_reqs self.req_id_to_index: Dict[str, int] = {} - self.token_ids_cpu = torch.empty((max_num_reqs, max_model_len), - dtype=torch.int32, - device="cpu") - self.num_computed_tokens_cpu = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu") + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) # Attention-related. - self.block_table = torch.empty((max_num_reqs, max_num_blocks_per_req), + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), device=self.device, dtype=torch.int32) - self.block_table_cpu = torch.empty( + self.block_table_cpu_tensor = torch.zeros( (max_num_reqs, max_num_blocks_per_req), device="cpu", dtype=torch.int32, - pin_memory=pin_memory) + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), dtype=torch.float32, device=device) - self.temperature_cpu = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: Set[str] = set() self.random_reqs: Set[str] = set() self.top_p = torch.empty((max_num_reqs, ), dtype=torch.float32, device=device) - self.top_p_cpu = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: Set[str] = set() self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.float32, + dtype=torch.int32, device=device) - self.top_k_cpu = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() self.generators: List[Optional[torch.Generator]] = [None @@ -529,15 +534,16 @@ def add_request( # Copy the prompt token ids and output token ids. num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[req_index, :num_prompt_tokens] = torch.as_tensor( - request.prompt_token_ids, dtype=torch.int32, device="cpu") - for i, token_id in enumerate(request.output_token_ids): - self.token_ids_cpu[req_index, num_prompt_tokens + i] = token_id + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - # TODO: Optimize. - for i, block_id in enumerate(request.block_ids): - self.block_table_cpu[req_index, i] = block_id + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature @@ -637,10 +643,10 @@ def make_sampling_metadata( ) -> SamplingMetadata: if not skip_copy: self.temperature[:self.num_reqs].copy_( - self.temperature_cpu[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_(self.top_p_cpu[:self.num_reqs], + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_(self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_(self.top_k_cpu[:self.num_reqs], + self.top_k[:self.num_reqs].copy_(self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) return SamplingMetadata( temperature=self.temperature[:self.num_reqs], From ea2b5e068b302e16e84a8852c5ca57ea70ab57ad Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 18:08:27 -0700 Subject: [PATCH 23/31] Add back ref cnts --- vllm/core/kv_cache_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/core/kv_cache_manager.py b/vllm/core/kv_cache_manager.py index 7d705e86a1af3..ea5e1edc486ae 100644 --- a/vllm/core/kv_cache_manager.py +++ b/vllm/core/kv_cache_manager.py @@ -28,7 +28,7 @@ def __init__( # Reserve block id 0 for padding. self.free_block_ids = list(range(num_gpu_blocks)) self.req_to_block_ids: Dict[str, List[int]] = {} - # self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32) + self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32) def get_computed_blocks(self, request: Request) -> List[int]: return [] @@ -53,7 +53,7 @@ def append_slots( # Allocate new blocks. new_block_ids = self._get_new_blocks(num_new_blocks) req_block_ids.extend(new_block_ids) - # self.ref_cnts[new_block_ids] += 1 + self.ref_cnts[new_block_ids] += 1 return new_block_ids def allocate_slots( @@ -71,16 +71,16 @@ def allocate_slots( new_block_ids = self._get_new_blocks(num_new_blocks) block_ids = computed_block_ids + new_block_ids self.req_to_block_ids[request.request_id] = block_ids - # self.ref_cnts[block_ids] += 1 + self.ref_cnts[block_ids] += 1 return new_block_ids def free(self, request: Request) -> None: block_ids = self.req_to_block_ids.pop(request.request_id) - # self.ref_cnts[block_ids] -= 1 + self.ref_cnts[block_ids] -= 1 for block_id in block_ids: - # ref_cnt = self.ref_cnts[block_id] - # if ref_cnt == 0: - self.free_block_ids.append(block_id) + ref_cnt = self.ref_cnts[block_id] + if ref_cnt == 0: + self.free_block_ids.append(block_id) def _get_new_blocks(self, num_blocks: int) -> List[int]: assert num_blocks <= len(self.free_block_ids) From 5f8bc7d7783a45f9e716353045b51a71727ed94b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:17:08 -0700 Subject: [PATCH 24/31] Minor --- vllm/core/scheduler_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler_v2.py b/vllm/core/scheduler_v2.py index 7c2f72bf43d0a..8c7be73bf1d25 100644 --- a/vllm/core/scheduler_v2.py +++ b/vllm/core/scheduler_v2.py @@ -203,7 +203,7 @@ def update_from_output( scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", ) -> List[Request]: - sampled_token_ids = model_runner_output.sampled_token_ids_cpu.numpy() + sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist() num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] finished_reqs: List[Request] = [] From 36453c124849f1d474b1c725f8b2241834a9a6d4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:17:19 -0700 Subject: [PATCH 25/31] Output text --- vllm/outputs_v2.py | 2 +- vllm/request.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/outputs_v2.py b/vllm/outputs_v2.py index a893907968643..dd96c8049314b 100644 --- a/vllm/outputs_v2.py +++ b/vllm/outputs_v2.py @@ -67,7 +67,7 @@ def from_request(cls, request: Request) -> "RequestOutput": # TODO: Support `n` > 1. completion_output = CompletionOutput( index=0, - text="", + text=request.output_text, token_ids=request.output_token_ids, logprobs=None, finish_reason=request.get_finished_reason(), diff --git a/vllm/request.py b/vllm/request.py index 7c0a25835ae95..58fb5ce31d825 100644 --- a/vllm/request.py +++ b/vllm/request.py @@ -42,6 +42,7 @@ def __init__( self.prompt_token_ids = inputs["prompt_token_ids"] self.num_prompt_tokens = len(self.prompt_token_ids) self.output_token_ids: List[int] = [] + self.output_text = "" self.num_computed_tokens = 0 @property From 7a813f6dcce261495dd35b3bba2eb7f68d221ae4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:20:46 -0700 Subject: [PATCH 26/31] Detokenizer --- vllm/transformers_utils/detokenizer.py | 331 ------------------- vllm/transformers_utils/detokenizer_utils.py | 167 ++++++++++ vllm/transformers_utils/detokenizer_v2.py | 150 +++++++++ 3 files changed, 317 insertions(+), 331 deletions(-) delete mode 100644 vllm/transformers_utils/detokenizer.py create mode 100644 vllm/transformers_utils/detokenizer_utils.py create mode 100644 vllm/transformers_utils/detokenizer_v2.py diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py deleted file mode 100644 index fba9ea69a5f48..0000000000000 --- a/vllm/transformers_utils/detokenizer.py +++ /dev/null @@ -1,331 +0,0 @@ -from typing import Dict, List, Optional, Tuple - -from .tokenizer import AnyTokenizer -from .tokenizer_group import BaseTokenizerGroup -from vllm.sampling_params import SamplingParams - -# Used eg. for marking rejected tokens in spec decoding. -INVALID_TOKEN_ID = -1 -Sequence = None -SequenceGroup = None -Logprob = None - - -class Detokenizer: - """Provides methods to decode the output of a model into text.""" - - def __init__(self, tokenizer_group: BaseTokenizerGroup): - self.tokenizer_group = tokenizer_group - - def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: - """Returns the HF tokenizer to use for a given sequence.""" - return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) - - def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, - prompt_logprobs: List[Optional[Dict[ - int, Logprob]]], - position_offset: int) -> None: - """Decodes the logprobs for the prompt of a sequence group. - - Args: - seq_group: The sequence group to decode. - prompt_logprobs: The logprobs to decode. - position_offset: Offset of the first index of the logprobs - relative to the start of the sequence (for chunked prefill). - - Returns: - The prompt logprobs with the decoded tokens. - """ - prms = seq_group.sampling_params - assert prms is not None - - # We can pick any sequence for the prompt. - seq = seq_group.get_seqs()[0] - # Only prompt, without the generated token. - all_token_ids = seq.get_token_ids() - prompt_token_ids = all_token_ids[:-1] - tokenizer = self.get_tokenizer_for_seq(seq) - prefix_offset = 0 - read_offset = 0 - next_iter_prefix_offset = 0 - next_iter_read_offset = 0 - next_iter_tokens: List[str] = [] - prev_tokens = None - - for token_position_in_logprob, prompt_logprobs_for_token in enumerate( - prompt_logprobs): - - # Absolute token position equals the index in the logprobs - # list plus the offset of the entire logprobs list relative - # to the start of the sequence. - token_position = token_position_in_logprob + position_offset - if not prompt_logprobs_for_token: - continue - for token_id, sample_logprob in prompt_logprobs_for_token.items(): - if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): - prompt_token_ids_with_token = ( - prompt_token_ids[:token_position] + [token_id]) - (new_tokens, new_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=prompt_token_ids_with_token, - prev_tokens=prev_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - - sample_logprob.decoded_token = new_text - - # Use the offsets & prev tokens corresponding to - # real tokens to ensure detokenization is consistent - # actual with prompt. - if token_id == all_token_ids[token_position]: - next_iter_prefix_offset = new_prefix_offset - next_iter_read_offset = new_read_offset - next_iter_tokens = new_tokens - - # Advance to the next token position. - prefix_offset = next_iter_prefix_offset - read_offset = next_iter_read_offset - if prev_tokens is None: - prev_tokens = next_iter_tokens - else: - prev_tokens.extend(next_iter_tokens) - - def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> int: - """Decodes the new token for a sequence. In-place operation. - - Args: - seq: The sequence to decode. - prms: The sampling parameters used to generate the sequence. - - Returns: - The number of characters added to the output text. - """ - all_input_ids = seq.get_token_ids() - token_id_generated_this_iteration = all_input_ids[-1] - tokenizer = self.get_tokenizer_for_seq(seq) - - # Convert prompt token IDs to tokens if necessary. - # Do it here so that we don't have to repeat this - # computation for each logprob. - if seq.tokens is None: - (seq.tokens, seq.prefix_offset, - seq.read_offset) = convert_prompt_ids_to_tokens( - tokenizer=tokenizer, - prompt_ids=all_input_ids[:-1], - skip_special_tokens=prms.skip_special_tokens, - ) - - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=all_input_ids, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) - - # Decode logprobs - logprobs = seq.output_logprobs[-1] - if logprobs: - previous_tokens = all_input_ids[:-1] - for token_id, sample_logprob in logprobs.items(): - # If the token was generated this iteration, - # use the provided text. - if token_id == token_id_generated_this_iteration: - sample_logprob.decoded_token = new_decoded_token_text - continue - - if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): - all_input_ids_with_logprob = previous_tokens + [token_id] - (_, new_text, _, _) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=all_input_ids_with_logprob, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - sample_logprob.decoded_token = new_text - - seq.tokens.extend(new_tokens) - seq.prefix_offset = prefix_offset - seq.read_offset = read_offset - seq.output_text += new_decoded_token_text - - return len(new_decoded_token_text) - - -def _replace_none_with_empty(tokens: List[Optional[str]]): - for i, token in enumerate(tokens): - if token is None: - tokens[i] = "" - - -def _convert_tokens_to_string_with_added_encoders( - tokenizer: AnyTokenizer, - output_tokens: List[str], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, -) -> str: - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 - # NOTE(woosuk): The following code is slow because it runs a for loop over - # the output_tokens. In Python, running a for loop over a list can be slow - # even when the loop body is very simple. - sub_texts: List[str] = [] - current_sub_text: List[str] = [] - all_special_tokens = set(tokenizer.all_special_tokens) - for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: - continue - if token in tokenizer.get_added_vocab(): - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] - sub_texts.append(token) - else: - current_sub_text.append(token) - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - if spaces_between_special_tokens: - return " ".join(sub_texts) - else: - return "".join(sub_texts) - - -# 5 is an arbitrary value that should work for all -# tokenizers (bigger = more conservative). -INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 - - -def convert_prompt_ids_to_tokens( - tokenizer: AnyTokenizer, - prompt_ids: List[int], - skip_special_tokens: bool = False, -) -> Tuple[List[str], int, int]: - """Converts the prompt ids to tokens and returns the tokens and offsets - for incremental detokenization. - - Note that not all tokens are converted to strings. Only the tokens that - are necessary for incremental detokenization are converted to strings. - """ - # We do not need to convert the whole prompt to tokens. - # Offset a little more in case we have special tokens. - new_tokens = tokenizer.convert_ids_to_tokens( - prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], - skip_special_tokens=skip_special_tokens) - read_offset = len(new_tokens) - prefix_offset = max( - read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) - # This is required to guard against out-of-vocab prompt token ids - _replace_none_with_empty(new_tokens) # type: ignore[arg-type] - return new_tokens, prefix_offset, read_offset - - -# Based on -# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 -# under Apache 2.0 license -def detokenize_incrementally( - tokenizer: AnyTokenizer, - all_input_ids: List[int], - prev_tokens: Optional[List[str]], - prefix_offset: int, - read_offset: int, - skip_special_tokens: bool = False, - spaces_between_special_tokens: bool = True, -) -> Tuple[List[str], str, int, int]: - """Detokenizes the input ids incrementally and returns the new tokens - and the new text. - - If `prev_tokens` is None, this function will convert the input ids to - tokens and return the tokens and the new text. Otherwise, it will return the - new tokens and the new text. - - This function will also return the new prefix offset and the new read - offset to be used in the next iteration. - - The offsets are necessary to defeat cleanup algorithms in the decode which - decide to add a space or not depending on the surrounding ids. - - Args: - tokenizer: The tokenizer to use. - all_input_ids: The input ids. The last id is the new token id. - prev_tokens: The previous tokens. If None, this function will convert - the input ids to tokens and return the tokens and the new text. - prefix_offset: The prefix offset. - read_offset: The read offset. - skip_special_tokens: Whether to skip special tokens. - spaces_between_special_tokens: Whether to add spaces between special - tokens. - """ - new_token_id = all_input_ids[-1] - # This is the first iteration for this sequence - is_first_iter = prev_tokens is None - if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) - assert prev_tokens is not None - - # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: - # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) - if isinstance(new_tokens, str): - new_tokens = [new_tokens] - output_tokens = prev_tokens + new_tokens - - # If this is the first iteration, return all tokens. - if is_first_iter: - new_tokens = output_tokens - - # The prefix text is necessary only to defeat cleanup algorithms in - # the decode which decide to add a space or not depending on the - # surrounding ids. - if tokenizer.is_fast or not tokenizer.get_added_vocab(): - prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) - else: - prefix_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:read_offset], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - new_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - if len(new_text) <= len(prefix_text) or new_text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. - # If it's in the middle, it's probably a real invalid id generated - # by the model - return new_tokens, "", prefix_offset, read_offset - - new_text = new_text[len(prefix_text):] - return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py new file mode 100644 index 0000000000000..795fb34478aaf --- /dev/null +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -0,0 +1,167 @@ +from typing import Dict, List, Optional, Tuple + +from .tokenizer import AnyTokenizer + + +def _replace_none_with_empty(tokens: List[Optional[str]]): + for i, token in enumerate(tokens): + if token is None: + tokens[i] = "" + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: AnyTokenizer, + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. + sub_texts: List[str] = [] + current_sub_text: List[str] = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: AnyTokenizer, + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. + new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + # This is required to guard against out-of-vocab prompt token ids + _replace_none_with_empty(new_tokens) # type: ignore[arg-type] + return new_tokens, prefix_offset, read_offset + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: AnyTokenizer, + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. + prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. + """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None + + # If the new token id is out of bounds, return an empty string. + if new_token_id >= len(tokenizer): + new_tokens = [""] + else: + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/detokenizer_v2.py b/vllm/transformers_utils/detokenizer_v2.py new file mode 100644 index 0000000000000..486f06ade92f0 --- /dev/null +++ b/vllm/transformers_utils/detokenizer_v2.py @@ -0,0 +1,150 @@ +import multiprocessing +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import msgspec +import zmq + +from .tokenizer import get_tokenizer +from .detokenizer_utils import ( + convert_prompt_ids_to_tokens, + detokenize_incrementally) + + +class RequestData(msgspec.Struct): + + # [num_reqs] + request_ids: List[str] + prompt_token_ids: List[List[int]] + new_token_ids: List[List[int]] + skip_special_tokens: List[bool] + spaces_between_special_tokens: List[bool] + + # [num_free_reqs] + free_request_ids: List[str] + + +class DetokenizedData(msgspec.Struct): + + # [num_reqs] + request_ids: List[str] + detokenized_texts: List[str] + + +@dataclass +class RequestState: + + req_id: str + + token_ids: List[int] + tokens: List[str] + + prefix_offset: int + read_offset: int + + skip_special_tokens: bool + spaces_between_special_tokens: bool + + output_text: str = "" + + +class Detokenizer(multiprocessing.Process): + + def __init__( + self, + tokenizer_name: str, + port1: int, + port2: int, + ): + super().__init__() + self.port1 = port1 + self.port2 = port2 + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder(RequestData) + + self.tokenizer = get_tokenizer(tokenizer_name) + self.requests: Dict[str, RequestState] = {} + + def run(self): + self.context = zmq.Context() + self.pull_socket = self.context.socket(zmq.PULL) + self.pull_socket.bind(f"tcp://*:{self.port1}") + self.push_socket = self.context.socket(zmq.PUSH) + self.push_socket.bind(f"tcp://*:{self.port2}") + + while True: + message = self.pull_socket.recv() + data = self.decoder.decode(message) + + for req_id in data.free_request_ids: + self.free(req_id) + + req_ids: List[str] = [] + detokenized_texts: List[str] = [] + num_reqs = len(data.request_ids) + for i in range(num_reqs): + req_id = data.request_ids[i] + req_ids.append(req_id) + if req_id not in self.requests: + self.add_request( + request_id=req_id, + prompt_token_ids=data.prompt_token_ids[i], + skip_special_tokens=data.skip_special_tokens[i], + spaces_between_special_tokens=data.spaces_between_special_tokens[i], + ) + new_str = self.detokenize(req_id, data.new_token_ids[i]) + detokenized_texts.append(new_str) + + detokenized = DetokenizedData( + request_ids=req_ids, + detokenized_texts=detokenized_texts, + ) + self.push_socket.send(self.encoder.encode(detokenized), flags=zmq.NOBLOCK) + + def add_request( + self, + request_id: str, + prompt_token_ids: List[int], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, + ) -> None: + tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( + tokenizer=self.tokenizer, + prompt_ids=prompt_token_ids, + skip_special_tokens=skip_special_tokens, + ) + self.requests[request_id] = RequestState( + req_id=request_id, + token_ids=prompt_token_ids, + tokens=tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + def free(self, request_id: str) -> None: + del self.requests[request_id] + + def detokenize(self, request_id: str, new_token_ids: List[int]) -> str: + req_state = self.requests[request_id] + decoded_text = "" + for new_token_id in new_token_ids: + req_state.token_ids.append(new_token_id) + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=req_state.token_ids, + prev_tokens=req_state.tokens, + prefix_offset=req_state.prefix_offset, + read_offset=req_state.read_offset, + skip_special_tokens=req_state.skip_special_tokens, + spaces_between_special_tokens=req_state.spaces_between_special_tokens, + ) + + req_state.tokens.extend(new_tokens) + req_state.prefix_offset = prefix_offset + req_state.read_offset = read_offset + req_state.output_text += new_decoded_token_text + decoded_text += new_decoded_token_text + return decoded_text From 9c15340d053099bb96ab95fcbdaf72bc032e72dc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:21:56 -0700 Subject: [PATCH 27/31] yapf --- vllm/worker/model_runner_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner_v2.py b/vllm/worker/model_runner_v2.py index d6ebb5285bc2b..de5290fbf29e6 100644 --- a/vllm/worker/model_runner_v2.py +++ b/vllm/worker/model_runner_v2.py @@ -644,10 +644,10 @@ def make_sampling_metadata( if not skip_copy: self.temperature[:self.num_reqs].copy_( self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_(self.top_p_cpu_tensor[:self.num_reqs], - non_blocking=True) - self.top_k[:self.num_reqs].copy_(self.top_k_cpu_tensor[:self.num_reqs], - non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) return SamplingMetadata( temperature=self.temperature[:self.num_reqs], all_greedy=self.all_greedy, From 3777a59c33672a3397a378596e5e909c3461c7e0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:27:26 -0700 Subject: [PATCH 28/31] Minor --- vllm/transformers_utils/detokenizer_v2.py | 33 +++++++++++++---------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/vllm/transformers_utils/detokenizer_v2.py b/vllm/transformers_utils/detokenizer_v2.py index 486f06ade92f0..75033c6699d87 100644 --- a/vllm/transformers_utils/detokenizer_v2.py +++ b/vllm/transformers_utils/detokenizer_v2.py @@ -6,9 +6,8 @@ import zmq from .tokenizer import get_tokenizer -from .detokenizer_utils import ( - convert_prompt_ids_to_tokens, - detokenize_incrementally) +from .detokenizer_utils import (convert_prompt_ids_to_tokens, + detokenize_incrementally) class RequestData(msgspec.Struct): @@ -74,6 +73,9 @@ def run(self): while True: message = self.pull_socket.recv() + if message == b"": + # Terminate signal. + break data = self.decoder.decode(message) for req_id in data.free_request_ids: @@ -90,7 +92,8 @@ def run(self): request_id=req_id, prompt_token_ids=data.prompt_token_ids[i], skip_special_tokens=data.skip_special_tokens[i], - spaces_between_special_tokens=data.spaces_between_special_tokens[i], + spaces_between_special_tokens=data. + spaces_between_special_tokens[i], ) new_str = self.detokenize(req_id, data.new_token_ids[i]) detokenized_texts.append(new_str) @@ -99,7 +102,8 @@ def run(self): request_ids=req_ids, detokenized_texts=detokenized_texts, ) - self.push_socket.send(self.encoder.encode(detokenized), flags=zmq.NOBLOCK) + self.push_socket.send(self.encoder.encode(detokenized), + flags=zmq.NOBLOCK) def add_request( self, @@ -132,15 +136,16 @@ def detokenize(self, request_id: str, new_token_ids: List[int]) -> str: for new_token_id in new_token_ids: req_state.token_ids.append(new_token_id) (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=req_state.token_ids, - prev_tokens=req_state.tokens, - prefix_offset=req_state.prefix_offset, - read_offset=req_state.read_offset, - skip_special_tokens=req_state.skip_special_tokens, - spaces_between_special_tokens=req_state.spaces_between_special_tokens, - ) + read_offset) = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=req_state.token_ids, + prev_tokens=req_state.tokens, + prefix_offset=req_state.prefix_offset, + read_offset=req_state.read_offset, + skip_special_tokens=req_state.skip_special_tokens, + spaces_between_special_tokens=req_state. + spaces_between_special_tokens, + ) req_state.tokens.extend(new_tokens) req_state.prefix_offset = prefix_offset From f7e80625e8fa885f795ce6f4887aba2b90635b1b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:27:39 -0700 Subject: [PATCH 29/31] Hacky impl of detokenizer --- vllm/engine/llm_engine_v2.py | 117 +++++++++++++++++++++++------------ 1 file changed, 76 insertions(+), 41 deletions(-) diff --git a/vllm/engine/llm_engine_v2.py b/vllm/engine/llm_engine_v2.py index 1a41526027f24..6da2cd46578b1 100644 --- a/vllm/engine/llm_engine_v2.py +++ b/vllm/engine/llm_engine_v2.py @@ -9,6 +9,8 @@ from typing import Set, Type, Union, Tuple import torch +import msgspec +import zmq from typing_extensions import TypeVar import vllm.envs as envs @@ -33,7 +35,7 @@ from vllm.sampling_params import SamplingParams from vllm.request import Request from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.detokenizer_v2 import Detokenizer, RequestData, DetokenizedData from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( BaseTokenizerGroup, init_tokenizer_from_configs) @@ -41,6 +43,7 @@ usage_message) from vllm.version import __version__ as VLLM_VERSION from vllm.request import Request +from vllm.utils import get_open_port logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -59,9 +62,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: return config.to_diff_dict() -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) - - class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -183,14 +183,28 @@ def __init__( ) self.log_stats = log_stats - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - else: - self.tokenizer = None - self.detokenizer = None - tokenizer_group = None + # Detokenizer + # FIXME(woosuk): This is a temporary hack. + assert not self.model_config.skip_tokenizer_init + self.tokenizer = self._init_tokenizer() + + self.port1 = get_open_port() + self.port2 = get_open_port() + self.detokenizer = Detokenizer(self.model_config.tokenizer, self.port1, + self.port2) + self.detokenizer.start() + + self.context = zmq.Context() + self.push_socket = self.context.socket(zmq.PUSH) + self.push_socket.connect(f"tcp://localhost:{self.port1}") + self.pull_socket = self.context.socket(zmq.PULL) + self.pull_socket.connect(f"tcp://localhost:{self.port2}") + self.poller = zmq.Poller() + self.poller.register(self.pull_socket, zmq.POLLIN) + + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder(DetokenizedData) + self.detokenizer_reqs: Dict[str, Request] = {} self.generation_config_fields = _load_generation_config_dict( model_config) @@ -332,28 +346,6 @@ def from_engine_args( ) return engine - def get_tokenizer_group( - self, - group_type: Type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. " - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - - return tokenizer_group - - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - def _init_tokenizer(self) -> BaseTokenizerGroup: return init_tokenizer_from_configs( model_config=self.model_config, @@ -518,19 +510,62 @@ def get_lora_config(self) -> LoRAConfig: def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" + # FIXME(woosuk) return self.scheduler.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" - return self.scheduler.has_unfinished_requests() + # FIXME(woosuk) + return (self.scheduler.has_unfinished_requests() + or len(self.detokenizer_reqs) > 0) def step(self) -> Tuple[List[Request], List[Request]]: - scheduler_output = self.scheduler.schedule() - output = self.model_executor.execute_model(scheduler_output) - finished_reqs = self.scheduler.update_from_output( - scheduler_output, output) - running_reqs = self.scheduler.running - return finished_reqs, running_reqs + if self.scheduler.has_unfinished_requests(): + scheduler_output = self.scheduler.schedule() + output = self.model_executor.execute_model(scheduler_output) + finished_reqs = self.scheduler.update_from_output( + scheduler_output, output) + + if finished_reqs: + for req in finished_reqs: + self.detokenizer_reqs[req.request_id] = req + self.send_to_detokenizer(finished_reqs) + detokenized_reqs = self.recv_from_detokenizer() + return detokenized_reqs, self.scheduler.running + + def send_to_detokenizer(self, requests: List[Request]) -> None: + data = RequestData( + request_ids=[req.request_id for req in requests], + prompt_token_ids=[req.prompt_token_ids for req in requests], + new_token_ids=[req.output_token_ids for req in requests], + skip_special_tokens=[ + req.sampling_params.skip_special_tokens for req in requests + ], + spaces_between_special_tokens=[ + req.sampling_params.spaces_between_special_tokens + for req in requests + ], + free_request_ids=[], + ) + self.push_socket.send(self.encoder.encode(data), flags=zmq.NOBLOCK) + + def recv_from_detokenizer(self) -> List[Request]: + detokenized_reqs: List[Request] = [] + socks = dict(self.poller.poll(timeout=0)) + if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN: + msg = self.pull_socket.recv() + data = self.decoder.decode(msg) + num_reqs = len(data.request_ids) + for i in range(num_reqs): + req_id = data.request_ids[i] + assert req_id in self.detokenizer_reqs + req = self.detokenizer_reqs.pop(req_id) + req.output_text += data.detokenized_texts[i] + detokenized_reqs.append(req) + return detokenized_reqs + + def terminate_detokenizer(self) -> None: + self.push_socket.send(b"", flags=zmq.NOBLOCK) def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: if not self.log_stats: From 37b0d99384eae374475619ae1d6bb1b18c8c4de8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:28:45 -0700 Subject: [PATCH 30/31] terminate --- vllm/entrypoints/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 7257bc4e3c966..8a0c54a322326 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -743,6 +743,7 @@ def _run_engine( if use_tqdm: pbar.close() + self.llm_engine.terminate_detokenizer() # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. From f40d51ae9c0ccbbfb4ec6412a9bbdf7e86410c25 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 7 Oct 2024 22:28:55 -0700 Subject: [PATCH 31/31] Add TODO --- vllm/transformers_utils/detokenizer_v2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/transformers_utils/detokenizer_v2.py b/vllm/transformers_utils/detokenizer_v2.py index 75033c6699d87..36c1bfe3ed51b 100644 --- a/vllm/transformers_utils/detokenizer_v2.py +++ b/vllm/transformers_utils/detokenizer_v2.py @@ -131,6 +131,8 @@ def free(self, request_id: str) -> None: del self.requests[request_id] def detokenize(self, request_id: str, new_token_ids: List[int]) -> str: + # TODO(woosuk): This method becomes very inefficient when the number of + # new_token_ids is more than 1. We need to optimize this. req_state = self.requests[request_id] decoded_text = "" for new_token_id in new_token_ids: