-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
336 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# ernie_vilg | ||
|
||
|模型名称|ernie_vilg| | ||
| :--- | :---: | | ||
|类别|图像-文图生成| | ||
|网络|ERNIE-ViLG| | ||
|数据集|-| | ||
|是否支持Fine-tuning|否| | ||
|模型大小|-| | ||
|最新更新日期|2022-08-02| | ||
|数据指标|-| | ||
|
||
## 一、模型基本信息 | ||
|
||
### 应用效果展示 | ||
|
||
- 输入文本 "宁静的小镇" 风格 "油画" | ||
|
||
- 输出图像 | ||
<p align="center"> | ||
<img src="https://user-images.githubusercontent.com/22424850/183041589-57debf50-80ec-496f-8bb5-42d9d38646dd.png" width = "80%" hspace='10'/> | ||
<br /> | ||
|
||
|
||
### 模型介绍 | ||
|
||
文心ERNIE-ViLG参数规模达到100亿,是目前为止全球最大规模中文跨模态生成模型,在文本生成图像、图像描述等跨模态生成任务上效果全球领先,在图文生成领域MS-COCO、COCO-CN、AIC-ICC等数据集上取得最好效果。你可以输入一段文本描述以及生成风格,模型就会根据输入的内容自动创作出符合要求的图像。 | ||
|
||
## 二、安装 | ||
|
||
- ### 1、环境依赖 | ||
|
||
- paddlepaddle >= 2.0.0 | ||
|
||
- paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst) | ||
|
||
- ### 2、安装 | ||
|
||
- ```shell | ||
$ hub install ernie_vilg | ||
``` | ||
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) | ||
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) | ||
|
||
|
||
## 三、模型API预测 | ||
|
||
- ### 1、命令行预测 | ||
|
||
- ```shell | ||
$ hub run ernie_vilg --text_prompts "宁静的小镇" --output_dir ernie_vilg_out | ||
``` | ||
|
||
- ### 2、预测代码示例 | ||
|
||
- ```python | ||
import paddlehub as hub | ||
module = hub.Module(name="ernie_vilg") | ||
text_prompts = ["宁静的小镇"] | ||
images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/') | ||
``` | ||
|
||
- ### 3、API | ||
|
||
- ```python | ||
def __init__(ak: Optional[str]=None, sk: Optional[str]=None) | ||
``` | ||
- 初始化模块,可自定义用于申请访问文心API的ak和sk。 | ||
|
||
- **参数** | ||
- ak:(Optional[str]): 用于申请文心api使用token的ak,可不填。 | ||
- sk:(Optional[str]): 用于申请文心api使用token的sk,可不填。 | ||
|
||
- ```python | ||
def generate_image( | ||
text_prompts:str, | ||
style: Optional[str] = "油画", | ||
topk: Optional[int] = 10, | ||
output_dir: Optional[str] = 'ernievilg_output') | ||
``` | ||
|
||
- 文图生成API,生成文本描述内容的图像。 | ||
|
||
- **参数** | ||
|
||
- text_prompts(str): 输入的语句,描述想要生成的图像的内容。 | ||
- style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。 | ||
- topk(Optional[int]): 保存前多少张图,最多保存10张。 | ||
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。 | ||
|
||
|
||
- **返回** | ||
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。 | ||
|
||
## 四、更新历史 | ||
|
||
* 1.0.0 | ||
|
||
初始发布 | ||
|
||
```shell | ||
$ hub install ernie_vilg == 1.0.0 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
import argparse | ||
import ast | ||
import os | ||
import re | ||
import sys | ||
import time | ||
from functools import partial | ||
from io import BytesIO | ||
from typing import List | ||
from typing import Optional | ||
|
||
import requests | ||
from PIL import Image | ||
from tqdm.auto import tqdm | ||
|
||
import paddlehub as hub | ||
from paddlehub.module.module import moduleinfo | ||
from paddlehub.module.module import runnable | ||
from paddlehub.module.module import serving | ||
|
||
|
||
@moduleinfo(name="ernie_vilg", | ||
version="1.0.0", | ||
type="image/text_to_image", | ||
summary="", | ||
author="baidu-nlp", | ||
author_email="[email protected]") | ||
class ErnieVilG: | ||
|
||
def __init__(self, ak=None, sk=None): | ||
""" | ||
:param ak: ak for applying token to request wenxin api. | ||
:param sk: sk for applying token to request wenxin api. | ||
""" | ||
if ak is None or sk is None: | ||
self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE' | ||
self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs' | ||
else: | ||
self.ak = ak | ||
self.sk = sk | ||
self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' | ||
self.token = self._apply_token(self.ak, self.sk) | ||
|
||
def _apply_token(self, ak, sk): | ||
if ak is None or sk is None: | ||
ak = self.ak | ||
sk = self.sk | ||
response = requests.get(self.token_host, | ||
params={ | ||
'grant_type': 'client_credentials', | ||
'client_id': ak, | ||
'client_secret': sk | ||
}) | ||
if response: | ||
res = response.json() | ||
if res['code'] != 0: | ||
print('Request access token error.') | ||
raise RuntimeError("Request access token error.") | ||
else: | ||
print('Request access token error.') | ||
raise RuntimeError("Request access token error.") | ||
return res['data'] | ||
|
||
def generate_image(self, | ||
text_prompts, | ||
style: Optional[str] = "油画", | ||
topk: Optional[int] = 10, | ||
output_dir: Optional[str] = 'ernievilg_output'): | ||
""" | ||
Create image by text prompts using ErnieVilG model. | ||
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. | ||
:param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画 | ||
:param topk: Top k images to save. | ||
:output_dir: Output directory | ||
""" | ||
if not os.path.exists(output_dir): | ||
os.makedirs(output_dir, exist_ok=True) | ||
token = self.token | ||
create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub' | ||
get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub' | ||
if isinstance(text_prompts, str): | ||
text_prompts = [text_prompts] | ||
taskids = [] | ||
for text_prompt in text_prompts: | ||
res = requests.post(create_url, | ||
headers={'Content-Type': 'application/x-www-form-urlencoded'}, | ||
data={ | ||
'access_token': token, | ||
"text": text_prompt, | ||
"style": style | ||
}) | ||
res = res.json() | ||
if res['code'] == 4001: | ||
print('请求参数错误') | ||
raise RuntimeError("请求参数错误") | ||
elif res['code'] == 4002: | ||
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') | ||
raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") | ||
elif res['code'] == 4003: | ||
print('请求参数中,图片风格不在可选范围内') | ||
raise RuntimeError("请求参数中,图片风格不在可选范围内") | ||
elif res['code'] == 4004: | ||
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') | ||
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") | ||
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: | ||
token = self._apply_token(self.ak, self.sk) | ||
res = requests.post(create_url, | ||
headers={'Content-Type': 'application/x-www-form-urlencoded'}, | ||
data={ | ||
'access_token': token, | ||
"text": text_prompt, | ||
"style": style | ||
}) | ||
res = res.json() | ||
if res['code'] != 0: | ||
print("Token失效重新请求后依然发生错误,请检查输入的参数") | ||
raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") | ||
|
||
taskids.append(res['data']["taskId"]) | ||
|
||
start_time = time.time() | ||
process_bar = tqdm(total=100, unit='%') | ||
results = {} | ||
first_iter = True | ||
while True: | ||
if not taskids: | ||
break | ||
total_time = 0 | ||
has_done = [] | ||
for taskid in taskids: | ||
res = requests.post(get_url, | ||
headers={'Content-Type': 'application/x-www-form-urlencoded'}, | ||
data={ | ||
'access_token': token, | ||
'taskId': {taskid} | ||
}) | ||
res = res.json() | ||
if res['code'] == 4001: | ||
print('请求参数错误') | ||
raise RuntimeError("请求参数错误") | ||
elif res['code'] == 4002: | ||
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') | ||
raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") | ||
elif res['code'] == 4003: | ||
print('请求参数中,图片风格不在可选范围内') | ||
raise RuntimeError("请求参数中,图片风格不在可选范围内") | ||
elif res['code'] == 4004: | ||
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') | ||
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") | ||
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: | ||
token = self._apply_token(self.ak, self.sk) | ||
res = requests.post(get_url, | ||
headers={'Content-Type': 'application/x-www-form-urlencoded'}, | ||
data={ | ||
'access_token': token, | ||
'taskId': {taskid} | ||
}) | ||
res = res.json() | ||
if res['code'] != 0: | ||
print("Token失效重新请求后依然发生错误,请检查输入的参数") | ||
raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") | ||
if res['data']['status'] == 1: | ||
has_done.append(res['data']['taskId']) | ||
results[res['data']['text']] = { | ||
'imgUrls': res['data']['imgUrls'], | ||
'waiting': res['data']['waiting'], | ||
'taskId': res['data']['taskId'] | ||
} | ||
total_time = int(re.match('[0-9]+', str(res['data']['waiting'])).group(0)) * 60 | ||
end_time = time.time() | ||
progress_rate = int(((end_time - start_time) / total_time * 100)) if total_time != 0 else 100 | ||
if progress_rate > process_bar.n: | ||
increase_rate = progress_rate - process_bar.n | ||
if progress_rate >= 100: | ||
increase_rate = 100 - process_bar.n | ||
else: | ||
increase_rate = 0 | ||
process_bar.update(increase_rate) | ||
time.sleep(5) | ||
for taskid in has_done: | ||
taskids.remove(taskid) | ||
print('Saving Images...') | ||
result_images = [] | ||
for text, data in results.items(): | ||
for idx, imgdata in enumerate(data['imgUrls']): | ||
image = Image.open(BytesIO(requests.get(imgdata['image']).content)) | ||
image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx))) | ||
result_images.append(image) | ||
if idx + 1 >= topk: | ||
break | ||
print('Done') | ||
return result_images | ||
|
||
@runnable | ||
def run_cmd(self, argvs): | ||
""" | ||
Run as a command. | ||
""" | ||
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name), | ||
prog='hub run {}'.format(self.name), | ||
usage='%(prog)s', | ||
add_help=True) | ||
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") | ||
self.add_module_input_arg() | ||
args = self.parser.parse_args(argvs) | ||
if args.ak is not None and args.sk is not None: | ||
self.ak = args.ak | ||
self.sk = args.sk | ||
self.token = self._apply_token(self.ak, self.sk) | ||
results = self.generate_image(text_prompts=args.text_prompts, | ||
style=args.style, | ||
topk=args.topk, | ||
output_dir=args.output_dir) | ||
return results | ||
|
||
def add_module_input_arg(self): | ||
""" | ||
Add the command input options. | ||
""" | ||
self.arg_input_group.add_argument('--text_prompts', type=str) | ||
self.arg_input_group.add_argument('--style', | ||
type=str, | ||
default='油画', | ||
choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'], | ||
help="绘画风格") | ||
self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张") | ||
self.arg_input_group.add_argument('--ak', type=str, default=None, help="申请文心api使用token的ak") | ||
self.arg_input_group.add_argument('--sk', type=str, default=None, help="申请文心api使用token的sk") | ||
self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
requests | ||
tqdm |