From b81e1c8e2e8aa6c19605835d9f9c1f7ab569874a Mon Sep 17 00:00:00 2001 From: Alex Zwanenburg Date: Tue, 26 Nov 2024 16:57:36 +0100 Subject: [PATCH] WIP on feature selection-less training. --- R/Familiar.R | 12 ++++--- R/HyperparameterOptimisation.R | 2 +- R/HyperparameterOptimisationUtilities.R | 5 +-- R/TaskFeatureInfo.R | 26 ++++++++++------ R/TaskLearn.R | 15 ++++++++- R/TaskVimp.R | 2 +- R/VimpTable.R | 4 +-- tests/testthat/test-task_based_workflow.R | 38 +++++++++++++++++++++++ 8 files changed, 82 insertions(+), 22 deletions(-) diff --git a/R/Familiar.R b/R/Familiar.R index cdab5fb2..64a18f52 100644 --- a/R/Familiar.R +++ b/R/Familiar.R @@ -376,6 +376,7 @@ summon_familiar <- function( if (.stop_after == "training") { tasks <- .generate_trainer_tasks( experiment_data = experiment_data, + optimisation_determine_vimp = settings$hpo$hpo_determine_vimp, vimp_methods = settings$vimp$vimp_methods, learners = settings$mb$learners, file_paths = file_paths @@ -442,10 +443,13 @@ summon_familiar <- function( # Check if the process should be stopped at this point. if (.stop_after %in% c("vimp")) { - feature_info <- get_feature_info_from_backend( - data_id = waiver(), - run_id = waiver() - ) + feature_info <- NULL + if (!is_empty(tasks$feature_info)) { + feature_info <- get_feature_info_from_backend( + data_id = waiver(), + run_id = waiver() + ) + } vimp_hyperparameters <- NULL if (!is_empty(tasks$hyperparameters_vimp)) { diff --git a/R/HyperparameterOptimisation.R b/R/HyperparameterOptimisation.R index 3d5ccc21..5aaf4e5c 100644 --- a/R/HyperparameterOptimisation.R +++ b/R/HyperparameterOptimisation.R @@ -1410,7 +1410,7 @@ setMethod( ) # Attach variable importance tables. - object@vimp_table <- decluster_vimp_table(vimp_table_list) + object@vimp_table <- vimp_table_list return(object) } diff --git a/R/HyperparameterOptimisationUtilities.R b/R/HyperparameterOptimisationUtilities.R index 120b5187..1cc56bb6 100644 --- a/R/HyperparameterOptimisationUtilities.R +++ b/R/HyperparameterOptimisationUtilities.R @@ -445,10 +445,7 @@ feature_info_list = feature_info, data = data ) - - # Form clusters. - vimp_table <- recluster_vimp_table(vimp_table) - + return(vimp_table) } diff --git a/R/TaskFeatureInfo.R b/R/TaskFeatureInfo.R index a373d008..95a49256 100644 --- a/R/TaskFeatureInfo.R +++ b/R/TaskFeatureInfo.R @@ -489,18 +489,26 @@ setMethod( data_id <- run_id <- list_name <- complete <- NULL # Create or load generic feature info. - if (!.file_exists(tasks$generic_feature_info[[1L]])) { - generic_feature_info <- .perform_task( - object = tasks$generic_feature_info[[1L]], - data = NULL, - experiment_data = experiment_data, - ... - ) + if (!is_empty(tasks$generic_feature_info)) { + if (!.file_exists(tasks$generic_feature_info[[1L]])) { + generic_feature_info <- .perform_task( + object = tasks$generic_feature_info[[1L]], + data = NULL, + experiment_data = experiment_data, + ... + ) + + } else { + generic_feature_info <- readRDS(tasks$generic_feature_info[[1L]]@file) + } } else { - generic_feature_info <- readRDS(tasks$generic_feature_info[[1L]]@file) + generic_feature_info <- NULL } + # Check that any feature info tasks are required. + if (is_empty(tasks$feature_info)) return(invisible(FALSE)) + # Determine which feature info objects need to be obtained. finished_tasks <- sapply(tasks$feature_info, .file_exists) unfinished_tasks <- tasks$feature_info[!finished_tasks] @@ -510,7 +518,7 @@ setMethod( if (length(unfinished_tasks) > 0L) { ..run_preprocessing( tasks = unfinished_tasks, - generic_feature_info <- generic_feature_info, + generic_feature_info = generic_feature_info, experiment_data = experiment_data, ... ) diff --git a/R/TaskLearn.R b/R/TaskLearn.R index 5f63dad9..de4ade59 100644 --- a/R/TaskLearn.R +++ b/R/TaskLearn.R @@ -416,13 +416,14 @@ setMethod( .generate_trainer_tasks <- function( experiment_data, + optimisation_determine_vimp, vimp_methods, learners, file_paths, skip_existing = FALSE ) { # Suppress NOTES due to non-standard evaluation in data.table - train <- main_data_id <- can_pre_process <- NULL + train <- main_data_id <- can_pre_process <- vimp <- NULL # Find the data_id related to model training. data_id <- experiment_data@experiment_setup[train == TRUE, ]$main_data_id[1L] @@ -474,6 +475,17 @@ setMethod( # learner hyperparameter tasks ----------------------------------------------- + # Check how variable importance data should be handled. + if (is_empty(experiment_data@experiment_setup[vimp == TRUE, ])) { + use_vimp <- "return_hpo_vimp" + + } else if (optimisation_determine_vimp) { + use_vimp <- "use_hpo_vimp" + + } else { + use_vimp <- "use_main_vimp" + } + # Set up variable importance hyperparameter task. train_run_table <- .get_run_table_from_experiment_setup( data_id = data_id, @@ -495,6 +507,7 @@ setMethod( "familiarTaskLearnerHyperparameters", data_id = learner_hyperparameter_data_id, run_id = run_id, + use_vimp = use_vimp, vimp_method = vimp_method, learner = learner, run_table = run_tables, diff --git a/R/TaskVimp.R b/R/TaskVimp.R index e5b5ced5..57fefa4d 100644 --- a/R/TaskVimp.R +++ b/R/TaskVimp.R @@ -407,7 +407,7 @@ setMethod( # Find the data_id related to computing variable importance. data_id <- experiment_data@experiment_setup[vimp == TRUE, ]$main_data_id[1L] - if (is_empty(data_id)) return(NULL) + if (is.na(data_id)) return(NULL) # Initialise empty list. task_list <- list() diff --git a/R/VimpTable.R b/R/VimpTable.R index 48c3a164..dbb7995a 100644 --- a/R/VimpTable.R +++ b/R/VimpTable.R @@ -372,7 +372,7 @@ setMethod( function(x, ...) { # If the list is empty, return NULL instead. if (is_empty(x)) return(NULL) - + browser() # Dispatch to method for single variable importance tables. return(lapply( x, @@ -388,7 +388,7 @@ setMethod( "decluster_vimp_table", signature(x = "vimpTable"), function(x, show_weights = FALSE, show_cluster_name = FALSE, ...) { - browser() + # Check if the table has already been declustered. if (.as_vimp_table_state(x@state) >= "declustered") return(x) diff --git a/tests/testthat/test-task_based_workflow.R b/tests/testthat/test-task_based_workflow.R index 61d4cdd1..7b546679 100644 --- a/tests/testthat/test-task_based_workflow.R +++ b/tests/testthat/test-task_based_workflow.R @@ -66,3 +66,41 @@ model <- familiar::train_familiar( verbose = TRUE, parallel = FALSE ) + + +# Check without explicit variable importance computation ----------------------- +# Create variable importance +experiment_vimp <- familiar::precompute_vimp( + data = data, + experimental_design = "bs(mb,3)", + vimp_method = "mim", + outcome_type = "binomial", + outcome_column = "outcome", + batch_id_column = "batch_id", + sample_id_column = "sample_id", + series_id_column = "series_id", + class_levels = c("red", "green"), + verbose = TRUE, + parallel = FALSE +) + +testthat::test_that("variable importance data is absent", { + testthat::expect_null(experiment_vimp@feature_info) + testthat::expect_null(experiment_vimp@vimp_table_list) +}) + +# Train model +model <- familiar::train_familiar( + data = data, + experiment_data = experiment_vimp, + vimp_method = "mim", + learner = "glm_logistic", + outcome_type = "binomial", + outcome_column = "outcome", + batch_id_column = "batch_id", + sample_id_column = "sample_id", + series_id_column = "series_id", + class_levels = c("red", "green"), + verbose = TRUE, + parallel = FALSE +)