Skip to content

Commit

Permalink
WIP on feature selection-less training.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzwanenburg committed Nov 26, 2024
1 parent c0a2a25 commit b81e1c8
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 22 deletions.
12 changes: 8 additions & 4 deletions R/Familiar.R
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ summon_familiar <- function(
if (.stop_after == "training") {
tasks <- .generate_trainer_tasks(
experiment_data = experiment_data,
optimisation_determine_vimp = settings$hpo$hpo_determine_vimp,
vimp_methods = settings$vimp$vimp_methods,
learners = settings$mb$learners,
file_paths = file_paths
Expand Down Expand Up @@ -442,10 +443,13 @@ summon_familiar <- function(

# Check if the process should be stopped at this point.
if (.stop_after %in% c("vimp")) {
feature_info <- get_feature_info_from_backend(
data_id = waiver(),
run_id = waiver()
)
feature_info <- NULL
if (!is_empty(tasks$feature_info)) {
feature_info <- get_feature_info_from_backend(
data_id = waiver(),
run_id = waiver()
)
}

vimp_hyperparameters <- NULL
if (!is_empty(tasks$hyperparameters_vimp)) {
Expand Down
2 changes: 1 addition & 1 deletion R/HyperparameterOptimisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,7 @@ setMethod(
)

# Attach variable importance tables.
object@vimp_table <- decluster_vimp_table(vimp_table_list)
object@vimp_table <- vimp_table_list

return(object)
}
Expand Down
5 changes: 1 addition & 4 deletions R/HyperparameterOptimisationUtilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,7 @@
feature_info_list = feature_info,
data = data
)

# Form clusters.
vimp_table <- recluster_vimp_table(vimp_table)


return(vimp_table)
}

Expand Down
26 changes: 17 additions & 9 deletions R/TaskFeatureInfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -489,18 +489,26 @@ setMethod(
data_id <- run_id <- list_name <- complete <- NULL

# Create or load generic feature info.
if (!.file_exists(tasks$generic_feature_info[[1L]])) {
generic_feature_info <- .perform_task(
object = tasks$generic_feature_info[[1L]],
data = NULL,
experiment_data = experiment_data,
...
)
if (!is_empty(tasks$generic_feature_info)) {
if (!.file_exists(tasks$generic_feature_info[[1L]])) {
generic_feature_info <- .perform_task(
object = tasks$generic_feature_info[[1L]],
data = NULL,
experiment_data = experiment_data,
...
)

} else {
generic_feature_info <- readRDS(tasks$generic_feature_info[[1L]]@file)
}

} else {
generic_feature_info <- readRDS(tasks$generic_feature_info[[1L]]@file)
generic_feature_info <- NULL
}

# Check that any feature info tasks are required.
if (is_empty(tasks$feature_info)) return(invisible(FALSE))

# Determine which feature info objects need to be obtained.
finished_tasks <- sapply(tasks$feature_info, .file_exists)
unfinished_tasks <- tasks$feature_info[!finished_tasks]
Expand All @@ -510,7 +518,7 @@ setMethod(
if (length(unfinished_tasks) > 0L) {
..run_preprocessing(
tasks = unfinished_tasks,
generic_feature_info <- generic_feature_info,
generic_feature_info = generic_feature_info,
experiment_data = experiment_data,
...
)
Expand Down
15 changes: 14 additions & 1 deletion R/TaskLearn.R
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,14 @@ setMethod(

.generate_trainer_tasks <- function(
experiment_data,
optimisation_determine_vimp,
vimp_methods,
learners,
file_paths,
skip_existing = FALSE
) {
# Suppress NOTES due to non-standard evaluation in data.table
train <- main_data_id <- can_pre_process <- NULL
train <- main_data_id <- can_pre_process <- vimp <- NULL

# Find the data_id related to model training.
data_id <- experiment_data@experiment_setup[train == TRUE, ]$main_data_id[1L]
Expand Down Expand Up @@ -474,6 +475,17 @@ setMethod(

# learner hyperparameter tasks -----------------------------------------------

# Check how variable importance data should be handled.
if (is_empty(experiment_data@experiment_setup[vimp == TRUE, ])) {
use_vimp <- "return_hpo_vimp"

} else if (optimisation_determine_vimp) {
use_vimp <- "use_hpo_vimp"

} else {
use_vimp <- "use_main_vimp"
}

# Set up variable importance hyperparameter task.
train_run_table <- .get_run_table_from_experiment_setup(
data_id = data_id,
Expand All @@ -495,6 +507,7 @@ setMethod(
"familiarTaskLearnerHyperparameters",
data_id = learner_hyperparameter_data_id,
run_id = run_id,
use_vimp = use_vimp,
vimp_method = vimp_method,
learner = learner,
run_table = run_tables,
Expand Down
2 changes: 1 addition & 1 deletion R/TaskVimp.R
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ setMethod(

# Find the data_id related to computing variable importance.
data_id <- experiment_data@experiment_setup[vimp == TRUE, ]$main_data_id[1L]
if (is_empty(data_id)) return(NULL)
if (is.na(data_id)) return(NULL)

# Initialise empty list.
task_list <- list()
Expand Down
4 changes: 2 additions & 2 deletions R/VimpTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ setMethod(
function(x, ...) {
# If the list is empty, return NULL instead.
if (is_empty(x)) return(NULL)

browser()
# Dispatch to method for single variable importance tables.
return(lapply(
x,
Expand All @@ -388,7 +388,7 @@ setMethod(
"decluster_vimp_table",
signature(x = "vimpTable"),
function(x, show_weights = FALSE, show_cluster_name = FALSE, ...) {
browser()

# Check if the table has already been declustered.
if (.as_vimp_table_state(x@state) >= "declustered") return(x)

Expand Down
38 changes: 38 additions & 0 deletions tests/testthat/test-task_based_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,41 @@ model <- familiar::train_familiar(
verbose = TRUE,
parallel = FALSE
)


# Check without explicit variable importance computation -----------------------
# Create variable importance
experiment_vimp <- familiar::precompute_vimp(
data = data,
experimental_design = "bs(mb,3)",
vimp_method = "mim",
outcome_type = "binomial",
outcome_column = "outcome",
batch_id_column = "batch_id",
sample_id_column = "sample_id",
series_id_column = "series_id",
class_levels = c("red", "green"),
verbose = TRUE,
parallel = FALSE
)

testthat::test_that("variable importance data is absent", {
testthat::expect_null(experiment_vimp@feature_info)
testthat::expect_null(experiment_vimp@vimp_table_list)
})

# Train model
model <- familiar::train_familiar(
data = data,
experiment_data = experiment_vimp,
vimp_method = "mim",
learner = "glm_logistic",
outcome_type = "binomial",
outcome_column = "outcome",
batch_id_column = "batch_id",
sample_id_column = "sample_id",
series_id_column = "series_id",
class_levels = c("red", "green"),
verbose = TRUE,
parallel = FALSE
)

0 comments on commit b81e1c8

Please sign in to comment.