diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index e4a93814b88..7c18c683e96 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1,132 +1,647 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MRotaryEmbedding""" +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py + +"""Rotary Positional Embeddings.""" +import math from typing import Any, Dict, List, Optional, Tuple, Union import torch -from vllm.model_executor.layers.rotary_embedding import ( - RotaryEmbedding, - _rotate_gptj, - _rotate_neox, - _yarn_find_correction_range, - _yarn_linear_ramp_mask, - get_rope, - yarn_get_mscale, -) - - -class MRotaryEmbedding: - """Rotary Embedding with Multimodal Sections.""" +import torch.nn as nn +from vllm.model_executor.custom_op import CustomOp - @staticmethod - def get_input_positions( - input_tokens: torch.Tensor, - image_grid_thw: Union[List[List[int]], torch.Tensor], - vision_start_token_id: int, - spatial_merge_size: int, - context_len: int = 0, - ) -> Tuple[List[List[int]], int]: - """Get mrope input positions and delta value.""" - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) - vision_start_indices = torch.argwhere( - input_tokens == vision_start_token_id - ).squeeze(1) - image_indices = vision_start_indices + 1 - image_nums = image_indices.shape[0] - llm_pos_ids_list: list = [] - st = 0 - input_tokens_len = input_tokens.shape[0] - for image_index in range(image_nums): - ed = image_indices[image_index].item() - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +@CustomOp.register("rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim ) - text_len = ed - st + ) + return inv_freq - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # 
are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + self.rotary_dim, + offsets, ) + else: + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + self.rotary_dim, + offsets, ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() + else: + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() + return query, key + + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, + ) + + positions = positions.flatten() + if offsets is not None: + positions = positions + offsets + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, 
rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factors: Union[List[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: List[float] = scaling_factors # noqa + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. 
+ + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ + rotary_dim != head_size ({rotary_dim}!={head_size})." ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - if st < input_tokens_len: - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = input_tokens_len - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(self.original_max_position_embeddings) ) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:] - mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item() - return llm_positions.tolist(), mrope_position_delta + self.short_mscale = short_mscale + self.long_mscale = long_mscale - @staticmethod - def get_next_input_positions( - mrope_position_delta: int, - context_len: int, - seq_len: int, - ) -> List[List[int]]: - return [ - list( - range( - context_len + mrope_position_delta, seq_len + mrope_position_delta + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale + ) + short_cache = short_cache.to(dtype) + self.register_buffer("short_cos_sin_cache", short_cache, persistent=False) + + long_cache = self._compute_cos_sin_cache( + max_position_embeddings, long_factor, long_mscale + ) + long_cache = long_cache.to(dtype) + self.register_buffer("long_cos_sin_cache", long_cache, persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0 + ) + self.register_buffer( + "long_short_cos_sin_cache", long_short_cache, persistent=False + ) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / ( + rescale_factors + * ( + self.base + ** ( + torch.arange(0, self.head_size, 2, dtype=torch.float) + / self.head_size ) ) - for _ in range(3) - ] + ) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = ( + torch.any(positions > k).float() * torch.full_like(positions, k) + ).long() + idx = ( + torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None + else positions + ) + self.long_short_cos_sin_cache: torch.Tensor = self.long_short_cos_sin_cache.to( + idx.device + ) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = 
torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 -# TODO: in the DeepseekScalingRotaryEmbedding class defined in vllm, -# the device has been hard-coded to "cuda" in these two places: -# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L646 -# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L665 -# We port the related code to this file to make it compatible with the CPU version. -# We will add an optimized rotary embedding kernel for CPU and will remove the ported code then. class DeepseekScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with YaRN method. @@ -149,7 +664,6 @@ def __init__( beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, - device: Optional[str] = None, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -162,14 +676,13 @@ def __init__( / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor ) - self.device = device super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / self.rotary_dim ) inv_freq_extrapolation = 1.0 / pos_freqs @@ -197,7 +710,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange( self.max_position_embeddings * self.scaling_factor, - device=self.device, + device="cuda", dtype=torch.float32, ) freqs = torch.einsum("i,j -> ij", t, inv_freq) @@ -248,10 +761,249 @@ def forward( return query, key +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / 
self.scaling_factor + smooth * inv_freqs, + ), + ) + return new_freqs + + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @staticmethod + def get_input_positions( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> Tuple[List[List[int]], int]: + """Get mrope input positions and delta value.""" + + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, 
w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> List[List[int]]: + return [ + list( + range( + context_len + mrope_position_delta, seq_len + mrope_position_delta + ) + ) + for _ in range(3) + ] + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} -def get_rope_cpu( +def get_rope( head_size: int, rotary_dim: int, max_position: int, @@ -260,7 +1012,6 @@ def get_rope_cpu( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, - device: Optional[str] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -286,75 +1037,140 @@ def get_rope_cpu( if key in _ROPE_DICT: return _ROPE_DICT[key] - assert rope_scaling is not None - scaling_type = rope_scaling["rope_type"] - assert ( - scaling_type == "deepseek_yarn" - ), "Only deepseek_yarn is supported for CPU for now" - - scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling["original_max_position_embeddings"] - # assert max_position == original_max_position * scaling_factor - extra_kwargs = { - k: v - for k, v in rope_scaling.items() - if k - in ( - "extrapolation_factor", - "attn_factor", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", + if rope_scaling is None: + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype ) - } - extra_kwargs["device"] = device - rotary_emb = DeepseekScalingRotaryEmbedding( - head_size, - rotary_dim, - original_max_position, - base, - is_neox_style, - scaling_factor, - dtype, - **extra_kwargs, - ) + else: + scaling_type = rope_scaling["rope_type"] + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = 
rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "linear": + scaling_factor = rope_scaling["factor"] + rotary_emb = LinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "dynamic": + scaling_factor = rope_scaling["factor"] + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, + rotary_dim, + max_position, + original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb return rotary_emb - - -def get_rope_wrapper( - head_size: int, - rotary_dim: int, - max_position: int, - base: int, - is_neox_style: bool = True, - rope_scaling: Optional[Dict[str, Any]] = None, - dtype: Optional[torch.dtype] = None, - partial_rotary_factor: float = 1.0, - device: Optional[str] = None, -): - if device != "cpu": - return get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - rope_scaling, - dtype, - partial_rotary_factor, - ) - - return get_rope_cpu( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - rope_scaling, - dtype, - partial_rotary_factor, - device, - ) diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index d8916abacfb..066157f05ce 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -24,7 
+24,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 4d73aa0def7..222cc3e2d80 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -21,7 +21,6 @@ import torch from torch import nn from torch.nn import LayerNorm -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.configs import ChatGLMConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -35,6 +34,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index d701c10a767..151087732f0 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -44,7 +44,6 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -59,6 +58,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 206f24d61bf..cedc9639220 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -19,7 +19,6 @@ import torch import torch.nn as nn -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.configs import DbrxConfig from sglang.srt.distributed import ( @@ -36,6 +35,7 @@ from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index cbd123c9e96..7d2c0700fe4 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -21,7 +21,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from 
sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -40,6 +39,7 @@ from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5a76c8ac981..0d327c0ca97 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -23,7 +23,6 @@ from torch import nn from transformers import PretrainedConfig from vllm import _custom_ops as ops -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -49,7 +48,7 @@ normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope_wrapper +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -272,7 +271,7 @@ def __init__( quant_config=quant_config, ) rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope_wrapper( + self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 5bb0ea538f9..10be1e74d61 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -20,7 +20,6 @@ import torch from torch import nn -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index bed496c6952..9940c569e25 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -21,7 +21,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import GeluAndMul @@ -34,6 +33,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index ee0a762aa0a..4d21901de7c 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -15,12 +15,11 @@ # Adapted 
from: # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import GeluAndMul @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -45,10 +45,6 @@ def get_attention_sliding_window_size(config): return config.sliding_window - 1 -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding - - class Gemma2MLP(nn.Module): def __init__( self, diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index d457603b063..04c3005ce2f 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -25,8 +25,6 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn - -# from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index 1383e0ef0fd..255f23227ff 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -22,7 +22,6 @@ import torch from torch import nn from transformers import GraniteConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -36,6 +35,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 490798b5b66..c13d3e25368 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -40,6 +39,7 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 
31617db5e5a..ce8f9a3cf65 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -19,7 +19,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -32,6 +31,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 9ea80d0c05d..4ea77eede9b 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -22,7 +22,6 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -39,6 +38,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index b0853c3eee6..f5e69411acc 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -18,7 +18,6 @@ import torch from torch import nn -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -31,6 +30,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 2d15af43ff2..118be8ff6c8 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -19,7 +19,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index c2c8d2294d3..4ea734836af 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -21,7 +21,6 @@ import torch from torch import nn from transformers import MixtralConfig 
-from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, @@ -38,6 +37,7 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index c38328369c4..244dc7df2d0 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -23,7 +23,6 @@ import torch.nn.functional as F from torch import nn from transformers import MixtralConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -39,6 +38,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 45a3f3ff461..4d8a79900f4 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -20,7 +20,6 @@ import torch from torch import nn from transformers import OlmoConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -32,6 +31,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index fafe39d7189..f3e1979f849 100755 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -21,7 +21,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -39,6 +38,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 91722d966c8..10b781d72ff 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -22,7 +22,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.layernorm import RMSNorm @@ -35,6 +34,7 @@ from sglang.srt.layers.moe.fused_moe_triton import 
FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index e59f88013d3..b7195dbaa28 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -5,7 +5,6 @@ from torch import nn from transformers import Phi3Config from transformers.configuration_utils import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.linear import ( @@ -17,6 +16,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index ed9ca02b7da..2c99da926b6 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -20,7 +20,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 04faa8dea1b..935e743bf6a 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -20,7 +20,6 @@ import torch from torch import nn -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -37,6 +36,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index b0f08f975ba..6183f30daf4 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, @@ -40,6 +39,7 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from 
sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 2f144dcb1ee..c169dd6fba4 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -24,7 +24,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 04d21f0d4f3..024a6f317fa 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -47,7 +47,6 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import LlamaConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -58,6 +57,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 799e513ae7e..7fd24182374 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -21,7 +21,6 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul @@ -34,6 +33,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 97b62815a87..218b96f9cb4 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -18,7 +18,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -37,6 +36,7 @@ from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding,
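
Usage sketch for the ported `get_rope` factory and its PyTorch-native forward path. The shapes, sizes, and dtype below are illustrative assumptions, not values taken from this diff.

    import torch
    from sglang.srt.layers.rotary_embedding import get_rope

    head_size, num_heads, num_tokens = 128, 8, 16   # illustrative sizes
    rope = get_rope(
        head_size=head_size,
        rotary_dim=head_size,
        max_position=4096,
        base=10000,
        is_neox_style=True,
        rope_scaling=None,        # None -> plain RotaryEmbedding
        dtype=torch.float32,
    )
    positions = torch.arange(num_tokens)
    query = torch.randn(num_tokens, num_heads * head_size)
    key = torch.randn(num_tokens, num_heads * head_size)
    # forward_native() applies the cached cos/sin table in pure PyTorch;
    # forward_cuda()/forward_xpu() instead dispatch to fused in-place kernels.
    query, key = rope.forward_native(positions, query, key)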
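
A sketch of the multimodal (mrope) position helpers added with MRotaryEmbedding; the token ids and grid sizes here are made up for illustration.

    from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

    # Hypothetical ids: 151652 = vision_start token, 151655 = image token.
    input_tokens = [151652, 151655, 11, 22, 33]
    positions, mrope_delta = MRotaryEmbedding.get_input_positions(
        input_tokens=input_tokens,
        image_grid_thw=[[1, 2, 2]],   # one image: t=1, h=2, w=2 patches
        video_grid_thw=[],
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        spatial_merge_size=2,
    )
    # positions is a [3, seq_len] list with one row each for the T/H/W
    # sections, consumed by MRotaryEmbedding.forward() when positions.ndim == 2.
    # During decode, later positions are plain ranges shifted by the delta:
    next_positions = MRotaryEmbedding.get_next_input_positions(
        mrope_delta, context_len=len(input_tokens), seq_len=len(input_tokens) + 1
    )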
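
For the scaled variants, get_rope dispatches on rope_scaling["rope_type"]. Continuing the sketch above, a deepseek_yarn configuration might look like the following; the factor, beta, and head-size values are placeholders, and note that DeepseekScalingRotaryEmbedding builds its cache on "cuda" after this change, so a GPU is required.

    rope_scaling = {
        "rope_type": "deepseek_yarn",
        "factor": 40,
        "original_max_position_embeddings": 4096,
        "beta_fast": 32,
        "beta_slow": 1,
        "mscale": 1.0,
        "mscale_all_dim": 0.0,
    }
    deepseek_rope = get_rope(
        head_size=64,
        rotary_dim=64,
        max_position=4096 * 40,
        base=10000,
        is_neox_style=False,
        rope_scaling=rope_scaling,
        dtype=torch.float32,
    )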