Falcon40b - enable SDPA as single op - updated version (#11051)
* #9637: Enable SDPA as single op - updated version

* #9637: Update expected output tokens

* #9637: Remove unnecessary commented code

* #9637: Fix SDPA
djordje-tt authored Aug 7, 2024
1 parent 7987dce commit 4c579d3
Showing 12 changed files with 60 additions and 240 deletions.
16 changes: 16 additions & 0 deletions models/demos/t3000/falcon40b/demo/demo.py
@@ -33,6 +33,22 @@
SPACE = 204


# Used for debugging non-deterministic outputs of prefill stage
def save_kv_cache_to_file(device_mesh, kv_cache, kv_cache_path):
# Aggregate the key and value caches of all 60 layers into one tensor; each per-device K/V cache is (32, 1, 128, 64), and concatenating across the mesh on the last dim gives (32, 1, 128, 512)
final_tensor = torch.zeros(60, 2, 32, 1, 128, 512)
for layer in range(60):
for type in range(len(kv_cache[layer])):
# gather this layer's key/value cache from all devices in the mesh
tensor = ttnn.to_torch(
kv_cache[layer][type], device=device_mesh, mesh_composer=ttnn.ConcatMeshToTensor(device_mesh, dim=-1)
)
# store it in the aggregate tensor (written to disk once after the loop)
final_tensor[layer][type] = tensor

torch.save(final_tensor, kv_cache_path)


# load from json, return as a list
def load_inputs(user_input, batch):
if isinstance(user_input, str):
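
save_kv_cache_to_file above is a debugging aid for the non-deterministic prefill outputs mentioned in its comment. A minimal sketch of how two caches saved by it might be compared offline; the file names and the per-layer report are hypothetical, not part of this commit:

import torch

# Load the KV caches dumped by save_kv_cache_to_file after two identical prefill runs.
ref = torch.load("kv_cache_run1.pt")     # shape (60, 2, 32, 1, 128, 512)
test = torch.load("kv_cache_run2.pt")

# Report the worst mismatch per layer so the offending layer is easy to spot.
per_layer_max = (ref - test).abs().amax(dim=(1, 2, 3, 4, 5))
for layer, diff in enumerate(per_layer_max.tolist()):
    if diff > 0:
        print(f"layer {layer}: max |ref - test| = {diff}")
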


12 changes: 9 additions & 3 deletions models/demos/t3000/falcon40b/tests/test_falcon_attention.py
@@ -101,12 +101,18 @@ def run_test_FalconAttention_inference(

tt_attention_mask = ttnn.as_tensor(
tensor=attention_mask_bool,
dtype=model_config["ATTN_MASK_DTYPE"],
layout=ttnn.TILE_LAYOUT,
dtype=model_config["BFLOAT16_DTYPE"],
layout=ttnn.ROW_MAJOR_LAYOUT,
device=device_mesh,
memory_config=attention_mask_memconfig,
mesh_mapper=ReplicateTensorToMesh(device_mesh),
preprocess=lambda x: x * (-1e5),
preprocess=lambda x: (x * (-1e5)).expand(1, 1, -1, -1),
)

tt_attention_mask = ttnn.tilize(
tt_attention_mask,
memory_config=model_config["DRAM_MEMCFG"],
dtype=model_config["ATTN_MASK_DTYPE"],
)

tt_k_cache_host = torch.zeros(batch, configuration.num_kv_heads, max_position_embeddings, head_dim)
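
The mask is now pushed to the device as a row-major bfloat16 tensor and tilized on device; the host-side preprocess lambda above turns a boolean mask into a large-negative additive mask with a broadcastable (1, 1, seq, seq) shape. A plain-torch sketch of that preprocessing, with illustrative sizes (not the test's actual shapes):

import torch

seq_len = 128
# True where attention must be blocked (strict upper triangle for causal attention).
attention_mask_bool = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()

# bool -> additive mask: blocked positions become -1e5, allowed positions 0,
# then two leading singleton dims are added so it broadcasts over batch and heads.
additive_mask = (attention_mask_bool * (-1e5)).expand(1, 1, -1, -1)
print(additive_mask.shape)   # torch.Size([1, 1, 128, 128])
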
4 changes: 2 additions & 2 deletions models/demos/t3000/falcon40b/tests/test_falcon_causallm.py
@@ -79,10 +79,10 @@ def run_test_FalconCausalLM_inference(
use_global_cos_sin_cache = True

if 1:
model_input = torch.arange(seq_len * batch).reshape(batch, seq_len)
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
else:
# batch identical sequences for debugging
model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len)
model_input = torch.stack([torch.randint(0, seq_len * batch, (seq_len,))] * batch).reshape(batch, seq_len)

# Generate dummy kv_cache --------------------------------------------------------------
if llm_mode == "prefill":
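
These tests now feed random token IDs instead of torch.arange. A short sketch of the two input modes used above, with illustrative sizes (the upper bound seq_len * batch simply mirrors the test code; it is not a real vocabulary size):

import torch

batch, seq_len = 32, 128                                        # illustrative
# Default mode: independent random token IDs for every user.
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
# Debug mode: one random sequence repeated for all users, so outputs should match row-wise.
same_seq = torch.randint(0, seq_len * batch, (seq_len,))
model_input_debug = torch.stack([same_seq] * batch)             # (batch, seq_len)
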
4 changes: 2 additions & 2 deletions models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
@@ -73,10 +73,10 @@ def run_test_FalconCausalLM_end_to_end(
use_global_cos_sin_cache = True

if 1:
model_input = torch.arange(seq_len * batch).reshape(batch, seq_len)
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
else:
# batch identical sequences for debugging
model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len)
model_input = torch.stack([torch.randint(0, seq_len * batch, (seq_len,))] * batch).reshape(batch, seq_len)

# Generate dummy kv_cache --------------------------------------------------------------
if llm_mode == "prefill":
4 changes: 2 additions & 2 deletions models/demos/t3000/falcon40b/tests/test_falcon_model.py
@@ -73,10 +73,10 @@ def run_test_FalconModel_inference(
use_global_cos_sin_cache = True

if 1:
model_input = torch.arange(seq_len * batch).reshape(batch, seq_len)
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
else:
# batch identical sequences for debugging
model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len)
model_input = torch.stack([torch.randint(0, seq_len * batch, (seq_len,))] * batch).reshape(batch, seq_len)

# Generate dummy kv_cache --------------------------------------------------------------
if llm_mode == "prefill":
@@ -99,7 +99,7 @@ def run_test_falcon_prefill_end_to_end_determinism(
logger.info("Done loading TT Falcon Model")

# Prepare inputs -----------------------------------------------------------------------
model_input = torch.arange(seq_len * batch).reshape(batch, seq_len)
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
model_inputs = torch.split(model_input, 1)

# First run to get reference output ----------------------------------------------------
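
In the determinism test above, torch.split(model_input, 1) turns the batch into per-user prefill inputs that can be run one at a time and compared against the reference run. A small sketch of what that split produces (sizes are illustrative):

import torch

batch, seq_len = 32, 128                                 # illustrative
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
model_inputs = torch.split(model_input, 1)               # tuple of `batch` tensors
print(len(model_inputs), model_inputs[0].shape)          # 32 torch.Size([1, 128])
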
4 changes: 2 additions & 2 deletions models/demos/t3000/falcon40b/tests/test_perf_falcon.py
@@ -73,10 +73,10 @@ def run_test_FalconCausalLM_end_to_end(
use_global_cos_sin_cache = True

if True:
model_input = torch.arange(seq_len * batch).reshape(batch, seq_len)
model_input = torch.randint(0, seq_len * batch, (batch, seq_len))
else:
# batch identical sequences for debugging
model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len)
model_input = torch.stack([torch.randint(0, seq_len * batch, (seq_len,))] * batch).reshape(batch, seq_len)

# Generate dummy kv_cache --------------------------------------------------------------
if llm_mode == "prefill":
129 changes: 11 additions & 118 deletions models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -10,15 +10,8 @@
import ttnn
from ttnn import ShardTensorToMesh, ReplicateTensorToMesh

from models.utility_functions import (
torch2tt_tensor,
tt2torch_tensor,
pad_by_zero,
nearest_32,
)
from models.demos.t3000.falcon40b.tt.model_utils import (
convert_to_layout,
)
from models.utility_functions import nearest_32
from models.demos.t3000.falcon40b.tt.model_utils import convert_to_layout

from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation

@@ -205,7 +198,7 @@ def __init__(
# self.scalar = pad_by_zero(torch.Tensor([1 / math.sqrt(self.head_dim)]), self.device)[0]
self.scalar = 1 / math.sqrt(self.head_dim)

self.init_preprocessing(self.model_config["LLM_MODE"], max_position_embeddings)
# self.init_preprocessing(self.model_config["LLM_MODE"], max_position_embeddings)
self.layer_past = None

def initialize_kvcache(self):
@@ -252,22 +245,6 @@ def initialize_kvcache():

def set_model_config(self, model_config):
self.model_config = model_config
self.init_preprocessing(self.model_config["LLM_MODE"], self.max_position_embeddings)

def init_preprocessing(self, llm_mode, max_sequence_size):
if llm_mode == "prefill":
self.attn_output = ttnn.as_tensor(
torch.zeros([1, self.num_heads_per_device, max_sequence_size, self.head_dim]),
dtype=self.model_config["POST_SOFTMAX_MM_OUTPUT_DTYPE"],
layout=ttnn.TILE_LAYOUT,
device=self.device_mesh,
memory_config=self.model_config["DRAM_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
)

def online_preprocessing(self, llm_mode, sequence_size):
if llm_mode == "prefill":
self.sliced_attn_output = self.attn_output[:, :, :sequence_size, :]

def __call__(
self,
@@ -369,72 +346,18 @@ def fwd_prefill(
ttnn.experimental.tensor.typecast(value_layer, self.model_config["KV_CACHE_DTYPE"]),
user_id,
)
key_layer_transposed = ttnn.transpose(
attn_output = ttnn.experimental.operations.primary.transformers.scaled_dot_product_attention(
query_layer,
key_layer,
-2,
-1,
memory_config=self.model_config["K_TRANSPOSED_OUTPUT_MEMCFG"],
value_layer,
attention_mask,
is_causal=True,
scale=self.scalar,
program_config=self.model_config["SDPA_PROGCFG"],
)
key_layer.deallocate(True)

slice_size = self.model_config["attention_params"]["attention_slice_size"]
num_slices = self.model_config["attention_params"]["attention_num_slices"]

if num_slices > 1:
if not hasattr(self, "sliced_attn_output"):
self.online_preprocessing(llm_mode, q_len)
attn_output_tensor = self.sliced_attn_output

for slice_i in range(num_slices):
# Partially slice and convert activations to sharded
q_slices = ttnn.experimental.tensor.interleaved_to_sharded_partial(
query_layer,
(8, 8),
[slice_size * 16 // 64, self.head_dim], # each slice is [1,16,128,64], we use 64 cores
num_slices,
slice_i,
ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)
attn_output_slice = self.scaled_dot_product_attention(
q_slices,
key_layer_transposed,
attention_mask,
value_layer,
q_len,
)
ttnn.experimental.tensor.sharded_to_interleaved_partial(
attn_output_slice,
attn_output_tensor,
num_slices,
slice_i,
self.model_config["DRAM_MEMCFG"],
)
attn_output_slice.deallocate(True)
attn_output = attn_output_tensor
q_slices.deallocate(True)
else:
query_layer = convert_to_layout(
query_layer,
self.model_config["DRAM_MEMCFG"],
self.model_config["QUERY_HEIGHT_SHARDED_MEMCFG"],
)
attn_output = self.scaled_dot_product_attention(
query_layer,
key_layer_transposed,
attention_mask,
value_layer,
q_len,
)
attn_output = convert_to_layout(
attn_output,
self.model_config["ATTN_OUTPUT_HEIGHT_SHARDED_MEMCFG"],
self.model_config["DRAM_MEMCFG"],
)

# Deallocate query, key, value
query_layer.deallocate(True)
key_layer_transposed.deallocate(True)
key_layer.deallocate(True)
value_layer.deallocate(True)

# Output projection
@@ -467,36 +390,6 @@ def fwd_prefill(
layer_present = layer_past if use_cache else None
return attn_output, layer_present

def scaled_dot_product_attention(self, q_slices, key_layer_transposed, attn_mask_slices, value_layer, q_len):
# Q * KˆT
attn_weights = ttnn.matmul(
q_slices,
key_layer_transposed,
compute_kernel_config=self.model_config["COMPUTE_KERNEL_FP16_ACC_CONFIG"],
memory_config=self.model_config["HEIGHT_SHARDED_MEMCFG"],
program_config=self.model_config["ATTENTION_MM_PROGCFG"],
dtype=self.model_config["ATTENTION_DTYPE"],
)
# Softmax
attn_weights = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
attn_weights,
self.scalar,
attn_mask_slices,
program_config=self.model_config["SOFTMAX_PROGCFG"],
)
# Attention score * V
attn_output_slice = ttnn.matmul(
attn_weights,
value_layer,
compute_kernel_config=self.model_config["COMPUTE_KERNEL_FP16_ACC_CONFIG"],
memory_config=self.model_config["HEIGHT_SHARDED_MEMCFG"],
program_config=self.model_config["ATTENTION_MM_2_PROGCFG"],
dtype=self.model_config["ATTENTION_OUT_DTYPE"],
)
attn_weights.deallocate(True)

return attn_output_slice

def fwd_decode(
self,
hidden_states: ttnn.experimental.tensor.Tensor,
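
For reference, the single scaled_dot_product_attention op introduced in fwd_prefill computes the same result as the deleted three-step helper (Q @ Kᵀ, scaled causally-masked softmax, then @ V). A plain-PyTorch sketch of that equivalence, with illustrative shapes and the model's -1e5 mask convention; this is a numerical reference only, not the ttnn kernels:

import math
import torch
import torch.nn.functional as F

heads, seq, head_dim = 16, 128, 64                       # illustrative per-device shapes
q = torch.randn(1, heads, seq, head_dim)
k = torch.randn(1, heads, seq, head_dim)
v = torch.randn(1, heads, seq, head_dim)
scalar = 1 / math.sqrt(head_dim)

# Removed path: explicit Q @ K^T, additive causal mask (-1e5 on blocked positions),
# scaled softmax, then a second matmul with V.
additive_mask = torch.ones(seq, seq).triu(diagonal=1) * -1e5
attn_weights = torch.softmax(q @ k.transpose(-2, -1) * scalar + additive_mask, dim=-1)
out_composite = attn_weights @ v

# New path: one fused call; its default scale is already 1 / sqrt(head_dim).
out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print((out_composite - out_fused).abs().max())           # ~1e-6, float32 rounding only
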
8 changes: 0 additions & 8 deletions models/demos/t3000/falcon40b/tt/falcon_decoder.py
@@ -131,19 +131,11 @@ def pad_ln_params(x):

self.layernorm_eps = config.layer_norm_epsilon

self.init_preprocessing("prefill", 1, max_position_embeddings)

def set_model_config(self, model_config):
self.model_config = model_config
self.self_attn.set_model_config(model_config)
self.mlp.set_model_config(model_config)

def init_preprocessing(self, llm_mode, batch_size, max_sequence_size):
self.self_attn.init_preprocessing(llm_mode, max_sequence_size)

def online_preprocessing(self, llm_mode, sequence_size):
self.self_attn.online_preprocessing(llm_mode, sequence_size)

def __call__(
self,
hidden_states: ttnn.experimental.tensor.Tensor,
16 changes: 4 additions & 12 deletions models/demos/t3000/falcon40b/tt/falcon_model.py
@@ -222,35 +222,27 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
),
dim=-1,
)
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
if attention_mask_memconfig.is_sharded():
attn_mask_shard_shape = attention_mask_memconfig.shard_spec.shape
attn_mask_shard_shape[-1] = num_max_tokens
attention_mask_memconfig.shard_spec.shape = attn_mask_shard_shape

# Push attention mask to device in row major order and then tilize on device (faster than tilizing on CPU)
tt_attention_mask = ttnn.as_tensor(
tensor=attention_mask_bool_padded,
dtype=self.model_config["BFLOAT16_DTYPE"], # subsequent tilize op expects bfloat16 inputs
layout=ttnn.ROW_MAJOR_LAYOUT,
device=self.device_mesh,
memory_config=attention_mask_memconfig,
mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=1),
preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, self.config.num_attention_heads, -1, -1),
memory_config=self.model_config["DEFAULT_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(1, 1, -1, -1),
)

tt_attention_mask = ttnn.tilize(
tt_attention_mask,
memory_config=attention_mask_memconfig,
memory_config=self.model_config["DEFAULT_MEMCFG"],
dtype=self.model_config["ATTN_MASK_DTYPE"],
)

else:
raise NotImplementedError(f"Llm mode {llm_mode} is not supported! Must be one of prefill or decode.")

for layer in self.layers:
layer.online_preprocessing(llm_mode, sequence_size)

return tt_inputs, tt_attention_mask

@abstractmethod
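
In model_preprocessing above, the mask is no longer sharded across heads (ShardTensorToMesh on dim 1) but kept as a single head-agnostic tensor and replicated to every device (ReplicateTensorToMesh). The (1, 1, seq, seq) shape makes that possible because it broadcasts against any per-device (batch, heads, seq, seq) score tensor. A small plain-torch illustration of that broadcast, with illustrative sizes:

import torch

seq = 32
mask = (torch.ones(seq, seq).triu(diagonal=1) * -1e5).expand(1, 1, -1, -1)   # (1, 1, seq, seq)
scores = torch.randn(1, 16, seq, seq)          # e.g. the heads resident on one device
print((scores + mask).shape)                   # torch.Size([1, 16, 32, 32])
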