-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e2f2388
commit fc11e0d
Showing
3 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
alembic==1.13.1 | ||
blinker==1.7.0 | ||
certifi==2023.11.17 | ||
charset-normalizer==3.3.2 | ||
click==8.1.7 | ||
cloudpickle==3.0.0 | ||
contourpy==1.1.1 | ||
cycler==0.12.1 | ||
databricks-cli==0.18.0 | ||
docker==6.1.3 | ||
entrypoints==0.4 | ||
Flask==3.0.0 | ||
fonttools==4.47.0 | ||
gitdb==4.0.11 | ||
GitPython==3.1.40 | ||
greenlet==3.0.3 | ||
gunicorn==21.2.0 | ||
h5py==3.10.0 | ||
idna==3.6 | ||
importlib-metadata==7.0.1 | ||
importlib-resources==6.1.1 | ||
itsdangerous==2.1.2 | ||
Jinja2==3.1.2 | ||
joblib==1.3.2 | ||
kiwisolver==1.4.5 | ||
Mako==1.3.0 | ||
Markdown==3.5.1 | ||
MarkupSafe==2.1.3 | ||
matplotlib==3.7.4 | ||
mlflow==2.9.2 | ||
mpi4py==3.1.5 | ||
numpy==1.24.4 | ||
nvidia-ml-py==12.535.133 | ||
oauthlib==3.2.2 | ||
packaging==23.2 | ||
pandas==2.0.3 | ||
-e git+ssh://git@github.com/Helmholtz-AI-Energy/perun.git@e2f23885dc8207838961a4a036583fdc5bd2e2bc#egg=perun | ||
Pillow==10.1.0 | ||
protobuf==4.25.1 | ||
psutil==5.9.7 | ||
py-cpuinfo==9.0.0 | ||
pyarrow==14.0.2 | ||
PyJWT==2.8.0 | ||
pyparsing==3.1.1 | ||
python-dateutil==2.8.2 | ||
pytz==2023.3.post1 | ||
PyYAML==6.0.1 | ||
querystring-parser==1.2.4 | ||
requests==2.31.0 | ||
scikit-learn==1.3.2 | ||
scipy==1.10.1 | ||
six==1.16.0 | ||
smmap==5.0.1 | ||
SQLAlchemy==2.0.24 | ||
sqlparse==0.4.4 | ||
tabulate==0.9.0 | ||
threadpoolctl==3.2.0 | ||
typing_extensions==4.9.0 | ||
tzdata==2023.3 | ||
urllib3==2.1.0 | ||
websocket-client==1.7.0 | ||
Werkzeug==3.0.1 | ||
zipp==3.17.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import logging | ||
import sys | ||
import warnings | ||
from urllib.parse import urlparse | ||
|
||
import mlflow | ||
import mlflow.sklearn | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.linear_model import ElasticNet | ||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score | ||
from sklearn.model_selection import train_test_split | ||
|
||
from perun import monitor, register_callback | ||
|
||
# Only warnings and above are emitted by default; the module logger is used
# by the script entry point to report download failures.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(name=__name__)
|
||
|
||
@monitor()
def eval_metrics(actual, pred):
    """Score predictions against the ground truth.

    Args:
        actual: Observed target values.
        pred: Model predictions for the same samples.

    Returns:
        A ``(rmse, mae, r2)`` tuple: root mean squared error, mean absolute
        error and the R^2 score, in that order.
    """
    squared_error = mean_squared_error(actual, pred)
    root_mse = np.sqrt(squared_error)
    abs_error = mean_absolute_error(actual, pred)
    determination = r2_score(actual, pred)
    return root_mse, abs_error, determination
|
||
|
||
@monitor()
def train(data):
    """Fit an ElasticNet model on ``data`` and track the run with MLflow.

    The input is split into train/test partitions, the model is fitted on
    every column except ``quality`` and scored with ``eval_metrics``. The
    hyper-parameters, the scores and the fitted model are logged to a new
    MLflow run. A perun callback is registered that copies the metrics found
    on the perun data node into that same run.

    ``alpha`` and ``l1_ratio`` are taken from ``sys.argv[1]`` and
    ``sys.argv[2]``; each falls back to 0.5 when the argument is absent.

    Args:
        data: Table with a ``quality`` column used as the target; expected to
            be a pandas DataFrame (it is column-dropped and column-indexed).
    """
    # Distinct names so the local partitions do not shadow this function.
    train_df, test_df = train_test_split(data)

    # The predicted column is "quality" which is a scalar from [3, 9]
    train_x = train_df.drop(["quality"], axis=1)
    test_x = test_df.drop(["quality"], axis=1)
    train_y = train_df[["quality"]]
    test_y = test_df[["quality"]]

    alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
    l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_id

        @register_callback
        def perun2mlflow(node):
            """Log every metric stored on the perun data node to the MLflow run."""
            # Re-open the run by id inside a context manager so it is ended
            # again once the metrics are written; a bare start_run() call
            # would leave the resumed run active.
            # NOTE(review): presumably perun invokes callbacks after the
            # monitored function has returned (i.e. once the enclosing run is
            # closed) - confirm; start_run raises if a run is still active.
            with mlflow.start_run(run_id=run_id):
                for metric in node.metrics.values():
                    mlflow.log_metric(str(metric.type.value), metric.value)

        lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
        lr.fit(train_x, train_y)

        predicted_qualities = lr.predict(test_x)

        rmse, mae, r2 = eval_metrics(test_y, predicted_qualities)

        print(f"Elasticnet model (alpha={alpha:f}, l1_ratio={l1_ratio:f}):")
        print(f" RMSE: {rmse}")
        print(f" MAE: {mae}")
        print(f" R2: {r2}")

        mlflow.log_param("alpha", alpha)
        mlflow.log_param("l1_ratio", l1_ratio)
        mlflow.log_metric("rmse", rmse)
        mlflow.log_metric("r2", r2)
        mlflow.log_metric("mae", mae)

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

        # Model registry does not work with file store
        if tracking_url_type_store != "file":
            # Register the model
            # There are other ways to use the Model Registry, which depends on the use case,
            # please refer to the doc for more information:
            # https://mlflow.org/docs/latest/model-registry.html#api-workflow
            mlflow.sklearn.log_model(
                lr, "model", registered_model_name="ElasticnetWineModel"
            )
        else:
            mlflow.sklearn.log_model(lr, "model")
|
||
|
||
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    np.random.seed(40)

    # Read the wine-quality csv file from the URL
    csv_url = "http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
    try:
        data = pd.read_csv(csv_url, sep=";")
    except Exception as e:
        logger.exception(
            "Unable to download training & test CSV, check your internet connection. Error: %s",
            e,
        )
        # Without the data there is nothing to train on. Stop with a failure
        # status instead of falling through to a NameError on the unbound
        # `data` name.
        sys.exit(1)

    train(data)