Add ready check mechanism, not finished
bgoelTT committed Jan 20, 2025
1 parent 302f293 commit 136841a
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions tt-metal-stable-diffusion-1.4/server/flaskserver.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from flask import (
abort,
Flask,
request,
jsonify,
@@ -11,6 +12,8 @@
import json
import os
import atexit
import time
from http import HTTPStatus

import subprocess
import signal
@@ -52,34 +55,44 @@ def signal_handler(sig, frame):

app = Flask(__name__)

# var to indicate ready state
ready = False

# internal json prompt file
json_file_path = (
"models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
return "Hello, World!"


@app.route("/submit", methods=["POST"])
def submit():
data = request.get_json()
prompt = data.get("prompt")
print(prompt)

json_file_path = (
"models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)

if not os.path.isfile(json_file_path):
with open(json_file_path, "w") as f:
def submit_prompt(prompt_file, prompt):
if not os.path.isfile(prompt_file):
with open(prompt_file, "w") as f:
json.dump({"prompts": []}, f)

with open(json_file_path, "r") as f:
with open(prompt_file, "r") as f:
prompts_data = json.load(f)

prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

with open(json_file_path, "w") as f:
with open(prompt_file, "w") as f:
json.dump(prompts_data, f, indent=4)


@app.route("/submit", methods=["POST"])
def submit():
if not ready:
abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
data = request.get_json()
prompt = data.get("prompt")
print(prompt)

submit_prompt(json_file_path, prompt)

return jsonify({"message": "Prompt received and added to queue."})


@@ -88,10 +101,6 @@ def update_status():
data = request.get_json()
prompt = data.get("prompt")

json_file_path = (
"models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)

with open(json_file_path, "r") as f:
prompts_data = json.load(f)

@@ -124,10 +133,6 @@ def image_exists():

@app.route("/clean_up", methods=["POST"])
def clean_up():
json_file_path = (
"models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)

with open(json_file_path, "r") as f:
prompts_data = json.load(f)

@@ -143,10 +148,6 @@ def clean_up():

@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
json_file_path = (
"models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)

if not os.path.isfile(json_file_path):
return jsonify({"message": "No prompts found"}), 404

@@ -199,4 +200,13 @@ def cleanup():


def create_server():
sample_prompt = "Unicorn on a banana"
submit_prompt(json_file_path, sample_prompt)
while not ready:
with open(json_file_path, "r") as f:
prompts_data = json.load(f)
for p in prompts_data["prompts"]:
if p["prompt"] == sample_prompt:
print(p["status"])
time.sleep(2)
return app
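
Note on the unfinished ready check: as committed, create_server() submits the warm-up prompt and then polls the prompt file, but nothing ever sets ready to True, so the loop never exits and /submit keeps aborting with 503. A minimal sketch of how the loop might be completed, reusing the module's existing names and assuming only that the prompt's status moves away from "not generated" once the image has been produced (the finished-status value is not shown in this commit):

def create_server():
    global ready  # flip the module-level flag once the warm-up prompt completes
    sample_prompt = "Unicorn on a banana"
    submit_prompt(json_file_path, sample_prompt)
    while not ready:
        with open(json_file_path, "r") as f:
            prompts_data = json.load(f)
        for p in prompts_data["prompts"]:
            if p["prompt"] == sample_prompt:
                print(p["status"])
                # assumption: any status other than "not generated" means the warm-up image is done
                if p["status"] != "not generated":
                    ready = True
        time.sleep(2)
    return app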
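On the client side, the ready gate in /submit means a caller can treat a 503 response as "model still warming up" and retry. A hedged example using the requests library (the host, port, and retry interval are assumptions; the commit does not show how the app is served):

import time
import requests

def submit_when_ready(prompt, base_url="http://localhost:5000"):
    # keep retrying until the server finishes warming up and stops returning 503
    while True:
        resp = requests.post(f"{base_url}/submit", json={"prompt": prompt})
        if resp.status_code != 503:
            return resp.json()
        time.sleep(2)

print(submit_when_ready("Unicorn on a banana"))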
