From d0b2db55614121b44cd9c2ef6338573ead7a2997 Mon Sep 17 00:00:00 2001
From: Alex Zwanenburg <alexander.zwanenburg@nct-dresden.de>
Date: Fri, 3 Jan 2025 12:12:18 +0100
Subject: [PATCH] Naive models now yield the correct number of predicted values

---
 NEWS.md            |  3 +++
 R/LearnerS4Naive.R |  9 +++++++++
 R/UtilitiesS4.R    | 47 +++++++++++++++++++++++++++++++++++-----------
 3 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index 359eb367..9aff5672 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -97,6 +97,9 @@
   
 - Some vignettes referred to `experiment_design` where `experimental_design` was
   intended.
+  
+- Fixed an error where naive models would yield an incorrect number of predicted
+  values when samples appear multiple times (e.g. in bootstraps).
 
 # Version 1.5.0 (Whole Whale)
 
diff --git a/R/LearnerS4Naive.R b/R/LearnerS4Naive.R
index bfe7e7a1..e14fdeb0 100644
--- a/R/LearnerS4Naive.R
+++ b/R/LearnerS4Naive.R
@@ -91,8 +91,11 @@ setMethod(
     ...
   ) {
 
+    # Defer to the next method if data are empty. Otherwise use the row count,
+    # not the unique-sample count, so that repeated samples are each predicted.
     n_samples <- get_n_samples(data)
     if (n_samples == 0L) return(callNextMethod())
+    n_samples <- nrow(data@data)
     
     if (type == "default") {
       # default ----------------------------------------------------------------
@@ -184,8 +187,11 @@ setMethod(
     ...
   ) {
     
+    # Defer to the next method if data are empty. Otherwise use the row count,
+    # not the unique-sample count, so that repeated samples are each predicted.
     n_samples <- get_n_samples(data)
     if (n_samples == 0L) return(callNextMethod())
+    n_samples <- nrow(data@data)
 
     if (object@outcome_type == "survival" && type == "default") {
       # For survival outcomes based on survival times, predict the average
@@ -268,8 +274,11 @@ setMethod(
     ...
   ) {
     
+    # Defer to the next method if data are empty. Otherwise use the row count,
+    # not the unique-sample count, so that repeated samples are each predicted.
     n_samples <- get_n_samples(data)
     if (n_samples == 0L) return(callNextMethod())
+    n_samples <- nrow(data@data)
 
     if (object@outcome_type %in% c("survival") &&  type == "default") {
       # For survival outcomes based on survival times, predict the average
diff --git a/R/UtilitiesS4.R b/R/UtilitiesS4.R
index c7041a28..52d31764 100644
--- a/R/UtilitiesS4.R
+++ b/R/UtilitiesS4.R
@@ -463,20 +463,40 @@ setMethod(
 
 
 # get_n_samples ----------------------------------------------------------------
-setMethod("get_n_samples", signature(x = "data.table"), function(x, id_depth = "sample") {
-  return(.get_n_samples(x = x, id_depth = id_depth))
-})
+setMethod(
+  "get_n_samples",
+  signature(x = "data.table"),
+  function(x, id_depth = "sample", count_unique = TRUE) {
+    return(.get_n_samples(
+      x = x,
+      id_depth = id_depth, 
+      count_unique = count_unique
+    ))
+  }
+)
 
-setMethod("get_n_samples", signature(x = "dataObject"), function(x, id_depth = "sample") {
-  return(.get_n_samples(x = x@data, id_depth = id_depth))
-})
+setMethod(
+  "get_n_samples",
+  signature(x = "dataObject"),
+  function(x, id_depth = "sample", count_unique = TRUE) {
+    return(.get_n_samples(
+      x = x@data,
+      id_depth = id_depth,
+      count_unique = count_unique
+    ))
+  }
+)
 
-setMethod("get_n_samples", signature(x = "NULL"), function(x, id_depth = "sample") {
-  return(0L)
-})
+setMethod(
+  "get_n_samples",
+  signature(x = "NULL"),
+  function(x, id_depth = "sample", ...) {
+    return(0L)
+  }
+)
 
 
-.get_n_samples <- function(x, id_depth) {
+.get_n_samples <- function(x, id_depth, count_unique) {
   # Check if x is empty.
   if (is_empty(x)) {
     return(0L)
@@ -487,7 +507,12 @@ setMethod("get_n_samples", signature(x = "NULL"), function(x, id_depth = "sample
 
   # Return the number of rows with unique values for the combination of
   # identifier columns.
-  return(nrow(unique(x[, mget(id_columns)])))
+  if (count_unique) {
+    return(nrow(unique(x[, mget(id_columns)])))
+    
+  } else {
+    return(nrow(x[, mget(id_columns)]))
+  }
 }