Skip to content

Commit

Permalink
enable change experiment name auto-create
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasFaria committed Nov 28, 2024
1 parent 9a04a41 commit 64a92a7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tests.test_main import run_test
from utils.data import get_df_naf, get_sirene_3_data, get_sirene_4_data, get_test_data, get_Y
from utils.mappings import mappings
from utils.mlflow_tracking_queries import create_or_restore_experiment

parser = argparse.ArgumentParser(
description="FastAPE 🚀 : Model for coding a company's main activity"
Expand Down Expand Up @@ -311,6 +312,7 @@ def main(
embedding_dims = [value for key, value in locals().items() if key.startswith("embedding_dim")]

mlflow.set_tracking_uri(remote_server_uri)
create_or_restore_experiment(experiment_name)
mlflow.set_experiment(experiment_name)

with mlflow.start_run(run_name=run_name):
Expand Down
30 changes: 30 additions & 0 deletions src/utils/mlflow_tracking_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from mlflow.exceptions import RestException
from mlflow.tracking import MlflowClient


def create_or_restore_experiment(experiment_name):
client = MlflowClient()

try:
# Check if the experiment exists (either active or deleted)
experiment = client.get_experiment_by_name(experiment_name)

if experiment:
if experiment.lifecycle_stage == "deleted":
# Restore the experiment if it's deleted
print(
f"Restoring deleted experiment: '{experiment_name}' (ID: {experiment.experiment_id})"
)
client.restore_experiment(experiment.experiment_id)
else:
print(
f"Experiment '{experiment_name}' already exists and is active (ID: {experiment.experiment_id})."
)
else:
# Create the experiment if it doesn't exist
print(f"Creating a new experiment: '{experiment_name}'")
experiment_id = client.create_experiment(experiment_name)
print(f"Created experiment '{experiment_name}' with ID: {experiment_id}")

except RestException as e:
print(f"An error occurred while handling the experiment '{experiment_name}': {e}")

0 comments on commit 64a92a7

Please sign in to comment.