"""
CUDA_VISIBLE_DEVICES=? python inference.py \
--input_csv ../MiniGPT-4/input_csv/visit_bench_single_image.csv \
--output_dir ../MiniGPT-4/output_csv/visit_bench_single_image
python inference.py --input_csv ../MiniGPT-4/input_csv/visit_bench.csv --output_dir ../MiniGPT-4/output_csv/visit_bench
"""
# Load mPLUG-Owl via the Hugging Face-style interface
import os
import json
import urllib.request
import urllib.error
from urllib.parse import urlparse
import csv
import argparse
from PIL import Image
import requests
from io import BytesIO
from tqdm import tqdm
import torch
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
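
# Command-line arguments: the input CSV of instructions and image URLs, the output
# directory for predictions, and a model name used to label the prediction column.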
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument('--input_csv', type=str, default='../MiniGPT-4/input_csv/visit_instructions_700.csv')
parser.add_argument('--output_dir', type=str, default='../MiniGPT-4/output_csv/')
parser.add_argument('--model_name', type=str, default='mPLUG-Owl')
parser.add_argument('--verbose', action='store_true', default=False)
args = parser.parse_args()
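
# Read a CSV file into a list of row dicts, returning the header field names alongside the rows.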
def read_csv_file(file_path):
    data = []
    with open(file_path, 'r') as file:
        csv_reader = csv.DictReader(file)
        for row in csv_reader:
            data.append(row)
    return csv_reader.fieldnames, data
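
# Percent-encode spaces and plus signs, which otherwise break raw URLs taken from the CSV.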
def check_url(url):
    if ' ' in url:
        url = url.replace(' ', "%20")
    if '+' in url:
        url = url.replace('+', "%2B")
    return url
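
# Download an image from a URL to a local path via urllib (kept as a utility; the main
# loop below fetches images in memory with requests instead).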
def download_image(url, file_path):
    if args.verbose:
        print(url)
        print(file_path)
    try:
        urllib.request.urlretrieve(url, file_path)
        if args.verbose:
            print("Image downloaded successfully!")
    except urllib.error.URLError as e:
        print("Error occurred while downloading the image:", e)


if __name__ == '__main__':
    # check output directory
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    args.output_csv = os.path.join(args.output_dir, f'{args.model_name.lower()}.csv')

    # Load model
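    # mPLUG-Owl (LLaMA-7B) checkpoint from the Hugging Face Hub, loaded in bfloat16
    # and moved to the selected device.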
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b'
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt,
        torch_dtype=torch.bfloat16,
    ).to(device)
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)
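
    # Sampling settings for generation: top-k sampling capped at 512 tokens.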
    generate_kwargs = {
        'do_sample': True,
        'top_k': 5,
        'max_length': 512
    }

    # Read CSV file
    fieldname_list, input_data_list = read_csv_file(args.input_csv)
    output_data_list = []
    prediction_fieldname = f'{args.model_name} prediction'
    fieldname_list.append(prediction_fieldname)
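
    # Run inference row by row; each row provides an instruction and one or more image URLs.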
    for row in tqdm(input_data_list, total=len(input_data_list), desc='predict'):
        if args.verbose:
            print(row)
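
        # The image column name differs across input CSVs: single-image splits use
        # 'Input.image_url' or 'image', while multi-image splits store a list literal in 'images'.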
        if 'Input.image_url' in row.keys():
            image_url_list = [row['Input.image_url']]
        elif 'image' in row.keys():
            image_url_list = [row['image']]
        else:
            image_url_list = list(eval(row['images'].replace(', NaN', '')))
        # if row['is_multiple_images'] == 'False':
        #     continue

        # prepare instruction prompt
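        # One 'Human: <image>' line per image marks where the processor splices in visual tokens.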
        sep = '\n'
        prompts = [
            f'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
{sep.join(['Human: <image>'] * len(image_url_list))}
Human: {row["instruction"]}
AI: ''']

        # download the images
        try:
            image_inputs = []
            for img_url in image_url_list:
                response = requests.get(check_url(img_url))
                image_inputs.append(Image.open(BytesIO(response.content)).convert("RGB"))
        except Exception as e:
            print(f'Error occurred while downloading the image: {e}')
            row[prediction_fieldname] = f'[Error]: {e}'
            output_data_list.append(row)
            continue
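
        # Tokenize the prompt and preprocess the images, then cast floating-point tensors
        # to bfloat16 to match the model weights before moving everything to the model device.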
        inputs = processor(text=prompts, images=image_inputs, return_tensors='pt')
        inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            res = model.generate(**inputs, **generate_kwargs)
        llm_prediction = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)

        if args.verbose:
            print(f'Question:\n\t{row["instruction"]}')
            print(f'Image URL:\t{image_url_list}')
            print(f'Answer:\n\t{llm_prediction}')
            print('-'*30 + '\n')

        row[prediction_fieldname] = llm_prediction
        output_data_list.append(row)
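
        # Checkpoint intermediate results after every row so completed predictions
        # are not lost if the run is interrupted.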
        with open('tmp.json', 'w') as f:
            json.dump(output_data_list, f, indent=2)

    # Write to output csv file
    output_file = args.output_csv
    with open(output_file, 'w', newline='') as file:
        csv_writer = csv.DictWriter(file, fieldnames=fieldname_list)
        csv_writer.writeheader()
        csv_writer.writerows(output_data_list)