Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test branch #32

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
[add] comments
  • Loading branch information
zjesko committed Jun 30, 2021
commit 956b24485f910de2b219720b1850e741efb23f0c
7 changes: 6 additions & 1 deletion .github/workflows/cicd.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
name: ci-cd

# run the action on pull_requests and pushes
on: [pull_request, push]

jobs:
# first job to test the application using pytest
build:
runs-on: ubuntu-latest
steps:
@@ -18,10 +20,13 @@ jobs:
- name: Run pytest
run: |
pytest


# second job to zip the codebase and upload it as an artifact when build succeeds
upload_zip:
runs-on: ubuntu-latest
needs: build

# only run this action for pushes
if: ${{ github.event_name == 'push' }}
steps:
- name: Checkout the repository
21 changes: 12 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
@@ -3,34 +3,37 @@
from pydantic import BaseModel
from ml_utils import load_model, predict

app = FastAPI(
title="Iris Predictor",
docs_url="/"
)
# defining the main app
app = FastAPI(title="Iris Predictor", docs_url="/")

# calling the load_model during startup
app.add_event_handler("startup", load_model)

# class which is expected in the payload
class QueryIn(BaseModel):
sepal_length: float
sepal_width: float
petal_length: float
petal_width: float


# class which is returned in the response
class QueryOut(BaseModel):
flower_class: str


# Route definitions
@app.get("/ping")
def ping():
return {"ping": "pong"}


@app.post("/predict_flower", response_model=QueryOut, status_code=200)
def predict_flower(
query_data: QueryIn
):
output = {'flower_class': predict(query_data)}
def predict_flower(query_data: QueryIn):
output = {"flower_class": predict(query_data)}
return output


# Main function to start the app when main.py is called
if __name__ == "__main__":
uvicorn.run("main:app", host='0.0.0.0', port=8888, reload=True)
uvicorn.run("main:app", host="0.0.0.0", port=8888, reload=True)
31 changes: 13 additions & 18 deletions ml_utils.py
Original file line number Diff line number Diff line change
@@ -5,27 +5,22 @@

clf = GaussianNB()

classes = {
0: "Iris Setosa",
1: "Iris Versicolour",
2: "Iris Virginica"
}
classes = {0: "Iris Setosa", 1: "Iris Versicolour", 2: "Iris Virginica"}

# function to train and load the model during startup
def load_model():
X, y = datasets.load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2)
clf.fit(X_train, y_train)

acc = accuracy_score(y_test, clf.predict(X_test))
print(f"Model trained with accuracy: {round(acc, 3)}")

def predict(query_data):
x = list(query_data.dict().values())
prediction = clf.predict([x])[0]
print(f"Model prediction: {classes[prediction]}")
return classes[prediction]
X, y = datasets.load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
clf.fit(X_train, y_train)

acc = accuracy_score(y_test, clf.predict(X_test))
print(f"Model trained with accuracy: {round(acc, 3)}")


# function to predict the flower using the model
def predict(query_data):
x = list(query_data.dict().values())
prediction = clf.predict([x])[0]
print(f"Model prediction: {classes[prediction]}")
return classes[prediction]
18 changes: 10 additions & 8 deletions test_app.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from fastapi.testclient import TestClient
from main import app


# test to check the correct functioning of the /ping route
def test_ping():
with TestClient(app) as client:
response = client.get("/ping")
assert response.status_code == 200
assert response.json() == {"ping":"pong"}
assert response.json() == {"ping": "pong"}


# test to check if Iris Virginica is classified correctly
def test_pred_virginica():
payload = {
"sepal_length": 3,
"sepal_width": 5,
"petal_length": 3.2,
"petal_width": 4.4
"sepal_length": 3,
"sepal_width": 5,
"petal_length": 3.2,
"petal_width": 4.4,
}
with TestClient(app) as client:
response = client.post('/predict_flower', json=payload)
response = client.post("/predict_flower", json=payload)
assert response.status_code == 200
assert response.json() == {'flower_class': "Iris Virginica"}
assert response.json() == {"flower_class": "Iris Virginica"}