Model Training: Add additional metric logging on test set
manuelfuchs committed Nov 16, 2022
1 parent 13834d1 commit 0da486d
Showing 3 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/1_conda_env.yml
@@ -18,6 +18,7 @@ dependencies:
 - mlflow==1.30.0
 - numpy==1.19.5
 - pandas==1.1.5
+- Pillow==9.3.0
 - psutil==5.8.0
 - scikit-learn==1.1.2
 - scipy==1.7.1
23 changes: 14 additions & 9 deletions src/2_training.py
@@ -5,20 +5,18 @@
 from dotenv import load_dotenv
 import mlflow
 import mlflow.sklearn
+from sklearn.metrics import (
+    accuracy_score, f1_score, precision_score, recall_score, classification_report,
+    confusion_matrix, ConfusionMatrixDisplay
+)
 from sklearn.model_selection import GridSearchCV
 from sklearn.neural_network import MLPClassifier
 from sklearn.preprocessing import StandardScaler
 
-from utils import start_action, end_action
+from utils import start_action, end_action, matplotlib_figure_to_pillow_image
 
 load_dotenv('./.env')
 
 mlflow.sklearn.autolog()
 
-from sklearn.metrics import (
-    accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
-)
 
 
 def main():
     subscription_id = os.getenv('SUBSCRIPTION_ID')
@@ -89,9 +87,16 @@ def analyze_model(digit_classifier, x_test, y_test):
     y_pred = digit_classifier.predict(x_test)
     print(classification_report(y_test, y_pred))
 
     mlflow.log_metric('test_accuracy', accuracy_score(y_test, y_pred, normalize=True))
     mlflow.log_metric('test_f1_score', f1_score(y_test, y_pred, average="macro"))
-    # print(precision_score(y_test, y_pred, average="macro"))
-    # print(recall_score(y_test, y_pred, average="macro"))
+    mlflow.log_metric('test_precision', precision_score(y_test, y_pred, average="macro"))
+    mlflow.log_metric('test_recall', recall_score(y_test, y_pred, average="macro"))
+
+    # Log the test-set confusion matrix to MLflow as an image artifact.
+    confusion_matrix_display = ConfusionMatrixDisplay(
+        confusion_matrix=confusion_matrix(y_test, y_pred, labels=digit_classifier.classes_),
+        display_labels=digit_classifier.classes_
+    )
+    confusion_matrix_display.plot()  # plot() must be called to create the figure_ attribute
+    mlflow.log_image(matplotlib_figure_to_pillow_image(confusion_matrix_display.figure_), 'test_confusion_matrix.png')
 
     end_action(action_text)

6 changes: 6 additions & 0 deletions src/utils.py
@@ -1,4 +1,6 @@
 import json
+from matplotlib import figure
+from PIL import Image as PILImage
 import subprocess
 import time
 from typing import Union
@@ -53,3 +55,7 @@ def request_user_consent(question: str) -> bool:
     print(question)
     response = input('Do you want to continue? [Y/n] ')
     return len(response) == 0 or response.lower() == 'y'
+
+
+def matplotlib_figure_to_pillow_image(fig: figure.Figure) -> PILImage.Image:
+    # Render the figure before reading its RGB buffer; the Agg canvas only fills it on draw.
+    fig.canvas.draw()
+    return PILImage.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
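Note: a minimal end-to-end sketch of how the pieces above fit together (not part of the commit; assumes a headless Agg backend and toy y_test/y_pred data in place of the real test set; figure_to_image is a hypothetical stand-in for the helper above):

import matplotlib
matplotlib.use('Agg')  # assumed: headless backend so the canvas renders without a display

import mlflow
from PIL import Image
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


def figure_to_image(fig):
    # Same idea as matplotlib_figure_to_pillow_image above: draw, then read the RGB buffer.
    fig.canvas.draw()
    return Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())


y_test = [0, 1, 1, 2, 2, 2]  # toy labels standing in for the real test set
y_pred = [0, 1, 2, 2, 2, 1]

display = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred))
display.plot()  # plot() creates display.figure_

with mlflow.start_run():
    mlflow.log_image(figure_to_image(display.figure_), 'test_confusion_matrix.png')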
