diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 6720c644..94bfae50 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -6,7 +6,8 @@ on: [pull_request, push] jobs: # first job to test the application using pytest build: - runs-on: ubuntu-latest + runs-on: ubuntu-latest # choose the OS for running the action + # define the individual sequential steps to be run steps: - name: Checkout the repository uses: actions/checkout@v2 @@ -23,11 +24,13 @@ jobs: # second job to zip the codebase and upload it as an artifact when build succeeds upload_zip: - runs-on: ubuntu-latest + runs-on: ubuntu-latest # choose the OS for running the action needs: build # only run this action for pushes if: ${{ github.event_name == 'push' }} + + # define the individual sequential steps to be run steps: - name: Checkout the repository uses: actions/checkout@v2 diff --git a/main.py b/main.py index 99cf35de..48d45d79 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,8 @@ # defining the main app app = FastAPI(title="Iris Predictor", docs_url="/") -# calling the load_model during startup +# calling the load_model during startup. +# this will train the model and keep it loaded for prediction. app.add_event_handler("startup", load_model) # class which is expected in the payload @@ -32,16 +33,23 @@ class FeedbackIn(BaseModel): # Route definitions @app.get("/ping") +# Healthcheck route to ensure that the API is up and running def ping(): return {"ping": "pong"} @app.post("/predict_flower", response_model=QueryOut, status_code=200) +# Route to do the prediction using the ML model defined. +# Payload: QueryIn containing the parameters +# Response: QueryOut containing the flower_class predicted (200) def predict_flower(query_data: QueryIn): output = {"flower_class": predict(query_data)} return output @app.post("/feedback_loop", status_code=200) +# Route to further train the model based on user input in form of feedback loop +# Payload: FeedbackIn containing the parameters and correct flower class +# Response: Dict with detail confirming success (200) def feedback_loop(data: List[FeedbackIn]): retrain(data) return {"detail": "Feedback loop successful"} @@ -49,4 +57,5 @@ def feedback_loop(data: List[FeedbackIn]): # Main function to start the app when main.py is called if __name__ == "__main__": + # Uvicorn is used to run the server and listen for incoming API requests on 0.0.0.0:8888 uvicorn.run("main:app", host="0.0.0.0", port=8888, reload=True) diff --git a/ml_utils.py b/ml_utils.py index 24703d04..bdd4dc83 100644 --- a/ml_utils.py +++ b/ml_utils.py @@ -3,18 +3,23 @@ from sklearn.naive_bayes import GaussianNB from sklearn.metrics import accuracy_score +# define a Gaussain NB classifier clf = GaussianNB() +# define the class encodings and reverse encodings classes = {0: "Iris Setosa", 1: "Iris Versicolour", 2: "Iris Virginica"} r_classes = {y: x for x, y in classes.items()} # function to train and load the model during startup def load_model(): + # load the dataset from the official sklearn datasets X, y = datasets.load_iris(return_X_y=True) + # do the test-train split and train the model X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) clf.fit(X_train, y_train) + # calculate the print the accuracy score acc = accuracy_score(y_test, clf.predict(X_test)) print(f"Model trained with accuracy: {round(acc, 3)}") @@ -26,7 +31,11 @@ def predict(query_data): print(f"Model prediction: {classes[prediction]}") return classes[prediction] +# function to retrain the model as part of the feedback loop def retrain(data): + # pull out the relevant X and y from the FeedbackIn object X = [list(d.dict().values())[:-1] for d in data] y = [r_classes[d.flower_class] for d in data] + + # fit the classifier again based on the new data obtained clf.fit(X, y) diff --git a/test_app.py b/test_app.py index 0541cce9..b65fc902 100644 --- a/test_app.py +++ b/test_app.py @@ -5,12 +5,14 @@ def test_ping(): with TestClient(app) as client: response = client.get("/ping") + # asserting the correct response is received assert response.status_code == 200 assert response.json() == {"ping": "pong"} # test to check if Iris Virginica is classified correctly def test_pred_virginica(): + # defining a sample payload for the testcase payload = { "sepal_length": 3, "sepal_width": 5, @@ -19,5 +21,6 @@ def test_pred_virginica(): } with TestClient(app) as client: response = client.post("/predict_flower", json=payload) + # asserting the correct response is received assert response.status_code == 200 assert response.json() == {"flower_class": "Iris Virginica"}