-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy paththreed_api.py
91 lines (79 loc) · 3.37 KB
/
threed_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import sys
import time
import torch
import requests
from io import BytesIO
from pydantic import BaseModel
import base64
import warnings
warnings.filterwarnings("ignore")
import modal
from modal import web_endpoint
from ..common import stub
# Path of the cache directory (not referenced in the visible part of this file).
cache_path = "/vol/cache"

# Torch device used for every Shap-E model load below: CUDA when torch reports
# a usable GPU, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def download_models():
    """Load each Shap-E model once on the module-level ``device``.

    The loaded models are discarded; the function is called only for the side
    effects of ``load_model``. It is passed to ``.run_function`` in the image
    definition, presumably so the weights are fetched at image build time --
    TODO confirm against shap-e's download/caching behaviour.
    """
    from shap_e.models.download import load_model

    # Same models, in the same order, that ThreeD.__enter__ loads at runtime.
    for model_name in ('transmitter', 'image300M', 'text300M'):
        load_model(model_name, device=device)
# Container image for the endpoint, built step by step.
# Base: Debian slim with Python 3.10.
image = modal.Image.debian_slim(python_version="3.10")
# git is installed because shap-e is pip-installed from a git+https URL.
image = image.apt_install("git")
image = image.pip_install(
    "pyyaml",
    "ipywidgets",
    "git+https://github.com/openai/shap-e.git",
    "trimesh",
    "matplotlib",
)
# Run download_models as a build step of the image.
image = image.run_function(download_models)
class Item(BaseModel):
    """Request body for the ``ThreeD.api`` web endpoint."""

    # Either a plain text prompt or an image data URL (a string starting with
    # "data:image"); defaults to None when the field is omitted.
    prompt: str = None
@stub.cls(gpu="A100", image=image, timeout=180)
class ThreeD:
    """Modal class exposing a Shap-E based 3D generation web endpoint.

    ``__enter__`` loads the Shap-E models and stores them on the instance;
    ``api`` reuses them for every request.
    """

    def __enter__(self):
        """Load the three Shap-E models onto ``device`` and keep them on ``self``.

        Presumably invoked by Modal when the container starts -- confirm
        against the Modal lifecycle documentation for the version in use.
        """
        from shap_e.models.download import load_model

        self.xm = load_model('transmitter', device=device)
        self.image_model = load_model('image300M', device=device)
        self.text_model = load_model('text300M', device=device)

    @web_endpoint(method="POST")
    def api(self, item: Item):
        """Generate a GLB model from an image data URL or a text prompt.

        Args:
            item: Request body. ``item.prompt`` is either an image data URL
                (a string starting with ``"data:image"``) or a text prompt.

        Returns:
            On success, a dict with:
              * ``"glb"``: the generated GLB file, base64-encoded as a str;
              * ``"init_image"``: True when the caller supplied the image,
                False when it was fetched from the stock-image endpoint.
            On any failure the error is printed and ``""`` is returned.
        """
        from PIL import Image
        from .open_source.shap_e.app import generate_3D

        try:
            start = time.time()
            prompt = item.prompt
            init_image = prompt.startswith("data:image")

            if init_image:
                # The payload of a data URL is everything after the first
                # comma. Splitting there handles every image MIME type instead
                # of a fixed list of prefixes.
                _, _, encoded = prompt.partition(",")
                # Decode the payload taken from the prompt itself: Item
                # declares no `image` field, so reading `item.image` raised on
                # every image request.
                source_image = Image.open(BytesIO(base64.b64decode(encoded))).convert('RGB')
                model = self.image_model
            else:
                data = {"prompt": prompt}
                # Change this endpoint to match your own
                response = requests.post("https://mirageml--stock-image-api.modal.run", json=data)
                # Fail with a clear HTTP error rather than a confusing JSON or
                # KeyError further down when the upstream call did not succeed.
                response.raise_for_status()
                # Decode the returned PNG in memory (nothing is written to disk).
                img_data = base64.b64decode(response.json()["png"])
                source_image = Image.open(BytesIO(img_data)).convert('RGB')
                # NOTE(review): a fetched image is passed together with the
                # *text* model here, as in the original code -- confirm this
                # pairing is what generate_3D expects.
                model = self.text_model

            glb_path = generate_3D(source_image, model, self.xm)
            print("Time Taken:", time.time() - start)

            # Context manager so the GLB file handle is always closed.
            with open(glb_path, 'rb') as glb_file:
                glb_b64 = base64.b64encode(glb_file.read()).decode('utf-8')

            return {
                "glb": glb_b64,
                "init_image": init_image
            }
        except Exception as e:
            # Endpoint boundary: keep the existing contract of printing the
            # error and returning an empty string instead of propagating.
            print(e)
            return ""