Stable Diffusion 1.4 #70

Closed
wants to merge 55 commits
Changes from 8 commits
Commits
55 commits
a352aaa
Add initial docker setup
bgoelTT Jan 20, 2025
900c08b
Add sd server as background process in flask server
bgoelTT Jan 20, 2025
302f293
Update README and testing
bgoelTT Jan 20, 2025
136841a
Add ready check mechanism, not finished
bgoelTT Jan 20, 2025
572cdb3
Finish readycheck warmup thread
bgoelTT Jan 21, 2025
cd869eb
Add health check endpoint
bgoelTT Jan 21, 2025
f7774d7
Add common API token utility and use
bgoelTT Jan 22, 2025
a92e81c
Return JSON object
bgoelTT Jan 23, 2025
4582863
Use tt-metal release v0.55.0
bgoelTT Feb 2, 2025
f5650c5
Use logging instead of print
bgoelTT Feb 2, 2025
ada9877
Update header year to 2025
bgoelTT Feb 2, 2025
a75d0e6
Refactor repeated functionality
bgoelTT Feb 2, 2025
835fc91
Add mutual exclusion to ready variable
bgoelTT Feb 2, 2025
016af35
Remove locust tests and add healthcheck test
bgoelTT Feb 2, 2025
63d973c
Test more API endpoints and fix bug in locking mechanism
bgoelTT Feb 2, 2025
b7f5251
Use specific fix for sd1.4 web demo
bgoelTT Feb 4, 2025
5439ab2
Use newest fix
bgoelTT Feb 4, 2025
1b39319
Use fix to enable repeated prompts
bgoelTT Feb 4, 2025
4f97159
Add initial docker setup
bgoelTT Jan 20, 2025
eb5e4ff
Add sd server as background process in flask server
bgoelTT Jan 20, 2025
94b5995
Update README and testing
bgoelTT Jan 20, 2025
b90ba77
Add ready check mechanism, not finished
bgoelTT Jan 20, 2025
4e41606
Finish readycheck warmup thread
bgoelTT Jan 21, 2025
0c75a5d
Add health check endpoint
bgoelTT Jan 21, 2025
2b15cef
Add common API token utility and use
bgoelTT Jan 22, 2025
aeaaea9
Return JSON object
bgoelTT Jan 23, 2025
06e1842
Use tt-metal release v0.55.0
bgoelTT Feb 2, 2025
353420a
Use logging instead of print
bgoelTT Feb 2, 2025
c823b01
Update header year to 2025
bgoelTT Feb 2, 2025
5639d91
Refactor repeated functionality
bgoelTT Feb 2, 2025
ef3fc03
Add mutual exclusion to ready variable
bgoelTT Feb 2, 2025
8b5ff77
Remove locust tests and add healthcheck test
bgoelTT Feb 2, 2025
1ba7d3f
Test more API endpoints and fix bug in locking mechanism
bgoelTT Feb 2, 2025
b9fa027
Use specific fix for sd1.4 web demo
bgoelTT Feb 4, 2025
327841f
Use newest fix
bgoelTT Feb 4, 2025
2e8038b
Use fix to enable repeated prompts
bgoelTT Feb 4, 2025
8c1cb42
Merge branch 'ben/sd1.4' of github.com:tenstorrent/tt-inference-serve…
bgoelTT Feb 4, 2025
16bc4e9
Merge branch 'ben/sd1.4' of github.com:tenstorrent/tt-inference-serve…
bgoelTT Feb 5, 2025
12f1419
Initial commit of async worker server
bgoelTT Feb 5, 2025
22a7bf4
Integrate TtStableDiffusion3Pipeline -- debugging segfault
bgoelTT Feb 6, 2025
2e4528f
Add placeholder HF_TOKEN
bgoelTT Feb 6, 2025
57ecc53
Finish integration -- add API key requirements
bgoelTT Feb 7, 2025
5814102
Merge branch 'dev' of github.com:tenstorrent/tt-inference-server into…
bgoelTT Feb 8, 2025
a8af17f
Add model to setup.sh
bgoelTT Feb 9, 2025
f5ad09d
Configure docker compose to use persisent volume
bgoelTT Feb 9, 2025
0b8f2f5
Merge branch 'dev' of github.com:tenstorrent/tt-inference-server into…
bgoelTT Feb 9, 2025
37081ee
Update README.md
bgoelTT Feb 9, 2025
7429218
Add permission entrypoint script to dockerfile
bgoelTT Feb 9, 2025
53dba8c
Update healthcheck route to return JSON response
bgoelTT Feb 10, 2025
ea7fb0b
Use random seed
bgoelTT Feb 10, 2025
ef21d84
Add API key requirement to main inference endpoints
bgoelTT Feb 18, 2025
bdfadf3
Update README instructions for building and running
bgoelTT Feb 20, 2025
fd6103e
Update testing requirements all passing now
bgoelTT Feb 20, 2025
5f2e850
Rename .env.default to .env.build
bgoelTT Feb 23, 2025
31ba1b7
Merge SD3.5 upstream
bgoelTT Feb 23, 2025
6 changes: 6 additions & 0 deletions tt-metal-stable-diffusion-1.4/.env.default
Contributor

Is this .env file needed?

I'd advocate for using the setup.sh script to generate the .env file for SD1.4. Also, it can use the format that integrates with tt-studio.

The environment variables defined in it are Docker build environment variables. Are you sourcing the .env file for the Docker build, or does docker compose --env-file on first run pass the env vars through to the correct ARGs in the Dockerfile? We have been using the .env file for runtime setup and dependencies, and keeping the Docker build variables in documentation. I'm not against this, but it's different from other model implementations and not necessary if the user is running docker build.

@@ -0,0 +1,6 @@
TT_METAL_DOCKERFILE_VERSION=v0.53.0-rc34
TT_METAL_COMMIT_SHA_OR_TAG=4da4a5e79a13ece7ff5096c30cef79cb0c504f0e
Contributor

Is there a reference to this commit in tt-metal anywhere? Curious how you selected this one.

Contributor Author
bgoelTT Feb 1, 2025

I copy+pasted this from the YOLOv4 server, which required a specific commit because the metal YOLOv4 improvements were reverted after failing CI.

Should I just use the latest release? That would be release v0.55.0?

Contributor Author

Updated to use release v0.55.0 in 4582863

TT_METAL_COMMIT_DOCKER_TAG=4da4a5e79a13 # 12-character version of TT_METAL_COMMIT_SHA_OR_TAG
IMAGE_VERSION=v0.0.1
# These are secrets and must be stored securely for production environments
JWT_SECRET=testing
51 changes: 51 additions & 0 deletions tt-metal-stable-diffusion-1.4/README.md
@@ -0,0 +1,51 @@
# TT Metalium Stable Diffusion 1.4 Inference API

This implementation supports Stable Diffusion 1.4 execution on Wormhole n150 (n300 is currently broken).


## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)


## Run server
To run the SD1.4 inference server, run the following command from the project root at `tt-inference-server`:
```bash
cd tt-inference-server
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml up --build
```

This will start the default Docker container with the entrypoint command set to run the gunicorn server. The [Development](#development) section describes how to override the container's default command with an interactive shell via `bash`.


### JWT_TOKEN Authorization

To authenticate requests, use the `Authorization` header. The JWT token can be computed using the script `jwt_util.py`. For example:
```bash
cd tt-inference-server/tt-metal-yolov4/server
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```
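
The exported `AUTHORIZATION` value can then be sent with each request. As an illustrative sketch (assuming the server is reachable on the default port 7000 mapped in `docker-compose.yaml`, and using the `/submit` route from `server/flaskserver.py`):
```python
import os

import requests

# assumes AUTHORIZATION was exported as shown above
headers = {"Authorization": os.environ["AUTHORIZATION"]}

response = requests.post(
    "http://localhost:7000/submit",
    json={"prompt": "Unicorn on a banana"},
    headers=headers,
)
print(response.status_code, response.json())
```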


## Development
To get an interactive shell inside the container for development, override the default command with `bash`:
```bash
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml run --rm --build inference_server /bin/bash
```

Inside the container, run `cd ~/app/server` to navigate to the server implementation.


## Tests
Tests can be found in `tests/`. The tests have their own dependencies found in `requirements-test.txt`.

First, ensure the server is running (see [how to run the server](#run-server)). Then in a different shell with the base dev `venv` activated:
```bash
cd tt-metal-stable-diffusion-1.4
pip install -r requirements-test.txt
cd tests/
locust --config locust_config.conf
```
26 changes: 26 additions & 0 deletions tt-metal-stable-diffusion-1.4/docker-compose.yaml
Contributor

Do you use this for development often? If you are finding it more useful than having docker run commands in documentation, we should consider this for other model implementations.

@@ -0,0 +1,26 @@
services:
  inference_server:
    image: ghcr.io/tenstorrent/tt-inference-server/tt-metal-stable-diffusion-1.4-src-base:${IMAGE_VERSION}-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG}
    build:
      context: ../
      dockerfile: tt-metal-stable-diffusion-1.4/stable-diffusion-1.4.src.Dockerfile
      args:
        TT_METAL_DOCKERFILE_VERSION: ${TT_METAL_DOCKERFILE_VERSION}
        TT_METAL_COMMIT_SHA_OR_TAG: ${TT_METAL_COMMIT_SHA_OR_TAG}
    container_name: sd_inference_server
    ports:
      - "7000:7000"
    devices:
      - "/dev/tenstorrent:/dev/tenstorrent"
    volumes:
      - "/dev/hugepages-1G/:/dev/hugepages-1G:rw"
    shm_size: "32G"
    cap_add:
      - ALL
    stdin_open: true
    tty: true
    # this is redundant, as docker compose automatically uses a .env file in the same directory,
    # but it explicitly demonstrates the usage
    env_file:
      - .env.default
    restart: "no"
3 changes: 3 additions & 0 deletions tt-metal-stable-diffusion-1.4/requirements-test.txt
@@ -0,0 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
5 changes: 5 additions & 0 deletions tt-metal-stable-diffusion-1.4/requirements.txt
@@ -0,0 +1,5 @@
# inference server requirements
flask==3.0.2
gunicorn==21.2.0
requests==2.31.0
pyjwt==2.7.0
217 changes: 217 additions & 0 deletions tt-metal-stable-diffusion-1.4/server/flaskserver.py
Contributor

Leaving it to you to decide whether the code should be moved, but I don't love having the inference server backend implementation living in the tt-metal repo.

The interaction with that server via a combination of an HTTP API and the file system needs some improvement to be threadsafe. I'd recommend, for example, a threadsafe lock on a global dict. Django has some facilities for this as well.
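
As an illustrative sketch of that approach (names here are hypothetical, not from this PR), the shared prompt state could live in a module-level dict guarded by a lock:

```python
import threading

# hypothetical in-memory replacement for input_prompts.json
prompts_state = {"prompts": []}
prompts_lock = threading.Lock()


def add_prompt(prompt):
    # every read/write of the shared dict happens under the lock
    with prompts_lock:
        prompts_state["prompts"].append({"prompt": prompt, "status": "not generated"})


def mark_generated(prompt):
    with prompts_lock:
        for p in prompts_state["prompts"]:
            if p["prompt"] == prompt:
                p["status"] = "generated"
                break
```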

The current solution needs an example client script; it's not clear how to call the image API. Does the client need to call the following (a sketch is given after this list):

  1. /submit
  2. /update_status
  3. loop over /get_latest_time or /image_exists
  4. /get_image
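
One possible client sketch of that flow (endpoint names and response shapes are taken from flaskserver.py in this PR; the host, port, polling interval, and token are assumptions):

```python
import time

import requests

BASE = "http://localhost:7000"  # assumed default port from docker-compose.yaml
HEADERS = {"Authorization": "Bearer <token>"}  # see the README for generating a token

# 1. submit a prompt
requests.post(f"{BASE}/submit", json={"prompt": "Unicorn on a banana"}, headers=HEADERS)

# 2./3. poll until the server reports the image exists
while not requests.get(f"{BASE}/image_exists").json()["exists"]:
    time.sleep(3)

# 4. fetch the generated image
image = requests.get(f"{BASE}/get_image")
with open("output.png", "wb") as f:
    f.write(image.content)
```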

That said, I'd recommend message passing over an inter-process queue instead of using the filesystem and an additional REST API. Either Python multiprocessing, or a more robust solution using e.g. zmq (https://github.com/zeromq/pyzmq), which is used in the vLLM backend.
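
A rough sketch of the multiprocessing variant (illustrative only, not code from this PR):

```python
import multiprocessing as mp


def sd_worker(prompt_queue, result_queue):
    # stand-in for the actual stable diffusion pipeline loop
    while True:
        prompt = prompt_queue.get()  # blocks until a prompt arrives
        image_path = "interactive_512x512_ttnn.png"  # placeholder result
        result_queue.put((prompt, image_path))


if __name__ == "__main__":
    prompt_queue = mp.Queue()
    result_queue = mp.Queue()
    mp.Process(target=sd_worker, args=(prompt_queue, result_queue), daemon=True).start()

    prompt_queue.put("Unicorn on a banana")
    print(result_queue.get())  # blocks until the worker replies
```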

I believe you have already overhauled this work for SD3.5, but I'm adding these comments for tracking.

@@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from flask import (
    abort,
    Flask,
    request,
    jsonify,
    send_from_directory,
)
import json
import os
import atexit
import time
import threading
from http import HTTPStatus
from utils.authentication import api_key_required

import subprocess
import signal
import sys

# script to run in background
script = "pytest models/demos/wormhole/stable_diffusion/demo/web_demo/sdserver.py"

# Start script using subprocess
process1 = subprocess.Popen(script, shell=True)


# Terminate the background SD server process on shutdown
def signal_handler(sig, frame):
    print("Terminating processes...")
    process1.terminate()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

app = Flask(__name__)

# var to indicate ready state
ready = False
Contributor

Global variable ready: this flag is accessed and modified by multiple threads (warmup and request handlers). This may lead to race conditions. Consider using thread-safe mechanisms like threading.Lock or Event to manage state changes.
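
A minimal sketch of the Event-based variant (illustrative; not the exact fix adopted in the commits referenced below):

```python
import threading
from http import HTTPStatus

from flask import Flask, abort, jsonify

app = Flask(__name__)
ready_event = threading.Event()  # replaces the bare `ready` bool


def warmup():
    # ... submit the sample prompt and poll until its status is "done" ...
    ready_event.set()  # atomic flag flip, safe across threads


threading.Thread(target=warmup, name="warmup").start()


@app.route("/health")
def health_check():
    if not ready_event.is_set():
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    return jsonify({"message": "OK\n"}), 200
```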

Contributor Author
bgoelTT Feb 3, 2025

Enforced mutual exclusion of the ready flag in 835fc91, then fixed a bug in that implementation in 63d973c.


# internal json prompt file
json_file_path = (
    "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
return "Hello, World!"


def submit_prompt(prompt_file, prompt):
    if not os.path.isfile(prompt_file):
        with open(prompt_file, "w") as f:
            json.dump({"prompts": []}, f)

    with open(prompt_file, "r") as f:
        prompts_data = json.load(f)

    prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

    with open(prompt_file, "w") as f:
        json.dump(prompts_data, f, indent=4)


def warmup():
    sample_prompt = "Unicorn on a banana"
    # submit sample prompt to perform tracing and server warmup
    submit_prompt(json_file_path, sample_prompt)
    global ready
    while not ready:
        with open(json_file_path, "r") as f:
            prompts_data = json.load(f)
        # sample prompt should be first prompt
        sample_prompt_data = prompts_data["prompts"][0]
        if sample_prompt_data["prompt"] == sample_prompt:
            # TODO: remove this and replace with status check == "done"
            # to flip ready flag
            if sample_prompt_data["status"] == "done":
                ready = True
        print(sample_prompt_data["status"])
        time.sleep(3)


# start warmup routine in background while server starts
warmup_thread = threading.Thread(target=warmup, name="warmup")
warmup_thread.start()


@app.route("/health")
def health_check():
if not ready:
abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
return jsonify({"message": "OK\n"}), 200


@app.route("/submit", methods=["POST"])
@api_key_required
def submit():
global ready
if not ready:
abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
data = request.get_json()
prompt = data.get("prompt")
print(prompt)

submit_prompt(json_file_path, prompt)

return jsonify({"message": "Prompt received and added to queue."})


@app.route("/update_status", methods=["POST"])
def update_status():
data = request.get_json()
prompt = data.get("prompt")

with open(json_file_path, "r") as f:
prompts_data = json.load(f)

for p in prompts_data["prompts"]:
if p["prompt"] == prompt:
p["status"] = "generated"
break

with open(json_file_path, "w") as f:
json.dump(prompts_data, f, indent=4)

return jsonify({"message": "Prompt status updated to generated."})


@app.route("/get_image", methods=["GET"])
def get_image():
image_name = "interactive_512x512_ttnn.png"
directory = os.getcwd() # Get the current working directory
return send_from_directory(directory, image_name)


@app.route("/image_exists", methods=["GET"])
def image_exists():
image_path = "interactive_512x512_ttnn.png"
if os.path.isfile(image_path):
return jsonify({"exists": True}), 200
else:
return jsonify({"exists": False}), 200


@app.route("/clean_up", methods=["POST"])
def clean_up():
with open(json_file_path, "r") as f:
Contributor

Repeated code blocks: the code repeatedly reads and writes json_file_path. This can be refactored into utility functions for cleaner logic:

def read_json_file(file_path):
    if not os.path.isfile(file_path):
        return {"prompts": []}
    with open(file_path, "r") as f:
        return json.load(f)

def write_json_file(file_path, data):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)

Contributor Author
Refactored in a75d0e6

        prompts_data = json.load(f)

    prompts_data["prompts"] = [
        p for p in prompts_data["prompts"] if p["status"] != "done"
    ]

    with open(json_file_path, "w") as f:
        json.dump(prompts_data, f, indent=4)

    return jsonify({"message": "Cleaned up done prompts."})


@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
if not os.path.isfile(json_file_path):
return jsonify({"message": "No prompts found"}), 404

with open(json_file_path, "r") as f:
prompts_data = json.load(f)

# Filter prompts that have a total_acc time available
completed_prompts = [p for p in prompts_data["prompts"] if "total_acc" in p]

if not completed_prompts:
return jsonify({"message": "No completed prompts with time available"}), 404

# Get the latest prompt with total_acc
latest_prompt = completed_prompts[-1] # Assuming prompts are in chronological order

return (
jsonify(
{
"prompt": latest_prompt["prompt"],
"total_acc": latest_prompt["total_acc"],
"batch_size": latest_prompt["batch_size"],
"steps": latest_prompt["steps"],
}
),
200,
)


def cleanup():
    if os.path.isfile(
        "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
    ):
        os.remove(
            "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
        )
        print("Deleted json")
Contributor

Use logging instead of print statements for better production diagnostics.
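
For example (sketch):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Deleted json")  # instead of print("Deleted json")
```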

Contributor Author
Used logging instead of print in f5650c5


if os.path.isfile("interactive_512x512_ttnn.png"):
os.remove("interactive_512x512_ttnn.png")
print("Deleted image")

signal_handler(None, None)


atexit.register(cleanup)


def create_server():
    return app
15 changes: 15 additions & 0 deletions tt-metal-stable-diffusion-1.4/server/gunicorn.conf.py
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC


workers = 1
# use 0.0.0.0 for externally accessible
bind = f"0.0.0.0:{7000}"
reload = False
worker_class = "gthread"
threads = 16
Contributor

Running with 16 threads could have thread-safety issues given the current inter-server communication.

timeout = 0

# server factory
wsgi_app = "server.flaskserver:create_server()"