Fix bugs and update default parameter (#12)
* Support video QA for the InternVL and Qwen-VL adapters

* Fix parameter errors
philokey authored Feb 10, 2025
1 parent 12b272f commit 0a30d73
Showing 7 changed files with 44 additions and 21 deletions.
12 changes: 6 additions & 6 deletions flagevalmm/common/video_utils.py
@@ -43,14 +43,14 @@ def load_image_or_video(
             frame_list = [frame]
         frames = np.array(frame_list)
     elif image_or_video_path.endswith(".mp4"):
         decord.bridge.set_bridge("native")
         video_reader = decord.VideoReader(image_or_video_path, num_threads=1)
         total_frame_num = len(video_reader)
-        sampled_frame_indices = np.linspace(
-            start=0, stop=total_frame_num - 1, num=max_num_frames, dtype=int
-        )
-        # Ensure the last frame is included
-        sampled_frame_indices[-1] = total_frame_num - 1
+        if total_frame_num <= max_num_frames:
+            sampled_frame_indices = np.arange(total_frame_num)
+        else:
+            sampled_frame_indices = np.linspace(
+                start=0, stop=total_frame_num - 1, num=max_num_frames, dtype=int
+            )
         frames = video_reader.get_batch(sampled_frame_indices)
         frames = frames.asnumpy().astype(np.uint8)
     else:
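Why the new branch matters: when a clip has fewer frames than max_num_frames, the old np.linspace call produced duplicate indices (and the explicit last-frame assignment was redundant, since linspace already ends at total_frame_num - 1). A minimal standalone sketch of the updated sampling behavior, using only NumPy and hypothetical frame counts:

import numpy as np

def sample_frame_indices(total_frame_num: int, max_num_frames: int) -> np.ndarray:
    # Mirrors the updated logic in load_image_or_video: take every frame when the
    # clip is short, otherwise spread max_num_frames indices evenly across the clip.
    if total_frame_num <= max_num_frames:
        return np.arange(total_frame_num)
    return np.linspace(start=0, stop=total_frame_num - 1, num=max_num_frames, dtype=int)

# A 5-frame clip with max_num_frames=8 now yields [0 1 2 3 4] instead of the
# duplicated [0 0 1 1 2 2 3 4] that linspace alone would produce.
print(sample_frame_indices(5, 8))
print(sample_frame_indices(100, 8))  # evenly spaced: [ 0 14 28 42 56 70 84 99]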
7 changes: 6 additions & 1 deletion flagevalmm/models/base_model_adapter.py
@@ -126,14 +126,19 @@ def create_data_loader(
         self,
         dataset_cls: type[ServerDataset],
         task_name: str,
+        task_type: str = "vqa",
         collate_fn: Optional[Callable] = None,
         batch_size: int = 1,
         num_workers: int = 2,
     ):
         if self.accelerator is not None:
             with self.accelerator.main_process_first():
                 dataset = dataset_cls(
-                    task_name, self.server_ip, self.server_port, self.timeout
+                    task_name,
+                    self.server_ip,
+                    self.server_port,
+                    self.timeout,
+                    task_type=task_type,
                 )
         data_loader = DataLoader(
             dataset,
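The new task_type keyword is threaded from the adapter into ServerDataset so video-QA tasks can be distinguished from the default image VQA; it also explains why callers below now pass collate_fn by keyword (a positional collate_fn would otherwise bind to the new task_type parameter). A hedged sketch of how an adapter's run_one_task might now call it, following the InternVL usage shown later in this commit (CustomDataset and collate_fn are placeholders for the adapter's own dataset class and collation function):

# Inside a model adapter's run_one_task; CustomDataset subclasses ServerDataset.
data_loader = self.create_data_loader(
    CustomDataset,
    task_name,
    task_type=meta_info["type"],  # e.g. "vqa" (the default) or a video-QA task type
    collate_fn=collate_fn,        # now always passed by keyword
    batch_size=1,
    num_workers=2,
)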
4 changes: 2 additions & 2 deletions flagevalmm/models/http_client.py
@@ -18,12 +18,12 @@ def __init__(
         self,
         model_name: str,
         chat_name: Optional[str] = None,
-        max_tokens: int = 1024,
+        max_tokens: int = 4096,
         temperature: float = 0.0,
         max_image_size: Optional[int] = None,
         min_short_side: Optional[int] = None,
         max_long_side: Optional[int] = None,
-        max_num_frames: Optional[int] = 8,
+        max_num_frames: Optional[int] = 16,
         use_cache: bool = False,
         api_key: Optional[str] = None,
         url: Optional[Union[str, httpx.URL]] = None,
2 changes: 1 addition & 1 deletion model_zoo/vlm/Phi_3.5_v/model_adapter.py
@@ -70,7 +70,7 @@ def run_one_task(self, task_name: str, meta_info: Dict[str, Any]):
         results = []
         cnt = 0
         data_loader = self.create_data_loader(
-            CustomDataset, task_name, collate_fn, batch_size=1
+            CustomDataset, task_name, collate_fn=collate_fn, batch_size=1
         )

         for question_id, question, images in data_loader:
2 changes: 1 addition & 1 deletion model_zoo/vlm/idefics3/model_adapter.py
@@ -94,7 +94,7 @@ def run_one_task(self, task_name: str, meta_info: Dict[str, Any]):
         cnt = 0

         data_loader = self.create_data_loader(
-            CustomDataset, task_name, collate_fn, batch_size=1, num_workers=2
+            CustomDataset, task_name, collate_fn=collate_fn, batch_size=1, num_workers=2
         )
         for question_id, question, images in data_loader:
             if cnt == 1:
1 change: 1 addition & 0 deletions model_zoo/vlm/intern_vl/model_adapter.py
@@ -67,6 +67,7 @@ def run_one_task(self, task_name: str, meta_info: Dict[str, Any]):
         data_loader = self.create_data_loader(
             CustomDataset,
             task_name,
+            task_type=meta_info["type"],
             collate_fn=default_collate_fn,
             batch_size=1,
             num_workers=2,
37 changes: 27 additions & 10 deletions model_zoo/vlm/qwen_vl/model_adapter.py
@@ -1,7 +1,12 @@
 import torch
 from typing import Dict, Any
 import time
-from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from transformers import (
+    Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+    AutoTokenizer,
+    AutoProcessor,
+)
 from flagevalmm.server import ServerDataset
 from flagevalmm.models.base_model_adapter import BaseModelAdapter
 from flagevalmm.server.utils import parse_args, process_images_symbol
@@ -15,6 +20,7 @@ def __getitem__(self, index):
         img_path = data["img_path"]
         qs = data["question"]
         qs, idx = process_images_symbol(qs)
+        qs = qs.strip()
         idx = set(idx)
         img_path_idx = []
         for i in idx:
@@ -31,19 +37,31 @@ def model_init(self, task_info: Dict):
         torch.set_grad_enabled(False)
         with self.accelerator.main_process_first():
             tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
-            model = Qwen2VLForConditionalGeneration.from_pretrained(
-                ckpt_path,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-                attn_implementation="flash_attention_2",
-            )
+            if "Qwen2.5" in ckpt_path:
+                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                    ckpt_path,
+                    device_map="auto",
+                    torch_dtype=torch.bfloat16,
+                    attn_implementation="flash_attention_2",
+                )
+            else:
+                model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    ckpt_path,
+                    device_map="auto",
+                    torch_dtype=torch.bfloat16,
+                    attn_implementation="flash_attention_2",
+                )

         model = self.accelerator.prepare_model(model, evaluation_mode=True)
         self.tokenizer = tokenizer
         if hasattr(model, "module"):
             model = model.module
         self.model = model
-        self.processor = AutoProcessor.from_pretrained(ckpt_path)
+        min_pixels = 256 * 28 * 28
+        max_pixels = 1280 * 28 * 28
+        self.processor = AutoProcessor.from_pretrained(
+            ckpt_path, min_pixels=min_pixels, max_pixels=max_pixels
+        )

     def build_message(
         self,
@@ -86,7 +104,6 @@ def run_one_task(self, task_name: str, meta_info: Dict[str, Any]):
             img_path_flaten = [p[0] for p in img_path]
             qs = qs[0]
             messages = self.build_message(qs, image_paths=img_path_flaten)
-
             text = self.processor.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
@@ -101,7 +118,7 @@ def run_one_task(self, task_name: str, meta_info: Dict[str, Any]):
             inputs = inputs.to("cuda")

             # Inference
-            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
+            generated_ids = self.model.generate(**inputs, max_new_tokens=4096)
             generated_ids_trimmed = [
                 out_ids[len(in_ids) :]
                 for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
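A note on the new processor limits: in Qwen2-VL-style processors, each visual token corresponds to roughly a 28x28-pixel area after the 2x2 patch merge, so min_pixels = 256 * 28 * 28 and max_pixels = 1280 * 28 * 28 bound every image to approximately 256-1280 visual tokens. A small sketch of that arithmetic (the token-per-area mapping follows Qwen's description of min_pixels/max_pixels; the exact resize behavior is an assumption, not a verbatim reimplementation):

# Rough visual-token budget implied by the new AutoProcessor limits.
PATCH_PIXELS = 28 * 28              # one visual token covers about a 28x28 pixel area

min_pixels = 256 * PATCH_PIXELS     # 200_704 pixels  -> ~256 tokens minimum
max_pixels = 1280 * PATCH_PIXELS    # 1_003_520 pixels -> ~1280 tokens maximum

def approx_visual_tokens(width: int, height: int) -> int:
    # Images outside [min_pixels, max_pixels] are assumed to be resized into that
    # range by the processor; the token count is then area / (28 * 28).
    area = min(max(width * height, min_pixels), max_pixels)
    return round(area / PATCH_PIXELS)

print(approx_visual_tokens(640, 480))    # ~392 tokens
print(approx_visual_tokens(4000, 3000))  # capped at ~1280 tokens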
