From 3cfa159e8cab69695b2d1b085294e19c16741db1 Mon Sep 17 00:00:00 2001 From: morrisnein Date: Thu, 14 Dec 2023 18:29:47 +0000 Subject: [PATCH] fix after rebase --- experiments/fedot_warm_start/run.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/experiments/fedot_warm_start/run.py b/experiments/fedot_warm_start/run.py index 591e6f00..06202e55 100644 --- a/experiments/fedot_warm_start/run.py +++ b/experiments/fedot_warm_start/run.py @@ -33,10 +33,19 @@ from meta_automl.data_preparation.datasets_train_test_split import openml_datasets_train_test_split from meta_automl.data_preparation.file_system import get_cache_dir -CONFIG_PATH = Path(__file__).parent.joinpath('config_light.yaml') +CONFIGS_DIR = Path(__file__).parent -with open(CONFIG_PATH, 'r') as config_file: - config = yaml.load(config_file, yaml.Loader) +with open(CONFIGS_DIR / 'configs_list.yaml', 'r') as config_file: + configs_list = yaml.load(config_file, yaml.Loader) + +config = {} +for conf_name in configs_list: + with open(CONFIGS_DIR / conf_name, 'r') as config_file: + conf = yaml.load(config_file, yaml.Loader) + intersection = set(config).intersection(set(conf)) + if intersection: + raise ValueError(f'Parameter values given twice: {conf_name}, {intersection}.') + config.update(conf) # Load constants SEED = config['seed'] @@ -94,8 +103,8 @@ def get_current_formatted_date() -> Tuple[datetime, str, str]: def get_save_dir(time_now_for_path) -> Path: save_dir = get_cache_dir(). \ joinpath('experiments').joinpath('fedot_warm_start').joinpath(f'run_{time_now_for_path}') - if 'debug' in CONFIG_PATH.name: - save_dir = save_dir.with_name('debug_' + save_dir.name) + if SAVE_DIR_PREFIX: + save_dir = save_dir.with_name(SAVE_DIR_PREFIX + save_dir.name) if save_dir.exists(): shutil.rmtree(save_dir) save_dir.mkdir(parents=True) @@ -320,6 +329,7 @@ def main(): dataset_ids = get_dataset_ids() dataset_ids_train, dataset_ids_test = split_datasets(dataset_ids, N_DATASETS, UPDATE_TRAIN_TEST_DATASETS_SPLIT) + dataset_ids = dataset_ids_train + dataset_ids_test algorithm = KNNSimilarityModelAdvice( N_BEST_DATASET_MODELS_TO_MEMORIZE, @@ -350,6 +360,7 @@ def main(): random_state=DATA_SPLIT_SEED) train_data, test_data = dataset_data[idx_train], dataset_data[idx_test] dataset_splits[dataset_id] = dict(train=train_data, test=test_data) + knowledge_base = {dataset_id: [] for dataset_id in dataset_ids_train} fedot_evaluations_cache = CacheDict(get_cache_dir() / 'fedot_runs.pkl') description = 'FEDOT, all datasets'