[Fix] Fix Emu3 Inference
kennymckormick committed Jan 24, 2025
1 parent c8bb6d1 commit f67440d
Showing 1 changed file with 46 additions and 15 deletions.
61 changes: 46 additions & 15 deletions vlmeval/vlm/emu.py
@@ -6,6 +6,7 @@
 from .base import BaseModel
 from ..smp import *
 from huggingface_hub import snapshot_download
+from PIL import Image, ImageOps
 
 
 def get_local_root(repo_id):
@@ -19,6 +20,42 @@ def get_local_root(repo_id):
     return cache_path
 
 
+def pad_image_to_aspect_ratio(img, max_aspect_ratio=5):
+    """
+    Pad an image so that its aspect ratio (width/height or height/width) does not exceed the given value.
+
+    Parameters:
+        img (PIL.Image): The input PIL Image object.
+        max_aspect_ratio (float): The maximum allowed aspect ratio.
+
+    Returns:
+        PIL.Image: The padded image.
+    """
+    width, height = img.size
+    # Calculate the minimum dimensions required to satisfy the aspect ratio constraint
+    if width > height * max_aspect_ratio:
+        # Width is too large, pad height
+        new_height = int(width / max_aspect_ratio + 1)
+        new_width = width
+    elif height > width * max_aspect_ratio:
+        # Height is too large, pad width
+        new_width = int(height / max_aspect_ratio + 1)
+        new_height = height
+    else:
+        # Aspect ratio is already within the limit
+        return img
+
+    # Calculate the padding amounts
+    pad_width = new_width - width
+    pad_height = new_height - height
+
+    # Pad the image symmetrically; fill color is black (0, 0, 0)
+    padding = (pad_width // 2, pad_height // 2, pad_width - pad_width // 2, pad_height - pad_height // 2)
+    padded_img = ImageOps.expand(img, padding, fill=(0, 0, 0))
+
+    return padded_img
+
+
 class Emu(BaseModel):
 
     INSTALL_REQ = False
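The new helper can be sanity-checked in isolation. A minimal sketch (not part of the commit; the image sizes are made up for illustration):

from PIL import Image

from vlmeval.vlm.emu import pad_image_to_aspect_ratio

# 4000x600 has aspect ratio ~6.67 > 5, so the helper pads the height:
# new_height = int(4000 / 5 + 1) = 801, split symmetrically and filled with black.
img = Image.new('RGB', (4000, 600))
padded = pad_image_to_aspect_ratio(img, max_aspect_ratio=5)
print(padded.size)  # (4000, 801)

# An image already within the limit is returned unchanged.
square = Image.new('RGB', (512, 512))
assert pad_image_to_aspect_ratio(square, 5) is square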
@@ -52,7 +89,6 @@ def __init__(self,
         model = AutoModelForCausalLM.from_pretrained(
             model_path,  # "BAAI/Emu2-Chat"
             torch_dtype=torch.bfloat16,
-            low_cpu_mem_usage=True,
             trust_remote_code=True)
 
         device_map = infer_auto_device_map(
@@ -103,7 +139,7 @@ def generate_inner(self, message, dataset=None):
 
 
 class Emu3_chat(BaseModel):
-    INSTALL_REQ = True
+    INSTALL_REQ = False
     INTERLEAVE = False
 
     def __init__(self, model_path='BAAI/Emu3-Chat', tokenizer_path='BAAI/Emu3-VisionTokenizer', **kwargs):
@@ -132,19 +168,15 @@ def __init__(self, model_path='BAAI/Emu3-Chat', tokenizer_path='BAAI/Emu3-Vision
             tokenizer_path, device_map='cuda', trust_remote_code=True).eval()
         self.processor = Emu3Processor(self.image_processor, self.image_tokenizer, self.tokenizer)
         self.kwargs = kwargs
-        self.cuda = cuda
 
     def generate_inner(self, message, dataset=None):
-        query, images = '', []
-        for item in message:
-            if item['type'] == 'image':
-                images.append(Image.open(item['value']).convert('RGB'))
-            elif item['type'] == 'text':
-                query += item['value']
+        prompt, image = self.message_to_promptimg(message)
+        image = Image.open(image).convert('RGB')
+        image = pad_image_to_aspect_ratio(image, 5)
 
         inputs = self.processor(
-            text=[query],
-            image=images,
+            text=[prompt],
+            image=[image],
             mode='U',
             return_tensors="pt",
             padding="longest",
@@ -159,9 +191,9 @@
         )
         # generate
         outputs = self.model.generate(
-            inputs.input_ids.to(self.cuda),
+            inputs.input_ids.to('cuda'),
             GENERATION_CONFIG,
-            attention_mask=inputs.attention_mask.to(self.cuda),
+            attention_mask=inputs.attention_mask.to('cuda'),
         )
 
         outputs = outputs[:, inputs.input_ids.shape[-1]:]
@@ -170,7 +202,7 @@
 
 
 class Emu3_gen(BaseModel):
-    INSTALL_REQ = True
+    INSTALL_REQ = False
     INTERLEAVE = False
 
     def __init__(self,
@@ -207,7 +239,6 @@ def __init__(self,
             trust_remote_code=True).eval()
         self.processor = Emu3Processor(self.image_processor, self.image_tokenizer, self.tokenizer)
         self.kwargs = kwargs
-        self.cuda = cuda
         self.output_path = output_path
 
     def generate_inner(self, message, dataset=None):
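With these changes in place, a minimal smoke test of the fixed Emu3-Chat path might look as follows (a sketch, not part of the commit; it assumes a CUDA-capable machine and a hypothetical local image demo.jpg):

from vlmeval.vlm.emu import Emu3_chat

model = Emu3_chat(model_path='BAAI/Emu3-Chat', tokenizer_path='BAAI/Emu3-VisionTokenizer')
message = [
    {'type': 'image', 'value': 'demo.jpg'},  # hypothetical image path
    {'type': 'text', 'value': 'Describe this image.'},
]
print(model.generate_inner(message))

Because INTERLEAVE = False, generate_inner now routes through message_to_promptimg, which reduces the message to a single prompt string and one image before tokenization.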
