From f378d7f3db56d82c1be95f2e18ef45525a02cd86 Mon Sep 17 00:00:00 2001 From: Alex Zwanenburg Date: Fri, 1 Nov 2024 16:49:19 +0100 Subject: [PATCH] WIP --- R/TaskVimp.R | 103 +++++++++++++++++++++++++++++++++++++++++++++----- R/Utilities.R | 30 +++++++-------- 2 files changed, 109 insertions(+), 24 deletions(-) diff --git a/R/TaskVimp.R b/R/TaskVimp.R index 546266a4..b28f81b0 100644 --- a/R/TaskVimp.R +++ b/R/TaskVimp.R @@ -26,7 +26,12 @@ setMethod( function(object, file_paths = NULL) { if (is.null(file_paths)) return(object) -browser() + # Generate file name of variable importance table + object@file <- get_object_file_name( + object_type = "vimpTable", + project_id = object@project_id, + dir_path = file_paths$vimp_dir + ) return(object) } @@ -151,6 +156,13 @@ setMethod( ... ) + # Check and retrieve hyperparameters. + hyperparameters <- .get_hyperparameters( + object = object, + hyperparameters = hyperparameters, + data = data + ) + if (is.null(hyperparameters) && is.na(object@hyperparameter_file)) { # Check that a list of hyperparameters is provided, otherwise create an ad- # hoc list of hyperparameters. @@ -197,15 +209,15 @@ setMethod( vimp_object@feature_info <- feature_info_list[required_features] # Compute variable importance. - vimp_object <- .vimp( + vimp_table <- .vimp( object = vimp_object, data = data ) if (!is.na(object@file)) { - saveRDS(vimp_object, file = object@file) + saveRDS(vimp_table, file = object@file) } else { - return(vimp_object) + return(vimp_table) } return(invisible(TRUE)) @@ -217,7 +229,7 @@ setMethod( # .get_feature_info_list (vimp task) ------------------------------------------- setMethod( ".get_feature_info_list", - signature(object = "familiarTask"), + signature(object = "familiarTaskVimp"), function(object, feature_info_list, ...) { # Suppress NOTES due to non-standard evaluation in data.table can_pre_process <- NULL @@ -227,9 +239,12 @@ setMethod( # Find the last entry that is available for pre-processing pre_processing_run <- tail(object@run_table[can_pre_process == TRUE, ], n = 1L) - feature_info_list <- get_feature_info_from_backend( - data_id = pre_processing_run$data_id[1L], - run_id = pre_processing_run$run_id[1L] + feature_info_list <- tryCatch( + get_feature_info_from_backend( + data_id = pre_processing_run$data_id[1L], + run_id = pre_processing_run$run_id[1L] + ), + error = NULL ) } @@ -254,11 +269,24 @@ setMethod( } else if (is.null(feature_info_list)) { # Assume that the feature info file attribute contains the path to the # file containing feature info. + if (!file.exists(object@feature_info_file)) { + ..error(paste0("feature info file does not exist at location: ", object@feature_info_file)) + } + feature_info_list <- readRDS(object@feature_info_file) + feature_info_list <- update_object(feature_info_list) } else if (is.character(feature_info_list)) { # If the feature info list is a string, interpret this as a path to the # file containing the feature info. - browser() + if (!file.exists(feature_info_list)) { + ..error(paste0("feature info file does not exist at location: ", feature_info_list)) + } + feature_info_list <- readRDS(feature_info_list) + feature_info_list <- update_object(feature_info_list) + } + + if (is.null(feature_info_list)) { + ..error("no feature info objects were found.") } return(feature_info_list) @@ -267,6 +295,63 @@ setMethod( +setMethod( + ".get_hyperparameters", + signature(object = "familiarTaskVimp"), + function(object, hyperparameters, ...) { + + if (is.null(hyperparameters) && is.na(object@hyperparameter_file)) { + + + + } else if (is.null(hyperparameters)) { + # Identify the right entry on hpo_list. + for (ii in rev(run$run_table$perturb_level)) { + run_id_list <- .get_iteration_identifiers( + run = run, + perturb_level = ii + ) + + # Check whether there are any matching data and run ids by determining the + # number of rows in the table after matching + match_hpo <- sapply( + hpo_list, + function(iter_hpo, run_id_list) { + # Determine if there are any rows in the run_table of the parameter list + # that match the data and run identifiers of the current level. + match_size <- nrow(iter_hpo@run_table[data_id == run_id_list$data & run_id == run_id_list$run]) + + # Return TRUE if any matching rows are found. + return(match_size > 0L) + }, + run_id_list = run_id_list + ) + + # If there is a match, we step out of the loop + if (any(match_hpo)) break + } + + # Extract the table of parameters + if (allow_random_selection && sum(match_hpo) > 1L) { + random_set <- sample(which(match_hpo), size = 1L) + + object <- hpo_list[[random_set]] + + } else { + object <- hpo_list[match_hpo][[1L]] + } + + if (as_list) { + return(object@hyperparameters) + } else { + return(object) + } + } + } +) + + + .generate_vimp_tasks <- function( file_paths, project_id diff --git a/R/Utilities.R b/R/Utilities.R index 0c78ca6c..4ace067e 100644 --- a/R/Utilities.R +++ b/R/Utilities.R @@ -765,20 +765,6 @@ get_object_file_name <- function( with_extension = TRUE, dir_path = NULL ) { - # Generate file name for an object - - if (!object_type %in% c("familiarModel", "familiarEnsemble", "familiarData")) { - ..error("The object type was not recognised.") - } - - # Generate the basic string - base_str <- paste0( - project_id, "_", - learner, "_", - vimp_method, "_", - data_id, "_", - run_id - ) if (object_type == "familiarModel") { # For familiarModel objects @@ -835,6 +821,20 @@ get_object_file_name <- function( ifelse(is_validation, "validation", "development"), "_data" ) + } else if (object_type == "vimpTable") { + + if (is.null(vimp_method) || is.null(project_id)) { + ..error_reached_unreachable_code("missing arguments") + } + + output_str <- paste0( + project_id, "_", + vimp_method, "_", + data_id, "_", + run_id, + "_vimp" + ) + } else if (object_type == "featureInfo") { # Complete feature info objects. @@ -846,7 +846,7 @@ get_object_file_name <- function( project_id, "_", data_id, "_", run_id, - "_feature_info.RDS" + "_feature_info" ) } else if (object_type == "genericFeatureInfo") {