forked from Dmmc123/taim-gan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
130 lines (118 loc) · 3.94 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio
import numpy as np # this should come first to mitigate mlk-service bug
from src.models.utils import get_image_arr, load_model
from src.data import TAIMGANTokenizer
from torchvision import transforms
from src.config import config_dict
from pathlib import Path
from PIL import Image
import gradio as gr
import logging
import torch
from src.models.modules import (
VGGEncoder,
InceptionEncoder,
TextEncoder,
Generator
)
##########
# PARAMS #
##########
# image geometry
IMG_CHANS = 3  # RGB channels for image
IMG_HW = 256  # height and width of images (same value is used for both)

# text-encoder sizing
HIDDEN_DIM = 128  # hidden dimensions of lstm cell in one direction
C = HIDDEN_DIM * 2  # length of embeddings (twice the per-direction hidden size)

# hyper-parameters read from the project configuration, in this order:
# "Ng", then the conditioning dimension, then the noise-vector dimension
Ng, cond_dim, z_dim = (
    config_dict[key] for key in ("Ng", "condition_dim", "noise_dim")
)
###############
# LOAD MODELS #
###############
# Registry of the available model sets, keyed by the name shown to the user.
# Every entry starts out holding only the directory with its weights under
# the "dir" key; the rest of the file adds further keys to each entry.
models = {
    name: {"dir": weights_dir}
    for name, weights_dir in (
        ("COCO", "weights/coco"),
        ("Bird", "weights/bird"),
        ("UTKFace", "weights/utkface"),
    )
}
# Fill every registry entry with its tokenizer and networks.
for model_name, entry in models.items():
    weights_dir = entry["dir"]

    # tokenizer is built from the captions pickle that sits next to the weights
    tokenizer = TAIMGANTokenizer(captions_path=f"{weights_dir}/captions.pickle")
    vocab_size = len(tokenizer.word_to_ix)

    # instantiate the networks; .eval() is called on each one right away
    generator = Generator(
        Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim
    ).eval()
    lstm = TextEncoder(
        vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM
    ).eval()
    vgg = VGGEncoder().eval()
    inception = InceptionEncoder(D=C).eval()

    # only keys are added to the inner dict, so mutating it while iterating
    # over the outer registry is safe
    entry.update(
        {
            "tokenizer": tokenizer,
            "generator": generator,
            "lstm": lstm,
            "vgg": vgg,
            "inception": inception,
        }
    )

    # hand the generator and both encoders to load_model together with the
    # weights directory, targeting the CPU; no discriminator is passed and
    # the VGG encoder is not part of this call
    load_model(
        generator=generator,
        discriminator=None,
        image_encoder=inception,
        text_encoder=lstm,
        output_dir=Path(weights_dir),
        device=torch.device("cpu"),
    )
def change_image_with_text(image: Image.Image, text: str, model_name: str) -> Image.Image:
    """
    Create an image modified by text from the original image.

    The generated image is returned to the caller; nothing is written to disk.

    :param PIL.Image.Image image: Source image to be edited
    :param str text: Desired caption
    :param str model_name: Key of ``models`` selecting which tokenizer and
        networks to use
    :return: Image produced by the generator
    :rtype: PIL.Image.Image
    :raises KeyError: If ``model_name`` is not a key of ``models``
    """
    # read-only access to the module-level registry, so no ``global`` is needed
    bundle = models[model_name]
    tokenizer = bundle["tokenizer"]
    G = bundle["generator"]
    lstm = bundle["lstm"]
    inception = bundle["inception"]
    vgg = bundle["vgg"]

    # the normalization below is defined for exactly three channels, so bring
    # PIL inputs in any other mode (RGBA, grayscale, palette, ...) to RGB first
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # tensor conversion -> square resize -> per-channel (x - 0.5) / 0.5
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMG_HW, IMG_HW)),
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5)
        )
    ])

    # pure inference: disable autograd so no graph is built per request
    with torch.no_grad():
        # generate some noise (batch of one)
        # NOTE(review): torch.rand samples uniformly from [0, 1) -- confirm
        # this matches the noise distribution used during training
        noise = torch.rand(z_dim).unsqueeze(0)

        # transform input text and get masks with embeddings;
        # mask is True at positions holding the tokenizer's pad token id
        tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
        mask = (tokens == tokenizer.pad_token_id)
        word_embs, sent_embs = lstm(tokens)

        # transform the image to a tensor and add the batch dimension
        image = preprocess(image).unsqueeze(0)

        # obtain visual features of the image
        vgg_features = vgg(image)
        local_features, global_features = inception(image)

        # generate new image from the old one; the generator returns three
        # values and only the first one is used here
        fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
                             local_features, vgg_features, mask)

    # take the first (only) item of the batch produced by get_image_arr and
    # return it as a PIL image, which is what the Gradio output expects
    return Image.fromarray(get_image_arr(fake_image)[0])
##########
# GRADIO #
##########
# close Gradio interfaces that may still be running before starting a new one
gr.close_all()

# UI: an input image, a free-text caption and a model selector go in,
# the edited image comes out
demo = gr.Interface(
    fn=change_image_with_text,
    inputs=[
        gr.Image(type="pil"),
        "text",
        # component API instead of the deprecated ``gr.inputs`` namespace,
        # matching the ``gr.Image`` usage in this same call
        gr.Dropdown(choices=list(models)),
    ],
    outputs=gr.Image(type="pil"),
    # each example row is [image path, caption, model name]
    examples=[
        ["src/data/stubs/bird.jpg", "black bird with blue wings", "Bird"],
        ["src/data/stubs/lady.jpg", "lady with blue eyes", "UTKFace"],
        ["src/data/stubs/bird.jpg", "white bird with black wings", "Bird"],
    ],
)

print("Please visit http://0.0.0.0:7861")
# listen on all interfaces on port 7861 and surface errors in the UI
demo.launch(
    server_name="0.0.0.0",
    server_port=7861,
    show_error=True,
    debug=True,
)