Further fixes for performance with internal bucketing. (huggingface#781)
Signed-off-by: Puneesh Khanna <[email protected]>
Puneesh Khanna authored Mar 11, 2024
1 parent 639c21b commit 20832c9
Showing 2 changed files with 24 additions and 19 deletions.
40 changes: 22 additions & 18 deletions optimum/habana/transformers/generation/utils.py
@@ -215,6 +215,8 @@ def _update_model_kwargs_for_generation(

if token_idx is not None:
token_idx.add_(1)
if "token_idx_cpu" in model_kwargs:
model_kwargs["token_idx_cpu"] += 1

return model_kwargs
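
This hunk, together with the token_idx_cpu entry initialized in generate() below, keeps a plain Python integer in lockstep with the device-side token_idx tensor, so host-side bookkeeping (such as the bucket-boundary check in greedy_search) never has to synchronize with the device. A minimal sketch of the idea, with toy values and outside the repository code:

import torch

# Toy stand-in for the model_kwargs used during generation; on Gaudi the
# tensor would live on the HPU, while the int always stays on the host.
model_kwargs = {
    "token_idx": torch.tensor(5),
    "token_idx_cpu": 5,
}

def advance_token_index(model_kwargs):
    # Mirrors _update_model_kwargs_for_generation above: both counters move together.
    model_kwargs["token_idx"].add_(1)
    if "token_idx_cpu" in model_kwargs:
        model_kwargs["token_idx_cpu"] += 1
    return model_kwargs

advance_token_index(model_kwargs)
assert model_kwargs["token_idx_cpu"] == int(model_kwargs["token_idx"])  # host-side check, no .item() in the loop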

@@ -576,6 +578,7 @@ def generate(
# token_idx is the current index in the generation process, it is incremented each time a new token is generated
token_idx = inputs_tensor.shape[-1]
model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
model_kwargs["token_idx_cpu"] = token_idx
if generation_config.max_new_tokens is None:
generation_config.max_new_tokens = generation_config.max_length - token_idx
inputs_tensor = torch.nn.functional.pad(
@@ -670,6 +673,7 @@ def generate(
model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16

# determine whether limit_hpu_graphs needs to be used
model_kwargs["use_hpu_graphs"] = hpu_graphs
model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs

# prepare for allocate kv cache
@@ -1333,8 +1337,9 @@ def greedy_search(
hb_profer.start()
this_peer_finished = False # used by synced_gpus only
bucket_size = model_kwargs.get("bucket_size", -1)
bucket_internal = model_kwargs["bucket_internal"]
reduce_recompile = model_kwargs.get("reduce_recompile", False)
prev_idx = None # avoiding calculate cache_idx when its value is not changing
prev_idx = -1 # avoiding calculate cache_idx when its value is not changing
bucket_internal = model_kwargs.get("bucket_internal", None)

prompt_len = input_ids.shape[-1]
@@ -1362,23 +1367,12 @@ def greedy_search(
if this_peer_finished_flag.item() == 0.0:
break

if bucket_size > 0:
if not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
else:
# Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
if idx != prev_idx:
cache_idx = (idx.item() + 1) * bucket_size
model_kwargs["cache_idx"] = cache_idx
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
if bucket_size > 0 and not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -1453,6 +1447,16 @@ def greedy_search(
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if bucket_size > 0 and bucket_internal:
# Calculate slice idx for kv cache during the decode phase.
# Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size
if prev_idx != idx:
model_kwargs["cache_idx"] = (idx + 1) * bucket_size
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
cur_len = cur_len + 1

# if eos_token was found in one sentence, set sentence to finished
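Taken together, the two greedy_search hunks above move the internal-bucketing slice computation from before the forward pass to after _update_model_kwargs_for_generation, and switch it from tensor arithmetic on token_idx to plain integer arithmetic on token_idx_cpu, so no torch.div or .item() call forces a host-device sync inside the decode loop. A minimal sketch of that arithmetic, using a hypothetical helper name and toy values:

def next_cache_idx(token_idx_cpu, kv_cache_len, bucket_size, prev_idx):
    # Returns (cache_idx, prev_idx); cache_idx is None when the slice is unchanged.
    if token_idx_cpu <= (kv_cache_len // bucket_size) * bucket_size:
        idx = (token_idx_cpu - 1) // bucket_size   # pure host-side integer math
        if idx != prev_idx:
            return (idx + 1) * bucket_size, idx    # widen the visible KV-cache slice by one bucket
        return None, prev_idx
    return kv_cache_len, prev_idx                  # past the last full bucket: use the whole cache

# Example: a 128-token KV cache with buckets of 32.
prev_idx = -1
for token_idx_cpu in (1, 31, 32, 33, 64, 65, 96, 97):
    cache_idx, prev_idx = next_cache_idx(token_idx_cpu, 128, 32, prev_idx)
    print(token_idx_cpu, cache_idx)  # cache_idx grows 32 -> 64 -> 96 -> 128 only at bucket boundaries

Initializing prev_idx to -1 (instead of None) keeps the idx != prev_idx comparison a simple integer check on the first iteration.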
3 changes: 2 additions & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
@@ -291,7 +291,8 @@ def pre_attn_forward(
if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
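The modeling_llama.py change guards the cache_idx slicing: key and value states are always narrowed to the current bucket during single-token decode, but attention_mask may legitimately be None (for example, when no padding mask is passed), so it is only sliced when present. A minimal, self-contained sketch of the guarded slicing; the helper name and tensor shapes are illustrative, not from the repository:

import torch

def slice_kv_for_bucket(key_states, value_states, attention_mask, cache_idx, q_len):
    # Mirrors the guarded slicing in pre_attn_forward above.
    if cache_idx is not None and q_len == 1:
        key_states = key_states[:, :, :cache_idx, :]
        value_states = value_states[:, :, :cache_idx, :]
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :, :cache_idx]
    return key_states, value_states, attention_mask

k = torch.zeros(1, 8, 128, 64)   # (batch, heads, kv_len, head_dim)
v = torch.zeros(1, 8, 128, 64)
k, v, mask = slice_kv_for_bucket(k, v, None, cache_idx=32, q_len=1)
print(k.shape, v.shape, mask)    # torch.Size([1, 8, 32, 64]) twice, and None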
