Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzwanenburg committed Nov 1, 2024
1 parent 8693001 commit f378d7f
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 24 deletions.
103 changes: 94 additions & 9 deletions R/TaskVimp.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ setMethod(
function(object, file_paths = NULL) {
if (is.null(file_paths)) return(object)

browser()
# Generate file name of variable importance table
object@file <- get_object_file_name(
object_type = "vimpTable",
project_id = object@project_id,
dir_path = file_paths$vimp_dir
)

return(object)
}
Expand Down Expand Up @@ -151,6 +156,13 @@ setMethod(
...
)

# Check and retrieve hyperparameters.
hyperparameters <- .get_hyperparameters(
object = object,
hyperparameters = hyperparameters,
data = data
)

if (is.null(hyperparameters) && is.na(object@hyperparameter_file)) {
# Check that a list of hyperparameters is provided, otherwise create an ad-
# hoc list of hyperparameters.
Expand Down Expand Up @@ -197,15 +209,15 @@ setMethod(
vimp_object@feature_info <- feature_info_list[required_features]

# Compute variable importance.
vimp_object <- .vimp(
vimp_table <- .vimp(
object = vimp_object,
data = data
)

if (!is.na(object@file)) {
saveRDS(vimp_object, file = object@file)
saveRDS(vimp_table, file = object@file)
} else {
return(vimp_object)
return(vimp_table)
}

return(invisible(TRUE))
Expand All @@ -217,7 +229,7 @@ setMethod(
# .get_feature_info_list (vimp task) -------------------------------------------
setMethod(
".get_feature_info_list",
signature(object = "familiarTask"),
signature(object = "familiarTaskVimp"),
function(object, feature_info_list, ...) {
# Suppress NOTES due to non-standard evaluation in data.table
can_pre_process <- NULL
Expand All @@ -227,9 +239,12 @@ setMethod(
# Find the last entry that is available for pre-processing
pre_processing_run <- tail(object@run_table[can_pre_process == TRUE, ], n = 1L)

feature_info_list <- get_feature_info_from_backend(
data_id = pre_processing_run$data_id[1L],
run_id = pre_processing_run$run_id[1L]
feature_info_list <- tryCatch(
get_feature_info_from_backend(
data_id = pre_processing_run$data_id[1L],
run_id = pre_processing_run$run_id[1L]
),
error = NULL
)
}

Expand All @@ -254,11 +269,24 @@ setMethod(
} else if (is.null(feature_info_list)) {
# Assume that the feature info file attribute contains the path to the
# file containing feature info.
if (!file.exists(object@feature_info_file)) {
..error(paste0("feature info file does not exist at location: ", object@feature_info_file))
}
feature_info_list <- readRDS(object@feature_info_file)
feature_info_list <- update_object(feature_info_list)

} else if (is.character(feature_info_list)) {
# If the feature info list is a string, interpret this as a path to the
# file containing the feature info.
browser()
if (!file.exists(feature_info_list)) {
..error(paste0("feature info file does not exist at location: ", feature_info_list))
}
feature_info_list <- readRDS(feature_info_list)
feature_info_list <- update_object(feature_info_list)
}

if (is.null(feature_info_list)) {
..error("no feature info objects were found.")
}

return(feature_info_list)
Expand All @@ -267,6 +295,63 @@ setMethod(



setMethod(
".get_hyperparameters",
signature(object = "familiarTaskVimp"),
function(object, hyperparameters, ...) {

if (is.null(hyperparameters) && is.na(object@hyperparameter_file)) {



} else if (is.null(hyperparameters)) {
# Identify the right entry on hpo_list.
for (ii in rev(run$run_table$perturb_level)) {
run_id_list <- .get_iteration_identifiers(
run = run,
perturb_level = ii
)

# Check whether there are any matching data and run ids by determining the
# number of rows in the table after matching
match_hpo <- sapply(
hpo_list,
function(iter_hpo, run_id_list) {
# Determine if there are any rows in the run_table of the parameter list
# that match the data and run identifiers of the current level.
match_size <- nrow(iter_hpo@run_table[data_id == run_id_list$data & run_id == run_id_list$run])

# Return TRUE if any matching rows are found.
return(match_size > 0L)
},
run_id_list = run_id_list
)

# If there is a match, we step out of the loop
if (any(match_hpo)) break
}

# Extract the table of parameters
if (allow_random_selection && sum(match_hpo) > 1L) {
random_set <- sample(which(match_hpo), size = 1L)

object <- hpo_list[[random_set]]

} else {
object <- hpo_list[match_hpo][[1L]]
}

if (as_list) {
return(object@hyperparameters)
} else {
return(object)
}
}
}
)



.generate_vimp_tasks <- function(
file_paths,
project_id
Expand Down
30 changes: 15 additions & 15 deletions R/Utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -765,20 +765,6 @@ get_object_file_name <- function(
with_extension = TRUE,
dir_path = NULL
) {
# Generate file name for an object

if (!object_type %in% c("familiarModel", "familiarEnsemble", "familiarData")) {
..error("The object type was not recognised.")
}

# Generate the basic string
base_str <- paste0(
project_id, "_",
learner, "_",
vimp_method, "_",
data_id, "_",
run_id
)

if (object_type == "familiarModel") {
# For familiarModel objects
Expand Down Expand Up @@ -835,6 +821,20 @@ get_object_file_name <- function(
ifelse(is_validation, "validation", "development"), "_data"
)

} else if (object_type == "vimpTable") {

if (is.null(vimp_method) || is.null(project_id)) {
..error_reached_unreachable_code("missing arguments")
}

output_str <- paste0(
project_id, "_",
vimp_method, "_",
data_id, "_",
run_id,
"_vimp"
)

} else if (object_type == "featureInfo") {
# Complete feature info objects.

Expand All @@ -846,7 +846,7 @@ get_object_file_name <- function(
project_id, "_",
data_id, "_",
run_id,
"_feature_info.RDS"
"_feature_info"
)

} else if (object_type == "genericFeatureInfo") {
Expand Down

0 comments on commit f378d7f

Please sign in to comment.