diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 1cfba93a71..09bad35274 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -40,6 +40,28 @@ class InferenceRunnerType(IntEnum): FULL_GRAPH = 3 +class AttentionImplementation(IntEnum): + # Ours + BASE = 0 + # Flash attention + FLASH = 1 + # scaled_dot_product_attention (multiple implementations) + TORCH = 2 + TORCH_FLASH = 3 + TORCH_MEM = 4 + TORCH_CPP = 5 + # DEBUG + OLD = 6 + + +TORCH_IMPLEMENTATIONS = ( + AttentionImplementation.TORCH, + AttentionImplementation.TORCH_FLASH, + AttentionImplementation.TORCH_MEM, + AttentionImplementation.TORCH_CPP, +) + + class GPTBigCodeConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a @@ -133,13 +155,16 @@ def __init__( eos_token_id=50256, attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, + fused_softmax=None, multi_query=True, + attention_implementation=AttentionImplementation.BASE, inference_runner=InferenceRunnerType.NO_RUNNER, validate_runner_input=True, pre_allocate_kv_cache=False, max_sequence_length=None, max_batch_size=None, pad_key_length=True, + predict_last_token: bool = False, **kwargs, ): self.vocab_size = vocab_size @@ -158,7 +183,9 @@ def __init__( self.use_cache = use_cache self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 + self.fused_softmax = fused_softmax self.multi_query = multi_query + self.attention_implementation = attention_implementation self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id @@ -175,4 +202,7 @@ def __init__( # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache). self.pad_key_length = pad_key_length + # Predict only the last token in inference even if the input is bigger. 
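A minimal sketch of how the new configuration flags above might be set (values are illustrative; assumes this branch of `transformers` is installed — the defaults keep the existing behaviour: `BASE` attention, `fused_softmax=None`, `predict_last_token=False`):

```python
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
    AttentionImplementation,
    GPTBigCodeConfig,
)

config = GPTBigCodeConfig(
    multi_query=True,
    attention_implementation=AttentionImplementation.TORCH_FLASH,  # SDPA, flash backend only
    fused_softmax=None,        # None: decide per call, based on key_length % 8 == 0
    predict_last_token=True,   # at inference, run lm_head on the last position only
)
```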
+ self.predict_last_token = predict_last_token + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 1767bf9642..b5fc56109d 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -4,8 +4,17 @@ from transformers import GPTBigCodeConfig from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( + AttentionImplementation, + InferenceRunnerType, +) +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, softmax_function + + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + flash_attn_unpadded_func = None def _align_tensor(x): @@ -23,6 +32,7 @@ def __init__(self, config: GPTBigCodeConfig, model): assert config.pre_allocate_kv_cache self.validate_input = config.validate_runner_input self.pad_key_length = 8 if config.pad_key_length else 1 + self.fused_softmax = True if config.fused_softmax is None and config.pad_key_length else config.fused_softmax # TODO: Support other attention types? assert model.multi_query @@ -51,9 +61,10 @@ def _allocate(self, batch_size, device, dtype): query_end = query_begin + self.batch_size * attn.embed_dim # KV: (bs, 2 * kv_dim), combines with query into c_attn. kv_end = query_end + 2 * self.batch_size * attn.kv_dim - # Attn weights: (batch_size, num_heads, key_length), no overlap with value + # Attn weights: (batch_size, num_heads, key_length), no overlap with value (not needed for torch/flash attn) attn_weights_begin = _align_tensor(kv_end) - attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length + attn_weights_end = attn_weights_begin + attn_weights_end += self.batch_size * attn.num_heads * self.max_sequence_length # Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query. # Also used for MLP projection c_proj_begin = _align_tensor(query_end) @@ -119,11 +130,13 @@ def _allocate(self, batch_size, device, dtype): # QKV: (bs, embed_dim + 2 * kv_dim). self.c_attn = activation_pool[query_begin:kv_end].view(self.batch_size, -1) self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim) + self.kv_attn = self.c_attn[:, attn.embed_dim :] keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches)) head_slice = 0 if attn.multi_query else slice(None) + # No transpose for torch/flash attn self.padded_keys = [ [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths ] @@ -159,6 +172,10 @@ def _allocate(self, batch_size, device, dtype): if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER: print("Generating cuda graphs") self.memory_pool = None + # This prevents some issue with cublas initialization. 
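The `try`/`except ImportError` block above keeps `flash-attn` an optional dependency. A hedged sketch of a guard a caller might add before requesting the FLASH implementation (the helper name is hypothetical, not part of this PR):

```python
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:  # flash-attn (and einops) are optional extras
    flash_attn_unpadded_func = None


def require_flash_attn() -> None:
    # Hypothetical helper: fail early with a clear message instead of calling
    # None deep inside the attention layer.
    if flash_attn_unpadded_func is None:
        raise ImportError(
            "attention_implementation=FLASH requires the `flash-attn` package "
            "(which needs `einops`); install it or pick another implementation."
        )
```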
+ # https://github.com/pytorch/pytorch/issues/99397 + dummy_matrix = self.mask_value.view([1, 1]) + torch.matmul(dummy_matrix, dummy_matrix) if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: self.cuda_graphs = {} # The output may not always be at the same memory location. @@ -187,22 +204,19 @@ def _generate_cuda_graphs(self): def _generate_full_cuda_graph(self, key_length): # We need to warmup the jit function before creating the graph, otherwise it will crash. + # https://github.com/pytorch/pytorch/issues/99397 # Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1 - if self.upcast: + if self.fused_softmax or (self.fused_softmax is None and key_length % 8 == 0): for scale in (1.0, 2.0): - upcast_masked_softmax( + softmax_function( self.padded_attn_weights[key_length], self.padded_attn_masks[key_length], self.mask_value, scale, self.softmax_dtype, + self.upcast, + self.fused_softmax, ) - else: - masked_softmax( - self.padded_attn_weights[key_length], - self.padded_attn_masks[key_length], - self.mask_value, - ) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, pool=self.memory_pool): self.output_hidden_states[key_length] = self._forward(key_length) @@ -239,18 +253,16 @@ def _forward_attn(self, block, key_length): alpha=self.scale[layer_idx], out=attn_weights, ) - # Use a fused kernel to prevent a large overhead from casting and scaling. # Jit doesn't allow inplace kernel. - if self.upcast: - attn_weights = upcast_masked_softmax( - attn_weights, - self.padded_attn_masks[key_length], - self.mask_value, - self.unscale[layer_idx], - self.softmax_dtype, - ) - else: - attn_weights = masked_softmax(attn_weights, self.padded_attn_masks[key_length], self.mask_value) + attn_weights = softmax_function( + attn_weights, + self.padded_attn_masks[key_length], + self.mask_value, + self.unscale[layer_idx], + self.softmax_dtype, + self.upcast, + self.fused_softmax, + ) torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 76ba07b73e..8aa801c4e2 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch GPTBigCode model.""" +import contextlib import math from typing import List, Optional, Tuple, Union @@ -34,7 +35,22 @@ add_start_docstrings_to_model_forward, logging, ) -from .configuration_gpt_bigcode import GPTBigCodeConfig, InferenceRunnerType +from .configuration_gpt_bigcode import ( + TORCH_IMPLEMENTATIONS, + AttentionImplementation, + GPTBigCodeConfig, + InferenceRunnerType, +) + + +try: + # TODO: This needs einops. + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + flash_attn_unpadded_func = None + pad_input = None + unpad_input = None logger = logging.get_logger(__name__) @@ -48,11 +64,6 @@ ] -# Fused kernels -# Use separate functions for each case because conditionals prevent kernel fusion. -# TODO: Could have better fused kernels depending on scaling, dropout and head mask. -# Is it doable without writing 32 functions? 
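A self-contained sketch of the warm-up-then-capture pattern the runner relies on (general PyTorch CUDA-graph usage on a CUDA device, not the runner's exact code): the TorchScript kernel is run once per shape/scale before capture, and an initial matmul initializes cuBLAS outside the graph.

```python
import torch


@torch.jit.script
def scaled_softmax(x: torch.Tensor, scale: float):
    return torch.nn.functional.softmax(x * scale, dim=-1)


x = torch.randn(4, 128, device="cuda")

# cuBLAS must be initialized outside the graph (see the pytorch issue above).
dummy = torch.ones(1, 1, device="cuda")
torch.matmul(dummy, dummy)

# Warm up the jit kernel for every scale the graph will replay.
for scale in (1.0, 2.0):
    scaled_softmax(x, scale)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = scaled_softmax(x, 2.0)

graph.replay()  # re-runs the captured kernels, writing into the same `out` storage
```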
-@torch.jit.script def upcast_masked_softmax( x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype ): @@ -64,6 +75,12 @@ def upcast_masked_softmax( @torch.jit.script +def upcast_masked_softmax_fused( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype) + + def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): input_dtype = x.dtype x = x.to(softmax_dtype) * scale @@ -72,18 +89,61 @@ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): @torch.jit.script +def upcast_softmax_fused(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + return upcast_softmax(x, scale, softmax_dtype) + + def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): x = torch.where(mask, x, mask_value) x = torch.nn.functional.softmax(x, dim=-1) return x +@torch.jit.script +def masked_softmax_fused(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + return masked_softmax(x, mask, mask_value) + + +def softmax_function( + x: torch.Tensor, + mask: torch.Tensor, + mask_value: torch.Tensor, + scale: float, + softmax_dtype: torch.dtype, + upcast: bool = True, + fused_softmax: Optional[bool] = None, +): + """ + This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case + needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting + and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely + inefficient. TODO: Could have better fused kernels depending on scaling, dropout and head mask. + Is it doable without writing 32 functions? 
+ """ + if fused_softmax is None: + fused_softmax = x.size(-1) % 8 == 0 + if upcast: + if mask is None: + return (upcast_softmax_fused if fused_softmax else upcast_softmax)(x, scale, softmax_dtype) + else: + return (upcast_masked_softmax_fused if fused_softmax else upcast_masked_softmax)( + x, mask, mask_value, scale, softmax_dtype + ) + else: + if mask is None: + return torch.nn.functional.softmax(x, dim=-1) + else: + return (masked_softmax_fused if fused_softmax else masked_softmax)(x, mask, mask_value) + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() self.mask_value = None self.multi_query = config.multi_query + # TODO: chack availability + self.attention_implementation = config.attention_implementation self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads @@ -104,6 +164,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.scale_attention_softmax_in_fp32 = ( config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 ) + self.fused_softmax = config.fused_softmax # KV caching and padding self.kv_cache = None @@ -114,26 +175,43 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self._frozen_kv_cache = False if self.is_cross_attention: - if self.multi_query: - raise NotImplementedError("Multi-Query Attention not supported for cross_attention") - - self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) - self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) - else: - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + if self.attention_implementation == AttentionImplementation.BASE: + self._attn_fn = self._attn_mqa if self.multi_query else self._attn_mha + elif self.attention_implementation in TORCH_IMPLEMENTATIONS: + assert not self.pre_allocate_kv_cache + self._attn_fn = self._attn_torch_mqa if self.multi_query else self._attn_torch_mha + self.backend_context = ( + lambda: contextlib.nullcontext() + if self.attention_implementation == AttentionImplementation.TORCH + else torch.backends.cuda.sdp_kernel( + enable_flash=self.attention_implementation == AttentionImplementation.TORCH_FLASH, + enable_math=self.attention_implementation == AttentionImplementation.TORCH_CPP, + enable_mem_efficient=self.attention_implementation == AttentionImplementation.TORCH_MEM, + ) + ) + elif self.attention_implementation == AttentionImplementation.FLASH: + assert not self.pre_allocate_kv_cache + self._attn_fn = self._attn_flash + elif self.attention_implementation == AttentionImplementation.OLD: + self._attn_fn = None + else: + raise ValueError() + def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
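A hedged usage sketch of the `softmax_function` dispatcher defined above (assumes this branch is installed; shapes are illustrative). With `fused_softmax=None`, the jit-fused kernel is chosen only when the key length (last dimension) is a multiple of 8:

```python
import torch

from transformers.models.gpt_bigcode.modeling_gpt_bigcode import softmax_function

# MQA layout: (batch, query_length, num_heads, key_length); key_length=16 is a
# multiple of 8, so fused_softmax=None resolves to the fused TorchScript path.
scores = torch.randn(2, 8, 4, 16, dtype=torch.float16)
mask = torch.ones(2, 8, 1, 16, dtype=torch.bool)
mask_value = torch.full([], torch.finfo(torch.float32).min)

probs = softmax_function(
    scores,
    mask,
    mask_value,
    scale=1.0,
    softmax_dtype=torch.float32,  # upcast=True: the softmax itself runs in fp32
    upcast=True,
    fused_softmax=None,
)
```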
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) return self.mask_value - def _attn(self, query, key, value, attention_mask=None, head_mask=None): + def _get_attn_weights(self, query, key, attn_view, attn_shape, attention_mask=None, head_mask=None): dtype = query.dtype softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype upcast = dtype != softmax_dtype @@ -143,30 +221,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if self.scale_attn_weights: scale_factor /= self.head_dim**0.5 - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key_length = key.size(-1) - if self.multi_query: - # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) - # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] - attn_shape = (batch_size, query_length, self.num_heads, key_length) - attn_view = (batch_size, query_length * self.num_heads, key_length) - # No copy needed for MQA 2, or when layer_past is provided. - query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) - else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. - key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) - attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) if query.device.type == "cpu": # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. @@ -178,37 +232,204 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): beta = 0 attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) - if upcast: - # Use a fused kernel to prevent a large overhead from casting and scaling. - # Sub-optimal when the key length is not a multiple of 8. - if attention_mask is None: - attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) - else: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) - else: - if attention_mask is not None: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - - # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
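A toy version of the `baddbmm` pattern in `_get_attn_weights` above: scaled scores are written straight into a preallocated buffer with `beta=0`, with the CPU `zero_()`/`beta=1` workaround for the pytorch bug referenced in the code.

```python
import torch

batch_x_heads, query_length, key_length, head_dim = 8, 5, 7, 64
query = torch.randn(batch_x_heads, query_length, head_dim)
key_t = torch.randn(batch_x_heads, head_dim, key_length)  # key already transposed

scale_factor = head_dim ** -0.5
attn_weights = torch.empty(batch_x_heads, query_length, key_length)
if query.device.type == "cpu":
    # Work around https://github.com/pytorch/pytorch/issues/80588: with beta=0 an
    # uninitialized CPU buffer could leak NaN/Inf into the result before the fix.
    attn_weights.zero_()
    beta = 1
else:
    beta = 0
attn_weights = torch.baddbmm(attn_weights, query, key_t, beta=beta, alpha=scale_factor)
```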
- attn_weights = torch.where(attention_mask, attn_weights, mask_value) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = softmax_function( + attn_weights, + attention_mask, + None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype), + unscale, + softmax_dtype, + upcast, + self.fused_softmax, + ) attn_weights = self.attn_dropout(attn_weights) # Mask heads if we want to if head_mask is not None: - if self.multi_query: - head_mask = head_mask.transpose(1, 2) attn_weights = attn_weights * head_mask - if self.multi_query: - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + return attn_weights + + def _attn_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # Q: (batch_size, num_heads, query_length, head_dim) + # K, V: (batch_size, num_heads, query_length, head_dim) + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + batch_size = query.size(0) + query_length = query.size(1) + key_length = key.size(1) + + # MHA models: (batch_size, num_heads, query_length, head_dim) + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.transpose(-1, -2).reshape(batch_size * self.num_heads, self.head_dim, key_length) + + # attn_weights: (batch_size, num_heads, query_length, key_length) + attn_weights = self._get_attn_weights(query, key, attn_view, attn_shape, attention_mask, head_mask) + + # attn_output: (batch_size, num_heads, query_length, head_dim) + attn_output = torch.matmul(attn_weights, value) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + + return attn_output, present, attn_weights + + def _attn_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + batch_size = query.size(0) + query_length = query.size(1) + key_length = key.size(1) + + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. 
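A toy illustration of the view trick that `_attn_mqa` continues with just below: since all query heads share a single key head, the heads can be folded into the sequence dimension and the scores computed with one `bmm` (shapes are illustrative).

```python
import torch

batch, query_length, key_length, num_heads, head_dim = 2, 5, 7, 4, 8
query = torch.randn(batch, query_length, num_heads * head_dim)
key = torch.randn(batch, key_length, head_dim)  # single shared key head

# (batch, query_length * num_heads, head_dim) x (batch, head_dim, key_length)
q = query.reshape(batch, query_length * num_heads, head_dim)  # a view, no copy
scores = torch.bmm(q, key.transpose(-1, -2))
# Back to (batch, query_length, num_heads, key_length) for masking/softmax.
scores = scores.view(batch, query_length, num_heads, key_length)
```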
+ query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + key = key.transpose(-1, -2) + + if head_mask is not None: + head_mask = head_mask.transpose(1, 2) + + # attn_weights: (batch_size, query_length, num_heads, key_length) + attn_weights = self._get_attn_weights(query, key, attn_view, attn_shape, attention_mask, head_mask) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(hidden_states.shape) + + return attn_output, present, attn_weights + + def _attn_torch_mha(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Scale? + # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError("Head mask is not supported with torch attention.") + # Q: (batch_size, num_heads, query_length, head_dim) + # K, V: (batch_size, num_heads, query_length, head_dim) + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + with self.backend_context(): + # attn_output: (batch_size, num_heads, query_length, head_dim) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, self.attn_dropout, is_causal=True # attention_mask, + ) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + + return attn_output, present, None + + def _attn_torch_mqa(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Scale? + # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError("Head mask is not supported with torch attention.") + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + key_value, present = self._merge_kv_caches(key_value, layer_past, use_cache) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + batch_size = query.size(0) + query_length = query.size(1) + + if query_length == 1: + # attn_output: (batch_size, 1, num_heads, head_dim) + is_causal = False + query = query.view(batch_size, 1, self.num_heads, self.head_dim) + key = key.unsqueeze(1) + value = value.unsqueeze(1) else: - attn_output = torch.matmul(attn_weights, value) + # attn_output: (batch_size, num_heads, query_length, head_dim) + is_causal = True + expanded_shape = (batch_size, self.num_heads, key.size(-2), self.head_dim) + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + key = key.unsqueeze(1).expand(expanded_shape) + value = value.unsqueeze(1).expand(expanded_shape) + with self.backend_context(): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + None, # attention_mask, + self.dropout_p if self.training else 0.0, + is_causal=is_causal, + ) - return attn_output, attn_weights + if query_length != 1: + attn_output = attn_output.transpose(1, 2) + + # attn_output: (batch_size, query_length, num_heads * head_dim) + # CPP backend needs reshape, others are ok with a view. 
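A sketch of the expand trick used in `_attn_torch_mqa`: the single K/V head is broadcast across the head dimension as a view (no copy) so `scaled_dot_product_attention` sees ordinary multi-head inputs. The backend restriction via `torch.backends.cuda.sdp_kernel` (TORCH_FLASH / TORCH_MEM / TORCH_CPP) is omitted here, so this runs on CPU with the default backend.

```python
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 64
query = torch.randn(batch, seq_len, num_heads * head_dim)
key = torch.randn(batch, seq_len, head_dim)
value = torch.randn(batch, seq_len, head_dim)

q = query.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)  # (b, h, s, d)
expanded = (batch, num_heads, seq_len, head_dim)
k = key.unsqueeze(1).expand(expanded)    # view, no copy
v = value.unsqueeze(1).expand(expanded)  # view, no copy

attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, num_heads * head_dim)
```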
+ attn_output = attn_output.reshape(hidden_states.shape) + + return attn_output, present, None + + def _attn_flash(self, hidden_states, layer_past, use_cache, attention_mask=None, head_mask=None): + # TODO: Use attn mask + if head_mask is not None: + raise NotImplementedError("Head mask is not supported with flash attention.") + # Q: (batch_size, query_length, num_heads * head_dim) + # K, V:(batch_size, key_length, num_heads * head_dim) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1) + _, present = self._merge_kv_caches(key_value, layer_past, use_cache, attention_mask) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + attn_shape = query.size(0), self.num_heads, self.head_dim + query = query.view(attn_shape) + if self.multi_query: + key = key.unsqueeze(1).expand(attn_shape) + value = value.unsqueeze(1).expand(attn_shape) + else: + key = key.view(attn_shape) + value = value.view(attn_shape) + + sequence_lengths, padding_index, _, max_sequence_length = attention_mask + + # attn_output: (sum_seq_len, num_heads * head_dim) + attn_output = flash_attn_unpadded_func( + query, + key, + value, + sequence_lengths, + sequence_lengths, + max_sequence_length, + max_sequence_length, + self.dropout_p if self.training else 0.0, + softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1, + causal=True, + ).view(hidden_states.shape) + + return attn_output, present, None def freeze_kv_cache(self, enable=True): if self.kv_cache is None: @@ -245,6 +466,33 @@ def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True ) return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] + def _merge_kv_caches(self, key_value, layer_past, use_cache, flash_attention_parameters=None): + present = None + if flash_attention_parameters is not None and (use_cache or layer_past is not None): + # Todo: unpadding is only needed if the cache is reused. + _, padding_index, batch_size, max_sequence_length = flash_attention_parameters + key_value = pad_input(key_value, padding_index, batch_size, max_sequence_length) + if self.pre_allocate_kv_cache: + if use_cache or layer_past is not None: + last_key_length = layer_past or 0 + batch_size = key_value.size(0) + key_length = last_key_length + key_value.size(-2) + padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) + kv_cache = self.get_kv_cache( + batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 + ) + if self.multi_query: + kv_cache[:, last_key_length:key_length, :].copy_(key_value) + key_value = kv_cache + if use_cache: + present = key_length + else: + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + if use_cache: + present = key_value + return key_value, present + def forward( self, hidden_states: torch.Tensor, @@ -259,17 +507,58 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." 
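A toy version of the simple (non-pre-allocated) cache path in `_merge_kv_caches` above: new key/value states are concatenated onto the cached ones along the sequence dimension, and the merged tensor is returned as `present` when `use_cache` is set.

```python
import torch

head_dim = 64
# MQA cache layout: (batch, key_length, 2 * head_dim), holding K and V side by side.
layer_past = torch.randn(2, 10, 2 * head_dim)  # 10 cached positions
key_value = torch.randn(2, 1, 2 * head_dim)    # one newly decoded token

key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value  # returned as the new cache when use_cache=True
key, value = key_value.split((head_dim, head_dim), dim=-1)
```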
- ) + if encoder_hidden_states is not None or encoder_attention_mask is not None: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: + if self.attention_implementation == AttentionImplementation.OLD: + return self._old_forward( + hidden_states, + layer_past, + attention_mask, + head_mask, + use_cache, + output_attentions, + ) + + if self.attention_implementation == AttentionImplementation.BASE or layer_past is not None: + attn_fn = self._attn_mqa if self.multi_query else self._attn_mha + elif self.attention_implementation in TORCH_IMPLEMENTATIONS: + assert not self.pre_allocate_kv_cache + attn_fn = self._attn_torch_mqa if self.multi_query else self._attn_torch_mha + elif self.attention_implementation == AttentionImplementation.FLASH: + attn_fn = self._attn_flash + else: + raise ValueError() + + attn_output, present, attn_weights = attn_fn(hidden_states, layer_past, use_cache, attention_mask, head_mask) + + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if attn_weights is None: + raise NotImplementedError("`output_attentions` not supported.") + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + def _old_forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if self.multi_query: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), @@ -322,6 +611,83 @@ def forward( return outputs # a, present, (attentions) + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + dtype = query.dtype + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype + upcast = dtype != softmax_dtype + + unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + scale_factor = unscale**-1 + if self.scale_attn_weights: + scale_factor /= self.head_dim**0.5 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-1) + if self.multi_query: + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. 
+ query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + else: + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[2] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) + if query.device.type == "cpu": + # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. + # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, + # but the fix has not been released as of pytorch version 2.0.0. + attn_weights.zero_() + beta = 1 + else: + beta = 0 + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + + if upcast: + # Use a fused kernel to prevent a large overhead from casting and scaling. + # Sub-optimal when the key length is not a multiple of 8. + if attention_mask is None: + attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) + else: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) + else: + if attention_mask is not None: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + + # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
+ attn_weights = torch.where(attention_mask, attn_weights, mask_value) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + if self.multi_query: + head_mask = head_mask.transpose(1, 2) + attn_weights = attn_weights * head_mask + + if self.multi_query: + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + else: + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): @@ -352,10 +718,7 @@ def __init__(self, config, layer_idx=None): self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - if config.multi_query: - raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) - self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -372,6 +735,9 @@ def forward( ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: + if encoder_hidden_states is not None or encoder_attention_mask is not None: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -387,28 +753,6 @@ def forward( # residual connection hidden_states = attn_output + residual - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) - residual = hidden_states - hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( - hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) - attn_output = cross_attn_outputs[0] - # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights - residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) @@ -567,6 +911,9 @@ def __init__(self, config): self.multi_query = config.multi_query self.embed_dim = config.hidden_size + if config.add_cross_attention: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) @@ -578,6 +925,8 @@ def __init__(self, config): self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache self.inference_runner_type = InferenceRunnerType(config.inference_runner) + self.attention_implementation = config.attention_implementation + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: self.inference_runner = None else: @@ -586,6 +935,7 @@ def __init__(self, config): self.inference_runner = GPTBigCodeInferenceRunner(config, self) max_positions = 
config.max_position_embeddings + # Causal mask self.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False ) @@ -601,6 +951,32 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + def _get_causal_mask(self, padding_mask, query_length, key_length): + # Self-attention mask. + attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + + if padding_mask is not None: + attention_mask = attention_mask * padding_mask.unsqueeze(1).to( + dtype=torch.bool, device=attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + return attention_mask.unsqueeze(2 if self.multi_query else 1) + + def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device): + if position_ids is not None: + position_ids = position_ids.to(device) + elif padding_mask is not None and padding_mask.ndim == 2: + # create position_ids on the fly for batch generation + position_ids = padding_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(padding_mask == 0, 1) + if key_length > query_length: + position_ids = position_ids[:, key_length - query_length : key_length :] + else: + position_ids = torch.arange(key_length - query_length, key_length, dtype=torch.long, device=device) + return position_ids.view(-1, query_length) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -652,30 +1028,28 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = self.config.use_cache if use_cache is None else use_cache + return_dict = self.config.use_return_dict if return_dict is None else return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() + if input_ids is not None: + if inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + input_shape = input_ids.shape input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + batch_size, query_length = input_ids.shape elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + input_shape = inputs_embeds.shape[:-1] + inputs_embeds = inputs_embeds.view(-1, input_shape[-2:]) + batch_size, query_length = inputs_embeds.shape[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) + using_flash_attention = ( + self.attention_implementation == AttentionImplementation.FLASH and past_key_values is None + ) if past_key_values is None: past_length = 0 @@ -684,49 +1058,23 @@ def forward( past_length = past_key_values[0] else: past_length = 
past_key_values[0].size(-2) - - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Self-attention mask. - query_length = input_shape[-1] key_length = past_length + query_length - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) + position_ids = self._get_position_ids(position_ids, attention_mask, query_length, key_length, input_ids.device) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, query_length) - if self.pad_key_length: - pad = -key_length % 8 - if pad > 0: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + if not using_flash_attention: + # Self-attention mask (padding + causal). + attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) + if self.pad_key_length: + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if ( - self.config.add_cross_attention - and encoder_hidden_states is not None - and encoder_attention_mask is not None - ): - if encoder_attention_mask.dim() == 2: - encoder_attention_mask.unsqueeze(1) - assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) - else: - encoder_attention_mask = None + if encoder_hidden_states is not None or encoder_attention_mask is not None: + raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -745,7 +1093,13 @@ def forward( hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) + # TODO: Unpad earlier (input ids), support unpadded input? 
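A hedged sketch of the unpad → flash-attention → re-pad flow used when `using_flash_attention` (the unpadding happens in the model forward just below, the attention call in `_attn_flash` above). This follows the flash-attn 1.x API used in this PR (`unpad_input`, `flash_attn_unpadded_func`, `pad_input`), needs a CUDA device and fp16/bf16 tensors, and differs from the flash-attn 2.x API.

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

batch, seq_len, num_heads, head_dim = 2, 6, 4, 64
hidden = torch.randn(batch, seq_len, num_heads * head_dim, device="cuda", dtype=torch.float16)
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1, 1]], device="cuda", dtype=torch.bool)

# Drop padded positions: (total_tokens, hidden) plus the re-padding bookkeeping.
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden, padding_mask)

q = k = v = unpadded.view(-1, num_heads, head_dim)  # toy self-attention, shared projections
out = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
    0.0,                                  # dropout_p
    softmax_scale=head_dim ** -0.5,
    causal=True,
)
# Back to the padded (batch, seq_len, hidden) layout.
repadded = pad_input(out.view(-1, num_heads * head_dim), indices, batch, seq_len)
```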
+ if using_flash_attention: + hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( + hidden_states, attention_mask + ) + # Pass the required parameters through the attention_mask argument + attention_mask = (sequence_lengths, padding_index, batch_size, max_sequence_length) presents = [] if use_cache else None all_self_attentions = () if output_attentions else None @@ -770,8 +1124,6 @@ def custom_forward(*inputs): None, attention_mask, head_mask[i], - encoder_hidden_states, - encoder_attention_mask, ) else: outputs = block( @@ -779,8 +1131,6 @@ def custom_forward(*inputs): layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, ) @@ -796,7 +1146,11 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) + if using_flash_attention: + hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + + hidden_states = hidden_states.view(input_shape + (hidden_states.size(-1),)) + # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -831,6 +1185,8 @@ def __init__(self, config): super().__init__(config) self.transformer = GPTBigCodeModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.predict_last_token = config.predict_last_token + self.attention_implementation = config.attention_implementation # Initialize weights and apply final processing self.post_init() @@ -845,21 +1201,9 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + input_ids = input_ids[:, -1:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None + token_type_ids = token_type_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -871,8 +1215,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, + "position_ids": kwargs.get("position_ids", None), + "attention_mask": kwargs.get("attention_mask", None), "token_type_ids": token_type_ids, } ) @@ -926,6 +1270,10 @@ def forward( ) hidden_states = transformer_outputs[0] + if self.predict_last_token and not self.training: + # We only care about the last token. + hidden_states = hidden_states[:, -1:] + lm_logits = self.lm_head(hidden_states) loss = None
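What `predict_last_token` buys at inference time, as a toy calculation (sizes are illustrative): only the final position is pushed through the large `lm_head` projection.

```python
import torch

vocab_size, hidden_size, seq_len = 50257, 2048, 512
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(1, seq_len, hidden_size)

logits_all = lm_head(hidden_states)           # (1, 512, vocab): projects every position
logits_last = lm_head(hidden_states[:, -1:])  # (1, 1, vocab): ~512x fewer lm_head FLOPs
```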