diff --git a/R/DataPreProcessing.R b/R/DataPreProcessing.R index 4e58b1c1..ba5676e6 100644 --- a/R/DataPreProcessing.R +++ b/R/DataPreProcessing.R @@ -293,19 +293,19 @@ determine_preprocessing_parameters <- function( -.get_feature_info_list <- function(run) { - - # Find pre-processing control element for the current run - pre_proc_id_list <- .get_preprocessing_iteration_identifiers(run = run) - - # Load feature info list from backend - feature_info_list <- get_feature_info_from_backend( - data_id = pre_proc_id_list$data, - run_id = pre_proc_id_list$run - ) - - return(feature_info_list) -} +# .get_feature_info_list <- function(run) { +# +# # Find pre-processing control element for the current run +# pre_proc_id_list <- .get_preprocessing_iteration_identifiers(run = run) +# +# # Load feature info list from backend +# feature_info_list <- get_feature_info_from_backend( +# data_id = pre_proc_id_list$data, +# run_id = pre_proc_id_list$run +# ) +# +# return(feature_info_list) +# } diff --git a/R/FamiliarS4Generics.R b/R/FamiliarS4Generics.R index 2829e35b..05fc3b5e 100644 --- a/R/FamiliarS4Generics.R +++ b/R/FamiliarS4Generics.R @@ -377,3 +377,5 @@ setGeneric(".file_exists", function(object, ...) standardGeneric(".file_exists") setGeneric(".perform_task", function(object, data, ...) standardGeneric(".perform_task")) setGeneric(".get_task_descriptor", function(object, ...) standardGeneric(".get_task_descriptor")) + +setGeneric(".get_feature_info_list", function(object, ...) standardGeneric(".get_feature_info_list")) diff --git a/R/TaskFeatureInfo.R b/R/TaskFeatureInfo.R index 46463b8e..7e59840d 100644 --- a/R/TaskFeatureInfo.R +++ b/R/TaskFeatureInfo.R @@ -23,12 +23,10 @@ setMethod( if (is.null(file_paths)) return(object) # Generate file name of pre-processing file - file_name <- paste0(object@project_id, "_generic_feature_info.RDS") - - # Add file path and normalise according to the OS - object@file <- normalizePath( - file.path(file_paths$process_data_dir, file_name), - mustWork = FALSE + object@file <- get_object_file_name( + object_type = "genericFeatureInfo", + project_id = object@project_id, + dir_path = file_paths$process_data_dir ) return(object) @@ -58,17 +56,14 @@ setMethod( function( object, data, - settings = NULL, + outcome_info = NULL, ... ) { # This method is called when "data" is expected to be available somewhere in # the backend. - - if (is.null(project_info)) { - ..error_reached_unreachable_code("project_info is required for retrieving data from the backend.") - } - if (is.null(settings)) { - ..error_reached_unreachable_code("settings is required for retrieving data from the backend.") + + if (is.null(outcome_info)) { + ..error_reached_unreachable_code("outcome_info is required.") } # Create a dataObject. @@ -76,7 +71,8 @@ setMethod( "dataObject", data = get_data_from_backend(), preprocessing_level = "none", - outcome_type = settings$data$outcome_type + outcome_type = outcome_info@outcome_type, + outcome_info = outcome_info ) # Pass to .perform_task for dataObject. @@ -141,15 +137,13 @@ setMethod( function(object, file_paths = NULL) { if (is.null(file_paths)) return(object) - # Generate file name of pre-processing file. - file_name <- paste0( - object@project_id, "_", object@data_id, "_", object@run_id, "_feature_info.RDS" - ) - - # Add file path and normalise according to the OS - object@file <- normalizePath( - file.path(file_paths$process_data_dir, file_name), - mustWork = FALSE + # Generate file name of pre-processing file + object@file <- get_object_file_name( + object_type = "featureInfo", + project_id = object@project_id, + data_id = object@data_id, + run_id = object@run_id, + dir_path = file_paths$process_data_dir ) return(object) @@ -179,8 +173,8 @@ setMethod( function( object, data, - settings = NULL, project_info = NULL, + outcome_info = NULL, ... ) { # This method is called when "data" is expected to be available somewhere in @@ -189,8 +183,8 @@ setMethod( if (is.null(project_info)) { ..error_reached_unreachable_code("project_info is required for retrieving data from the backend.") } - if (is.null(settings)) { - ..error_reached_unreachable_code("settings is required for retrieving data from the backend.") + if (is.null(outcome_info)) { + ..error_reached_unreachable_code("outcome_info is required.") } # Find the run list. @@ -212,14 +206,14 @@ setMethod( "dataObject", data = get_data_from_backend(sample_identifiers = sample_identifiers), preprocessing_level = "none", - outcome_type = settings$data$outcome_type + outcome_type = outcome_info@outcome_type, + outcome_info = outcome_info ) # Pass to method that dispatches with dataObject for further processing. return(.perform_task( object = object, data = data, - settings = settings, ... )) } diff --git a/R/TaskMain.R b/R/TaskMain.R index 5a9a7c32..3a92da08 100644 --- a/R/TaskMain.R +++ b/R/TaskMain.R @@ -15,6 +15,7 @@ setMethod( ) + .generate_trainer_tasks <- function( file_paths, project_id @@ -48,7 +49,7 @@ setMethod( # Add tasks related to data processing for learners. task_list <- c( task_list, - .generate_learner_tasks( + .generate_learner_data_preprocessing_tasks( file_paths = file_paths, project_id = project_id ) @@ -58,46 +59,7 @@ setMethod( } -.generate_vimp_tasks <- function( - file_paths, - project_id -) { - - task_list <- list() - - # Check if vimp should be computed separately or is computed during - # hyperparameter optimisation. - - for (data_id in data_ids) { - for (run_id in run_ids) { - for (vimp_method in vimp_methods) { - - # Check if the variable importance method requires any computation. - # For example, signature_only, none and random do not require - # computation. - - # Set up variable importance computation task. - - # Set up variable importance hyperparameter task. - - } - } - } - - # Check if any vimp-related tasks are required. - if (len(task_list) == 0L) return(NULL) - - # Add tasks related to data processing for vimp methods. - task_list <- c( - task_list, - .generate_vimp_data_preprocessing_tasks( - file_paths = file_paths, - project_id = project_id - ) - ) - - return(task_list) -} + diff --git a/R/TaskVimp.R b/R/TaskVimp.R new file mode 100644 index 00000000..546266a4 --- /dev/null +++ b/R/TaskVimp.R @@ -0,0 +1,309 @@ +# familiarTaskVimp ------------------------------------------------------------- +setClass( + "familiarTaskVimp", + contains = "familiarTask", + slots = list( + "vimp_method" = "character", + "hyperparameter_file" = "character", + "feature_info_file" = "character", + "run_table" = "ANY" + ), + prototype = methods::prototype( + vimp_method = NA_character_, + hyperparameter_file = NA_character_, + feature_info_file = NA_character_, + run_table = NULL, + task_name = "compute_variable_importance" + ) +) + + + +# .set_file_name (vimp task) --------------------------------------------------- +setMethod( + ".set_file_name", + signature(object = "familiarTaskVimp"), + function(object, file_paths = NULL) { + if (is.null(file_paths)) return(object) + +browser() + + return(object) + } +) + + + +# .get_task_descriptor (vimp task) --------------------------------------------- +setMethod( + ".get_task_descriptor", + signature(object = "familiarTaskVimp"), + function(object, ...) { + return(paste0(object@task_name, "_", object@data_id, "_", object@run_id, "_", object@vimp_method)) + } +) + + + +# .perform_task (vimp task , NULL) --------------------------------------------- +setMethod( + ".perform_task", + signature( + object = "familiarTaskVimp", + data = "NULL" + ), + function( + object, + data, + project_info = NULL, + outcome_info = NULL, + ... + ) { + # This method is called when "data" is expected to be available somewhere in + # the backend. + + if (is.null(project_info)) { + ..error_reached_unreachable_code("project_info is required for retrieving data from the backend.") + } + if (is.null(outcome_info)) { + ..error_reached_unreachable_code("outcome_info is required.") + } + + # Find the run list. + run_list <- .get_run_list( + iteration_list = project_info$iter_list, + data_id = object@data_id, + run_id = object@run_id + ) + + # Select unique samples. + sample_identifiers <- .get_sample_identifiers( + run = run_list, + train_or_validate = "train" + ) + sample_identifiers <- unique(sample_identifiers) + + # Create a dataObject. + data <- methods::new( + "dataObject", + data = get_data_from_backend(sample_identifiers = sample_identifiers), + preprocessing_level = "none", + outcome_type = outcome_info@outcome_type, + outcome_info = outcome_info + ) + + # Pass to method that dispatches with dataObject for further processing. + return(.perform_task( + object = object, + data = data, + ... + )) + } +) + + +# .perform_task (vimp task, dataObject) ---------------------------------------- +setMethod( + ".perform_task", + signature( + object = "familiarTaskVimp", + data = "dataObject" + ), + function( + object, + data, + settings = NULL, + feature_info_list = NULL, + hyperparameters = NULL, + message_indent = 0L, + verbose = FALSE, + cl = NULL, + ... + ) { + + logger_message( + paste0( + "\nVariable importance: starting variable importance computation using the \"", + object@vimp_method, "\" method for run ", + object@task_id, " of ", + object@n_tasks, "." + ), + indent = message_indent, + verbose = verbose + ) + + # Check that outcome_info is present on data + if (!is(data@outcome_info, "outcomeInfo")) { + ..error_reached_unreachable_code( + "outcome_info attribute of data (dataObject) does not contain an outcomeInfo object" + ) + } + + # Check and retrieve feature info list. + feature_info_list <- .get_feature_info_list( + object = object, + feature_info_list = feature_info_list, + data = data, + settings = settings, + message_indent = message_indent, + verbose = verbose, + cl = cl, + ... + ) + + if (is.null(hyperparameters) && is.na(object@hyperparameter_file)) { + # Check that a list of hyperparameters is provided, otherwise create an ad- + # hoc list of hyperparameters. + browser() + + } else if (is.null(hyperparameters)) { + # Assume that the hyperparameter file attribute contains the path to the + # file containing hyperparameters. + + } else if (is.character(hyperparameters)) { + # If hyperparameters is a string, interpret this as a path to file + # containing the feature info. + + } + + # Create the variable importance method object or familiar model object to + # compute variable importance with. + vimp_object <- methods::new( + "familiarVimpMethod", + outcome_type = data@outcome_type, + hyperparameters = hyperparameters, + vimp_method = object@vimp_method, + outcome_info = data@outcome_info, + run_table = object@run_table + ) + + # Promote to the correct subclass. + vimp_object <- promote_vimp_method(object = vimp_object) + + # Set multivariate methods. + if (is(vimp_object, "familiarModel")) is_multivariate <- TRUE + if (is(vimp_object, "familiarVimpMethod")) is_multivariate <- vimp_object@multivariate + + # Find required features. Exclude the signature features at this point, as + # these will have been dropped from the variable importance table. + required_features <- get_required_features( + x = data, + feature_info_list = feature_info_list, + exclude_signature = !is_multivariate + ) + + # Limit to required features. + vimp_object@required_features <- required_features + vimp_object@feature_info <- feature_info_list[required_features] + + # Compute variable importance. + vimp_object <- .vimp( + object = vimp_object, + data = data + ) + + if (!is.na(object@file)) { + saveRDS(vimp_object, file = object@file) + } else { + return(vimp_object) + } + + return(invisible(TRUE)) + } +) + + + +# .get_feature_info_list (vimp task) ------------------------------------------- +setMethod( + ".get_feature_info_list", + signature(object = "familiarTask"), + function(object, feature_info_list, ...) { + # Suppress NOTES due to non-standard evaluation in data.table + can_pre_process <- NULL + + # Attempt to get the feature info list from the backend. + if (is.null(feature_info_list) && !is.null(object@run_table)) { + # Find the last entry that is available for pre-processing + pre_processing_run <- tail(object@run_table[can_pre_process == TRUE, ], n = 1L) + + feature_info_list <- get_feature_info_from_backend( + data_id = pre_processing_run$data_id[1L], + run_id = pre_processing_run$run_id[1L] + ) + } + + # If no feature list is present on the backend, check other options. + if (is.null(feature_info_list) && is.na(object@feature_info_file)) { + # Check that a feature info list is provided, otherwise create an ad-hoc + # list as an template. + + # Set up task, and explicitly don't write to file. + generic_feature_info_task <- methods::new( + "familiarTaskFeatureInfo", + project_id = object@project_id, + file = NA_character_ + ) + + # Execute the task. + feature_info_list <- .perform_task( + object = generic_feature_info_task, + ... + ) + + } else if (is.null(feature_info_list)) { + # Assume that the feature info file attribute contains the path to the + # file containing feature info. + + } else if (is.character(feature_info_list)) { + # If the feature info list is a string, interpret this as a path to the + # file containing the feature info. + browser() + } + + return(feature_info_list) + } +) + + + +.generate_vimp_tasks <- function( + file_paths, + project_id +) { + + task_list <- list() + + # Check if vimp should be computed separately or is computed during + # hyperparameter optimisation. + + for (data_id in data_ids) { + for (run_id in run_ids) { + for (vimp_method in vimp_methods) { + + # Check if the variable importance method requires any computation. + # For example, signature_only, none and random do not require + # computation. + + # Set up variable importance computation task. + + # Set up variable importance hyperparameter task. + + } + } + } + + # Check if any vimp-related tasks are required. + if (length(task_list) == 0L) return(NULL) + + # Add tasks related to data processing for vimp methods. + task_list <- c( + task_list, + .generate_vimp_data_preprocessing_tasks( + file_paths = file_paths, + project_id = project_id + ) + ) + + return(task_list) +} diff --git a/R/Utilities.R b/R/Utilities.R index ffce3923..0c78ca6c 100644 --- a/R/Utilities.R +++ b/R/Utilities.R @@ -752,11 +752,11 @@ get_id_columns <- function(id_depth = "repetition", single_column = NULL) { get_object_file_name <- function( - learner, - vimp_method, project_id, data_id, run_id, + learner = NULL, + vimp_method = NULL, pool_data_id = NULL, pool_run_id = NULL, object_type, @@ -783,42 +783,83 @@ get_object_file_name <- function( if (object_type == "familiarModel") { # For familiarModel objects - output_str <- paste0(base_str, "_model") + if (is.null(learner) || is.null(vimp_method) || is.null(project_id)) { + ..error_reached_unreachable_code("missing arguments") + } + + output_str <- paste0( + project_id, "_", + learner, "_", + vimp_method, "_", + data_id, "_", + run_id, + "_model" + ) } else if (object_type == "familiarEnsemble") { # For familiarEnsemble objects - if (is.null(is_ensemble)) { - ..error("The \"is_ensemble\" parameter is not set to TRUE or FALSE.") + if (is.null(learner) || is.null(vimp_method) || is.null(project_id) || is.null(is_ensemble)) { + ..error_reached_unreachable_code("missing arguments") } output_str <- paste0( - base_str, "_", + project_id, "_", + learner, "_", + vimp_method, "_", + data_id, "_", + run_id, "_", ifelse(is_ensemble, "ensemble", "pool") ) } else if (object_type == "familiarData") { # For familiarData objects - if (is.null(is_ensemble)) { - ..error("The \"is_ensemble\" parameter is not set to TRUE or FALSE.") - } - - if (is.null(is_validation)) { - ..error("The \"is_validation\" parameter is not set to TRUE or FALSE.") - } - - if (is.null(pool_data_id) || is.null(pool_run_id)) { - ..error("pool_data_id and pool_run_id should be provided.") + if ( + is.null(learner) || is.null(vimp_method) || is.null(project_id) || + is.null(is_ensemble) || is.null(is_validation) || is.null(pool_data_id) || + is.null(pool_run_id) + ) { + ..error_reached_unreachable_code("missing arguments") } output_str <- paste0( - base_str, "_", + project_id, "_", + learner, "_", + vimp_method, "_", + data_id, "_", + run_id, "_", ifelse(is_ensemble, "ensemble", "pool"), "_", pool_data_id, "_", pool_run_id, "_", ifelse(is_validation, "validation", "development"), "_data" ) + + } else if (object_type == "featureInfo") { + # Complete feature info objects. + + if (is.null(project_id)) { + ..error_reached_unreachable_code("missing arguments") + } + + output_str <- paste0( + project_id, "_", + data_id, "_", + run_id, + "_feature_info.RDS" + ) + + } else if (object_type == "genericFeatureInfo") { + # Generic feature info objects. + + if (is.null(project_id)) { + ..error_reached_unreachable_code("missing arguments") + } + + output_str <- paste0(project_id, "_generic_feature_info") + + } else { + ..error_reached_unreachable_code(paste0("unknown object_type: ", object_type)) } # Add extension