-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathmain.py
63 lines (51 loc) · 2.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from ml_utils import load_model, predict, retrain
from typing import List
from datetime import date, datetime
# The FastAPI application; its interactive docs are served at the root path "/".
app = FastAPI(
    title="Iris Predictor",
    docs_url="/",
)

# Run ml_utils.load_model when the server starts so the model is available
# before requests are served. (Per the original note it trains the model and
# keeps it loaded for prediction — confirm in ml_utils.)
app.add_event_handler("startup", load_model)
# Request body accepted by the prediction route.
class QueryIn(BaseModel):
    """Iris measurements sent to /predict_flower."""

    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float
# Response body returned by the prediction route.
class QueryOut(BaseModel):
    """Predicted iris class plus the time the prediction was made."""

    flower_class: str
    # Declared as a datetime; the route must supply a value pydantic can
    # validate as a datetime.
    timestamp: datetime
# Request body item accepted by the feedback / re-training route.
class FeedbackIn(QueryIn):
    """One labelled sample for /feedback_loop.

    Inherits the four measurement fields (sepal_length, sepal_width,
    petal_length, petal_width) from QueryIn, so the prediction and feedback
    payloads cannot drift apart, and adds the correct flower class.
    """

    # The correct (user-supplied) label for these measurements.
    flower_class: str
# Route definitions
@app.get("/ping")
# Healthcheck route to ensure that the API is up and running
def ping():
return {"ping": "pong", "timestamp": datetime.now().strftime("%b %d %Y %H:%M:%S")}
@app.post("/predict_flower", response_model=QueryOut, status_code=200)
def predict_flower(query_data: QueryIn):
    """Predict the iris flower class for the supplied measurements.

    Payload: QueryIn containing the four measurements.
    Response: QueryOut with the predicted flower_class and a timestamp (200).
    """
    # QueryOut.timestamp is declared as ``datetime``, so return a real datetime.
    # The previous "%b %d %Y %H:%M:%S" string (e.g. "Jan 05 2024 12:00:00") is
    # not a format pydantic can parse as a datetime, so response validation
    # rejected it and the route answered with a 500.
    return {"flower_class": predict(query_data), "timestamp": datetime.now()}
@app.post("/feedback_loop", status_code=200)
def feedback_loop(data: List[FeedbackIn]):
    """Further train the model on user feedback.

    Payload: a list of FeedbackIn items, each with measurements and the correct
    flower class. Response: a dict confirming success, with a timestamp (200).
    """
    retrain(data)
    stamp = datetime.now().strftime("%b %d %Y %H:%M:%S")
    return {"detail": "Feedback loop successful", "timestamp": stamp}
# Entry point when main.py is executed directly.
if __name__ == "__main__":
    # Serve "main:app" with Uvicorn on 0.0.0.0:8888; reload=True restarts the
    # server when source files change (development convenience).
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8888,
        reload=True,
    )