Skip to content

Commit

Permalink
feat: mlperf example
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanPedroGHM committed Dec 29, 2023
1 parent e2f2388 commit fc11e0d
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 0 deletions.
File renamed without changes.
63 changes: 63 additions & 0 deletions examples/mlflow/requirenments.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
alembic==1.13.1
blinker==1.7.0
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
contourpy==1.1.1
cycler==0.12.1
databricks-cli==0.18.0
docker==6.1.3
entrypoints==0.4
Flask==3.0.0
fonttools==4.47.0
gitdb==4.0.11
GitPython==3.1.40
greenlet==3.0.3
gunicorn==21.2.0
h5py==3.10.0
idna==3.6
importlib-metadata==7.0.1
importlib-resources==6.1.1
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.3.2
kiwisolver==1.4.5
Mako==1.3.0
Markdown==3.5.1
MarkupSafe==2.1.3
matplotlib==3.7.4
mlflow==2.9.2
mpi4py==3.1.5
numpy==1.24.4
nvidia-ml-py==12.535.133
oauthlib==3.2.2
packaging==23.2
pandas==2.0.3
-e git+ssh://[email protected]/Helmholtz-AI-Energy/perun.git@e2f23885dc8207838961a4a036583fdc5bd2e2bc#egg=perun
Pillow==10.1.0
protobuf==4.25.1
psutil==5.9.7
py-cpuinfo==9.0.0
pyarrow==14.0.2
PyJWT==2.8.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
querystring-parser==1.2.4
requests==2.31.0
scikit-learn==1.3.2
scipy==1.10.1
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.24
sqlparse==0.4.4
tabulate==0.9.0
threadpoolctl==3.2.0
typing_extensions==4.9.0
tzdata==2023.3
urllib3==2.1.0
websocket-client==1.7.0
Werkzeug==3.0.1
zipp==3.17.0
96 changes: 96 additions & 0 deletions examples/mlflow/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import logging
import sys
import warnings
from urllib.parse import urlparse

import mlflow
import mlflow.sklearn
import numpy as np
import pandas as pd
from sklearn.linear_model import ElasticNet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

from perun import monitor, register_callback

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)


@monitor()
def eval_metrics(actual, pred):
rmse = np.sqrt(mean_squared_error(actual, pred))
mae = mean_absolute_error(actual, pred)
r2 = r2_score(actual, pred)
return rmse, mae, r2


@monitor()
def train(data):
train, test = train_test_split(data)

# The predicted column is "quality" which is a scalar from [3, 9]
train_x = train.drop(["quality"], axis=1)
test_x = test.drop(["quality"], axis=1)
train_y = train[["quality"]]
test_y = test[["quality"]]

alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

with mlflow.start_run() as active_run:

@register_callback
def perun2mlflow(node):
mlflow.start_run(active_run.info.run_id)
for metricType, metric in node.metrics.items():
name = f"{metric.type.value}"
mlflow.log_metric(name, metric.value)

lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
lr.fit(train_x, train_y)

predicted_qualities = lr.predict(test_x)

(rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)

print("Elasticnet model (alpha=%f, l1_ratio=%f):" % (alpha, l1_ratio))
print(" RMSE: %s" % rmse)
print(" MAE: %s" % mae)
print(" R2: %s" % r2)

mlflow.log_param("alpha", alpha)
mlflow.log_param("l1_ratio", l1_ratio)
mlflow.log_metric("rmse", rmse)
mlflow.log_metric("r2", r2)
mlflow.log_metric("mae", mae)

tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

# Model registry does not work with file store
if tracking_url_type_store != "file":
# Register the model
# There are other ways to use the Model Registry, which depends on the use case,
# please refer to the doc for more information:
# https://mlflow.org/docs/latest/model-registry.html#api-workflow
mlflow.sklearn.log_model(
lr, "model", registered_model_name="ElasticnetWineModel"
)
else:
mlflow.sklearn.log_model(lr, "model")


if __name__ == "__main__":
warnings.filterwarnings("ignore")
np.random.seed(40)

# Read the wine-quality csv file from the URL
csv_url = "http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
try:
data = pd.read_csv(csv_url, sep=";")
except Exception as e:
logger.exception(
"Unable to download training & test CSV, check your internet connection. Error: %s",
e,
)
train(data)

0 comments on commit fc11e0d

Please sign in to comment.