From 0b4585cce20c1f853301d2e0f2ca2b0344ec4140 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Apr 2023 21:08:19 -0400 Subject: [PATCH 1/8] Add back experimental features --- .../gpt_bigcode/configuration_gpt_bigcode.py | 33 ++++++ .../models/gpt_bigcode/inference_runner.py | 6 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 108 +++++++++++++++++- 3 files changed, 139 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 9cbaf3e184..1cfba93a71 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -14,6 +14,8 @@ # limitations under the License. """ GPTBigCode configuration""" +from enum import IntEnum + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -25,6 +27,19 @@ } +class InferenceRunnerType(IntEnum): + NO_RUNNER = 0 + # Use the inference runner without cuda graphs. + BASE_RUNNER = 1 + # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape. + # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models. + PARTIAL_GRAPH = 2 + # Turn the whole model into a cuda graph. One graph for each sequence length. + # Note: only useful for small batches and models, graphs take some time to generate, flaky. + # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. + FULL_GRAPH = 3 + + class GPTBigCodeConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a @@ -119,6 +134,12 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, multi_query=True, + inference_runner=InferenceRunnerType.NO_RUNNER, + validate_runner_input=True, + pre_allocate_kv_cache=False, + max_sequence_length=None, + max_batch_size=None, + pad_key_length=True, **kwargs, ): self.vocab_size = vocab_size @@ -142,4 +163,16 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id + self.inference_runner = InferenceRunnerType(inference_runner) + # Set to False to disable input validation of safe inputs, for a small speedup. + self.validate_runner_input = validate_runner_input + + self.pre_allocate_kv_cache = pre_allocate_kv_cache + # The max sequence length for the pre-allocated KV cache (`n_positions` if not provided). + self.max_sequence_length = max_sequence_length + # The max batch size for the pre-allocated KV cache, (deduce from input if not provided). + self.max_batch_size = max_batch_size + # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache). 
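As a rough usage sketch for the options introduced above (the values are made up; only the parameter and enum names come from this patch):

    from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
        GPTBigCodeConfig,
        InferenceRunnerType,
    )

    # Partial-graph inference runner with a pre-allocated, padded KV cache.
    config = GPTBigCodeConfig(
        inference_runner=InferenceRunnerType.PARTIAL_GRAPH,
        validate_runner_input=True,
        pre_allocate_kv_cache=True,
        max_batch_size=32,
        max_sequence_length=2048,
        pad_key_length=True,
    )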
+ self.pad_key_length = pad_key_length + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 0cba0cc7e5..3e896fa2c7 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -285,9 +285,9 @@ def _forward(self, key_length): def forward( self, - input_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, past_key_values: Union[List[torch.Tensor], int], ) -> BaseModelOutputWithPastAndCrossAttentions: batch_size, query_length = input_ids.shape diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 72858532bf..8a4880ceb6 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -34,7 +34,7 @@ add_start_docstrings_to_model_forward, logging, ) -from .configuration_gpt_bigcode import GPTBigCodeConfig +from .configuration_gpt_bigcode import GPTBigCodeConfig, InferenceRunnerType logger = logging.get_logger(__name__) @@ -105,6 +105,14 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 ) + # KV caching and padding + self.kv_cache = None + self.kv_cache_max_batch_size = config.max_batch_size or 0 + self.kv_cache_max_sequence_length = config.max_sequence_length or config.n_positions + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and config.pre_allocate_kv_cache + self._frozen_kv_cache = False + if self.is_cross_attention: if self.multi_query: raise NotImplementedError("Multi-Query Attention not supported for cross_attention") @@ -202,6 +210,35 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights + def freeze_kv_cache(self, enable=True): + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + # Prevent re-allocation of the KV cache. + self._frozen_kv_cache = enable + + def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True): + if ( + self.kv_cache is None + or self.kv_cache.dtype != dtype + or self.kv_cache.device != device + or batch_size > self.kv_cache_max_batch_size + or sequence_length > self.kv_cache_max_sequence_length + ): + if self._frozen_kv_cache or not allocate: + # TODO: Improve error message + raise RuntimeError("KV cache not found." if self.kv_cache is None else "Invalid KV cache.") + # Free memory first. + self.kv_cache = None + self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length) + self.kv_cache_max_batch_size = max(batch_size, self.kv_cache_max_batch_size) + kv_cache_size = 2 * self.kv_cache_max_batch_size * self.kv_cache_max_sequence_length * self.kv_dim + self.kv_cache = torch.empty([kv_cache_size], device=device, dtype=dtype) + # This view ensures the cache is contiguous for all batch sizes. 
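As a standalone illustration of the buffer-and-view scheme used by get_kv_cache (sizes are made up, and the multi-query case with a single KV head is assumed):

    import torch

    max_batch, max_seq, kv_heads, head_dim = 4, 16, 1, 64
    kv_dim = kv_heads * head_dim

    # One flat buffer sized for the worst case...
    flat_cache = torch.empty([2 * max_batch * max_seq * kv_dim])

    # ...sliced before viewing, so the view stays contiguous for any batch size.
    batch, seq = 2, 10
    view = flat_cache[: 2 * batch * max_seq * kv_dim].view(batch, kv_heads, max_seq, 2 * head_dim)
    kv = view[:, 0, :seq, :]  # multi-query slice: (batch, seq, 2 * head_dim)
    assert kv.shape == (batch, seq, 2 * head_dim)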
+ kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view( + batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim + ) + return kv_cache[:, 0, :sequence_length, :] if self.is_mqa else kv_cache[:, :, :sequence_length, :] + def forward( self, hidden_states: torch.Tensor, @@ -239,9 +276,27 @@ def forward( .split((self.head_dim, 2 * self.head_dim), dim=3) ) - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + present = None + + if self.pre_allocate_kv_cache: + if use_cache or layer_past is not None: + last_key_length = layer_past or 0 + batch_size = key_value.size(0) + key_length = last_key_length + key_value.size(-2) + padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) + kv_cache = self.get_kv_cache( + batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 + ) + if self.is_mqa: + kv_cache[:, last_key_length:key_length, :].copy_(key_value) + key_value = kv_cache + if use_cache: + present = key_length + else: + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + if use_cache: + present = key_value key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) @@ -513,6 +568,17 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache + self.inference_runner_type = InferenceRunnerType(config.inference_runner) + + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: + self.inference_runner = None + else: + from .inference_runner import GPTBigCodeInferenceRunner + + self.inference_runner = GPTBigCodeInferenceRunner(config, self) + max_positions = config.max_position_embeddings self.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False @@ -551,6 +617,31 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + if self.inference_runner is not None and past_key_values is not None: + if self.config.validate_runner_input: + assert input_ids is not None + assert past_key_values is not None + assert attention_mask is not None + assert token_type_ids is None + assert position_ids is not None + assert head_mask is None + assert inputs_embeds is None + assert encoder_hidden_states is None + assert encoder_attention_mask is None + use_cache = use_cache if use_cache is not None else self.config.use_cache + assert use_cache is True + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + assert output_attentions is False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + assert output_hidden_states is False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + assert return_dict is True + return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( 
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -583,8 +674,10 @@ def forward( if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) + elif self.pre_allocate_kv_cache: + past_length = past_key_values[0] else: - past_length = past_key_values[0][0].size(-2) + past_length = past_key_values[0].size(-2) if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: # create position_ids on the fly for batch generation @@ -610,6 +703,11 @@ def forward( # MHA models: (batch_size, n_heads, query_length, key_length) attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + if self.pad_key_length: + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if ( From 4187786eb5c42074f0ea6b3f50b7f6ae4810c282 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Apr 2023 23:38:59 -0400 Subject: [PATCH 2/8] Fixes --- src/transformers/models/gpt_bigcode/inference_runner.py | 8 ++++---- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 3e896fa2c7..1767bf9642 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -4,7 +4,7 @@ from transformers import GPTBigCodeConfig from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType, InferenceRunnerType +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax @@ -25,7 +25,7 @@ def __init__(self, config: GPTBigCodeConfig, model): self.pad_key_length = 8 if config.pad_key_length else 1 # TODO: Support other attention types? 
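The key-length padding above is plain modular arithmetic plus a boolean pad on the last mask dimension; a small worked example with an arbitrary key length:

    import torch

    key_length = 21
    pad = -key_length % 8  # 3, so the padded key length becomes 24
    attention_mask = torch.ones(2, 1, 1, key_length, dtype=torch.bool)
    attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
    assert attention_mask.size(-1) == 24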
- assert model.attention_type == AttentionType.MULTI_QUERY_1 + assert model.multi_query self.max_sequence_length = config.max_sequence_length or config.n_positions @@ -71,7 +71,7 @@ def _allocate(self, batch_size, device, dtype): for block in self.model.h: block.attn.freeze_kv_cache() kv_cache = block.attn.get_kv_cache(self.batch_size, self.max_sequence_length, self.device, self.dtype) - if attn.is_mqa: + if attn.multi_query: kv_cache = kv_cache.unsqueeze(1) kv_caches.append(kv_cache) @@ -122,7 +122,7 @@ def _allocate(self, batch_size, device, dtype): self.kv_attn = self.c_attn[:, attn.embed_dim :] keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches)) - head_slice = 0 if attn.is_mqa else slice(None) + head_slice = 0 if attn.multi_query else slice(None) self.padded_keys = [ [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 8a4880ceb6..8810b34584 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -237,7 +237,7 @@ def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view( batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim ) - return kv_cache[:, 0, :sequence_length, :] if self.is_mqa else kv_cache[:, :, :sequence_length, :] + return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] def forward( self, @@ -287,7 +287,7 @@ def forward( kv_cache = self.get_kv_cache( batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 ) - if self.is_mqa: + if self.multi_query: kv_cache[:, last_key_length:key_length, :].copy_(key_value) key_value = kv_cache if use_cache: From 10f4a98dbfbbf8e00754267949cf85898e60795a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Apr 2023 17:20:56 -0400 Subject: [PATCH 3/8] Error message --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 8810b34584..76ba07b73e 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -225,8 +225,14 @@ def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True or sequence_length > self.kv_cache_max_sequence_length ): if self._frozen_kv_cache or not allocate: - # TODO: Improve error message - raise RuntimeError("KV cache not found." if self.kv_cache is None else "Invalid KV cache.") + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + else: + raise RuntimeError( + f"Invalid KV cache: " + f"existing = {(self.kv_cache.dtype,self.kv_cache.device,self.kv_cache_max_batch_size,self.kv_cache_max_sequence_length)}, " + f"requested = {(dtype,device,batch_size,sequence_length)}" + ) # Free memory first. 
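To make the new error message concrete, a hypothetical call sequence against one attention layer; model and device are placeholders and the sizes are arbitrary:

    import torch

    attn = model.transformer.h[0].attn  # assumes a loaded model with pre_allocate_kv_cache=True
    attn.get_kv_cache(batch_size=8, sequence_length=512, device=device, dtype=torch.float16)
    attn.freeze_kv_cache()  # lock the buffer: no further re-allocation allowed

    attn.get_kv_cache(8, 256, device, torch.float16)   # fine, fits the frozen buffer
    attn.get_kv_cache(16, 512, device, torch.float16)  # raises RuntimeError("Invalid KV cache: ...")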
self.kv_cache = None self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length) From d08ce174360daf9d49425859882720b4c50b6813 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 19 Apr 2023 23:23:37 -0400 Subject: [PATCH 4/8] Attn implementations --- .../gpt_bigcode/configuration_gpt_bigcode.py | 24 + .../models/gpt_bigcode/inference_runner.py | 155 +++++- .../gpt_bigcode/modeling_gpt_bigcode.py | 448 ++++++++++++++++-- 3 files changed, 569 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 1cfba93a71..807a17b258 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -40,6 +40,28 @@ class InferenceRunnerType(IntEnum): FULL_GRAPH = 3 +class AttentionImplementation(IntEnum): + # Ours + BASE = 0 + # Flash attention + FLASH = 1 + # scaled_dot_product_attention (multiple implementations) + TORCH = 2 + TORCH_FLASH = 3 + TORCH_MEM = 4 + TORCH_CPP = 5 + # DEBUG + OLD = 6 + + +TORCH_IMPLEMENTATIONS = ( + AttentionImplementation.TORCH, + AttentionImplementation.TORCH_FLASH, + AttentionImplementation.TORCH_MEM, + AttentionImplementation.TORCH_CPP, +) + + class GPTBigCodeConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a @@ -134,6 +156,7 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, multi_query=True, + attention_implementation=AttentionImplementation.BASE, inference_runner=InferenceRunnerType.NO_RUNNER, validate_runner_input=True, pre_allocate_kv_cache=False, @@ -159,6 +182,7 @@ def __init__( self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 self.multi_query = multi_query + self.attention_implementation = attention_implementation self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 1767bf9642..c532852b5d 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -4,10 +4,20 @@ from transformers import GPTBigCodeConfig from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( + TORCH_IMPLEMENTATIONS, + AttentionImplementation, + InferenceRunnerType, +) from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + flash_attn_unpadded_func = None + + def _align_tensor(x): return x + -x % 128 @@ -19,11 +29,26 @@ def __init__(self, config: GPTBigCodeConfig, model): self.n_layer = len(self.model.h) self.inference_runner_type = InferenceRunnerType(config.inference_runner) + self.attention_implementation = config.attention_implementation assert self.inference_runner_type != InferenceRunnerType.NO_RUNNER + assert self.attention_implementation in ( + AttentionImplementation.BASE, + *TORCH_IMPLEMENTATIONS, + AttentionImplementation.FLASH, + ) assert config.pre_allocate_kv_cache 
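A small configuration sketch for the back-end selection added above; AttentionImplementation.FLASH is only usable when the optional flash_attn package imports successfully (hence the guarded import), and the other names come from this patch:

    from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
        AttentionImplementation,
        GPTBigCodeConfig,
    )

    # torch scaled_dot_product_attention restricted to its flash kernel:
    config = GPTBigCodeConfig(attention_implementation=AttentionImplementation.TORCH_FLASH)

    # external flash-attn kernels, if the package is installed:
    config = GPTBigCodeConfig(attention_implementation=AttentionImplementation.FLASH)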
self.validate_input = config.validate_runner_input self.pad_key_length = 8 if config.pad_key_length else 1 + if self.attention_implementation == AttentionImplementation.BASE: + self._forward_attn = self._forward_attn_base + elif self.attention_implementation in TORCH_IMPLEMENTATIONS: + self._forward_attn = self._forward_attn_torch + elif self.attention_implementation == AttentionImplementation.FLASH: + self._forward_attn = self._forward_attn_flash + else: + raise NotImplementedError(self.attention_implementation) + # TODO: Support other attention types? assert model.multi_query @@ -38,7 +63,11 @@ def _allocate(self, batch_size, device, dtype): self.softmax_dtype = torch.float32 if attn.attention_softmax_in_fp32 else self.dtype self.upcast = self.softmax_dtype != self.dtype - do_unscale = attn.scale_attention_softmax_in_fp32 and self.upcast + do_unscale = ( + attn.scale_attention_softmax_in_fp32 + and self.upcast + and self.attention_implementation == AttentionImplementation.BASE + ) self.unscale = [i + 1.0 if do_unscale else 1.0 for i in range(self.n_layer)] scale = attn.head_dim**-0.5 if attn.scale_attn_weights else 1 self.scale = [scale / unscale for unscale in self.unscale] @@ -51,9 +80,11 @@ def _allocate(self, batch_size, device, dtype): query_end = query_begin + self.batch_size * attn.embed_dim # KV: (bs, 2 * kv_dim), combines with query into c_attn. kv_end = query_end + 2 * self.batch_size * attn.kv_dim - # Attn weights: (batch_size, num_heads, key_length), no overlap with value + # Attn weights: (batch_size, num_heads, key_length), no overlap with value (not needed for torch/flash attn) attn_weights_begin = _align_tensor(kv_end) - attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length + attn_weights_end = attn_weights_begin + if self.attention_implementation == AttentionImplementation.BASE: + attn_weights_end += self.batch_size * attn.num_heads * self.max_sequence_length # Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query. # Also used for MLP projection c_proj_begin = _align_tensor(query_end) @@ -119,13 +150,25 @@ def _allocate(self, batch_size, device, dtype): # QKV: (bs, embed_dim + 2 * kv_dim). self.c_attn = activation_pool[query_begin:kv_end].view(self.batch_size, -1) self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim) + # if self.attention_implementation==AttentionImplementation.FLASH: + # self.query=query.view(self.batch_size * attn.num_heads, 1, attn.head_dim) + # else: + # self.query=query.view(self.batch_size, attn.num_heads, attn.head_dim) + self.kv_attn = self.c_attn[:, attn.embed_dim :] keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches)) head_slice = 0 if attn.multi_query else slice(None) + # No transpose for torch/flash attn self.padded_keys = [ - [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths + [ + key[:, head_slice, :key_length, :].transpose( + -1, -2 if self.attention_implementation == AttentionImplementation.BASE else -1 + ) + for key in keys + ] + for key_length in padded_key_lengths ] self.padded_values = [ [value[:, head_slice, :key_length, :] for value in values] for key_length in padded_key_lengths @@ -139,15 +182,22 @@ def _allocate(self, batch_size, device, dtype): [kv_cache[:, head_slice, : key_length - 1, :] for kv_cache in kv_caches] for key_length in key_lengths ] - # Attn weights: (batch_size, num_heads, key_length), no overlap with value. 
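The _align_tensor helper used for these offsets simply rounds an element offset up to the next multiple of 128, which keeps every view of the activation pool well aligned; for example:

    def _align_tensor(x):
        return x + -x % 128

    assert _align_tensor(0) == 0
    assert _align_tensor(1) == 128
    assert _align_tensor(300) == 384  # -300 % 128 == 84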
- attn_weights = activation_pool[attn_weights_begin:attn_weights_end].view( - self.batch_size, attn.num_heads, self.max_sequence_length - ) - self.padded_attn_weights = [attn_weights[:, :, :key_length] for key_length in padded_key_lengths] - - # Attn outputs: (batch_size, embed_dim), no overlap with value. - self.attn_output = activation_pool[query_begin:query_end].view(self.batch_size, -1) - self.attn_output_expanded = self.attn_output.view(self.batch_size, attn.num_heads, attn.head_dim) + if self.attention_implementation == AttentionImplementation.BASE: + # Attn weights: (batch_size, num_heads, key_length), no overlap with value. + attn_weights = activation_pool[attn_weights_begin:attn_weights_end].view( + self.batch_size, attn.num_heads, self.max_sequence_length + ) + self.padded_attn_weights = [attn_weights[:, :, :key_length] for key_length in padded_key_lengths] + + # Attn outputs: (batch_size, embed_dim), no overlap with value. + self.attn_output = activation_pool[query_begin:query_end].view(self.batch_size, -1) + self.attn_output_expanded = self.attn_output.view(self.batch_size, attn.num_heads, attn.head_dim) + elif self.attention_implementation == AttentionImplementation.FLASH: + self.cu_sq = torch.arange( + 0, (self.batch_size + 1) * attn.num_heads, step=attn.num_heads, dtype=torch.int32, device=self.device + ) + # self.cu_sk = torch.arange(0, (self.batch_size + 1) * self.max_sequence_length, step=self.max_sequence_length, dtype=torch.int32, + # device=self.device) # Attn projection: (batch_size, embed_dim), no overlap with attn outputs. self.c_proj = activation_pool[c_proj_begin:c_proj_end].view(self.batch_size, -1) @@ -159,6 +209,10 @@ def _allocate(self, batch_size, device, dtype): if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER: print("Generating cuda graphs") self.memory_pool = None + # This prevents some issue with cublas initialization. + # https://github.com/pytorch/pytorch/issues/99397 + dummy_matrix = self.mask_value.view([1, 1]) + torch.matmul(dummy_matrix, dummy_matrix) if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: self.cuda_graphs = {} # The output may not always be at the same memory location. @@ -187,22 +241,24 @@ def _generate_cuda_graphs(self): def _generate_full_cuda_graph(self, key_length): # We need to warmup the jit function before creating the graph, otherwise it will crash. 
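The graph machinery above follows the usual PyTorch pattern: run the computation a few times to warm up, capture it into a torch.cuda.CUDAGraph, then replay it with new data copied into the same static buffers. A generic sketch, not the runner's actual code:

    import torch

    linear = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(8, 16, device="cuda")

    for _ in range(3):  # warm-up runs (the runner also warms up the jit kernels here)
        static_output = linear(static_input)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = linear(static_input)

    static_input.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()  # static_output now holds the result for the new input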
+ # https://github.com/pytorch/pytorch/issues/99397 # Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1 - if self.upcast: - for scale in (1.0, 2.0): - upcast_masked_softmax( + if self.attention_implementation == AttentionImplementation.BASE: + if self.upcast: + for scale in (1.0, 2.0): + upcast_masked_softmax( + self.padded_attn_weights[key_length], + self.padded_attn_masks[key_length], + self.mask_value, + scale, + self.softmax_dtype, + ) + else: + masked_softmax( self.padded_attn_weights[key_length], self.padded_attn_masks[key_length], self.mask_value, - scale, - self.softmax_dtype, ) - else: - masked_softmax( - self.padded_attn_weights[key_length], - self.padded_attn_masks[key_length], - self.mask_value, - ) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, pool=self.memory_pool): self.output_hidden_states[key_length] = self._forward(key_length) @@ -226,7 +282,7 @@ def _forward_qkv(self, block): out=self.c_attn, ) - def _forward_attn(self, block, key_length): + def _forward_attn_base(self, block, key_length): layer_idx = block.attn.layer_idx self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) attn_weights = self.padded_attn_weights[key_length] @@ -254,6 +310,50 @@ def _forward_attn(self, block, key_length): torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) + def _forward_attn_torch(self, block, key_length): + layer_idx = block.attn.layer_idx + self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) + with block.attn.backend_context(): + attn_output = torch.nn.functional.scaled_dot_product_attention( + self.query, + self.padded_keys[key_length][layer_idx], + self.padded_values[key_length][layer_idx], + None, # attention_mask, + 0.0, + is_causal=False, + ) + # Out arg not supported so we set the variable instead. + self.attn_output = attn_output.view(self.batch_size, -1) + + def _forward_attn_flash(self, block, key_length): + layer_idx = block.attn.layer_idx + num_heads = block.attn.num_heads + self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) + # TODO: Pre-allocate? + # TODO: Adjust for non-contiguous key/value? (max seq len instead of key length) + cu_sk = torch.arange( + 0, (self.batch_size + 1) * key_length, step=key_length, dtype=torch.int32, device=self.device + ) + # TODO: Avoid reshape + q = self.query.reshape(self.batch_size * num_heads, 1, block.attn.head_dim) + k = self.padded_keys[key_length][layer_idx].reshape(self.batch_size * key_length, 1, block.attn.head_dim) + v = self.padded_values[key_length][layer_idx].reshape(self.batch_size * key_length, 1, block.attn.head_dim) + print("A", q.shape, k.shape, v.shape) + attn_output = flash_attn_unpadded_func( + q, + k, + v, + self.cu_sq, + cu_sk, + num_heads, + key_length, + 0.0, + softmax_scale=self.scale[layer_idx], + causal=False, + ) + # Out arg not supported so we set the variable instead. + self.attn_output = attn_output.view(self.batch_size, -1) + def _forward_post_attn(self, block): torch.nn.functional.linear( self.attn_output, @@ -261,6 +361,9 @@ def _forward_post_attn(self, block): block.attn.c_proj.bias, out=self.c_proj, ) + if self.attention_implementation != AttentionImplementation.BASE: + # Free memory. + del self.attn_output self.hidden_states_squeezed.add_(self.c_proj) # LN doesn't support out argument. 
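Most runner ops write straight into slices of the shared activation pool through their out= argument, which is why the few ops without one (layer norm here, scaled_dot_product_attention and flash attention above) fall back to rebinding a variable instead. A toy version of the pattern, with made-up names and sizes:

    import torch

    activation_pool = torch.empty(1024)
    attn_output = activation_pool[:64].view(8, 8)

    a = torch.randn(8, 4)
    b = torch.randn(4, 8)
    torch.matmul(a, b, out=attn_output)  # writes in place, no new allocation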
hidden_states = block.ln_2(self.hidden_states_squeezed) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 76ba07b73e..77df8e26f9 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch GPTBigCode model.""" +import contextlib import math from typing import List, Optional, Tuple, Union @@ -34,7 +35,18 @@ add_start_docstrings_to_model_forward, logging, ) -from .configuration_gpt_bigcode import GPTBigCodeConfig, InferenceRunnerType +from .configuration_gpt_bigcode import ( + TORCH_IMPLEMENTATIONS, + AttentionImplementation, + GPTBigCodeConfig, + InferenceRunnerType, +) + + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + flash_attn_unpadded_func = None logger = logging.get_logger(__name__) @@ -84,6 +96,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.mask_value = None self.multi_query = config.multi_query + # TODO: chack availability + self.attention_implementation = config.attention_implementation self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads @@ -127,13 +141,37 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + if self.attention_implementation == AttentionImplementation.BASE: + self._attn_fn = self._attn_mqa if self.multi_query else self._attn_mha + elif self.attention_implementation in TORCH_IMPLEMENTATIONS: + # TODO: Implement + assert not self.pre_allocate_kv_cache + self._attn_fn = self._attn_torch_mqa if self.multi_query else self._attn_torch_mha + self.backend_context = ( + lambda: contextlib.nullcontext() + if self.attention_implementation == AttentionImplementation.TORCH + else torch.backends.cuda.sdp_kernel( + enable_flash=self.attention_implementation == AttentionImplementation.TORCH_FLASH, + enable_math=self.attention_implementation == AttentionImplementation.TORCH_CPP, + enable_mem_efficient=self.attention_implementation == AttentionImplementation.TORCH_MEM, + ) + ) + elif self.attention_implementation == AttentionImplementation.FLASH: + # TODO: Implement + assert not self.pre_allocate_kv_cache + self._attn_fn = self._attn_flash_mqa if self.multi_query else self._attn_flash_mha + elif self.attention_implementation == AttentionImplementation.OLD: + self._attn_fn = None + else: + raise ValueError() + def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
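The backend_context lambdas above wrap torch.backends.cuda.sdp_kernel, which restricts which implementation scaled_dot_product_attention is allowed to pick. A standalone sketch with arbitrary shapes and dtype:

    import torch

    q = torch.randn(2, 4, 32, 64, device="cuda", dtype=torch.float16)
    k, v = q.clone(), q.clone()

    # Equivalent of AttentionImplementation.TORCH_FLASH: only the flash kernel is allowed.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)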
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) return self.mask_value - def _attn(self, query, key, value, attention_mask=None, head_mask=None): + def _get_attn_weights(self, query, key, attn_view, attn_shape, attention_mask=None, head_mask=None): dtype = query.dtype softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype upcast = dtype != softmax_dtype @@ -143,30 +181,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if self.scale_attn_weights: scale_factor /= self.head_dim**0.5 - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key_length = key.size(-1) - if self.multi_query: - # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) - # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] - attn_shape = (batch_size, query_length, self.num_heads, key_length) - attn_view = (batch_size, query_length * self.num_heads, key_length) - # No copy needed for MQA 2, or when layer_past is provided. - query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) - else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. - key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) - attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) if query.device.type == "cpu": # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. @@ -199,16 +213,239 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Mask heads if we want to if head_mask is not None: - if self.multi_query: - head_mask = head_mask.transpose(1, 2) attn_weights = attn_weights * head_mask - if self.multi_query: - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + return attn_weights + + def _attn_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # Q: (batch_size, num_heads, query_length, head_dim) + # K, V: (batch_size, num_heads, query_length, head_dim) + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. 
+ query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + + batch_size = query.size(0) + query_length = query.size(1) + key_length = key.size(1) + + # MHA models: (batch_size, num_heads, query_length, head_dim) + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.transpose(-1, -2).reshape(batch_size * self.num_heads, self.head_dim, key_length) + + # attn_weights: (batch_size, num_heads, query_length, key_length) + attn_weights = self._get_attn_weights(query, key, attn_view, attn_shape, attention_mask, head_mask) + + # attn_output: (batch_size, num_heads, query_length, head_dim) + attn_output = torch.matmul(attn_weights, value) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + + return attn_output, present, attn_weights + + def _attn_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + + batch_size = query.size(0) + query_length = query.size(1) + key_length = key.size(1) + + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. + query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + key = key.transpose(-1, -2) + + if head_mask is not None: + head_mask = head_mask.transpose(1, 2) + + # attn_weights: (batch_size, query_length, num_heads, key_length) + attn_weights = self._get_attn_weights(query, key, attn_view, attn_shape, attention_mask, head_mask) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(hidden_states.shape) + + return attn_output, present, attn_weights + + def _attn_torch_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Scale? 
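Shape-wise, the multi-query path above folds the head dimension into the query length so that a single bmm against the one shared key/value head covers all heads; a sketch with made-up sizes:

    import torch

    batch, q_len, k_len, num_heads, head_dim = 2, 3, 7, 4, 8
    query = torch.randn(batch, q_len, num_heads * head_dim)
    key = torch.randn(batch, k_len, head_dim)    # single shared key head
    value = torch.randn(batch, k_len, head_dim)  # single shared value head

    attn_view = (batch, q_len * num_heads, k_len)
    weights = torch.bmm(query.view(batch, q_len * num_heads, head_dim), key.transpose(-1, -2))
    weights = torch.softmax(weights.view(batch, q_len, num_heads, k_len), dim=-1)
    output = torch.bmm(weights.view(attn_view), value).view(batch, q_len, num_heads * head_dim)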
+ # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError() + # Q: (batch_size, num_heads, query_length, head_dim) + # K, V: (batch_size, num_heads, query_length, head_dim) + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + + with self.backend_context(): + # attn_output: (batch_size, num_heads, query_length, head_dim) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, self.attn_dropout, is_causal=True # attention_mask, + ) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + + return attn_output, present, None + + def _attn_torch_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Scale? + # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError() + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + + batch_size = query.size(0) + query_length = query.size(1) + + if query_length == 1: + # attn_output: (batch_size, 1, num_heads, head_dim) + is_causal = False + query = query.view(batch_size, 1, self.num_heads, self.head_dim) + key = key.unsqueeze(1) + value = value.unsqueeze(1) else: - attn_output = torch.matmul(attn_weights, value) + # attn_output: (batch_size, num_heads, query_length, head_dim) + is_causal = True + expanded_shape = (batch_size, self.num_heads, key.size(-2), self.head_dim) + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + key = key.unsqueeze(1).expand(expanded_shape) + value = value.unsqueeze(1).expand(expanded_shape) + with self.backend_context(): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + None, # attention_mask, + self.dropout_p if self.training else 0.0, + is_causal=is_causal, + ) - return attn_output, attn_weights + if query_length != 1: + attn_output = attn_output.transpose(1, 2) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + # CPP backend needs reshape, others are ok with a view. + attn_output = attn_output.reshape(hidden_states.shape) + + return attn_output, present, None + + def _attn_flash_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError() + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, num_heads * head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + + batch_size = query.size(0) + query_length = query.size(1) + key_length = key.size(1) + + # TODO: Pre-allocate? 
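The cu_seqlens tensors built below are the cumulative sequence-start offsets flash attention expects for packed (unpadded) inputs; with equal-length sequences they reduce to an arithmetic range. For a batch of 3 sequences with key length 5:

    import torch

    batch_size, key_length = 3, 5
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * key_length, step=key_length, dtype=torch.int32)
    # tensor([ 0,  5, 10, 15], dtype=torch.int32)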
+ cu_sq = torch.arange( + 0, (batch_size + 1) * query_length, step=query_length, dtype=torch.int32, device=query.device + ) + cu_sk = torch.arange(0, (batch_size + 1) * key_length, step=key_length, dtype=torch.int32, device=query.device) + + # attn_output: (batch_size * query_length, num_heads, head_dim) + attn_output = flash_attn_unpadded_func( + query.view(batch_size * query_length, self.num_heads, self.head_dim), + key.view(batch_size * key_length, self.num_heads, self.head_dim), + value.view(batch_size * key_length, self.num_heads, self.head_dim), + cu_sq, + cu_sk, + query_length, + key_length, + self.dropout_p if self.training else 0.0, + softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1, + causal=True, + ) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + # CPP backend needs reshape, others are ok with a view. (Checked for mqa, confirm for mha) + attn_output = attn_output.reshape(hidden_states.shape) + + return attn_output, present, None + + def _attn_flash_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError() + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + + batch_size = query.size(0) + query_length = query.size(1) + key_length = key.size(1) + + key = key.reshape(batch_size * key_length, 1, self.head_dim) + value = value.reshape(batch_size * key_length, 1, self.head_dim) + + if query_length == 1: + # attn_output: (batch_size * num_heads, 1, head_dim) + query_step = self.num_heads + causal = False + query = query.reshape(batch_size * query_step, 1, self.head_dim) + else: + # attn_output: (batch_size * query_length, num_heads, head_dim) + query_step = query_length + causal = True + query = query.view(batch_size * query_step, self.num_heads, self.head_dim) + key = key.expand(batch_size * key_length, self.num_heads, self.head_dim) + value = value.expand(batch_size * key_length, self.num_heads, self.head_dim) + + # TODO: Pre-allocate? 
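In the multi-query flash path, the single key/value head is broadcast to num_heads with expand, as done just above; this creates a stride-0 view rather than copying memory, roughly:

    import torch

    batch, key_length, num_heads, head_dim = 2, 6, 4, 8
    key = torch.randn(batch, key_length, head_dim)

    key = key.reshape(batch * key_length, 1, head_dim)
    key = key.expand(batch * key_length, num_heads, head_dim)
    assert key.shape == (12, 4, 8) and key.stride()[1] == 0  # broadcast view, no copy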
+ cu_sq = torch.arange(0, (batch_size + 1) * query_step, step=query_step, dtype=torch.int32, device=query.device) + cu_sk = torch.arange(0, (batch_size + 1) * key_length, step=key_length, dtype=torch.int32, device=query.device) + + attn_output = flash_attn_unpadded_func( + query, + key, + value, + cu_sq, + cu_sk, + query_step, + key_length, + self.dropout_p if self.training else 0.0, + softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1, + causal=causal, + ) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = attn_output.view(hidden_states.shape) + + return attn_output, present, None def freeze_kv_cache(self, enable=True): if self.kv_cache is None: @@ -245,6 +482,30 @@ def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True ) return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] + def _merge_kv_caches(self, key_value, layer_past, use_cache): + present = None + if self.pre_allocate_kv_cache: + if use_cache or layer_past is not None: + last_key_length = layer_past or 0 + batch_size = key_value.size(0) + key_length = last_key_length + key_value.size(-2) + padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) + kv_cache = self.get_kv_cache( + batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 + ) + if self.multi_query: + kv_cache[:, last_key_length:key_length, :].copy_(key_value) + key_value = kv_cache + if use_cache: + present = key_length + else: + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + if use_cache: + present = key_value + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + return key, value, present + def forward( self, hidden_states: torch.Tensor, @@ -258,6 +519,52 @@ def forward( ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if self.attention_implementation == AttentionImplementation.OLD: + return self._old_forward( + hidden_states, + layer_past, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + if encoder_hidden_states is not None: + raise NotImplementedError() + + attn_output, present, attn_weights = self._attn_fn( + hidden_states, layer_past, use_cache, attention_mask, head_mask + ) + + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if attn_weights is None: + raise NotImplementedError("`output_attentions` not supported.") + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + def _old_forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not 
self.is_cross_attention: @@ -322,6 +629,83 @@ def forward( return outputs # a, present, (attentions) + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + dtype = query.dtype + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype + upcast = dtype != softmax_dtype + + unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + scale_factor = unscale**-1 + if self.scale_attn_weights: + scale_factor /= self.head_dim**0.5 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-1) + if self.multi_query: + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. + query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + else: + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[2] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) + if query.device.type == "cpu": + # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. + # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, + # but the fix has not been released as of pytorch version 2.0.0. + attn_weights.zero_() + beta = 1 + else: + beta = 0 + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + + if upcast: + # Use a fused kernel to prevent a large overhead from casting and scaling. + # Sub-optimal when the key length is not a multiple of 8. + if attention_mask is None: + attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) + else: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) + else: + if attention_mask is not None: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + + # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
+ attn_weights = torch.where(attention_mask, attn_weights, mask_value) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + if self.multi_query: + head_mask = head_mask.transpose(1, 2) + attn_weights = attn_weights * head_mask + + if self.multi_query: + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + else: + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): From 23eacc74998cc30fdf965c917a1445e0f6c55943 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 20 Apr 2023 22:28:58 -0400 Subject: [PATCH 5/8] Optional jit --- .../gpt_bigcode/configuration_gpt_bigcode.py | 2 + .../models/gpt_bigcode/inference_runner.py | 42 +++++----- .../gpt_bigcode/modeling_gpt_bigcode.py | 76 ++++++++++++++----- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 807a17b258..c5a0bf1b56 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -155,6 +155,7 @@ def __init__( eos_token_id=50256, attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, + fused_softmax=None, multi_query=True, attention_implementation=AttentionImplementation.BASE, inference_runner=InferenceRunnerType.NO_RUNNER, @@ -181,6 +182,7 @@ def __init__( self.use_cache = use_cache self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 + self.fused_softmax = fused_softmax self.multi_query = multi_query self.attention_implementation = attention_implementation diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index c532852b5d..a4452195bd 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -9,7 +9,7 @@ AttentionImplementation, InferenceRunnerType, ) -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, softmax_function try: @@ -39,6 +39,7 @@ def __init__(self, config: GPTBigCodeConfig, model): assert config.pre_allocate_kv_cache self.validate_input = config.validate_runner_input self.pad_key_length = 8 if config.pad_key_length else 1 + self.fused_softmax = config.fused_softmax and self.attention_implementation == AttentionImplementation.BASE if self.attention_implementation == AttentionImplementation.BASE: self._forward_attn = self._forward_attn_base @@ -243,21 +244,16 @@ def _generate_full_cuda_graph(self, key_length): # We need to warmup the jit function before creating the graph, otherwise it will crash. 
# https://github.com/pytorch/pytorch/issues/99397 # Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1 - if self.attention_implementation == AttentionImplementation.BASE: - if self.upcast: - for scale in (1.0, 2.0): - upcast_masked_softmax( - self.padded_attn_weights[key_length], - self.padded_attn_masks[key_length], - self.mask_value, - scale, - self.softmax_dtype, - ) - else: - masked_softmax( + if self.attention_implementation == AttentionImplementation.BASE and self.fused_softmax: + for scale in (1.0, 2.0): + softmax_function( self.padded_attn_weights[key_length], self.padded_attn_masks[key_length], self.mask_value, + scale, + self.softmax_dtype, + self.upcast, + self.fused_softmax, ) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, pool=self.memory_pool): @@ -295,18 +291,16 @@ def _forward_attn_base(self, block, key_length): alpha=self.scale[layer_idx], out=attn_weights, ) - # Use a fused kernel to prevent a large overhead from casting and scaling. # Jit doesn't allow inplace kernel. - if self.upcast: - attn_weights = upcast_masked_softmax( - attn_weights, - self.padded_attn_masks[key_length], - self.mask_value, - self.unscale[layer_idx], - self.softmax_dtype, - ) - else: - attn_weights = masked_softmax(attn_weights, self.padded_attn_masks[key_length], self.mask_value) + attn_weights = softmax_function( + attn_weights, + self.padded_attn_masks[key_length], + self.mask_value, + self.unscale[layer_idx], + self.softmax_dtype, + self.upcast, + block.attn.self.fused_softmax, + ) torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 77df8e26f9..3976022f1c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -60,11 +60,6 @@ ] -# Fused kernels -# Use separate functions for each case because conditionals prevent kernel fusion. -# TODO: Could have better fused kernels depending on scaling, dropout and head mask. -# Is it doable without writing 32 functions? 
-@torch.jit.script def upcast_masked_softmax( x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype ): @@ -76,6 +71,12 @@ def upcast_masked_softmax( @torch.jit.script +def upcast_masked_softmax_fused( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype) + + def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): input_dtype = x.dtype x = x.to(softmax_dtype) * scale @@ -84,12 +85,51 @@ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): @torch.jit.script +def upcast_softmax_fused(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + return upcast_softmax(x, scale, softmax_dtype) + + def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): x = torch.where(mask, x, mask_value) x = torch.nn.functional.softmax(x, dim=-1) return x +@torch.jit.script +def masked_softmax_fused(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + return masked_softmax(x, mask, mask_value) + + +def softmax_function( + x: torch.Tensor, + mask: torch.Tensor, + mask_value: torch.Tensor, + scale: float, + softmax_dtype: torch.dtype, + upcast: bool = True, + fused_softmax: bool = False, +): + """ + This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case + needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting + and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely + inefficient. TODO: Could have better fused kernels depending on scaling, dropout and head mask. + Is it doable without writing 32 functions? + """ + if upcast: + if mask is None: + return (upcast_softmax_fused if fused_softmax else upcast_softmax)(x, scale, softmax_dtype) + else: + return (upcast_masked_softmax_fused if fused_softmax else upcast_masked_softmax)( + x, mask, mask_value, scale, softmax_dtype + ) + else: + if mask is None: + return torch.nn.functional.softmax(x, dim=-1) + else: + return (masked_softmax_fused if fused_softmax else masked_softmax)(x, mask, mask_value) + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -118,6 +158,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.scale_attention_softmax_in_fp32 = ( config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 ) + self.fused_softmax = config.fused_softmax # KV caching and padding self.kv_cache = None @@ -192,22 +233,15 @@ def _get_attn_weights(self, query, key, attn_view, attn_shape, attention_mask=No beta = 0 attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) - if upcast: - # Use a fused kernel to prevent a large overhead from casting and scaling. - # Sub-optimal when the key length is not a multiple of 8. 
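A usage sketch for softmax_function as defined above, using the same key-length heuristic that _get_attn_weights switches to (fused kernel only when the key length is a multiple of 8); the tensors are placeholders:

    import torch

    from transformers.models.gpt_bigcode.modeling_gpt_bigcode import softmax_function

    attn_weights = torch.randn(2, 4, 16, dtype=torch.float16)
    attention_mask = torch.ones(2, 4, 16, dtype=torch.bool)
    mask_value = torch.full([], torch.finfo(torch.float32).min)

    attn_weights = softmax_function(
        attn_weights,
        attention_mask,
        mask_value,
        scale=3.0,  # the "unscale" factor, layer_idx + 1, for layer index 2
        softmax_dtype=torch.float32,
        upcast=True,
        fused_softmax=attn_weights.size(-1) % 8 == 0,
    )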
- if attention_mask is None: - attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) - else: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) - else: - if attention_mask is not None: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - - # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. - attn_weights = torch.where(attention_mask, attn_weights, mask_value) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = softmax_function( + attn_weights, + attention_mask, + None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype), + unscale, + softmax_dtype, + upcast, + key.size(-1) % 8 == 0 if self.fused_softmax is None else self.fused_softmax, + ) attn_weights = self.attn_dropout(attn_weights) From 5dd048f1065ee9e9047a6f725457482b1a64278f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 20 Apr 2023 22:41:31 -0400 Subject: [PATCH 6/8] No flash for graphs --- .../models/gpt_bigcode/inference_runner.py | 68 ++++--------------- .../gpt_bigcode/modeling_gpt_bigcode.py | 8 +-- 2 files changed, 19 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index a4452195bd..8de50f05e7 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -5,7 +5,6 @@ from transformers import GPTBigCodeConfig from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( - TORCH_IMPLEMENTATIONS, AttentionImplementation, InferenceRunnerType, ) @@ -29,26 +28,11 @@ def __init__(self, config: GPTBigCodeConfig, model): self.n_layer = len(self.model.h) self.inference_runner_type = InferenceRunnerType(config.inference_runner) - self.attention_implementation = config.attention_implementation assert self.inference_runner_type != InferenceRunnerType.NO_RUNNER - assert self.attention_implementation in ( - AttentionImplementation.BASE, - *TORCH_IMPLEMENTATIONS, - AttentionImplementation.FLASH, - ) assert config.pre_allocate_kv_cache self.validate_input = config.validate_runner_input self.pad_key_length = 8 if config.pad_key_length else 1 - self.fused_softmax = config.fused_softmax and self.attention_implementation == AttentionImplementation.BASE - - if self.attention_implementation == AttentionImplementation.BASE: - self._forward_attn = self._forward_attn_base - elif self.attention_implementation in TORCH_IMPLEMENTATIONS: - self._forward_attn = self._forward_attn_torch - elif self.attention_implementation == AttentionImplementation.FLASH: - self._forward_attn = self._forward_attn_flash - else: - raise NotImplementedError(self.attention_implementation) + self.fused_softmax = True if config.fused_softmax is None and config.pad_key_length else config.fused_softmax # TODO: Support other attention types? 
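The tri-state `fused_softmax` handling above boils down to one rule: use the jit-fused kernels exactly when the key length is, or will be padded to, a multiple of 8. A hypothetical helper spelling that out (not part of the patch):

def resolve_fused_softmax(config_value, pad_key_length, key_length):
    # Hypothetical helper mirroring the inline expressions: None means "decide automatically",
    # and padded key lengths (multiples of 8) are when the jit-fused kernels pay off.
    if config_value is not None:
        return config_value
    return pad_key_length or key_length % 8 == 0

# resolve_fused_softmax(None, pad_key_length=True, key_length=21)  -> True  (keys padded to 24)
# resolve_fused_softmax(None, pad_key_length=False, key_length=21) -> False (fused path would be slow)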
assert model.multi_query @@ -64,11 +48,7 @@ def _allocate(self, batch_size, device, dtype): self.softmax_dtype = torch.float32 if attn.attention_softmax_in_fp32 else self.dtype self.upcast = self.softmax_dtype != self.dtype - do_unscale = ( - attn.scale_attention_softmax_in_fp32 - and self.upcast - and self.attention_implementation == AttentionImplementation.BASE - ) + do_unscale = attn.scale_attention_softmax_in_fp32 and self.upcast self.unscale = [i + 1.0 if do_unscale else 1.0 for i in range(self.n_layer)] scale = attn.head_dim**-0.5 if attn.scale_attn_weights else 1 self.scale = [scale / unscale for unscale in self.unscale] @@ -84,8 +64,7 @@ def _allocate(self, batch_size, device, dtype): # Attn weights: (batch_size, num_heads, key_length), no overlap with value (not needed for torch/flash attn) attn_weights_begin = _align_tensor(kv_end) attn_weights_end = attn_weights_begin - if self.attention_implementation == AttentionImplementation.BASE: - attn_weights_end += self.batch_size * attn.num_heads * self.max_sequence_length + attn_weights_end += self.batch_size * attn.num_heads * self.max_sequence_length # Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query. # Also used for MLP projection c_proj_begin = _align_tensor(query_end) @@ -151,10 +130,6 @@ def _allocate(self, batch_size, device, dtype): # QKV: (bs, embed_dim + 2 * kv_dim). self.c_attn = activation_pool[query_begin:kv_end].view(self.batch_size, -1) self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim) - # if self.attention_implementation==AttentionImplementation.FLASH: - # self.query=query.view(self.batch_size * attn.num_heads, 1, attn.head_dim) - # else: - # self.query=query.view(self.batch_size, attn.num_heads, attn.head_dim) self.kv_attn = self.c_attn[:, attn.embed_dim :] @@ -163,13 +138,7 @@ def _allocate(self, batch_size, device, dtype): # No transpose for torch/flash attn self.padded_keys = [ - [ - key[:, head_slice, :key_length, :].transpose( - -1, -2 if self.attention_implementation == AttentionImplementation.BASE else -1 - ) - for key in keys - ] - for key_length in padded_key_lengths + [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths ] self.padded_values = [ [value[:, head_slice, :key_length, :] for value in values] for key_length in padded_key_lengths @@ -183,22 +152,15 @@ def _allocate(self, batch_size, device, dtype): [kv_cache[:, head_slice, : key_length - 1, :] for kv_cache in kv_caches] for key_length in key_lengths ] - if self.attention_implementation == AttentionImplementation.BASE: - # Attn weights: (batch_size, num_heads, key_length), no overlap with value. - attn_weights = activation_pool[attn_weights_begin:attn_weights_end].view( - self.batch_size, attn.num_heads, self.max_sequence_length - ) - self.padded_attn_weights = [attn_weights[:, :, :key_length] for key_length in padded_key_lengths] - - # Attn outputs: (batch_size, embed_dim), no overlap with value. 
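The offsets being computed above all index into one flat activation buffer: every intermediate the graph touches is a view into that pool, so graph replays always land on the same storage. A minimal sketch of the idea with made-up sizes (alignment and the real offsets are handled by the runner):

import torch

batch_size, embed_dim, num_heads, max_sequence_length = 2, 64, 8, 32
pool = torch.empty(batch_size * (embed_dim + num_heads * max_sequence_length))

query_begin, query_end = 0, batch_size * embed_dim
attn_begin = query_end
attn_end = attn_begin + batch_size * num_heads * max_sequence_length

# Both tensors share the pool's storage; no per-step allocation happens while decoding.
query = pool[query_begin:query_end].view(batch_size, embed_dim)
attn_weights = pool[attn_begin:attn_end].view(batch_size, num_heads, max_sequence_length)

# Shorter key lengths reuse the same buffer by slicing the padded view.
key_length = 20
attn_weights_for_step = attn_weights[:, :, :key_length]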
- self.attn_output = activation_pool[query_begin:query_end].view(self.batch_size, -1) - self.attn_output_expanded = self.attn_output.view(self.batch_size, attn.num_heads, attn.head_dim) - elif self.attention_implementation == AttentionImplementation.FLASH: - self.cu_sq = torch.arange( - 0, (self.batch_size + 1) * attn.num_heads, step=attn.num_heads, dtype=torch.int32, device=self.device - ) - # self.cu_sk = torch.arange(0, (self.batch_size + 1) * self.max_sequence_length, step=self.max_sequence_length, dtype=torch.int32, - # device=self.device) + # Attn weights: (batch_size, num_heads, key_length), no overlap with value. + attn_weights = activation_pool[attn_weights_begin:attn_weights_end].view( + self.batch_size, attn.num_heads, self.max_sequence_length + ) + self.padded_attn_weights = [attn_weights[:, :, :key_length] for key_length in padded_key_lengths] + + # Attn outputs: (batch_size, embed_dim), no overlap with value. + self.attn_output = activation_pool[query_begin:query_end].view(self.batch_size, -1) + self.attn_output_expanded = self.attn_output.view(self.batch_size, attn.num_heads, attn.head_dim) # Attn projection: (batch_size, embed_dim), no overlap with attn outputs. self.c_proj = activation_pool[c_proj_begin:c_proj_end].view(self.batch_size, -1) @@ -244,7 +206,7 @@ def _generate_full_cuda_graph(self, key_length): # We need to warmup the jit function before creating the graph, otherwise it will crash. # https://github.com/pytorch/pytorch/issues/99397 # Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1 - if self.attention_implementation == AttentionImplementation.BASE and self.fused_softmax: + if self.fused_softmax or (self.fused_softmax is None and key_length % 8 == 0): for scale in (1.0, 2.0): softmax_function( self.padded_attn_weights[key_length], @@ -299,7 +261,7 @@ def _forward_attn_base(self, block, key_length): self.unscale[layer_idx], self.softmax_dtype, self.upcast, - block.attn.self.fused_softmax, + self.fused_softmax, ) torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3976022f1c..3b6a5c9075 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -107,7 +107,7 @@ def softmax_function( scale: float, softmax_dtype: torch.dtype, upcast: bool = True, - fused_softmax: bool = False, + fused_softmax: Optional[bool] = None, ): """ This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case @@ -116,6 +116,8 @@ def softmax_function( inefficient. TODO: Could have better fused kernels depending on scaling, dropout and head mask. Is it doable without writing 32 functions? 
""" + if fused_softmax is None: + fused_softmax = x.size(-1) % 8 == 0 if upcast: if mask is None: return (upcast_softmax_fused if fused_softmax else upcast_softmax)(x, scale, softmax_dtype) @@ -185,7 +187,6 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.attention_implementation == AttentionImplementation.BASE: self._attn_fn = self._attn_mqa if self.multi_query else self._attn_mha elif self.attention_implementation in TORCH_IMPLEMENTATIONS: - # TODO: Implement assert not self.pre_allocate_kv_cache self._attn_fn = self._attn_torch_mqa if self.multi_query else self._attn_torch_mha self.backend_context = ( @@ -198,7 +199,6 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): ) ) elif self.attention_implementation == AttentionImplementation.FLASH: - # TODO: Implement assert not self.pre_allocate_kv_cache self._attn_fn = self._attn_flash_mqa if self.multi_query else self._attn_flash_mha elif self.attention_implementation == AttentionImplementation.OLD: @@ -240,7 +240,7 @@ def _get_attn_weights(self, query, key, attn_view, attn_shape, attention_mask=No unscale, softmax_dtype, upcast, - key.size(-1) % 8 == 0 if self.fused_softmax is None else self.fused_softmax, + self.fused_softmax, ) attn_weights = self.attn_dropout(attn_weights) From a6efba9d1c93fb13c331a5cecbe837e7608c6205 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 22 Apr 2023 02:08:21 -0400 Subject: [PATCH 7/8] Unpad for flash attn, remove flash/torch for decode, misc improvements --- .../gpt_bigcode/configuration_gpt_bigcode.py | 4 + .../gpt_bigcode/modeling_gpt_bigcode.py | 352 +++++++----------- 2 files changed, 145 insertions(+), 211 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index c5a0bf1b56..09bad35274 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -164,6 +164,7 @@ def __init__( max_sequence_length=None, max_batch_size=None, pad_key_length=True, + predict_last_token: bool = False, **kwargs, ): self.vocab_size = vocab_size @@ -201,4 +202,7 @@ def __init__( # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache). self.pad_key_length = pad_key_length + # Predict only the last token in inference even if the input is bigger. + self.predict_last_token = predict_last_token + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3b6a5c9075..8aa801c4e2 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -44,9 +44,13 @@ try: + # TODO: This needs einops. 
+ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_func except ImportError: flash_attn_unpadded_func = None + pad_input = None + unpad_input = None logger = logging.get_logger(__name__) @@ -171,13 +175,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self._frozen_kv_cache = False if self.is_cross_attention: - if self.multi_query: - raise NotImplementedError("Multi-Query Attention not supported for cross_attention") - - self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) - self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) - else: - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) @@ -200,7 +199,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): ) elif self.attention_implementation == AttentionImplementation.FLASH: assert not self.pre_allocate_kv_cache - self._attn_fn = self._attn_flash_mqa if self.multi_query else self._attn_flash_mha + self._attn_fn = self._attn_flash elif self.attention_implementation == AttentionImplementation.OLD: self._attn_fn = None else: @@ -263,7 +262,8 @@ def _attn_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, h .transpose(1, 2) .split((self.head_dim, 2 * self.head_dim), dim=3) ) - key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) batch_size = query.size(0) query_length = query.size(1) @@ -294,7 +294,8 @@ def _attn_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, h # Q: (batch_size, query_length, num_heads * head_dim) # K, V:(batch_size, key_length, head_dim) query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) batch_size = query.size(0) query_length = query.size(1) @@ -323,7 +324,7 @@ def _attn_torch_mha(self, hidden_states, layer_past, use_cache, attention_mask=N # TODO: Scale? # TODO: Use attn mask if head_mask is not None: - raise NotImplementedError() + raise NotImplementedError("Head mask is not supported with torch attention.") # Q: (batch_size, num_heads, query_length, head_dim) # K, V: (batch_size, num_heads, query_length, head_dim) query, key_value = ( @@ -333,7 +334,8 @@ def _attn_torch_mha(self, hidden_states, layer_past, use_cache, attention_mask=N .split((self.head_dim, 2 * self.head_dim), dim=3) ) - key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) with self.backend_context(): # attn_output: (batch_size, num_heads, query_length, head_dim) @@ -350,11 +352,12 @@ def _attn_torch_mqa(self, hidden_states, layer_past, use_cache, attention_mask=N # TODO: Scale? 
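The split pattern repeated in the hunks above (the fused c_attn output is cut into the query plus a single shared key/value head, which is then cut into key and value) is easiest to see with concrete shapes; a standalone sketch with toy dimensions:

import torch

batch, seq, num_heads, head_dim = 2, 5, 8, 64
embed_dim = num_heads * head_dim
kv_dim = head_dim  # multi-query: one key/value head shared by all query heads

c_attn = torch.nn.Linear(embed_dim, embed_dim + 2 * kv_dim)
hidden_states = torch.randn(batch, seq, embed_dim)

query, key_value = c_attn(hidden_states).split((embed_dim, 2 * kv_dim), dim=2)
key, value = key_value.split((head_dim, head_dim), dim=-1)
# query: (2, 5, 512); key, value: (2, 5, 64) -- a single K/V head serves all 8 query heads.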
# TODO: Use attn mask if head_mask is not None: - raise NotImplementedError() + raise NotImplementedError("Head mask is not supported with torch attention.") # Q: (batch_size, query_length, num_heads * head_dim) # K, V:(batch_size, key_length, head_dim) query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) batch_size = query.size(0) query_length = query.size(1) @@ -391,93 +394,40 @@ def _attn_torch_mqa(self, hidden_states, layer_past, use_cache, attention_mask=N return attn_output, present, None - def _attn_flash_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + def _attn_flash(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): # TODO: Use attn mask if head_mask is not None: - raise NotImplementedError() + raise NotImplementedError("Head mask is not supported with flash attention.") # Q: (batch_size, query_length, num_heads * head_dim) # K, V:(batch_size, key_length, num_heads * head_dim) - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) - - batch_size = query.size(0) - query_length = query.size(1) - key_length = key.size(1) - - # TODO: Pre-allocate? - cu_sq = torch.arange( - 0, (batch_size + 1) * query_length, step=query_length, dtype=torch.int32, device=query.device - ) - cu_sk = torch.arange(0, (batch_size + 1) * key_length, step=key_length, dtype=torch.int32, device=query.device) - - # attn_output: (batch_size * query_length, num_heads, head_dim) - attn_output = flash_attn_unpadded_func( - query.view(batch_size * query_length, self.num_heads, self.head_dim), - key.view(batch_size * key_length, self.num_heads, self.head_dim), - value.view(batch_size * key_length, self.num_heads, self.head_dim), - cu_sq, - cu_sk, - query_length, - key_length, - self.dropout_p if self.training else 0.0, - softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1, - causal=True, - ) - - # attn_output: (batch_size, query_length, num_heads * head_dim) - # CPP backend needs reshape, others are ok with a view. 
(Checked for mqa, confirm for mha) - attn_output = attn_output.reshape(hidden_states.shape) - - return attn_output, present, None - - def _attn_flash_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): - # TODO: Use attn mask - if head_mask is not None: - raise NotImplementedError() - # Q: (batch_size, query_length, num_heads * head_dim) - # K, V:(batch_size, key_length, head_dim) - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - key, value, present = self._merge_kv_caches(key_value, layer_past, use_cache) - - batch_size = query.size(0) - query_length = query.size(1) - key_length = key.size(1) - - key = key.reshape(batch_size * key_length, 1, self.head_dim) - value = value.reshape(batch_size * key_length, 1, self.head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1) + _, present = self._merge_kv_caches(key_value, layer_past, use_cache, attention_mask) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - if query_length == 1: - # attn_output: (batch_size * num_heads, 1, head_dim) - query_step = self.num_heads - causal = False - query = query.reshape(batch_size * query_step, 1, self.head_dim) + attn_shape = query.size(0), self.num_heads, self.head_dim + query = query.view(attn_shape) + if self.multi_query: + key = key.unsqueeze(1).expand(attn_shape) + value = value.unsqueeze(1).expand(attn_shape) else: - # attn_output: (batch_size * query_length, num_heads, head_dim) - query_step = query_length - causal = True - query = query.view(batch_size * query_step, self.num_heads, self.head_dim) - key = key.expand(batch_size * key_length, self.num_heads, self.head_dim) - value = value.expand(batch_size * key_length, self.num_heads, self.head_dim) + key = key.view(attn_shape) + value = value.view(attn_shape) - # TODO: Pre-allocate? - cu_sq = torch.arange(0, (batch_size + 1) * query_step, step=query_step, dtype=torch.int32, device=query.device) - cu_sk = torch.arange(0, (batch_size + 1) * key_length, step=key_length, dtype=torch.int32, device=query.device) + sequence_lengths, padding_index, _, max_sequence_length = attention_mask + # attn_output: (sum_seq_len, num_heads * head_dim) attn_output = flash_attn_unpadded_func( query, key, value, - cu_sq, - cu_sk, - query_step, - key_length, + sequence_lengths, + sequence_lengths, + max_sequence_length, + max_sequence_length, self.dropout_p if self.training else 0.0, softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1, - causal=causal, - ) - - # attn_output: (batch_size, query_length, num_heads * head_dim) - attn_output = attn_output.view(hidden_states.shape) + causal=True, + ).view(hidden_states.shape) return attn_output, present, None @@ -516,8 +466,12 @@ def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True ) return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] - def _merge_kv_caches(self, key_value, layer_past, use_cache): + def _merge_kv_caches(self, key_value, layer_past, use_cache, flash_attention_parameters=None): present = None + if flash_attention_parameters is not None and (use_cache or layer_past is not None): + # Todo: unpadding is only needed if the cache is reused. 
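The flash kernel consumes a packed ("unpadded") batch indexed by cumulative sequence lengths rather than a padded (batch, seq, ...) tensor. For equal-length sequences, as in the removed decode path, those cumulative lengths reduce to a simple arange (illustrative sizes):

import torch

batch_size, query_length, key_length = 4, 1, 37
device = "cpu"  # the real runner builds these on the model's CUDA device

# Entry i marks where sequence i starts in the packed (total_tokens, num_heads, head_dim) tensor.
cu_seqlens_q = torch.arange(0, (batch_size + 1) * query_length, step=query_length,
                            dtype=torch.int32, device=device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * key_length, step=key_length,
                            dtype=torch.int32, device=device)
# cu_seqlens_q == tensor([0, 1, 2, 3, 4]); cu_seqlens_k == tensor([0, 37, 74, 111, 148]).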
+ _, padding_index, batch_size, max_sequence_length = flash_attention_parameters + key_value = pad_input(key_value, padding_index, batch_size, max_sequence_length) if self.pre_allocate_kv_cache: if use_cache or layer_past is not None: last_key_length = layer_past or 0 @@ -537,8 +491,7 @@ def _merge_kv_caches(self, key_value, layer_past, use_cache): key_value = torch.cat((layer_past, key_value), dim=-2) if use_cache: present = key_value - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - return key, value, present + return key_value, present def forward( self, @@ -554,23 +507,30 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + if encoder_hidden_states is not None or encoder_attention_mask is not None: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + if self.attention_implementation == AttentionImplementation.OLD: return self._old_forward( hidden_states, layer_past, attention_mask, head_mask, - encoder_hidden_states, - encoder_attention_mask, use_cache, output_attentions, ) - if encoder_hidden_states is not None: - raise NotImplementedError() - attn_output, present, attn_weights = self._attn_fn( - hidden_states, layer_past, use_cache, attention_mask, head_mask - ) + if self.attention_implementation == AttentionImplementation.BASE or layer_past is not None: + attn_fn = self._attn_mqa if self.multi_query else self._attn_mha + elif self.attention_implementation in TORCH_IMPLEMENTATIONS: + assert not self.pre_allocate_kv_cache + attn_fn = self._attn_torch_mqa if self.multi_query else self._attn_torch_mha + elif self.attention_implementation == AttentionImplementation.FLASH: + attn_fn = self._attn_flash + else: + raise ValueError() + + attn_output, present, attn_weights = attn_fn(hidden_states, layer_past, use_cache, attention_mask, head_mask) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -592,25 +552,13 @@ def _old_forward( layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." 
- ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: + if self.multi_query: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), @@ -770,10 +718,7 @@ def __init__(self, config, layer_idx=None): self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - if config.multi_query: - raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) - self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -790,6 +735,9 @@ def forward( ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: + if encoder_hidden_states is not None or encoder_attention_mask is not None: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -805,28 +753,6 @@ def forward( # residual connection hidden_states = attn_output + residual - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) - residual = hidden_states - hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( - hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) - attn_output = cross_attn_outputs[0] - # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights - residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) @@ -985,6 +911,9 @@ def __init__(self, config): self.multi_query = config.multi_query self.embed_dim = config.hidden_size + if config.add_cross_attention: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) @@ -996,6 +925,8 @@ def __init__(self, config): self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache self.inference_runner_type = InferenceRunnerType(config.inference_runner) + self.attention_implementation = config.attention_implementation + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: self.inference_runner = None else: @@ -1004,6 +935,7 @@ def __init__(self, config): self.inference_runner = GPTBigCodeInferenceRunner(config, self) max_positions = config.max_position_embeddings + # Causal mask self.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False ) @@ -1019,6 +951,32 @@ def get_input_embeddings(self): def set_input_embeddings(self, 
new_embeddings): self.wte = new_embeddings + def _get_causal_mask(self, padding_mask, query_length, key_length): + # Self-attention mask. + attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + + if padding_mask is not None: + attention_mask = attention_mask * padding_mask.unsqueeze(1).to( + dtype=torch.bool, device=attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + return attention_mask.unsqueeze(2 if self.multi_query else 1) + + def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device): + if position_ids is not None: + position_ids = position_ids.to(device) + elif padding_mask is not None and padding_mask.ndim == 2: + # create position_ids on the fly for batch generation + position_ids = padding_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(padding_mask == 0, 1) + if key_length > query_length: + position_ids = position_ids[:, key_length - query_length : key_length :] + else: + position_ids = torch.arange(key_length - query_length, key_length, dtype=torch.long, device=device) + return position_ids.view(-1, query_length) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1070,30 +1028,28 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = self.config.use_cache if use_cache is None else use_cache + return_dict = self.config.use_return_dict if return_dict is None else return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() + if input_ids is not None: + if inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + input_shape = input_ids.shape input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + batch_size, query_length = input_ids.shape elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + input_shape = inputs_embeds.shape[:-1] + inputs_embeds = inputs_embeds.view(-1, input_shape[-2:]) + batch_size, query_length = inputs_embeds.shape[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) + using_flash_attention = ( + self.attention_implementation == AttentionImplementation.FLASH and past_key_values is None + ) if past_key_values is None: past_length = 0 @@ -1102,49 +1058,23 @@ def forward( past_length = past_key_values[0] else: past_length = past_key_values[0].size(-2) - - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - 
position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Self-attention mask. - query_length = input_shape[-1] key_length = past_length + query_length - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) + position_ids = self._get_position_ids(position_ids, attention_mask, query_length, key_length, input_ids.device) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, query_length) - if self.pad_key_length: - pad = -key_length % 8 - if pad > 0: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + if not using_flash_attention: + # Self-attention mask (padding + causal). + attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) + if self.pad_key_length: + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if ( - self.config.add_cross_attention - and encoder_hidden_states is not None - and encoder_attention_mask is not None - ): - if encoder_attention_mask.dim() == 2: - encoder_attention_mask.unsqueeze(1) - assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) - else: - encoder_attention_mask = None + if encoder_hidden_states is not None or encoder_attention_mask is not None: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -1163,7 +1093,13 @@ def forward( hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) + # TODO: Unpad earlier (input ids), support unpadded input? 
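What the unpadding done just below amounts to, written in plain torch rather than flash-attn's helpers: keep only the positions the padding mask marks as real tokens, remember their flat indices, and scatter back afterwards. Names and shapes here are illustrative:

import torch

hidden_states = torch.randn(2, 4, 8)                     # (batch, seq, hidden)
padding_mask = torch.tensor([[1, 1, 0, 0],
                             [1, 1, 1, 1]], dtype=torch.bool)

indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).squeeze(-1)
packed = hidden_states.flatten(0, 1)[indices]            # (total_tokens, hidden) == (6, 8)

# ... attention runs over the packed tokens ...

restored = torch.zeros_like(hidden_states.flatten(0, 1))
restored[indices] = packed
restored = restored.view_as(hidden_states)               # zeros at the padded positions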
+ if using_flash_attention: + hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( + hidden_states, attention_mask + ) + # Pass the required parameters through the attention_mask argument + attention_mask = (sequence_lengths, padding_index, batch_size, max_sequence_length) presents = [] if use_cache else None all_self_attentions = () if output_attentions else None @@ -1188,8 +1124,6 @@ def custom_forward(*inputs): None, attention_mask, head_mask[i], - encoder_hidden_states, - encoder_attention_mask, ) else: outputs = block( @@ -1197,8 +1131,6 @@ def custom_forward(*inputs): layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, ) @@ -1214,7 +1146,11 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) + if using_flash_attention: + hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + + hidden_states = hidden_states.view(input_shape + (hidden_states.size(-1),)) + # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1249,6 +1185,8 @@ def __init__(self, config): super().__init__(config) self.transformer = GPTBigCodeModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.predict_last_token = config.predict_last_token + self.attention_implementation = config.attention_implementation # Initialize weights and apply final processing self.post_init() @@ -1263,21 +1201,9 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + input_ids = input_ids[:, -1:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None + token_type_ids = token_type_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1289,8 +1215,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, + "position_ids": kwargs.get("position_ids", None), + "attention_mask": kwargs.get("attention_mask", None), "token_type_ids": token_type_ids, } ) @@ -1344,6 +1270,10 @@ def forward( ) hidden_states = transformer_outputs[0] + if self.predict_last_token and not self.training: + # We only care about the last token. 
+ hidden_states = hidden_states[:, -1:] + lm_logits = self.lm_head(hidden_states) loss = None From a2efad2c96e6da982f102eea53918c7b8431da80 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 24 Apr 2023 10:35:09 -0400 Subject: [PATCH 8/8] cleanup --- .../models/gpt_bigcode/inference_runner.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 8de50f05e7..b5fc56109d 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -240,7 +240,7 @@ def _forward_qkv(self, block): out=self.c_attn, ) - def _forward_attn_base(self, block, key_length): + def _forward_attn(self, block, key_length): layer_idx = block.attn.layer_idx self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) attn_weights = self.padded_attn_weights[key_length] @@ -266,50 +266,6 @@ def _forward_attn_base(self, block, key_length): torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) - def _forward_attn_torch(self, block, key_length): - layer_idx = block.attn.layer_idx - self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) - with block.attn.backend_context(): - attn_output = torch.nn.functional.scaled_dot_product_attention( - self.query, - self.padded_keys[key_length][layer_idx], - self.padded_values[key_length][layer_idx], - None, # attention_mask, - 0.0, - is_causal=False, - ) - # Out arg not supported so we set the variable instead. - self.attn_output = attn_output.view(self.batch_size, -1) - - def _forward_attn_flash(self, block, key_length): - layer_idx = block.attn.layer_idx - num_heads = block.attn.num_heads - self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) - # TODO: Pre-allocate? - # TODO: Adjust for non-contiguous key/value? (max seq len instead of key length) - cu_sk = torch.arange( - 0, (self.batch_size + 1) * key_length, step=key_length, dtype=torch.int32, device=self.device - ) - # TODO: Avoid reshape - q = self.query.reshape(self.batch_size * num_heads, 1, block.attn.head_dim) - k = self.padded_keys[key_length][layer_idx].reshape(self.batch_size * key_length, 1, block.attn.head_dim) - v = self.padded_values[key_length][layer_idx].reshape(self.batch_size * key_length, 1, block.attn.head_dim) - print("A", q.shape, k.shape, v.shape) - attn_output = flash_attn_unpadded_func( - q, - k, - v, - self.cu_sq, - cu_sk, - num_heads, - key_length, - 0.0, - softmax_scale=self.scale[layer_idx], - causal=False, - ) - # Out arg not supported so we set the variable instead. - self.attn_output = attn_output.view(self.batch_size, -1) - def _forward_post_attn(self, block): torch.nn.functional.linear( self.attn_output, @@ -317,9 +273,6 @@ def _forward_post_attn(self, block): block.attn.c_proj.bias, out=self.c_proj, ) - if self.attention_implementation != AttentionImplementation.BASE: - # Free memory. - del self.attn_output self.hidden_states_squeezed.add_(self.c_proj) # LN doesn't support out argument. hidden_states = block.ln_2(self.hidden_states_squeezed)
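Taken together, the experimental switches added in this series are all plumbed through the config; a hedged usage sketch with toy sizes (flag behavior and availability depend on the final state of the branch, and the graph runners require a CUDA device):

import torch
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
    GPTBigCodeConfig,
    InferenceRunnerType,
)
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM

config = GPTBigCodeConfig(
    n_layer=2, n_head=8, n_embd=256, n_positions=512,     # toy sizes for the sketch
    inference_runner=InferenceRunnerType.PARTIAL_GRAPH,   # experimental cuda-graph runner
    pre_allocate_kv_cache=True,                           # required by the inference runner
    max_batch_size=4,
    pad_key_length=True,                                  # keep key lengths on the fast path
    predict_last_token=True,                              # only the last token's logits at inference
)
model = GPTBigCodeForCausalLM(config)
if torch.cuda.is_available():
    model = model.cuda().eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
    output = model.generate(input_ids, max_new_tokens=8)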