From 136841a7905ebb57a47bd0bc977e2e2a3b2b8e00 Mon Sep 17 00:00:00 2001 From: bgoelTT Date: Mon, 20 Jan 2025 11:20:49 -0500 Subject: [PATCH] Add ready check mechanism, not finished --- .../server/flaskserver.py | 62 +++++++++++-------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/tt-metal-stable-diffusion-1.4/server/flaskserver.py b/tt-metal-stable-diffusion-1.4/server/flaskserver.py index a7a756c..ecf1f91 100644 --- a/tt-metal-stable-diffusion-1.4/server/flaskserver.py +++ b/tt-metal-stable-diffusion-1.4/server/flaskserver.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from flask import ( + abort, Flask, request, jsonify, @@ -11,6 +12,8 @@ import json import os import atexit +import time +from http import HTTPStatus import subprocess import signal @@ -52,34 +55,44 @@ def signal_handler(sig, frame): app = Flask(__name__) +# var to indicate ready state +ready = False + +# internal json prompt file +json_file_path = ( + "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json" +) + @app.route("/") def hello_world(): return "Hello, World!" -@app.route("/submit", methods=["POST"]) -def submit(): - data = request.get_json() - prompt = data.get("prompt") - print(prompt) - - json_file_path = ( - "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json" - ) - - if not os.path.isfile(json_file_path): - with open(json_file_path, "w") as f: +def submit_prompt(prompt_file, prompt): + if not os.path.isfile(prompt_file): + with open(prompt_file, "w") as f: json.dump({"prompts": []}, f) - with open(json_file_path, "r") as f: + with open(prompt_file, "r") as f: prompts_data = json.load(f) prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"}) - with open(json_file_path, "w") as f: + with open(prompt_file, "w") as f: json.dump(prompts_data, f, indent=4) + +@app.route("/submit", methods=["POST"]) +def submit(): + if not ready: + abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet") + data = request.get_json() + prompt = data.get("prompt") + print(prompt) + + submit_prompt(json_file_path, prompt) + return jsonify({"message": "Prompt received and added to queue."}) @@ -88,10 +101,6 @@ def update_status(): data = request.get_json() prompt = data.get("prompt") - json_file_path = ( - "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json" - ) - with open(json_file_path, "r") as f: prompts_data = json.load(f) @@ -124,10 +133,6 @@ def image_exists(): @app.route("/clean_up", methods=["POST"]) def clean_up(): - json_file_path = ( - "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json" - ) - with open(json_file_path, "r") as f: prompts_data = json.load(f) @@ -143,10 +148,6 @@ def clean_up(): @app.route("/get_latest_time", methods=["GET"]) def get_latest_time(): - json_file_path = ( - "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json" - ) - if not os.path.isfile(json_file_path): return jsonify({"message": "No prompts found"}), 404 @@ -199,4 +200,13 @@ def cleanup(): def create_server(): + sample_prompt = "Unicorn on a banana" + submit_prompt(json_file_path, sample_prompt) + while not ready: + with open(json_file_path, "r") as f: + prompts_data = json.load(f) + for p in prompts_data["prompts"]: + if p["prompt"] == sample_prompt: + print(p["status"]) + time.sleep(2) return app