diff --git a/main.py b/main.py index 7351584f..99cf35de 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,8 @@ import uvicorn from fastapi import FastAPI from pydantic import BaseModel -from ml_utils import load_model, predict +from ml_utils import load_model, predict, retrain +from typing import List # defining the main app app = FastAPI(title="Iris Predictor", docs_url="/") @@ -21,6 +22,13 @@ class QueryIn(BaseModel): class QueryOut(BaseModel): flower_class: str +# class which is expected in the payload while re-training +class FeedbackIn(BaseModel): + sepal_length: float + sepal_width: float + petal_length: float + petal_width: float + flower_class: str # Route definitions @app.get("/ping") @@ -33,6 +41,11 @@ def predict_flower(query_data: QueryIn): output = {"flower_class": predict(query_data)} return output +@app.post("/feedback_loop", status_code=200) +def feedback_loop(data: List[FeedbackIn]): + retrain(data) + return {"detail": "Feedback loop successful"} + # Main function to start the app when main.py is called if __name__ == "__main__": diff --git a/ml_utils.py b/ml_utils.py index 75a358b8..24703d04 100644 --- a/ml_utils.py +++ b/ml_utils.py @@ -6,6 +6,7 @@ clf = GaussianNB() classes = {0: "Iris Setosa", 1: "Iris Versicolour", 2: "Iris Virginica"} +r_classes = {y: x for x, y in classes.items()} # function to train and load the model during startup def load_model(): @@ -24,3 +25,8 @@ def predict(query_data): prediction = clf.predict([x])[0] print(f"Model prediction: {classes[prediction]}") return classes[prediction] + +def retrain(data): + X = [list(d.dict().values())[:-1] for d in data] + y = [r_classes[d.flower_class] for d in data] + clf.fit(X, y)