Skip to content

Commit

Permalink
WIP on enabling experiments without an explicit feature selection step.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzwanenburg committed Nov 22, 2024
1 parent 97baae7 commit c0a2a25
Show file tree
Hide file tree
Showing 9 changed files with 265 additions and 157 deletions.
107 changes: 55 additions & 52 deletions R/ExperimentSetup.R
Original file line number Diff line number Diff line change
Expand Up @@ -543,60 +543,63 @@ extract_experimental_setup <- function(
)

} else {
# Feature selection first
main_message <- "Setup report: Feature selection on"

# Iteratively append message
dt_sub <- section_table[vimp == TRUE, ]
curr_ref_data_id <- dt_sub$main_data_id[1L]

while (curr_ref_data_id > 0L) {
if (any(section_table$vimp)) {
# Feature selection first
main_message <- "Setup report: Feature selection on"

dt_sub <- section_table[main_data_id == curr_ref_data_id, ]
# Iteratively append message
dt_sub <- section_table[vimp == TRUE, ]
curr_ref_data_id <- dt_sub$main_data_id[1L]

if (dt_sub$perturb_method[1L] == "main") {
main_message <- c(
main_message,
"the training data."
)
while (curr_ref_data_id > 0L) {

} else if (dt_sub$perturb_method[1L] %in% c("limited_bootstrap", "full_bootstrap")) {
main_message <- c(
main_message,
paste0(dt_sub$perturb_n_rep[1L], " bootstraps of")
)
dt_sub <- section_table[main_data_id == curr_ref_data_id, ]

} else if (dt_sub$perturb_method[1L] == "cross_val") {
main_message <- c(
main_message,
paste0(
dt_sub$perturb_n_rep[1L], " repetitions of ",
dt_sub$perturb_n_folds, "-fold cross validation of"
if (dt_sub$perturb_method[1L] == "main") {
main_message <- c(
main_message,
"the training data."
)
)

} else if (dt_sub$perturb_method[1L] == "loocv") {
main_message <- c(
main_message,
"folds of leave-one-out-cross-validation of"
)

} else if (dt_sub$perturb_method[1L] %in% c("limited_bootstrap", "full_bootstrap")) {
main_message <- c(
main_message,
paste0(dt_sub$perturb_n_rep[1L], " bootstraps of")
)

} else if (dt_sub$perturb_method[1L] == "cross_val") {
main_message <- c(
main_message,
paste0(
dt_sub$perturb_n_rep[1L], " repetitions of ",
dt_sub$perturb_n_folds, "-fold cross validation of"
)
)

} else if (dt_sub$perturb_method[1L] == "loocv") {
main_message <- c(
main_message,
"folds of leave-one-out-cross-validation of"
)

} else if (dt_sub$perturb_method[1L] == "imbalance_partition") {
main_message <- c(
main_message,
"class-balanced partitions of"
)
}

} else if (dt_sub$perturb_method[1L] == "imbalance_partition") {
main_message <- c(
main_message,
"class-balanced partitions of"
)
curr_ref_data_id <- dt_sub$ref_data_id[1L]
}

curr_ref_data_id <- dt_sub$ref_data_id[1L]
logger_message(
paste0(main_message, collapse = " "),
indent = message_indent,
verbose = verbose
)
}

logger_message(
paste0(main_message, collapse = " "),
indent = message_indent,
verbose = verbose
)

# Model building second
main_message <- "Setup report: Model building on"

Expand Down Expand Up @@ -666,15 +669,15 @@ extract_experimental_setup <- function(
)
}

if (sum(section_table$vimp) == 0L) {
..error(
paste0(
"The fs component for variable importance computation must appear in the ",
"experimental design. It was not found."
),
error_class = "input_argument_error"
)
}
# if (sum(section_table$vimp) == 0L) {
# ..error(
# paste0(
# "The fs component for variable importance computation must appear in the ",
# "experimental design. It was not found."
# ),
# error_class = "input_argument_error"
# )
# }

if (sum(section_table$train) > 1L) {
..error(
Expand Down
5 changes: 5 additions & 0 deletions R/FamiliarS4Classes.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#' detect out-of-distribution samples.
#' @slot learner Learning algorithm used to create the model.
#' @slot vimp_method Method used to determine variable importance for the model.
#' @slot vimp_table Variable importance table or list of variable importance
#' tables for the model.
#' @slot required_features The set of features required for complete
#' reproduction, i.e. with imputation.
#' @slot model_features The set of features that is used to train the model,
Expand Down Expand Up @@ -80,6 +82,8 @@ setClass("familiarModel",
learner = "character",
# Name of variable importance method
vimp_method = "character",
# Variable importance table
vimp_table = "ANY",
# Required features for complete reconstruction, including imputation.
required_features = "ANY",
# Features that are required for the model.
Expand Down Expand Up @@ -123,6 +127,7 @@ setClass("familiarModel",
novelty_detector = NULL,
learner = NA_character_,
vimp_method = NA_character_,
vimp_table = NA_character_,
required_features = NULL,
model_features = NULL,
novelty_features = NULL,
Expand Down
21 changes: 13 additions & 8 deletions R/HyperparameterOptimisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ setMethod(
function(
object,
data,
vimp_aggregation_method,
vimp_rank_threshold,
cl = NULL,
experiment_info = NULL,
user_list = NULL,
Expand All @@ -463,7 +465,6 @@ setMethod(
grid_initialisation_method = "fixed_subsample",
exploration_method = "successive_halving",
n_random_sets = 100L,
determine_vimp = TRUE,
measure_time = TRUE,
hyperparameter_learner = "gaussian_process",
n_max_bootstraps = 20L,
Expand Down Expand Up @@ -728,11 +729,12 @@ setMethod(
}

## Create or obtain variable importance ------------------------------------
rank_table_list <- .compute_hyperparameter_variable_importance(
vimp_table_list <- .compute_hyperparameter_variable_importance(
cl = cl,
determine_vimp = determine_vimp,
object = object,
data = data,
vimp_aggregation_method = vimp_aggregation_method,
vimp_rank_threshold = vimp_rank_threshold,
bootstraps = bootstraps$train_list,
metric = metric,
measure_time = measure_time,
Expand All @@ -759,7 +761,7 @@ setMethod(
# Set signature size.
user_list$sign_size <- .set_signature_size(
object = object,
rank_table_list = rank_table_list,
rank_table_list = get_vimp_table(vimp_table_list),
suggested_range = user_list$sign_size
)

Expand Down Expand Up @@ -902,7 +904,7 @@ setMethod(
run_table = run_table,
bootstraps = bootstraps,
data = data,
rank_table_list = rank_table_list,
rank_table_list = get_vimp_table(vimp_table_list),
parameter_table = parameter_table,
metric_objects = metric_object_list,
iteration_id = 0L,
Expand Down Expand Up @@ -955,7 +957,7 @@ setMethod(
run_table = run_table,
bootstraps = bootstraps,
data = data,
rank_table_list = rank_table_list,
rank_table_list = get_vimp_table(vimp_table_list),
parameter_table = parameter_table,
metric_objects = metric_object_list,
iteration_id = 0L,
Expand Down Expand Up @@ -999,7 +1001,7 @@ setMethod(
run_table = run_table,
bootstraps = bootstraps,
data = data,
rank_table_list = rank_table_list,
rank_table_list = get_vimp_table(vimp_table_list),
parameter_table = parameter_table,
iteration_id = 0L,
metric_objects = metric_object_list,
Expand Down Expand Up @@ -1161,7 +1163,7 @@ setMethod(
run_table = run_table,
bootstraps = bootstraps,
data = data,
rank_table_list = rank_table_list,
rank_table_list = get_vimp_table(vimp_table_list),
parameter_table = parameter_table,
metric_objects = metric_object_list,
iteration_id = optimisation_step + 1L,
Expand Down Expand Up @@ -1407,6 +1409,9 @@ setMethod(
"n_features" = get_n_features(data)
)

# Attach variable importance tables.
object@vimp_table <- decluster_vimp_table(vimp_table_list)

return(object)
}
)
Expand Down
114 changes: 73 additions & 41 deletions R/HyperparameterOptimisationUtilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,10 @@

.compute_hyperparameter_variable_importance <- function(
cl = NULL,
determine_vimp = TRUE,
object,
data,
vimp_aggregation_method,
vimp_rank_threshold,
bootstraps,
verbose,
message_indent,
Expand All @@ -366,38 +367,61 @@
object@vimp_method %in% .get_available_no_features_vimp_methods()
) {
return(NULL)

} else if (is(object@vimp_table, "vimpTable") || rlang::is_bare_list(object@vimp_table)) {
# Existing vimp_tables.
vimp_table <- object@vimp_table

vimp_table <- update_vimp_table_to_reference(
x = vimp_table,
reference_cluster_table = .create_clustering_table(
feature_info_list = object@feature_info
)
)

# Form clusters.
vimp_table <- recluster_vimp_table(vimp_table)

# Aggregate to single table.
vimp_table <- aggregate_vimp_table(
vimp_table,
aggregation_method = vimp_aggregation_method,
rank_threshold = vimp_rank_threshold
)

} else {

logger_message(
paste0(
"Computing variable importance for ",
length(bootstraps), " bootstraps."
),
indent = message_indent + 1L,
verbose = verbose
)

# Spawn task to obtain variable importance tables.
vimp_task <- methods::new(
"familiarTaskVimp",
project_id = object@project_id,
vimp_method = object@vimp_method,
file = NA_character_
)

vimp_table <- fam_lapply(
cl = cl,
assign = "all",
X = bootstraps,
FUN = ..compute_hyperparameter_variable_importance,
vimp_task = vimp_task,
data = data,
feature_info = object@feature_info,
progress_bar = verbose,
chopchop = TRUE
)
}

logger_message(
paste0(
"Computing variable importance for ",
length(bootstraps), " bootstraps."
),
indent = message_indent + 1L,
verbose = verbose
)

# Spawn task to obtain variable importance tables.
vimp_task <- methods::new(
"familiarTaskVimp",
project_id = object@project_id,
vimp_method = object@vimp_method,
file = NA_character_
)

vimp_list <- fam_lapply(
cl = cl,
assign = "all",
X = bootstraps,
FUN = ..compute_hyperparameter_variable_importance,
vimp_task = vimp_task,
data = data,
feature_info = object@feature_info,
progress_bar = verbose,
chopchop = TRUE
)

return(vimp_list)
return(vimp_table)
}


Expand Down Expand Up @@ -425,9 +449,6 @@
# Form clusters.
vimp_table <- recluster_vimp_table(vimp_table)

# Compute variable importance.
vimp_table <- get_vimp_table(vimp_table)

return(vimp_table)
}

Expand Down Expand Up @@ -475,13 +496,24 @@
parameter_table = parameter_table
)

# Generate variable importance sets (if any)
rank_table_list <- lapply(
run_table$run_id,
function(ii, rank_table_list) (rank_table_list[[ii]]),
rank_table_list = rank_table_list
)

# Replicate single variable importance features.
if (data.table::is.data.table(rank_table_list)) {
browser()
rank_table_list <- lapply(
run_table$run_id,
function(ii, x) (data.table::copy(x)),
x = rank_table_list
)

} else {
# Generate variable importance sets (if any)
rank_table_list <- lapply(
run_table$run_id,
function(ii, rank_table_list) (rank_table_list[[ii]]),
rank_table_list = rank_table_list
)
}

if (is.null(time_optimisation_model)) {
# Create a scoring table, with accompanying information.
score_results <- fam_mapply_lb(
Expand Down
Loading

0 comments on commit c0a2a25

Please sign in to comment.