Flash attention experiments #18

Draft: wants to merge 9 commits into base: main
30 changes: 30 additions & 0 deletions src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -40,6 +40,28 @@ class InferenceRunnerType(IntEnum):
FULL_GRAPH = 3


class AttentionImplementation(IntEnum):
# Ours
BASE = 0
# Flash attention
FLASH = 1
# scaled_dot_product_attention (multiple implementations)
TORCH = 2
TORCH_FLASH = 3
TORCH_MEM = 4
TORCH_CPP = 5
# DEBUG
OLD = 6


TORCH_IMPLEMENTATIONS = (
AttentionImplementation.TORCH,
AttentionImplementation.TORCH_FLASH,
AttentionImplementation.TORCH_MEM,
AttentionImplementation.TORCH_CPP,
)


class GPTBigCodeConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a
@@ -133,13 +155,16 @@ def __init__(
eos_token_id=50256,
attention_softmax_in_fp32=True,
scale_attention_softmax_in_fp32=True,
fused_softmax=None,
multi_query=True,
attention_implementation=AttentionImplementation.BASE,
inference_runner=InferenceRunnerType.NO_RUNNER,
validate_runner_input=True,
pre_allocate_kv_cache=False,
max_sequence_length=None,
max_batch_size=None,
pad_key_length=True,
predict_last_token: bool = False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -158,7 +183,9 @@ def __init__(
self.use_cache = use_cache
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
self.fused_softmax = fused_softmax
self.multi_query = multi_query
self.attention_implementation = attention_implementation

self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
@@ -175,4 +202,7 @@ def __init__(
# Pad key length to a multiple of 8 (requires pre_allocate_kv_cache).
self.pad_key_length = pad_key_length

# Predict only the last token in inference even if the input is bigger.
self.predict_last_token = predict_last_token

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
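For reference, here is a minimal, hypothetical usage sketch of the new configuration fields introduced above. Only the field and enum names come from the diff; the values shown and their interactions are illustrative assumptions.

from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
    AttentionImplementation,
    GPTBigCodeConfig,
    InferenceRunnerType,
)

# Illustrative values only; defaults and accepted combinations are assumptions.
config = GPTBigCodeConfig(
    attention_implementation=AttentionImplementation.FLASH,  # or one of the TORCH_* variants
    fused_softmax=None,        # None lets the runner enable fusion when key lengths are padded
    predict_last_token=True,   # compute logits only for the last input token at inference
    inference_runner=InferenceRunnerType.NO_RUNNER,
    pre_allocate_kv_cache=False,
    pad_key_length=True,
)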
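The TORCH, TORCH_FLASH, TORCH_MEM and TORCH_CPP values are described above only as "scaled_dot_product_attention (multiple implementations)". One plausible mapping, shown purely as a guess since the actual dispatch is not part of this diff, is to select torch's SDPA backends via the torch.backends.cuda.sdp_kernel context manager.

import torch
import torch.nn.functional as F
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionImplementation

# Guessed mapping: (enable_flash, enable_mem_efficient, enable_math) per implementation.
_SDPA_BACKENDS = {
    AttentionImplementation.TORCH: (True, True, True),        # let torch choose
    AttentionImplementation.TORCH_FLASH: (True, False, False),
    AttentionImplementation.TORCH_MEM: (False, True, False),
    AttentionImplementation.TORCH_CPP: (False, False, True),  # math (C++) fallback
}

def sdpa_attention(query, key, value, implementation, causal=True):
    enable_flash, enable_mem, enable_math = _SDPA_BACKENDS[implementation]
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash,
        enable_mem_efficient=enable_mem,
        enable_math=enable_math,
    ):
        return F.scaled_dot_product_attention(query, key, value, is_causal=causal)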
58 changes: 35 additions & 23 deletions src/transformers/models/gpt_bigcode/inference_runner.py
@@ -4,8 +4,17 @@

from transformers import GPTBigCodeConfig
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
AttentionImplementation,
InferenceRunnerType,
)
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, softmax_function


try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None


def _align_tensor(x):
@@ -23,6 +32,7 @@ def __init__(self, config: GPTBigCodeConfig, model):
assert config.pre_allocate_kv_cache
self.validate_input = config.validate_runner_input
self.pad_key_length = 8 if config.pad_key_length else 1
self.fused_softmax = True if config.fused_softmax is None and config.pad_key_length else config.fused_softmax

# TODO: Support other attention types?
assert model.multi_query
@@ -51,9 +61,10 @@ def _allocate(self, batch_size, device, dtype):
query_end = query_begin + self.batch_size * attn.embed_dim
# KV: (bs, 2 * kv_dim), combines with query into c_attn.
kv_end = query_end + 2 * self.batch_size * attn.kv_dim
# Attn weights: (batch_size, num_heads, key_length), no overlap with value
# Attn weights: (batch_size, num_heads, key_length), no overlap with value (not needed for torch/flash attn)
attn_weights_begin = _align_tensor(kv_end)
attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length
attn_weights_end = attn_weights_begin
attn_weights_end += self.batch_size * attn.num_heads * self.max_sequence_length
# Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query.
# Also used for MLP projection
c_proj_begin = _align_tensor(query_end)
@@ -119,11 +130,13 @@ def _allocate(self, batch_size, device, dtype):
# QKV: (bs, embed_dim + 2 * kv_dim).
self.c_attn = activation_pool[query_begin:kv_end].view(self.batch_size, -1)
self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim)

self.kv_attn = self.c_attn[:, attn.embed_dim :]

keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches))
head_slice = 0 if attn.multi_query else slice(None)

# No transpose for torch/flash attn
self.padded_keys = [
[key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths
]
@@ -159,6 +172,10 @@ def _allocate(self, batch_size, device, dtype):
if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER:
print("Generating cuda graphs")
self.memory_pool = None
# This prevents some issue with cublas initialization.
# https://github.com/pytorch/pytorch/issues/99397
dummy_matrix = self.mask_value.view([1, 1])
torch.matmul(dummy_matrix, dummy_matrix)
if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH:
self.cuda_graphs = {}
# The output may not always be at the same memory location.
@@ -187,22 +204,19 @@ def _generate_cuda_graphs(self):

def _generate_full_cuda_graph(self, key_length):
# We need to warmup the jit function before creating the graph, otherwise it will crash.
# https://github.com/pytorch/pytorch/issues/99397
# Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1
if self.upcast:
if self.fused_softmax or (self.fused_softmax is None and key_length % 8 == 0):
for scale in (1.0, 2.0):
upcast_masked_softmax(
softmax_function(
self.padded_attn_weights[key_length],
self.padded_attn_masks[key_length],
self.mask_value,
scale,
self.softmax_dtype,
self.upcast,
self.fused_softmax,
)
else:
masked_softmax(
self.padded_attn_weights[key_length],
self.padded_attn_masks[key_length],
self.mask_value,
)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=self.memory_pool):
self.output_hidden_states[key_length] = self._forward(key_length)
@@ -239,18 +253,16 @@ def _forward_attn(self, block, key_length):
alpha=self.scale[layer_idx],
out=attn_weights,
)
# Use a fused kernel to prevent a large overhead from casting and scaling.
# Jit doesn't allow inplace kernel.
if self.upcast:
attn_weights = upcast_masked_softmax(
attn_weights,
self.padded_attn_masks[key_length],
self.mask_value,
self.unscale[layer_idx],
self.softmax_dtype,
)
else:
attn_weights = masked_softmax(attn_weights, self.padded_attn_masks[key_length], self.mask_value)
attn_weights = softmax_function(
attn_weights,
self.padded_attn_masks[key_length],
self.mask_value,
self.unscale[layer_idx],
self.softmax_dtype,
self.upcast,
self.fused_softmax,
)

torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded)

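The diff replaces the former masked_softmax / upcast_masked_softmax pair with a single softmax_function helper. Its definition lives in modeling_gpt_bigcode.py and is not shown in this diff, so the following is only a rough sketch inferred from the call sites above; the argument order, the eager fallback, and the handling of fused_softmax are assumptions.

import torch

def softmax_function(x, mask, mask_value, scale, softmax_dtype, upcast, fused_softmax=None):
    # Single entry point replacing masked_softmax / upcast_masked_softmax.
    # The real helper presumably dispatches to torch.jit-fused kernels when
    # fused_softmax is enabled; only the eager path is sketched here.
    if upcast:
        input_dtype = x.dtype
        x = x.to(softmax_dtype) * scale
    if mask is not None:
        x = torch.where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1)
    if upcast:
        x = x.to(input_dtype)
    return x

If the fused path is a torch.jit.script-compiled kernel, that would also explain the warmup in _generate_full_cuda_graph above: the jit function has to be compiled for every input shape (key length) and scale before a CUDA graph can capture it.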