From e61f81e19c3b0b6bc739b987ec171db1704b39d6 Mon Sep 17 00:00:00 2001 From: Alex Zwanenburg Date: Tue, 5 Nov 2024 17:45:23 +0100 Subject: [PATCH] WIP on VIMP task and VIMP hyperparameter task. --- R/Familiar.R | 9 +- R/FamiliarS4Classes.R | 16 +-- R/TaskFeatureInfo.R | 20 ++++ R/TaskMain.R | 75 +++++++++++- R/TaskVimp.R | 175 ++++++++++++++-------------- R/TaskVimpHyperparameters.R | 220 ++++++++++++++++++++++++++++++++++++ 6 files changed, 409 insertions(+), 106 deletions(-) create mode 100644 R/TaskVimpHyperparameters.R diff --git a/R/Familiar.R b/R/Familiar.R index 7448f8b8..efdbeab3 100644 --- a/R/Familiar.R +++ b/R/Familiar.R @@ -392,8 +392,7 @@ summon_familiar <- function( # Select and sort unique tasks. tasks <- .sort_tasks(tasks) - browser() - + # Pre-processing ------------------------------------------------------------- # Start pre-processing .run_preprocessing( @@ -409,9 +408,9 @@ summon_familiar <- function( # Check if the process should be stopped at this point. if (.stop_after %in% c("preprocessing")) { return(create_experiment_data( - project_id = project_info$project_id, - experiment_setup = experiment_setup, - iteration_list = project_info$iter_list, + project_id = experiment_data@project_id, + experiment_setup = experiment_data@experiment_setup, + iteration_list = experiment_data@iteration_list, feature_info = get_feature_info_from_backend( data_id = waiver(), run_id = waiver() diff --git a/R/FamiliarS4Classes.R b/R/FamiliarS4Classes.R index 64b1c817..4770ea89 100644 --- a/R/FamiliarS4Classes.R +++ b/R/FamiliarS4Classes.R @@ -1572,13 +1572,14 @@ setClass( setClass( "familiarTask", slots = list( - task_name = "character", - task_id = "integer", - n_tasks = "integer", - data_id = "integer", - run_id = "integer", - file = "character", - project_id = "ANY" + "task_name" = "character", + "task_id" = "integer", + "n_tasks" = "integer", + "data_id" = "integer", + "run_id" = "integer", + "run_table" = "ANY", + "file" = "character", + "project_id" = 
"ANY" ), prototype = methods::prototype( task_name = NA_character_, @@ -1586,6 +1587,7 @@ setClass( n_tasks = 1L, data_id = NA_integer_, run_id = NA_integer_, + run_table = NULL, file = NA_character_, project_id = NULL ) diff --git a/R/TaskFeatureInfo.R b/R/TaskFeatureInfo.R index 75c85fab..df4529a1 100644 --- a/R/TaskFeatureInfo.R +++ b/R/TaskFeatureInfo.R @@ -116,6 +116,17 @@ setMethod( +# .get_feature_info_list (generic feature info task) --------------------------- +setMethod( + ".get_feature_info_list", + signature(object = "familiarTaskGenericFeatureInfo"), + function(object, feature_info_list, ...) { + ..error_reached_unreachable_code(".get_feature_info_list does not exist for this task") + } +) + + + # familiarTaskFeatureInfo ------------------------------------------------------ setClass( "familiarTaskFeatureInfo", @@ -320,6 +331,15 @@ setMethod( +# .get_feature_info_list (feature info task) --------------------------- +setMethod( + ".get_feature_info_list", + signature(object = "familiarTaskFeatureInfo"), + function(object, feature_info_list, ...) { + ..error_reached_unreachable_code(".get_feature_info_list does not exist for this task") + } +) + .generate_data_preprocessing_tasks <- function( diff --git a/R/TaskMain.R b/R/TaskMain.R index 6c4cf8a8..159fd752 100644 --- a/R/TaskMain.R +++ b/R/TaskMain.R @@ -15,6 +15,74 @@ setMethod( ) +# .get_feature_info_list (general task) ---------------------------------------- +setMethod( + ".get_feature_info_list", + signature(object = "familiarTask"), + function(object, feature_info_list, ...) { + # Suppress NOTES due to non-standard evaluation in data.table + can_pre_process <- NULL + + # Attempt to get the feature info list from the backend. 
+ if (is.null(feature_info_list) && !is.null(object@run_table)) { + # Find the last entry that is available for pre-processing + pre_processing_run <- tail(object@run_table[can_pre_process == TRUE, ], n = 1L) + + feature_info_list <- tryCatch( + get_feature_info_from_backend( + data_id = pre_processing_run$data_id[1L], + run_id = pre_processing_run$run_id[1L] + ), + error = NULL + ) + } + + # If no feature list is present on the backend, check other options. + if (is.null(feature_info_list) && is.na(object@feature_info_file)) { + # Check that a feature info list is provided, otherwise create an ad-hoc + # list as a template. + + # Set up task, and explicitly don't write to file. + generic_feature_info_task <- methods::new( + "familiarTaskFeatureInfo", + project_id = object@project_id, + file = NA_character_ + ) + + # Execute the task. + feature_info_list <- .perform_task( + object = generic_feature_info_task, + ... + ) + + } else if (is.null(feature_info_list)) { + # Assume that the feature info file attribute contains the path to the + # file containing feature info. + if (!file.exists(object@feature_info_file)) { + ..error(paste0("feature info file does not exist at location: ", object@feature_info_file)) + } + feature_info_list <- readRDS(object@feature_info_file) + feature_info_list <- update_object(feature_info_list) + + } else if (is.character(feature_info_list)) { + # If the feature info list is a string, interpret this as a path to the + # file containing the feature info. 
+ if (!file.exists(feature_info_list)) { + ..error(paste0("feature info file does not exist at location: ", feature_info_list)) + } + feature_info_list <- readRDS(feature_info_list) + feature_info_list <- update_object(feature_info_list) + } + + if (!rlang::is_bare_list(feature_info_list)) { + ..error("no feature info objects were found.") + } + + return(feature_info_list) + } +) + + .generate_trainer_tasks <- function( file_paths, @@ -41,8 +109,7 @@ setMethod( task_list <- c( task_list, .generate_vimp_tasks( - file_paths = file_paths, - project_id = project_id + experiment_data = experiment_data ) ) @@ -50,8 +117,8 @@ setMethod( task_list <- c( task_list, .generate_learner_data_preprocessing_tasks( - file_paths = file_paths, - project_id = project_id + experiment_data = experiment_data, + file_paths = file_paths ) ) diff --git a/R/TaskVimp.R b/R/TaskVimp.R index cd72b749..fc188fbd 100644 --- a/R/TaskVimp.R +++ b/R/TaskVimp.R @@ -11,14 +11,12 @@ setClass( slots = list( "vimp_method" = "character", "hyperparameter_file" = "character", - "feature_info_file" = "character", - "run_table" = "ANY" + "feature_info_file" = "character" ), prototype = methods::prototype( vimp_method = NA_character_, hyperparameter_file = NA_character_, feature_info_file = NA_character_, - run_table = NULL, task_name = "compute_variable_importance" ) ) @@ -35,6 +33,7 @@ setMethod( # Generate file name of variable importance table object@file <- get_object_file_name( object_type = "vimpTable", + vimp_method = object@vimp_method, project_id = object@project_id, dir_path = file_paths$vimp_dir ) @@ -224,76 +223,7 @@ setMethod( -# .get_feature_info_list (vimp task) ------------------------------------------- -setMethod( - ".get_feature_info_list", - signature(object = "familiarTaskVimp"), - function(object, feature_info_list, ...) { - # Suppress NOTES due to non-standard evaluation in data.table - can_pre_process <- NULL - - # Attempt to get the feature info list from the backend. 
- if (is.null(feature_info_list) && !is.null(object@run_table)) { - # Find the last entry that is available for pre-processing - pre_processing_run <- tail(object@run_table[can_pre_process == TRUE, ], n = 1L) - - feature_info_list <- tryCatch( - get_feature_info_from_backend( - data_id = pre_processing_run$data_id[1L], - run_id = pre_processing_run$run_id[1L] - ), - error = NULL - ) - } - - # If no feature list is present on the backend, check other options. - if (is.null(feature_info_list) && is.na(object@feature_info_file)) { - # Check that a feature info list is provided, otherwise create an ad-hoc - # list as an template. - - # Set up task, and explicitly don't write to file. - generic_feature_info_task <- methods::new( - "familiarTaskFeatureInfo", - project_id = object@project_id, - file = NA_character_ - ) - - # Execute the task. - feature_info_list <- .perform_task( - object = generic_feature_info_task, - ... - ) - - } else if (is.null(feature_info_list)) { - # Assume that the feature info file attribute contains the path to the - # file containing feature info. - if (!file.exists(object@feature_info_file)) { - ..error(paste0("feature info file does not exist at location: ", object@feature_info_file)) - } - feature_info_list <- readRDS(object@feature_info_file) - feature_info_list <- update_object(feature_info_list) - - } else if (is.character(feature_info_list)) { - # If the feature info list is a string, interpret this as a path to the - # file containing the feature info. 
- if (!file.exists(feature_info_list)) { - ..error(paste0("feature info file does not exist at location: ", feature_info_list)) - } - feature_info_list <- readRDS(feature_info_list) - feature_info_list <- update_object(feature_info_list) - } - - if (!rlang::is_bare_list(feature_info_list)) { - ..error("no feature info objects were found.") - } - - return(feature_info_list) - } -) - - - -# .get_hyperparameters --------------------------------------------------------- +# .get_hyperparameters (vimp task) --------------------------------------------- setMethod( ".get_hyperparameters", signature(object = "familiarTaskVimp"), @@ -305,7 +235,7 @@ setMethod( ) { # Suppress NOTES due to non-standard evaluation in data.table can_pre_process <- NULL - + if (is.null(hyperparameters) && !is.null(object@run_table)) { # This routine loads hyperparameters from disk, and is used when an # experiment is run using summon_familiar. @@ -386,27 +316,56 @@ setMethod( .generate_vimp_tasks <- function( + experiment_data, + vimp_methods, file_paths, - project_id + skip_existing = FALSE ) { + # Suppress NOTES due to non-standard evaluation in data.table + vimp <- can_pre_process <- NULL + + # TODO: Check if vimp should be computed separately or is computed during + # hyperparameter optimisation. + + # Find the data_id related to computing variable importance. + data_id <- experiment_data@experiment_setup[vimp == TRUE, ]$main_data_id[1L] + if (is_empty(data_id)) return(NULL) + # Initialise empty list. task_list <- list() + ii <- 1L - # Check if vimp should be computed separately or is computed during - # hyperparameter optimisation. + # vimp tasks ----------------------------------------------------------------- + + # Get run ids. + run_ids <- names(experiment_data@iteration_list[[as.character(data_id)]]$run) + run_ids <- as.integer(run_ids) - for (data_id in data_ids) { + + # Set up variable importance computation task. 
+ for (vimp_method in vimp_methods) { for (run_id in run_ids) { - for (vimp_method in vimp_methods) { - - # Check if the variable importance method requires any computation. - # For example, signature_only, none and random do not require - # computation. - - # Set up variable importance computation task. - - # Set up variable importance hyperparameter task. - + + # Create task to compute variable importance for this run. + vimp_task <- methods::new( + "familiarTaskVimp", + data_id = data_id, + run_id = run_id, + vimp_method = vimp_method, + run_table = data.table::copy(experiment_data@iteration_list[[as.character(data_id)]]$run[[as.character(run_id)]]$run_table), + project_id = experiment_data@project_id + ) + + # Add file names. + vimp_task <- .set_file_name( + object = vimp_task, + file_paths = file_paths + ) + + # Add to list, if the file does not exist on disk. + if (!skip_existing || !.file_exists(vimp_task)) { + task_list[[ii]] <- vimp_task + ii <- ii + 1L } } } @@ -414,12 +373,48 @@ setMethod( # Check if any vimp-related tasks are required. if (length(task_list) == 0L) return(NULL) + # vimp hyperparameter tasks -------------------------------------------------- + + # Set up variable importance hyperparameter task. + run_table <- experiment_data@iteration_list[[as.character(data_id)]]$run[[1L]]$run_table + vimp_hyperparameter_data_id <- tail(run_table[can_pre_process == TRUE, ], n = 1L)$data_id[1L] + + # Get run ids. + run_ids <- names(experiment_data@iteration_list[[as.character(vimp_hyperparameter_data_id)]]$run) + run_ids <- as.integer(run_ids) + + for (vimp_method in vimp_methods) { + for (run_id in run_ids) { + # Create task to set variable importance hyperparameters for this run. 
+ vimp_hyperparameter_task <- methods::new( + "familiarTaskVimpHyperparameters", + data_id = vimp_hyperparameter_data_id, + run_id = run_id, + vimp_method = vimp_method, + run_table = data.table::copy(experiment_data@iteration_list[[as.character(vimp_hyperparameter_data_id)]]$run[[as.character(run_id)]]$run_table), + project_id = experiment_data@project_id + ) + + # Add file names. + vimp_hyperparameter_task <- .set_file_name( + object = vimp_hyperparameter_task, + file_paths = file_paths + ) + + # Add to list, if the file does not exist on disk. + if (!skip_existing || !.file_exists(vimp_hyperparameter_task)) { + task_list[[ii]] <- vimp_hyperparameter_task + ii <- ii + 1L + } + } + } + # Add tasks related to data processing for vimp methods. task_list <- c( task_list, .generate_vimp_data_preprocessing_tasks( - file_paths = file_paths, - project_id = project_id + experiment_data = experiment_data, + file_paths = file_paths ) ) diff --git a/R/TaskVimpHyperparameters.R b/R/TaskVimpHyperparameters.R new file mode 100644 index 00000000..c0efa506 --- /dev/null +++ b/R/TaskVimpHyperparameters.R @@ -0,0 +1,220 @@ +#' @include FamiliarS4Generics.R +#' @include FamiliarS4Classes.R +NULL + + + +# familiarTaskVimpHyperparameters ---------------------------------------------- +setClass( + "familiarTaskVimpHyperparameters", + contains = "familiarTask", + slots = list( + "vimp_method" = "character", + "feature_info_file" = "character" + ), + prototype = methods::prototype( + vimp_method = NA_character_, + feature_info_file = NA_character_, + task_name = "set_variable_importance_hyperparameters" + ) +) + + + +# .set_file_name (vimp hyperparameters task) ----------------------------------- +setMethod( + ".set_file_name", + signature(object = "familiarTaskVimpHyperparameters"), + function(object, file_paths = NULL) { + if (is.null(file_paths)) return(object) + + # Generate file name of the variable importance hyperparameters file + object@file <- get_object_file_name( + object_type = 
"hyperparametersVimp", + vimp_method = object@vimp_method, + project_id = object@project_id, + dir_path = file_paths$vimp_dir + ) + + return(object) + } +) + + + +# .get_task_descriptor (vimp hyperparameters task) ----------------------------- +setMethod( + ".get_task_descriptor", + signature(object = "familiarTaskVimpHyperparameters"), + function(object, ...) { + return(paste0(object@task_name, "_", object@data_id, "_", object@run_id, "_", object@vimp_method)) + } +) + + + +# .perform_task (vimp hyperparameters task , NULL) ----------------------------- +setMethod( + ".perform_task", + signature( + object = "familiarTaskVimpHyperparameters", + data = "NULL" + ), + function( + object, + data, + experiment_data = NULL, + outcome_info = NULL, + ... + ) { + # This method is called when "data" is expected to be available somewhere in + # the backend. + + if (is.null(experiment_data)) { + ..error_reached_unreachable_code("experiment_data is required for retrieving data from the backend.") + } + if (is.null(outcome_info)) { + ..error_reached_unreachable_code("outcome_info is required.") + } + + # Find the run list. + run_list <- .get_run_list( + iteration_list = experiment_data@iteration_list, + data_id = object@data_id, + run_id = object@run_id + ) + + # Select unique samples. + sample_identifiers <- .get_sample_identifiers( + run = run_list, + train_or_validate = "train" + ) + sample_identifiers <- unique(sample_identifiers) + + # Create a dataObject. + data <- methods::new( + "dataObject", + data = get_data_from_backend(sample_identifiers = sample_identifiers), + preprocessing_level = "none", + outcome_type = outcome_info@outcome_type, + outcome_info = outcome_info + ) + + # Pass to method that dispatches with dataObject for further processing. + return(.perform_task( + object = object, + data = data, + ... 
+ )) + } +) + + +# .perform_task (vimp task, dataObject) ---------------------------------------- +setMethod( + ".perform_task", + signature( + object = "familiarTaskVimp", + data = "dataObject" + ), + function( + object, + data, + settings = NULL, + feature_info_list = NULL, + hyperparameters = NULL, + message_indent = 0L, + verbose = FALSE, + cl = NULL, + ... + ) { + + logger_message( + paste0( + "\nVariable importance: starting variable importance computation using the \"", + object@vimp_method, "\" method for run ", + object@task_id, " of ", + object@n_tasks, "." + ), + indent = message_indent, + verbose = verbose + ) + + # Check that outcome_info is present on data + if (!is(data@outcome_info, "outcomeInfo")) { + ..error_reached_unreachable_code( + "outcome_info attribute of data (dataObject) does not contain an outcomeInfo object" + ) + } + + # Check and retrieve feature info list. + feature_info_list <- .get_feature_info_list( + object = object, + feature_info_list = feature_info_list, + data = data, + settings = settings, + message_indent = message_indent, + verbose = verbose, + cl = cl, + ... + ) + + # TODO: preprocess data + + # Check and retrieve hyperparameters. + hyperparameters <- .get_hyperparameters( + object = object, + hyperparameters = hyperparameters, + data = data, + settings = settings, + message_indent = message_indent, + verbose = verbose, + cl = cl, + ... + ) + + # Create the variable importance method object or familiar model object to + # compute variable importance with. + vimp_object <- methods::new( + "familiarVimpMethod", + outcome_type = data@outcome_type, + hyperparameters = hyperparameters, + vimp_method = object@vimp_method, + outcome_info = data@outcome_info, + run_table = object@run_table + ) + + # Promote to the correct subclass. + vimp_object <- promote_vimp_method(object = vimp_object) + + # Set multivariate methods. 
+ if (is(vimp_object, "familiarModel")) is_multivariate <- TRUE + if (is(vimp_object, "familiarVimpMethod")) is_multivariate <- vimp_object@multivariate + + # Find required features. Exclude the signature features at this point, as + # these will have been dropped from the variable importance table. + required_features <- get_required_features( + x = data, + feature_info_list = feature_info_list, + exclude_signature = !is_multivariate + ) + + # Limit to required features. + vimp_object@required_features <- required_features + vimp_object@feature_info <- feature_info_list[required_features] + + # Compute variable importance. + vimp_table <- .vimp( + object = vimp_object, + data = data + ) + + if (!is.na(object@file)) { + saveRDS(vimp_table, file = object@file) + } else { + return(vimp_table) + } + + return(invisible(TRUE)) + } +)