From d0b2db55614121b44cd9c2ef6338573ead7a2997 Mon Sep 17 00:00:00 2001 From: Alex Zwanenburg <alexander.zwanenburg@nct-dresden.de> Date: Fri, 3 Jan 2025 12:12:18 +0100 Subject: [PATCH] Naive models now yield the correct number of predicted values --- NEWS.md | 3 +++ R/LearnerS4Naive.R | 9 +++++++++ R/UtilitiesS4.R | 47 +++++++++++++++++++++++++++++++++++----------- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/NEWS.md b/NEWS.md index 359eb367..9aff5672 100644 --- a/NEWS.md +++ b/NEWS.md @@ -97,6 +97,9 @@ - Some vignettes referred to `experiment_design` where `experimental_design` was intended. + +- Fixed an error where naive models would yield an incorrect number of predicted + values when samples appear multiple times (e.g. in bootstraps). # Version 1.5.0 (Whole Whale) diff --git a/R/LearnerS4Naive.R b/R/LearnerS4Naive.R index bfe7e7a1..e14fdeb0 100644 --- a/R/LearnerS4Naive.R +++ b/R/LearnerS4Naive.R @@ -91,8 +91,11 @@ setMethod( ... ) { + # Check n_samples, and then ensure that n_samples is equal to the number + # of rows in the dataset. n_samples <- get_n_samples(data) if (n_samples == 0L) return(callNextMethod()) + n_samples <- nrow(data@data) if (type == "default") { # default ---------------------------------------------------------------- @@ -184,8 +187,11 @@ setMethod( ... ) { + # Check n_samples, and then ensure that n_samples is equal to the number + # of rows in the dataset. n_samples <- get_n_samples(data) if (n_samples == 0L) return(callNextMethod()) + n_samples <- nrow(data@data) if (object@outcome_type == "survival" && type == "default") { # For survival outcomes based on survival times, predict the average @@ -268,8 +274,11 @@ setMethod( ... ) { + # Check n_samples, and then ensure that n_samples is equal to the number + # of rows in the dataset. n_samples <- get_n_samples(data) if (n_samples == 0L) return(callNextMethod()) + n_samples <- nrow(data@data) if (object@outcome_type %in% c("survival") && type == "default") { # For survival outcomes based on survival times, predict the average diff --git a/R/UtilitiesS4.R b/R/UtilitiesS4.R index c7041a28..52d31764 100644 --- a/R/UtilitiesS4.R +++ b/R/UtilitiesS4.R @@ -463,20 +463,40 @@ setMethod( # get_n_samples ---------------------------------------------------------------- -setMethod("get_n_samples", signature(x = "data.table"), function(x, id_depth = "sample") { - return(.get_n_samples(x = x, id_depth = id_depth)) -}) +setMethod( + "get_n_samples", + signature(x = "data.table"), + function(x, id_depth = "sample", count_unique = TRUE) { + return(.get_n_samples( + x = x, + id_depth = id_depth, + count_unique = count_unique + )) + } +) -setMethod("get_n_samples", signature(x = "dataObject"), function(x, id_depth = "sample") { - return(.get_n_samples(x = x@data, id_depth = id_depth)) -}) +setMethod( + "get_n_samples", + signature(x = "dataObject"), + function(x, id_depth = "sample", count_unique = TRUE) { + return(.get_n_samples( + x = x@data, + id_depth = id_depth, + count_unique = count_unique + )) + } +) -setMethod("get_n_samples", signature(x = "NULL"), function(x, id_depth = "sample") { - return(0L) -}) +setMethod( + "get_n_samples", + signature(x = "NULL"), + function(x, id_depth = "sample", ...) { + return(0L) + } +) -.get_n_samples <- function(x, id_depth) { +.get_n_samples <- function(x, id_depth, count_unique) { # Check if x is empty. if (is_empty(x)) { return(0L) @@ -487,7 +507,12 @@ setMethod("get_n_samples", signature(x = "NULL"), function(x, id_depth = "sample # Return the number of rows with unique values for the combination of # identifier columns. - return(nrow(unique(x[, mget(id_columns)]))) + if (count_unique) { + return(nrow(unique(x[, mget(id_columns)]))) + + } else { + return(nrow(x[, mget(id_columns)])) + } }