Stable Diffusion 1.4 #70
Changes from 8 commits
@@ -0,0 +1,6 @@
TT_METAL_DOCKERFILE_VERSION=v0.53.0-rc34
TT_METAL_COMMIT_SHA_OR_TAG=4da4a5e79a13ece7ff5096c30cef79cb0c504f0e
Review comment: Is there a reference to this commit in […]

Reply: I copy+pasted this from the YOLOv4 server because that required a specific commit, as the metal YOLOv4 improvements got reverted because they couldn't pass CI. Should I just use the latest release? That would be release v0.55.0?

Reply: Updated to use release v0.55.0 in 4582863
TT_METAL_COMMIT_DOCKER_TAG=4da4a5e79a13 # 12-character version of TT_METAL_COMMIT_SHA_OR_TAG
IMAGE_VERSION=v0.0.1
# These are secrets and must be stored securely for production environments
JWT_SECRET=testing
@@ -0,0 +1,51 @@
# TT Metalium Stable Diffusion 1.4 Inference API

This implementation supports Stable Diffusion 1.4 execution on Wormhole n150 (n300 currently broken).

## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)

## Run server
To run the SD1.4 inference server, run the following command from the project root at `tt-inference-server`:
```bash
cd tt-inference-server
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml up --build
```

This will start the default Docker container with the entrypoint command set to run the gunicorn server. The [Development](#development) section describes how to override the container's default command with an interactive shell via `bash`.
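Before submitting prompts it can help to wait for warmup to finish; the `/health` route defined in `flaskserver.py` returns 503 until the warmup prompt has completed. The following is an illustrative sketch only (not part of this PR), assuming the default `localhost:7000` port mapping from `docker-compose.yaml`:

```python
import time

import requests

SERVER = "http://localhost:7000"  # host port published in docker-compose.yaml


def wait_until_ready(timeout_s: int = 600) -> None:
    """Poll /health until the warmup prompt has completed or the timeout expires."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{SERVER}/health", timeout=5).status_code == 200:
                print("server is ready")
                return
        except requests.ConnectionError:
            pass  # container may still be starting up
        time.sleep(5)
    raise TimeoutError("server did not become ready in time")


if __name__ == "__main__":
    wait_until_ready()
```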

### JWT_TOKEN Authorization

To authenticate requests use the header `Authorization`. The JWT token can be computed using the script `jwt_util.py`. This is an example:
```bash
cd tt-inference-server/tt-metal-yolov4/server
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```
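With `AUTHORIZATION` exported, protected routes such as `/submit` expect the token in the `Authorization` header. A hypothetical client sketch (route and payload names taken from the `flaskserver.py` added in this PR; host and port assume the default compose setup):

```python
import os

import requests

SERVER = "http://localhost:7000"

# Reuse the bearer token produced by scripts/jwt_util.py above.
headers = {"Authorization": os.environ["AUTHORIZATION"]}
payload = {"prompt": "Unicorn on a banana"}

# Queue a prompt for generation; the server answers with a confirmation message.
resp = requests.post(f"{SERVER}/submit", json=payload, headers=headers, timeout=30)
resp.raise_for_status()
print(resp.json())  # {"message": "Prompt received and added to queue."}
```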

## Development
To get an interactive `bash` shell in the container (overriding the default gunicorn entrypoint), run:
```bash
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml run --rm --build inference_server /bin/bash
```

Inside the container, run `cd ~/app/server` to navigate to the server implementation.

## Tests
Tests can be found in `tests/`. The tests have their own dependencies found in `requirements-test.txt`.

First, ensure the server is running (see [how to run the server](#run-server)). Then in a different shell with the base dev `venv` activated:
```bash
cd tt-metal-stable-diffusion-1.4
pip install -r requirements-test.txt
cd tests/
locust --config locust_config.conf
```
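For orientation only, a minimal locustfile in the spirit of these tests might look like the sketch below; the actual load test and its `locust_config.conf` live in `tests/` of this PR and may differ:

```python
import os

from locust import HttpUser, between, task


class SDUser(HttpUser):
    host = "http://localhost:7000"  # can also come from locust_config.conf or --host
    wait_time = between(1, 3)

    @task
    def health(self):
        # Unauthenticated readiness probe.
        self.client.get("/health")

    @task
    def submit_prompt(self):
        # /submit is JWT-protected; reuse the AUTHORIZATION value from the README.
        headers = {"Authorization": os.environ.get("AUTHORIZATION", "")}
        self.client.post(
            "/submit",
            json={"prompt": "A lighthouse in a storm"},
            headers=headers,
        )
```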
Review comment: Do you use this for development often? If you are finding it more useful than having […]
@@ -0,0 +1,26 @@
services:
  inference_server:
    image: ghcr.io/tenstorrent/tt-inference-server/tt-metal-stable-diffusion-1.4-src-base:${IMAGE_VERSION}-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG}
    build:
      context: ../
      dockerfile: tt-metal-stable-diffusion-1.4/stable-diffusion-1.4.src.Dockerfile
      args:
        TT_METAL_DOCKERFILE_VERSION: ${TT_METAL_DOCKERFILE_VERSION}
        TT_METAL_COMMIT_SHA_OR_TAG: ${TT_METAL_COMMIT_SHA_OR_TAG}
    container_name: sd_inference_server
    ports:
      - "7000:7000"
    devices:
      - "/dev/tenstorrent:/dev/tenstorrent"
    volumes:
      - "/dev/hugepages-1G/:/dev/hugepages-1G:rw"
    shm_size: "32G"
    cap_add:
      - ALL
    stdin_open: true
    tty: true
    # this is redundant, as docker compose automatically uses the .env file since it's in the same directory,
    # but this explicitly demonstrates its usage
    env_file:
      - .env.default
    restart: no
@@ -0,0 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
@@ -0,0 +1,5 @@
# inference server requirements
flask==3.0.2
gunicorn==21.2.0
requests==2.31.0
pyjwt==2.7.0
Review comment: Leaving it to you to decide whether the code should be moved, but I don't love having the inference server backend implementation living in the tt-metal repo. The interaction with that server via a combination of HTTP API and file system needs some improvement to be thread-safe. I'd recommend a thread-safe lock on a global dict, for example. Django has some facilities for this as well. The current solution needs an example client script; it's not clear how to call the image API.

That said, I'd recommend message passing using an inter-process queue instead of the filesystem and an additional REST API: either Python multiprocessing or a more robust solution using e.g. zmq (https://github.com/zeromq/pyzmq), which is used in the vLLM backend. I believe you have already overhauled this work for SD3.5, but adding these comments for tracking.
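As an illustration of the reviewer's suggestion (not code from this PR), a minimal sketch of replacing the JSON-file handoff with a `multiprocessing.Queue`, with the Flask handler as producer and the Stable Diffusion worker as consumer; `run_stable_diffusion` is a hypothetical stand-in for the ttnn pipeline call:

```python
import multiprocessing as mp


def run_stable_diffusion(prompt: str) -> str:
    """Placeholder for the actual ttnn Stable Diffusion invocation."""
    raise NotImplementedError


def sd_worker(prompt_queue: mp.Queue, result_queue: mp.Queue) -> None:
    """Consumer process: pull prompts, push paths of generated images."""
    while True:
        prompt = prompt_queue.get()  # blocks until a prompt arrives
        if prompt is None:  # sentinel used for shutdown
            break
        image_path = run_stable_diffusion(prompt)
        result_queue.put({"prompt": prompt, "image_path": image_path})


if __name__ == "__main__":
    prompt_queue: mp.Queue = mp.Queue()
    result_queue: mp.Queue = mp.Queue()
    worker = mp.Process(target=sd_worker, args=(prompt_queue, result_queue), daemon=True)
    worker.start()
    # A Flask /submit handler would then call prompt_queue.put(prompt), and
    # /get_image would read results from result_queue instead of polling the filesystem.
```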
@@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from flask import (
    abort,
    Flask,
    request,
    jsonify,
    send_from_directory,
)
import json
import os
import atexit
import time
import threading
from http import HTTPStatus
from utils.authentication import api_key_required

import subprocess
import signal
import sys

# script to run in background
script = "pytest models/demos/wormhole/stable_diffusion/demo/web_demo/sdserver.py"

# Start script using subprocess
process1 = subprocess.Popen(script, shell=True)


# Function to terminate both processes and kill port 5000
def signal_handler(sig, frame):
    print("Terminating processes...")
    process1.terminate()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

app = Flask(__name__)

# var to indicate ready state
ready = False
Review comment: Global variable […]

# internal json prompt file
json_file_path = (
    "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
    return "Hello, World!"


def submit_prompt(prompt_file, prompt):
    if not os.path.isfile(prompt_file):
        with open(prompt_file, "w") as f:
            json.dump({"prompts": []}, f)

    with open(prompt_file, "r") as f:
        prompts_data = json.load(f)

    prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

    with open(prompt_file, "w") as f:
        json.dump(prompts_data, f, indent=4)


def warmup():
    sample_prompt = "Unicorn on a banana"
    # submit sample prompt to perform tracing and server warmup
    submit_prompt(json_file_path, sample_prompt)
    global ready
    while not ready:
        with open(json_file_path, "r") as f:
            prompts_data = json.load(f)
        # sample prompt should be first prompt
        sample_prompt_data = prompts_data["prompts"][0]
        if sample_prompt_data["prompt"] == sample_prompt:
            # TODO: remove this and replace with status check == "done"
            # to flip ready flag
            if sample_prompt_data["status"] == "done":
                ready = True
            print(sample_prompt_data["status"])
        time.sleep(3)


# start warmup routine in background while server starts
warmup_thread = threading.Thread(target=warmup, name="warmup")
warmup_thread.start()


@app.route("/health")
def health_check():
    if not ready:
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    return jsonify({"message": "OK\n"}), 200


@app.route("/submit", methods=["POST"])
@api_key_required
def submit():
    global ready
    if not ready:
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    data = request.get_json()
    prompt = data.get("prompt")
    print(prompt)

    submit_prompt(json_file_path, prompt)

    return jsonify({"message": "Prompt received and added to queue."})


@app.route("/update_status", methods=["POST"])
def update_status():
    data = request.get_json()
    prompt = data.get("prompt")

    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    for p in prompts_data["prompts"]:
        if p["prompt"] == prompt:
            p["status"] = "generated"
            break

    with open(json_file_path, "w") as f:
        json.dump(prompts_data, f, indent=4)

    return jsonify({"message": "Prompt status updated to generated."})


@app.route("/get_image", methods=["GET"])
def get_image():
    image_name = "interactive_512x512_ttnn.png"
    directory = os.getcwd()  # Get the current working directory
    return send_from_directory(directory, image_name)


@app.route("/image_exists", methods=["GET"])
def image_exists():
    image_path = "interactive_512x512_ttnn.png"
    if os.path.isfile(image_path):
        return jsonify({"exists": True}), 200
    else:
        return jsonify({"exists": False}), 200


@app.route("/clean_up", methods=["POST"])
def clean_up():
    with open(json_file_path, "r") as f:
Review comment: Repeated code blocks: the code repeatedly reads and writes the prompts JSON file. For example, helpers such as:

```python
def read_json_file(file_path):
    if not os.path.isfile(file_path):
        return {"prompts": []}
    with open(file_path, "r") as f:
        return json.load(f)


def write_json_file(file_path, data):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)
```

Reply: Refactored in a75d0e6
        prompts_data = json.load(f)

    prompts_data["prompts"] = [
        p for p in prompts_data["prompts"] if p["status"] != "done"
    ]

    with open(json_file_path, "w") as f:
        json.dump(prompts_data, f, indent=4)

    return jsonify({"message": "Cleaned up done prompts."})

@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
    if not os.path.isfile(json_file_path):
        return jsonify({"message": "No prompts found"}), 404

    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    # Filter prompts that have a total_acc time available
    completed_prompts = [p for p in prompts_data["prompts"] if "total_acc" in p]

    if not completed_prompts:
        return jsonify({"message": "No completed prompts with time available"}), 404

    # Get the latest prompt with total_acc
    latest_prompt = completed_prompts[-1]  # Assuming prompts are in chronological order

    return (
        jsonify(
            {
                "prompt": latest_prompt["prompt"],
                "total_acc": latest_prompt["total_acc"],
                "batch_size": latest_prompt["batch_size"],
                "steps": latest_prompt["steps"],
            }
        ),
        200,
    )


def cleanup():
    if os.path.isfile(
        "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
    ):
        os.remove(
            "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
        )
        print("Deleted json")
Review comment: Use […]

Reply: Used […]
|
||
if os.path.isfile("interactive_512x512_ttnn.png"): | ||
os.remove("interactive_512x512_ttnn.png") | ||
print("Deleted image") | ||
|
||
signal_handler(None, None) | ||
|
||
|
||
atexit.register(cleanup) | ||
|
||
|
||
def create_server(): | ||
return app |
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC


workers = 1
# use 0.0.0.0 for externally accessible
bind = f"0.0.0.0:{7000}"
reload = False
worker_class = "gthread"
threads = 16
Review comment: Running with 16 threads could have issues with thread safety given the current server inter-communication.
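To illustrate the concern (a suggestion only, not part of this PR): with `threads = 16`, concurrent handlers in `flaskserver.py` can interleave their read-modify-write cycles on the prompts JSON file. A module-level `threading.Lock` is one minimal way to serialize that access:

```python
import json
import os
import threading

# One lock shared by every handler that touches the prompts JSON file.
json_lock = threading.Lock()


def append_prompt_threadsafe(prompt_file: str, prompt: str) -> None:
    """Read-modify-write under the lock so concurrent gthread workers cannot drop prompts."""
    with json_lock:
        if not os.path.isfile(prompt_file):
            prompts_data = {"prompts": []}
        else:
            with open(prompt_file, "r") as f:
                prompts_data = json.load(f)
        prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})
        with open(prompt_file, "w") as f:
            json.dump(prompts_data, f, indent=4)
```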
timeout = 0

# server factory
wsgi_app = "server.flaskserver:create_server()"
Review comment: Is this .env file needed? I'd advocate for using the setup.sh script to generate the .env file for SD1.4. Also, it can use the format that integrates with tt-studio. The environment variables defined in it are Docker build environment variables; are you sourcing the .env file for `docker build`, or does `docker compose --env-file` on the first run pass the env vars through to the correct ARGs in the Dockerfile? We have been using the .env for setup runtime and dependencies, and keeping the Docker build variables in documentation. I'm not against this, but it's different from other model implementations and not necessary if the user is running `docker build`.