Skip to content

Commit

Permalink
forecasting data type and serve model support (#2290)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbchao authored Feb 8, 2024
1 parent 7b35e37 commit 35ac794
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
2 changes: 2 additions & 0 deletions assets/responsibleai/tabular/components/src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class PropertyKeyValues:
RAI_INSIGHTS_DATETIME_FEATURES_KEY = "datetime_features"
RAI_INSIGHTS_TIME_SERIES_ID_FEATURES_KEY = "time_series_id_features"

RAI_INSIGHTS_DATA_TYPE_KEY = "data_type"


class RAIToolType:
CAUSAL = "causal"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
)
data_type = "data_type"

FORECASTING = "forecasting"


_logger = logging.getLogger(__file__)
logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -117,7 +120,10 @@ def load_mlflow_model(
if use_model_dependency:
if not use_separate_conda_env:
try:
conda_file = mlflow.pyfunc.get_model_dependencies(model_uri, format="conda")
conda_file = mlflow.pyfunc.get_model_dependencies(
model_uri,
format="conda"
)
except Exception as e:
raise UserConfigError(
"Failed to get model dependency from given model {}, error:\n{}".format(
Expand All @@ -126,11 +132,28 @@ def load_mlflow_model(
)
try:
if use_separate_conda_env:
tmp_model_path = "./mlflow_model"
if (not model_path and model_id):
model_path = Model.get_model_path(model_name=model.name,
version=model.version)
shutil.copytree(model_path, tmp_model_path)
model_uri = tmp_model_path

_logger.info("MODEL URI: {}".format(
model_uri
))

for root, _, files in os.walk(model_uri):
for f in files:
full_path = os.path.join(root, f)
_logger.info("FILE: {}".format(
full_path
))

conda_install_command = ["mlflow", "models", "prepare-env",
"-m", model_uri,
"--env-manager", "conda"]
else:
# mlflow model input mount as read only. Conda need write access.
local_conda_dep = "./conda_dep.yaml"
shutil.copyfile(conda_file, local_conda_dep)
conda_prefix = str(Path(sys.executable).parents[1])
Expand Down Expand Up @@ -164,7 +187,7 @@ def load_mlflow_model(
return model

# Serve model from separate conda env using mlflow
mlflow_models_serve_logfile_name = "mlflow_models_serve.log"
mlflow_models_serve_logfile_name = "./logs/azureml/mlflow_models_serve.log"
try:
# run mlflow model server in background
with open(mlflow_models_serve_logfile_name, "w") as logfile:
Expand Down Expand Up @@ -231,6 +254,7 @@ def load_mlflow_model(
)
_logger.info("Successfully started mlflow model server.")
model = ServedModelWrapper(port=MLFLOW_MODEL_SERVER_PORT)
_logger.info("Successfully loaded model.")
return model
except Exception as e:
raise UserConfigError(
Expand Down Expand Up @@ -465,6 +489,15 @@ def add_properties_to_gather_run(
],
}

constructor_args = dashboard_info[
DashboardInfo.RAI_INSIGHTS_CONSTRUCTOR_ARGS_KEY
]
if "task_type" in constructor_args:
if constructor_args["task_type"] == FORECASTING:
run_properties[
PropertyKeyValues.RAI_INSIGHTS_DATA_TYPE_KEY
] = FORECASTING

_logger.info("Appending tool present information")
for k, v in tool_present_dict.items():
key = PropertyKeyValues.RAI_INSIGHTS_TOOL_KEY_FORMAT.format(k)
Expand Down

0 comments on commit 35ac794

Please sign in to comment.