SDK-392: Support for Mlflow Cluster Type (#306)
Anupam-pandey authored and chattarajoy committed Mar 18, 2020
1 parent 46ead9a commit 68657f8
Showing 2 changed files with 204 additions and 154 deletions.
46 changes: 32 additions & 14 deletions qds_sdk/engine.py
@@ -13,8 +13,9 @@ def __init__(self, flavour=None):
         self.hadoop_settings = {}
         self.presto_settings = {}
         self.spark_settings = {}
-        self.airflow_settings ={}
+        self.airflow_settings = {}
         self.engine_config = {}
+        self.mlflow_settings = {}

     def set_engine_config(self,
                           custom_hadoop_config=None,
@@ -31,7 +32,8 @@ def set_engine_config(self,
                           airflow_version=None,
                           airflow_python_version=None,
                           is_ha=None,
-                          enable_rubix=None):
+                          enable_rubix=None,
+                          mlflow_version=None):
         '''
         Args:
@@ -68,13 +70,16 @@ def set_engine_config(self,
             is_ha: Enabling HA config for cluster
             is_deeplearning : this is a deeplearning cluster config
             enable_rubix: Enable rubix on the cluster
+            mlflow_version : this is the version of the mlflow cluster
         '''

-        self.set_hadoop_settings(custom_hadoop_config, use_qubole_placement_policy, is_ha, fairscheduler_config_xml, default_pool, enable_rubix)
+        self.set_hadoop_settings(custom_hadoop_config, use_qubole_placement_policy, is_ha, fairscheduler_config_xml,
+                                 default_pool, enable_rubix)
         self.set_presto_settings(presto_version, custom_presto_config)
         self.set_spark_settings(spark_version, custom_spark_config)
         self.set_airflow_settings(dbtap_id, fernet_key, overrides, airflow_version, airflow_python_version)
+        self.set_mlflow_settings(mlflow_version)

     def set_fairscheduler_settings(self,
                                    fairscheduler_config_xml=None,
@@ -121,11 +126,15 @@ def set_airflow_settings(self,
         self.airflow_settings['version'] = airflow_version
         self.airflow_settings['airflow_python_version'] = airflow_python_version

+    def set_mlflow_settings(self,
+                            mlflow_version="1.5"):
+        self.mlflow_settings['version'] = mlflow_version
+
     def set_engine_config_settings(self, arguments):
         custom_hadoop_config = util._read_file(arguments.custom_hadoop_config_file)
         fairscheduler_config_xml = util._read_file(arguments.fairscheduler_config_xml_file)
         custom_presto_config = util._read_file(arguments.presto_custom_config_file)
-        is_deeplearning=False
+        is_deeplearning = False

         self.set_engine_config(custom_hadoop_config=custom_hadoop_config,
                                use_qubole_placement_policy=arguments.use_qubole_placement_policy,
@@ -140,14 +149,16 @@ def set_engine_config_settings(self, arguments):
                                overrides=arguments.overrides,
                                airflow_version=arguments.airflow_version,
                                airflow_python_version=arguments.airflow_python_version,
-                               enable_rubix=arguments.enable_rubix)
+                               enable_rubix=arguments.enable_rubix,
+                               mlflow_version=arguments.mlflow_version)

     @staticmethod
     def engine_parser(argparser):
         engine_group = argparser.add_argument_group("engine settings")
         engine_group.add_argument("--flavour",
                                   dest="flavour",
-                                  choices=["hadoop", "hadoop2", "hs2", "hive", "presto", "spark", "sparkstreaming", "hbase", "airflow", "deeplearning"],
+                                  choices=["hadoop", "hadoop2", "hs2", "hive", "presto", "spark", "sparkstreaming",
+                                           "hbase", "airflow", "deeplearning", "mlflow"],
                                   default=None,
                                   help="Set engine flavour")

@@ -172,15 +183,15 @@ def engine_parser(argparser):
" for clusters with spot nodes", )
enable_rubix_group = hadoop_settings_group.add_mutually_exclusive_group()
enable_rubix_group.add_argument("--enable-rubix",
dest="enable_rubix",
action="store_true",
default=None,
help="Enable rubix for cluster", )
dest="enable_rubix",
action="store_true",
default=None,
help="Enable rubix for cluster", )
enable_rubix_group.add_argument("--no-enable-rubix",
dest="enable_rubix",
action="store_false",
default=None,
help="Do not enable rubix for cluster", )
dest="enable_rubix",
action="store_false",
default=None,
help="Do not enable rubix for cluster", )

fairscheduler_group = argparser.add_argument_group(
"fairscheduler configuration options")
@@ -236,3 +247,10 @@ def engine_parser(argparser):
                                            default=None,
                                            help="python environment version for airflow cluster", )

+        mlflow_settings_group = argparser.add_argument_group("mlflow settings")
+
+        mlflow_settings_group.add_argument("--mlflow-version",
+                                           dest="mlflow_version",
+                                           default=None,
+                                           help="mlflow version for mlflow cluster", )

