Skip to content

Commit

Permalink
update qwenvl
Browse files Browse the repository at this point in the history
  • Loading branch information
kennymckormick committed Mar 23, 2024
1 parent f92be3a commit e16d1c0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 22 deletions.
2 changes: 1 addition & 1 deletion vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

# Model registry for the Qwen family: maps a short model tag to a
# zero-argument constructor (checkpoint path pre-bound via partial).
qwen_series = {
    'qwen_base': partial(QwenVL, model_path='Qwen/Qwen-VL'),
    # The chat checkpoint is served by the unified QwenVL wrapper as well.
    'qwen_chat': partial(QwenVL, model_path='Qwen/Qwen-VL-Chat'),
    'monkey': partial(Monkey, model_path='echo840/Monkey'),
    'monkey-chat': partial(MonkeyChat, model_path='echo840/Monkey-Chat'),
}
Expand Down
72 changes: 51 additions & 21 deletions vlmeval/vlm/qwen_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import os.path as osp
import copy as cp
from vlmeval.smp import isimg, listinstr
import re
from vlmeval.utils import DATASET_TYPE


class QwenVL:
Expand All @@ -13,33 +14,62 @@ class QwenVL:
def __init__(self, model_path='Qwen/Qwen-VL', **kwargs):
    """Load the Qwen-VL checkpoint and prepare a generation config.

    Args:
        model_path: HuggingFace hub id or local path of the checkpoint.
        **kwargs: generation-config overrides; merged on top of (and
            taking precedence over) the greedy-decoding defaults below.
    """
    assert model_path is not None
    self.model_path = model_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Left padding keeps prompts right-aligned for batched generation;
    # the Qwen tokenizer exposes no pad token, so reuse eod as pad.
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id
    self.tokenizer = tokenizer
    self.model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map='cuda', trust_remote_code=True).eval()
    # Deterministic (greedy) defaults; eod doubles as pad/eos for Qwen.
    default_kwargs = dict(
        do_sample=False,
        num_beams=1,
        max_new_tokens=512,
        min_new_tokens=1,
        num_return_sequences=1,
        use_cache=True,
        output_hidden_states=True,
        pad_token_id=tokenizer.eod_id,
        eos_token_id=tokenizer.eod_id)
    default_kwargs.update(kwargs)
    self.kwargs = default_kwargs
    warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
    torch.cuda.empty_cache()

# NOTE(review): pre-commit version of generate(), superseded by the
# interleave_generate-based wrapper later in this diff. Kept here only
# because it appears in the rendered diff; do not maintain both.
def generate(self, image_path, prompt, dataset=None):
    # Build the tokenizer's list-format query: one image item, one text item.
    vl_pair = [{'image': image_path}, {'text': prompt}]
    query = self.tokenizer.from_list_format(vl_pair)

    inputs = self.tokenizer(query, return_tensors='pt')
    inputs = inputs.to(self.model.device)
    pred = self.model.generate(**inputs, **self.kwargs)
    # Special tokens are kept so '<|endoftext|>' can be used as a delimiter.
    response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
    # NOTE(review): fragile — assumes `prompt` occurs verbatim exactly once
    # in the decoded text; split(prompt)[1] raises IndexError otherwise.
    response = response.split(prompt)[1].split('<|endoftext|>')[0]
    return response
def adjust_kwargs(self, dataset):
    """Return a per-dataset copy of the generation kwargs.

    Short-answer benchmarks (multi-choice, Y/N, COCO captions, most VQA
    sets) get a tighter `max_new_tokens` budget; anything unrecognized
    keeps the defaults set in __init__.

    Args:
        dataset: dataset name, or None when no hint is available.

    Returns:
        A deep copy of self.kwargs, possibly with max_new_tokens adjusted.
    """
    kwargs = cp.deepcopy(self.kwargs)
    # Guard: interleave_generate defaults dataset to None, and
    # DATASET_TYPE presumably expects a string — keep defaults in that case.
    if dataset is None:
        return kwargs
    if DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']:
        kwargs['max_new_tokens'] = 32
    elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
        kwargs['max_new_tokens'] = 32
    elif DATASET_TYPE(dataset) == 'VQA':
        if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
            kwargs['max_new_tokens'] = 100
        elif listinstr(['TextVQA'], dataset):
            kwargs['max_new_tokens'] = 10
    return kwargs

def interleave_generate(self, ti_list, dataset=None):
    """Generate a response from an interleaved list of images and text.

    Args:
        ti_list: sequence of strings; entries recognized as image paths
            (per `isimg`) become `<img>...</img>` tags, all other entries
            are concatenated verbatim into the prompt.
        dataset: optional dataset name used to tune generation kwargs.

    Returns:
        The decoded answer text with special tokens stripped.
    """
    kwargs = self.adjust_kwargs(dataset)
    prompt = ''
    for seg in ti_list:
        if isimg(seg):
            prompt += f'<img>{seg}</img>'
        else:
            prompt += seg
    # VQA answers are short; nudge the model to answer directly.
    # Guard against the None default before calling DATASET_TYPE.
    if dataset is not None and DATASET_TYPE(dataset) == 'VQA':
        prompt += ' Answer:'
    encoded = self.tokenizer([prompt], return_tensors='pt', padding='longest')
    input_ids = encoded.input_ids.to('cuda')
    attention_mask = encoded.attention_mask.to('cuda')

    pred = self.model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **kwargs)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    answer = self.tokenizer.decode(
        pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
    return answer

def generate(self, image_path, prompt, dataset=None):
    """Single-image entry point: delegate to the interleaved pipeline."""
    ti_list = [image_path, prompt]
    return self.interleave_generate(ti_list, dataset)


class QwenVLChat:
Expand Down

0 comments on commit e16d1c0

Please sign in to comment.