Skip to content

Commit

Permalink
[Model] Support HunYuan-Standard-Vision model (#721)
Browse files Browse the repository at this point in the history
* Update hunyuan-standard-vision

* Add HunYuan-Standard-Vision config

* Fix Lint

---------

Co-authored-by: berlinni <[email protected]>
Co-authored-by: kennymckormick <[email protected]>
  • Loading branch information
3 people authored Jan 14, 2025
1 parent 27dd622 commit 77b7d08
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 32 deletions.
136 changes: 104 additions & 32 deletions vlmeval/api/hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@
import os
import sys
from vlmeval.api.base import BaseAPI
import math
from vlmeval.dataset import DATASET_TYPE
from vlmeval.dataset import img_root_map
from io import BytesIO
import pandas as pd
import requests
import json
import base64
import time


class HunyuanWrapper(BaseAPI):

is_api: bool = True
_apiVersion = '2023-09-01'
_apiVersion = '2024-12-31'
_service = 'hunyuan'

def __init__(self,
model: str = 'hunyuan-vision',
model: str = 'hunyuan-standard-vision',
retry: int = 5,
wait: int = 5,
secret_key: str = None,
Expand Down Expand Up @@ -53,15 +62,88 @@ def __init__(self,
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

cred = credential.Credential(self.secret_id, self.secret_key)
httpProfile = HttpProfile()
httpProfile = HttpProfile(reqTimeout=300)
httpProfile.endpoint = self.endpoint
clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile
self.client = hunyuan_client.HunyuanClient(cred, 'ap-beijing', clientProfile)
self.client = hunyuan_client.HunyuanClient(cred, '', clientProfile)
self.logger.info(
f'Using Endpoint: {self.endpoint}; API Secret ID: {self.secret_id}; API Secret Key: {self.secret_key}'
)

def dump_image(self, line, dataset):
"""Dump the image(s) of the input line to the corresponding dataset folder.
Args:
line (line of pd.DataFrame): The raw input line.
dataset (str): The name of the dataset.
Returns:
str | list[str]: The paths of the dumped images.
"""
ROOT = LMUDataRoot()
assert isinstance(dataset, str)

img_root = os.path.join(ROOT, 'images', img_root_map(dataset) if dataset in img_root_map(dataset) else dataset)
os.makedirs(img_root, exist_ok=True)
if 'image' in line:
if isinstance(line['image'], list):
tgt_path = []
assert 'image_path' in line
for img, im_name in zip(line['image'], line['image_path']):
path = osp.join(img_root, im_name)
if not read_ok(path):
decode_base64_to_image_file(img, path)
tgt_path.append(path)
else:
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
if not read_ok(tgt_path):
decode_base64_to_image_file(line['image'], tgt_path)
tgt_path = [tgt_path]
else:
assert 'image_path' in line
tgt_path = toliststr(line['image_path'])

return tgt_path

def use_custom_prompt(self, dataset_name):
if DATASET_TYPE(dataset_name) == 'MCQ':
return True
else:
return False

def build_prompt(self, line, dataset=None):
assert self.use_custom_prompt(dataset)
assert dataset is None or isinstance(dataset, str)

tgt_path = self.dump_image(line, dataset)

question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += 'Answer with the option letter from the given choices directly.'

msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs

# inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
# content can be a string or a list of image & text
def prepare_itlist(self, inputs):
Expand Down Expand Up @@ -109,36 +191,26 @@ def generate_inner(self, inputs, **kwargs) -> str:
Model=self.model,
Messages=input_msgs,
Temperature=temperature,
TopK=1,
**kwargs)

retry_counter = 0
while retry_counter < 3:
try:
req = models.ChatCompletionsRequest()
req.from_json_string(json.dumps(payload))
resp = self.client.ChatCompletions(req)
resp = json.loads(resp.to_json_string())
answer = resp['Choices'][0]['Message']['Content']
return 0, answer, resp
except TencentCloudSDKException as e:
self.logger.error(f'Got error code: {e.get_code()}')
if e.get_code() == 'ClientNetworkError':
return -1, self.fail_msg + e.get_code(), None
elif e.get_code() in ['InternalError', 'ServerNetworkError']:
if retry_counter == 3:
return -1, self.fail_msg + e.get_code(), None
retry_counter += 1
continue
elif e.get_code() in ['LimitExceeded']:
time.sleep(5)
if retry_counter == 3:
return -1, self.fail_msg + e.get_code(), None
retry_counter += 1
continue
else:
return -1, self.fail_msg + str(e), None

return -1, self.fail_msg, None
try:
req = models.ChatCompletionsRequest()
req.from_json_string(json.dumps(payload))
resp = self.client.ChatCompletions(req)
resp = json.loads(resp.to_json_string())
answer = resp['Choices'][0]['Message']['Content']
return 0, answer, resp
except TencentCloudSDKException as e:
self.logger.error(f'Got error code: {e.get_code()}')
if e.get_code() == 'ClientNetworkError':
return -1, self.fail_msg + e.get_code(), None
elif e.get_code() in ['InternalError', 'ServerNetworkError']:
return -1, self.fail_msg + e.get_code(), None
elif e.get_code() in ['LimitExceeded']:
return -1, self.fail_msg + e.get_code(), None
else:
return -1, self.fail_msg + str(e), None


class HunyuanVision(HunyuanWrapper):
Expand Down
1 change: 1 addition & 0 deletions vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
# SenseChat-V
'SenseChat-Vision': partial(SenseChatVisionAPI, model='SenseChat-Vision', temperature=0, retry=10),
'HunYuan-Vision': partial(HunyuanVision, model='hunyuan-vision', temperature=0, retry=10),
'HunYuan-Standard-Vision': partial(HunyuanVision, model='hunyuan-standard-vision', temperature=0, retry=10),
'bailingMM': partial(bailingMMAPI, model='bailingMM-mini', temperature=0, retry=10),
# BlueLM-V
"BlueLM_V": partial(BlueLM_V_API, model='BlueLM-VL-v3.0', temperature=0, retry=10),
Expand Down

0 comments on commit 77b7d08

Please sign in to comment.