
The first generation token output sees the whole cache key and value #27

Open
PengWenChen opened this issue Jan 6, 2025 · 3 comments

@PengWenChen

past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)

Hi there~
Thanks for your great work!
The past_key_value at L130 does get updated with the new compressed key and value.
However, the first generated token (L168) is still produced with the full cache key and value after the prompt compression.

attn_output = self._flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    dropout=dropout_rate,
    use_sliding_windows=use_sliding_windows,
)
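
To make this concrete, here is a tiny self-contained toy in plain PyTorch (not the repo's code; the "compression" below just keeps the last few positions as a stand-in for the real scoring-based selection): the cache ends up holding the compressed KV, while the attention that produces the first token's logit still runs over the full KV, and the two outputs differ.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, head_dim, keep = 8, 16, 4

q = torch.randn(1, seq_len, head_dim)
k = torch.randn(1, seq_len, head_dim)
v = torch.randn(1, seq_len, head_dim)

# Stand-in "compression": keep only the last `keep` prompt positions.
k_compress, v_compress = k[:, -keep:], v[:, -keep:]

# What the cache stores for later decode steps: the compressed KV.
cache_k, cache_v = k_compress, v_compress

# What the prefill attention computes with: the full KV. The output at the
# last prompt position is what selects the first generated token.
out_full = F.scaled_dot_product_attention(q[:, -1:], k, v)
out_comp = F.scaled_dot_product_attention(q[:, -1:], cache_k, cache_v)

print(torch.allclose(out_full, out_comp))  # False: the first token sees the full KV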

Is this a bug?

@XiongxiaoL

Because this happens during the prefill stage, it is unrelated to KV compression, so the full KV is used for the computation.

@PengWenChen
Author

To my understanding, the first generated token is the last output logit of the prefill stage.
So the first token of the model's response comes from the attn_output here, right?

attn_output = self._flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    dropout=dropout_rate,
    use_sliding_windows=use_sliding_windows,
)

If so, then the first generated (predicted) token sees the whole KV of the input prompt.
If not, what is the input token for the first generated token after KV compression? There must be an input token that becomes a hidden state and predicts the first response token, right?
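
As a minimal illustration of what I mean (a sketch with a stock HF model, not this repo's code): the first response token is taken from the logits at the last prompt position of the prefill forward pass, and only the tokens after it are decoded against the cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is used here purely for illustration, not the model used in this repo.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tok("An example prompt", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids=prompt_ids, use_cache=True)   # prefill over the full prompt
first_token = out.logits[:, -1, :].argmax(dim=-1)       # first generated (response) token

# Only from the second generated token onward does decoding run against the cache.
with torch.no_grad():
    next_out = model(input_ids=first_token.unsqueeze(-1),
                     past_key_values=out.past_key_values,
                     use_cache=True)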

@akhauriyash

akhauriyash commented Jan 27, 2025

Hello,
Has there been a resolution / more discussion on this?

I think the simple fix is to do this here:
key_states, value_states = past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
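
That is, assign the returned tensors back instead of discarding them, so the attention below the update runs on what the cache now holds. Roughly (a sketch, not the exact repository lines):

# before: the return value of update() is discarded, and the attention below
# still runs on the full key/value states of the prompt
past_key_value.update(key_states_compress, value_states_compress,
                      self.layer_idx, cache_kwargs)

# after: reuse the cache contents so the first generated token also attends
# over the compressed KV
key_states, value_states = past_key_value.update(
    key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)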

Thanks!
