
[Model] Support LLaVANEXT #125

Merged · 14 commits · Mar 24, 2024
9 changes: 6 additions & 3 deletions README.md
@@ -20,6 +20,7 @@

## 🆕 News

- **[2024-03-22]** We have supported [**LLaVA-NeXT**](https://llava-vl.github.io/blog/2024-01-30-llava-next/) 🔥🔥🔥
- **[2024-03-21]** We have supported [**DeepSeek-VL**](https://github.com/deepseek-ai/DeepSeek-VL/tree/main) 🔥🔥🔥
- **[2024-03-20]** We now support using a `.env` file to manage all environment variables used in VLMEvalKit, see [**Quickstart**](/Quickstart.md) for more details
- **[2024-03-17]** We have added an API wrapper for [**Step-1V**](https://www.stepfun.com/#step1v) 🔥🔥🔥
@@ -29,7 +30,6 @@
- **[2024-02-24]** We have supported [**InternVL-Chat Series**](https://github.com/OpenGVLab/InternVL). The models achieve over 80% Top-1 accuracies on MMBench v1.0 [[**Blog**](https://github.com/OpenGVLab/InternVL/blob/main/BLOG.md)] 🔥🔥🔥
- **[2024-02-07]** We have supported two new models: [**MiniCPM-V**](https://huggingface.co/openbmb/MiniCPM-V) and [**OmniLMM-12B**](https://huggingface.co/openbmb/OmniLMM-12B). 🔥🔥🔥
- **[2024-01-30]** We have supported three new models: [**QwenVLMax**](https://huggingface.co/spaces/Qwen/Qwen-VL-Max), [**InternLM-XComposer2-7B**](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), [**MMAlaya**](https://huggingface.co/DataCanvas/MMAlaya) 🔥🔥🔥
- **[2024-01-30]** We have merged all performance numbers on our leaderboards into a single json file: [**OpenVLM.json**](http://opencompass.openxlab.space/utils/OpenVLM.json)


## 📊 Datasets, Models, and Evaluation Results
@@ -76,16 +76,19 @@
| [**LLaVA (XTuner)**](https://huggingface.co/xtuner/llava-internlm-7b)🚅 | [**CogVLM-17B-Chat**](https://huggingface.co/THUDM/cogvlm-chat-hf)🚅 | [**SharedCaptioner**](https://huggingface.co/spaces/Lin-Chen/Share-Captioner)🚅 | [**CogVLM-Grounding-Generalist**](https://huggingface.co/THUDM/cogvlm-grounding-generalist-hf)🚅 |
| [**Monkey**](https://github.com/Yuliang-Liu/Monkey)🚅 | [**EMU2 / EMU2-Chat**](https://github.com/baaivision/Emu)🚅🎞️ | [**Yi-VL-[6B/34B]**](https://huggingface.co/01-ai/Yi-VL-6B) | [**MMAlaya**](https://huggingface.co/DataCanvas/MMAlaya)🚅 |
| [**InternLM-XComposer2-7B**](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b)🚅🎞️ | [**MiniCPM-V**](https://huggingface.co/openbmb/MiniCPM-V)🚅 | [**OmniLMM-12B**](https://huggingface.co/openbmb/OmniLMM-12B) | [**InternVL-Chat Series**](https://github.com/OpenGVLab/InternVL)🚅 |
| [**DeepSeek-VL**](https://github.com/deepseek-ai/DeepSeek-VL/tree/main)🎞️ | | | |
| [**DeepSeek-VL**](https://github.com/deepseek-ai/DeepSeek-VL/tree/main)🎞️ | [**LLaVA-NeXT**](https://llava-vl.github.io/blog/2024-01-30-llava-next/)🚅 | | |

🎞️: Supports multiple images as inputs via the `interleave_generate` interface.

🚅: The model can be used without any additional configuration or operation.

**Transformers Version Recommendation: ** Note that some VLMs may not be able to run under certain transformer versions, we recommend the following settings to evaluate each VLM:
**Transformers Version Recommendation:**

Note that some VLMs may not be able to run under certain Transformers versions; we recommend the following settings to evaluate each VLM:

- **Please use** `transformers==4.33.0` **for**: `Qwen series`, `Monkey series`, `InternVL series`, `InternLM-XComposer Series`, `mPLUG-Owl2`, `OpenFlamingo v2`, `IDEFICS series`, `VisualGLM`, `MMAlaya`, `SharedCaptioner`, `MiniGPT-4 series`, `InstructBLIP series`, `PandaGPT`.
- **Please use** `transformers==4.37.0` **for**: `LLaVA series`, `ShareGPT4V series`, `TransCore-M`, `LLaVA (XTuner)`, `CogVLM Series`, `EMU2 Series`, `Yi-VL Series`, `MiniCPM-V`, `OmniLMM-12B`, `DeepSeek-VL series`.
- **Please use** `transformers==4.39.0` **for**: `LLaVA-Next series`.
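
A minimal pre-flight check (illustrative, not part of the original instructions) that the installed Transformers version matches the recommendation above:

```python
# Hedged sketch: pick the version string recommended above for your model.
from packaging import version
import transformers

required = '4.39.0'  # e.g., for the LLaVA-NeXT series
if version.parse(transformers.__version__) < version.parse(required):
    raise RuntimeError(
        f'transformers=={transformers.__version__} is installed, but '
        f'transformers=={required} is recommended for this model.'
    )
```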

```python
# Demo
```
6 changes: 5 additions & 1 deletion vlmeval/config.py
@@ -56,7 +56,7 @@

qwen_series = {
    'qwen_base': partial(QwenVL, model_path='Qwen/Qwen-VL'),
    'qwen_chat': partial(QwenVLChat, model_path='Qwen/Qwen-VL-Chat'),
    'qwen_chat': partial(QwenVL, model_path='Qwen/Qwen-VL-Chat'),
    'monkey': partial(Monkey, model_path='echo840/Monkey'),
    'monkey-chat': partial(MonkeyChat, model_path='echo840/Monkey-Chat')
}
@@ -67,6 +67,10 @@
    'llava_v1_7b': partial(LLaVA, model_pth=LLAVA_V1_7B_MODEL_PTH),
    'sharegpt4v_7b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-7B'),
    'sharegpt4v_13b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-13B'),
    'llava_next_vicuna_7b': partial(LLaVA_Next, model_pth='llava-hf/llava-v1.6-vicuna-7b-hf'),
    'llava_next_vicuna_13b': partial(LLaVA_Next, model_pth='llava-hf/llava-v1.6-vicuna-13b-hf'),
    'llava_next_mistral_7b': partial(LLaVA_Next, model_pth='llava-hf/llava-v1.6-mistral-7b-hf'),
    'llava_next_yi_34b': partial(LLaVA_Next, model_pth='llava-hf/llava-v1.6-34b-hf'),
}

internvl_series = {
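The new `llava_next_*` entries follow the same `functools.partial` convention as the rest of `config.py`: each value is a zero-argument constructor for the corresponding VLM class. A hedged usage sketch, assuming the `supported_VLM` registry that VLMEvalKit aggregates from these series dicts (image path and question are illustrative):

```python
from vlmeval.config import supported_VLM

# Resolves to LLaVA_Next(model_pth='llava-hf/llava-v1.6-vicuna-7b-hf')
model = supported_VLM['llava_next_vicuna_7b']()
print(model.generate('demo.jpg', 'Describe the image.'))
```
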
7 changes: 7 additions & 0 deletions vlmeval/smp/misc.py
@@ -181,3 +181,10 @@ def pip_install_robust(package):
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            retry -= 1
    return False


def version_cmp(v1, v2, op='eq'):
    # Compare two version strings with the comparison named by `op` ('eq', 'ge', 'le', ...).
    from packaging import version
    import operator
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))
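
A short usage note: `version_cmp` resolves `op` via the `operator` module, so any comparison name (`'eq'`, `'ge'`, `'lt'`, ...) is accepted. This is how `LLaVA_Next` guards its Transformers requirement below; the snippet is illustrative:

```python
import transformers

# True when the installed transformers is at least 4.39.0.
assert version_cmp(transformers.__version__, '4.39.0', 'ge')
```
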
2 changes: 1 addition & 1 deletion vlmeval/vlm/__init__.py
@@ -6,7 +6,7 @@
from .emu import Emu
from .idefics import IDEFICS
from .instructblip import InstructBLIP
from .llava import LLaVA
from .llava import LLaVA, LLaVA_Next
from .llava_xtuner import LLaVA_XTuner
from .minicpm_v import MiniCPM_V
from .minigpt4 import MiniGPT4
96 changes: 96 additions & 0 deletions vlmeval/vlm/llava.py
@@ -123,3 +123,99 @@ def generate(self, image_path, prompt, dataset=None):

        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return output


class LLaVA_Next(CustomPrompt):

    def __init__(self, model_pth='llava-hf/llava-v1.6-vicuna-7b-hf', **kwargs):
        import transformers
        assert version_cmp(transformers.__version__, '4.39.0', 'ge')
        from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
        self.model_pth = model_pth
        if '34b' in model_pth.lower():
            self.processor = LlavaNextProcessor.from_pretrained(self.model_pth, use_fast=False)
        else:
            self.processor = LlavaNextProcessor.from_pretrained(self.model_pth)
        model = LlavaNextForConditionalGeneration.from_pretrained(
            self.model_pth, torch_dtype=torch.float16, low_cpu_mem_usage=True)
        model = model.eval()
        self.model = model.cuda()
        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def apply_prompt_template(self, prompt):
        model_pth = self.model_pth.lower()
        if 'mistral' in model_pth:
            s = f'[INST] <image>\n {prompt} [/INST]'
        elif 'vicuna' in model_pth:
            s = (
                'A chat between a curious human and an artificial intelligence assistant. '
                "The assistant gives helpful, detailed, and polite answers to the human's questions. "
                f'USER: <image>\n{prompt} ASSISTANT:'
            )
        elif '34b' in model_pth:
            s = (
                f'<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|>'
                '<|im_start|>assistant\n'
            )
        else:
            raise NotImplementedError(f'Prompt template for {model_pth} not implemented.')
        return s

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if DATASET_TYPE(dataset) == 'multi-choice':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += (
                '\n请直接回答选项字母。' if cn_string(prompt) else
                "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
        return {'image': tgt_path, 'text': prompt}

    def generate(self, image_path, prompt, dataset=None):
        image = Image.open(image_path)
        prompt_wtmpl = self.apply_prompt_template(prompt)
        inputs = self.processor(prompt_wtmpl, image, return_tensors='pt').to('cuda')
        output = self.model.generate(**inputs, **self.kwargs)
        answer = self.processor.decode(output[0], skip_special_tokens=True)
        if '<s>' in answer:
            answer = answer.replace('<s>', '').strip()
        if '[/INST]' in answer:
            answer = answer.split('[/INST]')[1].strip()
        elif 'ASSISTANT:' in answer:
            answer = answer.split('ASSISTANT:')[1].strip()
        elif 'assistant\n' in answer:
            answer = answer.split('assistant\n')[1].strip()

        if '</s>' in answer:
            answer = answer.split('</s>')[0].strip()
        if '<|im_end|>' in answer:
            answer = answer.split('<|im_end|>')[0].strip()

        return answer
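
A hedged end-to-end sketch of the new class (the image path and question are illustrative; `generate` wraps the raw prompt with the backbone-specific template from `apply_prompt_template` before decoding):

```python
model = LLaVA_Next(model_pth='llava-hf/llava-v1.6-vicuna-7b-hf')
# For the Vicuna backbone the prompt becomes:
#   "A chat between a curious human ... USER: <image>\n<question> ASSISTANT:"
answer = model.generate('demo.jpg', 'What objects are on the table?')
print(answer)
```
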
75 changes: 54 additions & 21 deletions vlmeval/vlm/qwen_vl.py
@@ -2,8 +2,9 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import os.path as osp
import copy as cp
from vlmeval.smp import isimg, listinstr
import re
from vlmeval.utils import DATASET_TYPE


class QwenVL:
@@ -13,33 +14,65 @@ class QwenVL:
    def __init__(self, model_path='Qwen/Qwen-VL', **kwargs):
        assert model_path is not None
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.eod_id
        self.tokenizer = tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
        self.kwargs = kwargs
        default_kwargs = dict(
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            min_new_tokens=1,
            num_return_sequences=1,
            use_cache=True,
            output_hidden_states=True,
            pad_token_id=tokenizer.eod_id,
            eos_token_id=tokenizer.eod_id)
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
        torch.cuda.empty_cache()

    def generate(self, image_path, prompt, dataset=None):
        vl_pair = [{'image': image_path}, {'text': prompt}]
        query = self.tokenizer.from_list_format(vl_pair)

        inputs = self.tokenizer(query, return_tensors='pt')
        inputs = inputs.to(self.model.device)
        pred = self.model.generate(**inputs, **self.kwargs)
        response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
        response = response.split(prompt)[1].split('<|endoftext|>')[0]
        return response
    def adjust_kwargs(self, dataset):
        kwargs = cp.deepcopy(self.kwargs)
        if DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']:
            kwargs['max_new_tokens'] = 32
        elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
            kwargs['max_new_tokens'] = 32
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
                kwargs['max_new_tokens'] = 100
            elif listinstr(['TextVQA'], dataset):
                kwargs['max_new_tokens'] = 10
        return kwargs

    def interleave_generate(self, ti_list, dataset=None):
        vl_list = [{'image': s} if isimg(s) else {'text': s} for s in ti_list]
        query = self.tokenizer.from_list_format(vl_list)
        if dataset is not None:
            kwargs = self.adjust_kwargs(dataset)
        else:
            kwargs = self.kwargs
        prompt = ''
        for s in ti_list:
            if isimg(s):
                prompt += f'<img>{s}</img>'
            else:
                prompt += s
        if dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            prompt += ' Answer:'
        encoded = self.tokenizer([prompt], return_tensors='pt', padding='longest')
        input_ids = encoded.input_ids.to('cuda')
        attention_mask = encoded.attention_mask.to('cuda')

        inputs = self.tokenizer(query, return_tensors='pt')
        inputs = inputs.to(self.model.device)
        pred = self.model.generate(**inputs, **self.kwargs)
        response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
        response = response.split(query)[1].split('<|endoftext|>')[0]
        return response
        pred = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs)
        answer = self.tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
        return answer

    def generate(self, image_path, prompt, dataset=None):
        return self.interleave_generate([image_path, prompt], dataset)
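
With this change, `QwenVL.generate` is a thin wrapper over `interleave_generate`, which also accepts multiple images interleaved with text. A hedged usage sketch (file names are illustrative):

```python
model = QwenVL(model_path='Qwen/Qwen-VL')

# Single image + prompt, routed through interleave_generate:
print(model.generate('demo.jpg', 'What is written on the sign?'))

# Multiple images interleaved with text:
print(model.interleave_generate(['page_1.jpg', 'page_2.jpg', 'Compare the two pages.']))
```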


class QwenVLChat: