feat: remove vllm get_rope (#2964)
zhyncs authored Jan 18, 2025
1 parent 6f98c58 commit 2add697
Showing 30 changed files with 1,026 additions and 217 deletions.
1,176 changes: 996 additions & 180 deletions python/sglang/srt/layers/rotary_embedding.py

Large diffs are not rendered by default.
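
The new in-tree python/sglang/srt/layers/rotary_embedding.py provides the get_rope factory that every model file below switches to; the call sites keep their existing arguments and only the import path changes. A minimal usage sketch follows, assuming the SGLang factory keeps the vLLM-compatible signature implied by the call sites in this diff (head size, rotary_dim, and max_position appear below; the base and is_neox_style keywords and all concrete values are illustrative assumptions, not taken from this commit):

# Hedged sketch, not part of this commit: constructing a rotary embedding
# via the new in-tree factory after the import switch.
from sglang.srt.layers.rotary_embedding import get_rope

rotary_emb = get_rope(
    128,                 # head size (illustrative)
    rotary_dim=128,      # apply RoPE across the full head dimension
    max_position=4096,   # maximum sequence length (illustrative)
    base=10000,          # assumed keyword: RoPE theta
    is_neox_style=True,  # assumed keyword: GPT-NeoX-style rotation
)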

2 changes: 1 addition & 1 deletion python/sglang/srt/models/baichuan.py
@@ -24,7 +24,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -40,6 +39,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/chatglm.py
@@ -21,7 +21,6 @@
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.configs import ChatGLMConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -35,6 +34,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/commandr.py
@@ -44,7 +44,6 @@
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -59,6 +58,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader

2 changes: 1 addition & 1 deletion python/sglang/srt/models/dbrx.py
@@ -19,7 +19,6 @@

 import torch
 import torch.nn as nn
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.configs import DbrxConfig
 from sglang.srt.distributed import (
@@ -36,6 +35,7 @@
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek.py
@@ -21,7 +21,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -40,6 +39,7 @@
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

5 changes: 2 additions & 3 deletions python/sglang/srt/models/deepseek_v2.py
@@ -23,7 +23,6 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -49,7 +48,7 @@
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -272,7 +271,7 @@ def __init__(
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope_wrapper(
+        self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/exaone.py
@@ -20,7 +20,6 @@

 import torch
 from torch import nn
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -33,6 +32,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma.py
@@ -21,7 +21,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
@@ -34,6 +33,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader

8 changes: 2 additions & 6 deletions python/sglang/srt/models/gemma2.py
@@ -15,12 +15,11 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py

-from typing import Iterable, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
@@ -33,6 +32,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -45,10 +45,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1


-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,

2 changes: 0 additions & 2 deletions python/sglang/srt/models/gpt2.py
@@ -25,8 +25,6 @@

 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
-
-# from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/granite.py
@@ -22,7 +22,6 @@
 import torch
 from torch import nn
 from transformers import GraniteConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -36,6 +35,7 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/grok.py
@@ -22,7 +22,6 @@
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -40,6 +39,7 @@
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/internlm2.py
@@ -19,7 +19,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -32,6 +31,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/llama.py
@@ -22,7 +22,6 @@
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -39,6 +38,7 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/minicpm.py
@@ -18,7 +18,6 @@

 import torch
 from torch import nn
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -31,6 +30,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/minicpm3.py
@@ -19,7 +19,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -33,6 +32,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/mixtral.py
@@ -21,7 +21,6 @@
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
@@ -38,6 +37,7 @@
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/mixtral_quant.py
@@ -23,7 +23,6 @@
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -39,6 +38,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/olmo.py
@@ -20,7 +20,6 @@
 import torch
 from torch import nn
 from transformers import OlmoConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -32,6 +31,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/olmo2.py
@@ -21,7 +21,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -39,6 +38,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/olmoe.py
@@ -22,7 +22,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
@@ -35,6 +34,7 @@
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/phi3_small.py
@@ -5,7 +5,6 @@
 from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
@@ -17,6 +16,7 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,

2 changes: 1 addition & 1 deletion python/sglang/srt/models/qwen.py
@@ -20,7 +20,6 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
@@ -33,6 +32,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
