diff --git a/LAVIS/lavis/models/blip2_models/modeling_llama.py b/LAVIS/lavis/models/blip2_models/modeling_llama.py
index 1b714b9..19306ba 100644
--- a/LAVIS/lavis/models/blip2_models/modeling_llama.py
+++ b/LAVIS/lavis/models/blip2_models/modeling_llama.py
@@ -418,6 +418,8 @@ class LlamaFlashAttention2(LlamaAttention):
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
     flash attention and deal with padding tokens in case the input contains any of them.
     """
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
 
     def forward(
         self,
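
The added __init__ simply forwards to LlamaAttention.__init__, making the constructor explicit without changing behavior. For context on the docstring's note about "correctly call[ing] the public API of flash attention", the entry point in the flash-attn package is flash_attn_func. The minimal sketch below (not part of the patch) shows such a call; the tensor names and shapes are illustrative assumptions only:

    # Illustrative sketch, not the patched forward(): a direct call to the
    # public flash-attention API on inputs without padding tokens.
    import torch
    from flash_attn import flash_attn_func  # requires the flash-attn package (v2+)

    batch, seqlen, n_heads, head_dim = 2, 16, 8, 64  # assumed toy shapes
    q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Causal self-attention. Inputs that contain padding tokens would instead
    # be unpadded and dispatched to flash_attn_varlen_func, as the docstring notes.
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # same shape as q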