From 57119bb201b8d938e9cfcf4b55f8e8ae896d144e Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 00:11:03 +0300 Subject: [PATCH 1/8] Merge query/key/value projection layers --- src/petals/models/llama/block.py | 109 +++++++++++++++++++++++++++++- src/petals/utils/convert_block.py | 13 ++++ 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py index 55f659a61..b7616a151 100644 --- a/src/petals/models/llama/block.py +++ b/src/petals/models/llama/block.py @@ -3,13 +3,118 @@ Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py See commit history for authorship. """ +import math from typing import Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel +import torch.nn as nn +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + apply_rotary_pos_emb, + repeat_kv, +) -class WrappedLlamaBlock(LlamaDecoderLayer): +class OptimizedLlamaAttention(LlamaAttention): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.qkv_proj = nn.Linear( + self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False + ) + self.qkv_sizes = [ + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ] + self.attn_norm_constant = math.sqrt(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + assert ( + self.config.pretraining_tp == 1 + ), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1" + assert not output_attentions, "output_attentions=True is not supported" + + query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class OptimizedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = OptimizedLlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 94d3e29f3..fc4c6c7fc 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -50,6 +50,19 @@ def convert_block( if freeze: block.requires_grad_(False) + if hasattr(block, "self_attn") and hasattr(block.self_attn, "qkv_proj"): + offset = 0 + for data in [ + block.self_attn.q_proj.weight.data, + block.self_attn.k_proj.weight.data, + block.self_attn.v_proj.weight.data, + ]: + block.self_attn.qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data) + offset += data.size(0) + del block.self_attn.q_proj + del block.self_attn.k_proj + del block.self_attn.v_proj + block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device) if quant_type != QuantType.NONE: From c666a975d0435c37dbff3a4c13f3fcca72d131dc Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 00:38:35 +0300 Subject: [PATCH 2/8] Remove unused import in throughput.py --- src/petals/server/throughput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index bf71f44c0..dcb94d748 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -1,6 +1,5 @@ import fcntl import json -import math import multiprocessing as mp import os import time From b2ab84cc3347575dae4d43224caa127b3a53736c Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 00:41:03 +0300 Subject: [PATCH 3/8] Add dry_run option to --throughput --- src/petals/cli/run_server.py | 5 +++-- src/petals/server/server.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 3728c163e..94f5c2e53 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -106,12 +106,13 @@ def main(): "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") parser.add_argument('--throughput', - type=lambda value: value if value in ['auto', 'eval'] else float(value), + type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value), default='auto', help='Expected server throughput (a float measured in RPS). ' 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' - 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') + 'If set to "eval", the script re-evaluates the throughput and overrides the cache. ' + 'If set to "dry_run", the script re-evaluates the throughput and exits.') parser.add_argument('--update_period', type=float, required=False, default=120, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ab646a5f4..a5f2ba01e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -5,6 +5,7 @@ import multiprocessing as mp import os import random +import sys import threading import time from typing import Dict, List, Optional, Sequence, Union @@ -234,8 +235,9 @@ def __init__( self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") - assert isinstance(throughput, float) or throughput in ["auto", "eval"] - if throughput in ["auto", "eval"]: + assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"] + if throughput in ["auto", "eval", "dry_run"]: + force_eval = throughput in ["eval", "dry_run"] throughput_info = get_server_throughput( converted_model_name_or_path, self.block_config, @@ -245,9 +247,12 @@ def __init__( quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, reachable_via_relay=reachable_via_relay, - force_eval=(throughput == "eval"), + force_eval=force_eval, cache_dir=cache_dir, ) + if throughput == "dry_run": + logger.info("Finished estimating throughput, exiting") + sys.exit(0) else: throughput_info = {"throughput": throughput} self.server_info = ServerInfo( From 46441310861e78a17f6a29d75b7a643430b52bac Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 01:08:55 +0300 Subject: [PATCH 4/8] Ignore missing qkv_proj.weight when loading a checkpoint --- src/petals/server/from_pretrained.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 73956fe87..3ca8ef10a 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -65,6 +65,7 @@ def load_pretrained_block( # dummy load, check that keys match report = block.load_state_dict(state_dict, strict=False) + report.missing_keys.pop("self_attn.qkv_proj.weight", None) # will be filled later assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" for param_name, _ in block.named_parameters(): From f1009156412aed14891ac1bb27e3d4092c4932d6 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 01:09:56 +0300 Subject: [PATCH 5/8] Reformat code with black --- src/petals/server/from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 3ca8ef10a..5a571b73d 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -65,7 +65,7 @@ def load_pretrained_block( # dummy load, check that keys match report = block.load_state_dict(state_dict, strict=False) - report.missing_keys.pop("self_attn.qkv_proj.weight", None) # will be filled later + report.missing_keys.pop("self_attn.qkv_proj.weight", None) # will be filled later assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" for param_name, _ in block.named_parameters(): From 16fb5479603346aaa72dd35a81017b10337224e7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 01:35:00 +0300 Subject: [PATCH 6/8] Fix removal of nonexistent keys --- src/petals/server/from_pretrained.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 5a571b73d..85617e71b 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -65,7 +65,8 @@ def load_pretrained_block( # dummy load, check that keys match report = block.load_state_dict(state_dict, strict=False) - report.missing_keys.pop("self_attn.qkv_proj.weight", None) # will be filled later + if "self_attn.qkv_proj.weight" in report.missing_keys: + report.missing_keys.remove("self_attn.qkv_proj.weight") # will be filled later assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" for param_name, _ in block.named_parameters(): From 9cb4c721e79ccb5e20fec449e62d092c271c95fc Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 01:55:50 +0300 Subject: [PATCH 7/8] Fix checking for nonexistent keys --- src/petals/server/from_pretrained.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 85617e71b..c52ad4b68 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -70,11 +70,12 @@ def load_pretrained_block( assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" for param_name, _ in block.named_parameters(): - assert param_name in state_dict, f"{param_name} not in state dict" - param = state_dict[param_name] - if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): - param = param.to(torch_dtype) - set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) + if param_name != "self_attn.qkv_proj.weight": + assert param_name in state_dict, f"{param_name} not in state dict" + param = state_dict[param_name] + if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + param = param.to(torch_dtype) + set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) logger.info(f"Loaded {model_name} block {block_index}") logger.debug(f"Details: {report}") From 4159e557bfc191a71ef001b0350f1d76a09d5f0b Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 19:20:07 +0300 Subject: [PATCH 8/8] Create dummy data when materializing qkv_proj --- src/petals/server/from_pretrained.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index c52ad4b68..bb8901687 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -76,6 +76,10 @@ def load_pretrained_block( if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): param = param.to(torch_dtype) set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) + else: + cur_block = getattr(block, param_name) + dummy_value = torch.empty_like(cur_block, device="cpu") + set_module_tensor_to_device(block, param_name, "cpu", dummy_value) logger.info(f"Loaded {model_name} block {block_index}") logger.debug(f"Details: {report}")