diff --git a/vlms/minigpt4/models/minigpt_base.py b/vlms/minigpt4/models/minigpt_base.py index 0a848ca..a6cb637 100644 --- a/vlms/minigpt4/models/minigpt_base.py +++ b/vlms/minigpt4/models/minigpt_base.py @@ -285,7 +285,7 @@ def forward(self, samples, reduction='mean', concept_signals: torch.Tensor = Non bos_embeds = self.embed_tokens(bos) bos_atts = cond_atts[:, :1] - # add bos token at the begining + # add bos token at the beginning inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) attention_mask = torch.cat([bos_atts, attention_mask], dim=1)