Skip to content

Commit

Permalink
[add] elaborate comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zjesko committed Jul 1, 2021
1 parent 1bc156c commit c074254
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/cicd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ on: [pull_request, push]
jobs:
# first job to test the application using pytest
build:
runs-on: ubuntu-latest
runs-on: ubuntu-latest # choose the OS for running the action
# define the individual sequential steps to be run
steps:
- name: Checkout the repository
uses: actions/checkout@v2
Expand All @@ -23,11 +24,13 @@ jobs:
# second job to zip the codebase and upload it as an artifact when build succeeds
upload_zip:
runs-on: ubuntu-latest
runs-on: ubuntu-latest # choose the OS for running the action
needs: build

# only run this action for pushes
if: ${{ github.event_name == 'push' }}

# define the individual sequential steps to be run
steps:
- name: Checkout the repository
uses: actions/checkout@v2
Expand Down
11 changes: 10 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# defining the main app
app = FastAPI(title="Iris Predictor", docs_url="/")

# calling the load_model during startup
# calling the load_model during startup.
# this will train the model and keep it loaded for prediction.
app.add_event_handler("startup", load_model)

# class which is expected in the payload
Expand All @@ -32,21 +33,29 @@ class FeedbackIn(BaseModel):

# Route definitions
@app.get("/ping")
# Healthcheck route to ensure that the API is up and running
def ping():
return {"ping": "pong"}


@app.post("/predict_flower", response_model=QueryOut, status_code=200)
# Route to do the prediction using the ML model defined.
# Payload: QueryIn containing the parameters
# Response: QueryOut containing the flower_class predicted (200)
def predict_flower(query_data: QueryIn):
output = {"flower_class": predict(query_data)}
return output

@app.post("/feedback_loop", status_code=200)
# Route to further train the model based on user input in form of feedback loop
# Payload: FeedbackIn containing the parameters and correct flower class
# Response: Dict with detail confirming success (200)
def feedback_loop(data: List[FeedbackIn]):
retrain(data)
return {"detail": "Feedback loop successful"}


# Main function to start the app when main.py is called
if __name__ == "__main__":
# Uvicorn is used to run the server and listen for incoming API requests on 0.0.0.0:8888
uvicorn.run("main:app", host="0.0.0.0", port=8888, reload=True)
9 changes: 9 additions & 0 deletions ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,23 @@
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score

# define a Gaussain NB classifier
clf = GaussianNB()

# define the class encodings and reverse encodings
classes = {0: "Iris Setosa", 1: "Iris Versicolour", 2: "Iris Virginica"}
r_classes = {y: x for x, y in classes.items()}

# function to train and load the model during startup
def load_model():
# load the dataset from the official sklearn datasets
X, y = datasets.load_iris(return_X_y=True)

# do the test-train split and train the model
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
clf.fit(X_train, y_train)

# calculate the print the accuracy score
acc = accuracy_score(y_test, clf.predict(X_test))
print(f"Model trained with accuracy: {round(acc, 3)}")

Expand All @@ -26,7 +31,11 @@ def predict(query_data):
print(f"Model prediction: {classes[prediction]}")
return classes[prediction]

# function to retrain the model as part of the feedback loop
def retrain(data):
# pull out the relevant X and y from the FeedbackIn object
X = [list(d.dict().values())[:-1] for d in data]
y = [r_classes[d.flower_class] for d in data]

# fit the classifier again based on the new data obtained
clf.fit(X, y)
3 changes: 3 additions & 0 deletions test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
def test_ping():
with TestClient(app) as client:
response = client.get("/ping")
# asserting the correct response is received
assert response.status_code == 200
assert response.json() == {"ping": "pong"}


# test to check if Iris Virginica is classified correctly
def test_pred_virginica():
# defining a sample payload for the testcase
payload = {
"sepal_length": 3,
"sepal_width": 5,
Expand All @@ -19,5 +21,6 @@ def test_pred_virginica():
}
with TestClient(app) as client:
response = client.post("/predict_flower", json=payload)
# asserting the correct response is received
assert response.status_code == 200
assert response.json() == {"flower_class": "Iris Virginica"}

0 comments on commit c074254

Please sign in to comment.