Skip to content

Commit

Permalink
Add FLAML experiment runner.
Browse files Browse the repository at this point in the history
  • Loading branch information
AxiomAlive committed Feb 23, 2025
1 parent 6083e8d commit 7140702
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 1 deletion.
10 changes: 9 additions & 1 deletion experiment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ run_experiment() {
automl="ag"
fi

source env/bin/activate
if [[ "$*" == *"flaml"* ]]; then
automl="flaml"
fi

if [[ "$automl" == "imba" ]]; then
source env/bin/activate
else
source devenv/bin/activate
fi

"$VIRTUAL_ENV"/bin/python -m experiment.main --automl="$automl" --out="$out" --preset="$preset"
}
Expand Down
29 changes: 29 additions & 0 deletions experiment/flaml_automl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Union

import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError
from sklearn.metrics import f1_score

from experiment.runner import AutoMLRunner
from flaml import AutoML


class FLAMLExperimentRunner(AutoMLRunner):
def fit(self, X_train: Union[np.ndarray, pd.DataFrame], y_train: Union[np.ndarray, pd.Series], target_label: str,
dataset_name: str):
flaml = AutoML()
flaml.fit(X_train, y_train, task='classification', time_budget=-1, metric='f1')

best_loss = flaml.best_loss
best_model = flaml.best_estimator
self._log_val_loss_alongside_model_class({best_model: best_loss})

self._fitted_model = flaml

def predict(self, X_test: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
if self._fitted_model is None:
raise NotFittedError()

predictions = self._fitted_model.predict(X_test)
return predictions
4 changes: 4 additions & 0 deletions experiment/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def main():
from experiment.imba import ImbaExperimentRunner

automl_runner = ImbaExperimentRunner()
elif automl == 'flaml':
from experiment.flaml_automl import FLAMLExperimentRunner

automl_runner = FLAMLExperimentRunner()
else:
raise ValueError("Invalid --automl option. Options available: ['imba', 'ag'].")

Expand Down
2 changes: 2 additions & 0 deletions experiment/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def run(self, n_evals: Optional[int] = None):

iterator_of_class_belongings = iter(sorted(class_belongings))
*_, positive_class_label = iterator_of_class_belongings
logger.info(f"Pos class label: {positive_class_label}")

number_of_positives = class_belongings.get(positive_class_label, None)

if number_of_positives is None:
Expand Down

0 comments on commit 7140702

Please sign in to comment.