When I use SnapKV on Qwen2-VL to compress the visual tokens, the key and value states are compressed successfully. However, when I print the input shapes of _flash_attention_forward, they are the same as the original, uncompressed ones. Even if I set window_size to 1 and max_capacity_prompt to 2, the result does not change.
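For reference, init_snapkv() in the code below silently falls back to window_size=512 and max_capacity_prompt=2048 whenever those attributes are missing from the config, so the test values need to reach the model config before the first generate() call. A minimal sketch (assuming the attention modules read the same config object as the model; names mirror the script below):

# sketch only: pin the SnapKV parameters on the model config before generation
model.config.window_size = 1          # test value mentioned above
model.config.max_capacity_prompt = 2  # SnapKVCluster asserts max_capacity_prompt - window_size > 0
model.config.kernel_size = 5
model.config.pooling = "avgpool"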
Here is my code:
import math
from typing import Optional, Tuple

from loguru import logger as eval_logger
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
import transformers
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    apply_multimodal_rotary_pos_emb,
    repeat_kv
)
from transformers.utils import (
    is_flash_attn_2_available,
    logging,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import logging

logger = logging.get_logger(__name__)

try:
    from qwen_vl_utils import process_vision_info, extract_vision_info, fetch_image, fetch_video
except ImportError:
    eval_logger.warning("Failed to import qwen_vl_utils; Please install it via `pip install qwen-vl-utils`")

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    raise RuntimeError("Only flash attention 2 is supported for now; please install flash attention 2")
# NOTE: copied from SnapKV.snapkv_utils.py
class SnapKVCluster():
    def __init__(self, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling='avgpool'):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def reset(self, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling='avgpool'):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # check that we are in the prefill (prefix) phase
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states
        else:
            # compute the attention scores between the L_obs (observation window) queries and all keys
            attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
            # build a causal mask: entries on/below the diagonal are 0, entries above are -inf
            mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
            mask = mask.to(attn_weights.device)
            # expand the mask to 4 dims so it matches the key/value layout
            attention_mask = mask[None, None, :, :]
            # mask the attention scores inside the L_obs window; after the softmax the 0 positions are
            # unaffected and the -inf positions become 0
            attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            # NOTE: compute C from Eq. (2) of the paper, i.e. the sum over the L_obs-window queries of
            # their attention scores on the prefix
            attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim=-2)
            # pooling (I have not seen this step in the paper)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            # compute I from Eq. (3) of the paper; attn_cache.shape = (batch, num_heads, L_prompt),
            # indices.shape = (batch, num_heads, k)
            indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            # gather the corresponding key/value entries of the prefix according to indices
            k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim=2, index=indices)
            v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim=2, index=indices)
            # concatenate the compressed prefix key/value with the key/value inside the L_obs window
            k_cur = key_states[:, :, -self.window_size:, :]
            v_cur = value_states[:, :, -self.window_size:, :]
            key_states = torch.cat([k_past_compress, k_cur], dim=2)
            value_states = torch.cat([v_past_compress, v_cur], dim=2)
            return key_states, value_states


def init_snapkv(self):
    if not hasattr(self, "kv_cluster"):
        if not hasattr(self.config, 'window_size'):
            self.config.window_size = 512
        if not hasattr(self.config, 'max_capacity_prompt'):
            self.config.max_capacity_prompt = 2048
        if not hasattr(self.config, 'kernel_size'):
            self.config.kernel_size = 5
        if not hasattr(self.config, 'pooling'):
            self.config.pooling = 'avgpool'
    self.kv_cluster = SnapKVCluster(
        window_size=self.config.window_size,
        max_capacity_prompt=self.config.max_capacity_prompt,
        kernel_size=self.config.kernel_size,
        pooling=self.config.pooling
    )
def qwen2_vl_flash_attn2_forward(
    self: torch.nn.Module,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    init_snapkv(self)
    output_attentions = False
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        if hasattr(self, "kv_seq_len"):  # [SnapKV] add kv_seq_len
            if self.kv_seq_len != 0:
                kv_seq_len += self.kv_seq_len
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # print('kv_seq_len:', kv_seq_len)
        # print('key_states.shape:', key_states.shape)
        if key_states.shape[-2] == kv_seq_len:  # [SnapKV] add kv_cluster
            self.kv_seq_len = kv_seq_len  # [SnapKV] register kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
            past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
            if self.layer_idx == 0:
                eval_logger.info(f"SnapKV compressing...After compress, the length of kv is {key_states_compress.shape[-2]}")
        else:
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            if self.layer_idx == 0:
                eval_logger.info(f"SnapKV do nothing, the length of kv is {key_states.shape[-2]}")

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    print(f"SnapKV flash attn: Q: {query_states.shape}, K: {key_states.shape}, V: {value_states.shape}")

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def qwen2_vl_prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    **kwargs,
):
    # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
    # Exception 1: when passing input_embeds, input_ids may be missing entries
    # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
    if past_key_values is None or (isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0):  # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0

    if past_key_values is not None:
        if inputs_embeds is not None:  # Exception 1
            input_ids = input_ids[:, -cache_position.shape[0]:]
        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
            input_ids = input_ids[:, cache_position]

    if cache_position[0] != 0:
        pixel_values = None
        pixel_values_videos = None

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = inputs_embeds.shape
            device = inputs_embeds.device
        else:
            batch_size, sequence_length = input_ids.shape
            device = input_ids.device

        attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_cache_shape(),
            dtype=self.lm_head.weight.dtype,
            device=device,
            cache_position=cache_position,
            batch_size=batch_size,
            config=self.config,
            past_key_values=past_key_values,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "pixel_values_videos": pixel_values_videos,
            "image_grid_thw": image_grid_thw,
            "video_grid_thw": video_grid_thw,
            "cache_position": cache_position,
        }
    )
    return model_inputs


# patch
def patch_qwen2_vl():
    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2.forward = qwen2_vl_flash_attn2_forward
    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.prepare_inputs_for_generation = qwen2_vl_prepare_inputs_for_generation


patch_qwen2_vl()
def main():
    qwen2_vl_path = "path/to/your checkpoints"
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen2_vl_path, device_map="auto", torch_dtype="auto", attn_implementation="flash_attention_2",
    ).eval()
    processor = AutoProcessor.from_pretrained(qwen2_vl_path)
    tokenizer = AutoTokenizer.from_pretrained(qwen2_vl_path)

    prompt = "Hello, who are you?"
    input_ids = tokenizer([prompt], return_tensors="pt")
    input_ids = input_ids.to("cuda")
    output = model.generate(**input_ids)
    text = tokenizer.batch_decode(output)
    print(text)


if __name__ == "__main__":
    main()
The output text is always "['Hello, who are you? I am a language model created by Alibaba Cloud. I am called Qwen. I am a large']", no matter what window_size and max_capacity_prompt are set to.
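For completeness, one way to check how long the cached key/value actually end up after prefill is to run a single forward pass and inspect the returned cache. This is only a sketch, not part of the script above; it assumes a recent transformers version that returns a DynamicCache under past_key_values, and it reuses the model and input_ids from main():

# sketch only: inspect the cache length after one prefill forward pass
with torch.no_grad():
    out = model(**input_ids, use_cache=True)
# with SnapKV active this should be at most max_capacity_prompt; otherwise it equals the prompt length
print(out.past_key_values.get_seq_length())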