Skip to content

Commit

Permalink
Naive models now yield the correct number of predicted values
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzwanenburg committed Jan 3, 2025
1 parent 2729584 commit d0b2db5
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@

- Some vignettes referred to `experiment_design` where `experimental_design` was
intended.

- Fixed an error where naive models would yield an incorrect number of predicted
values when samples appear multiple times (e.g. in bootstraps).

# Version 1.5.0 (Whole Whale)

Expand Down
9 changes: 9 additions & 0 deletions R/LearnerS4Naive.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ setMethod(
...
) {

# Check n_samples, and then ensure that n_samples is equal to the number
# of rows in the dataset.
n_samples <- get_n_samples(data)
if (n_samples == 0L) return(callNextMethod())
n_samples <- nrow(data@data)

if (type == "default") {
# default ----------------------------------------------------------------
Expand Down Expand Up @@ -184,8 +187,11 @@ setMethod(
...
) {

# Check n_samples, and then ensure that n_samples is equal to the number
# of rows in the dataset.
n_samples <- get_n_samples(data)
if (n_samples == 0L) return(callNextMethod())
n_samples <- nrow(data@data)

if (object@outcome_type == "survival" && type == "default") {
# For survival outcomes based on survival times, predict the average
Expand Down Expand Up @@ -268,8 +274,11 @@ setMethod(
...
) {

# Check n_samples, and then ensure that n_samples is equal to the number
# of rows in the dataset.
n_samples <- get_n_samples(data)
if (n_samples == 0L) return(callNextMethod())
n_samples <- nrow(data@data)

if (object@outcome_type %in% c("survival") && type == "default") {
# For survival outcomes based on survival times, predict the average
Expand Down
47 changes: 36 additions & 11 deletions R/UtilitiesS4.R
Original file line number Diff line number Diff line change
Expand Up @@ -463,20 +463,40 @@ setMethod(


# get_n_samples ----------------------------------------------------------------
setMethod("get_n_samples", signature(x = "data.table"), function(x, id_depth = "sample") {
return(.get_n_samples(x = x, id_depth = id_depth))
})
setMethod(
"get_n_samples",
signature(x = "data.table"),
function(x, id_depth = "sample", count_unique = TRUE) {
return(.get_n_samples(
x = x,
id_depth = id_depth,
count_unique = count_unique
))
}
)

setMethod("get_n_samples", signature(x = "dataObject"), function(x, id_depth = "sample") {
return(.get_n_samples(x = x@data, id_depth = id_depth))
})
setMethod(
"get_n_samples",
signature(x = "dataObject"),
function(x, id_depth = "sample", count_unique = TRUE) {
return(.get_n_samples(
x = x@data,
id_depth = id_depth,
count_unique = count_unique
))
}
)

setMethod("get_n_samples", signature(x = "NULL"), function(x, id_depth = "sample") {
return(0L)
})
setMethod(
"get_n_samples",
signature(x = "NULL"),
function(x, id_depth = "sample", ...) {
return(0L)
}
)


.get_n_samples <- function(x, id_depth) {
.get_n_samples <- function(x, id_depth, count_unique) {
# Check if x is empty.
if (is_empty(x)) {
return(0L)
Expand All @@ -487,7 +507,12 @@ setMethod("get_n_samples", signature(x = "NULL"), function(x, id_depth = "sample

# Return the number of rows with unique values for the combination of
# identifier columns.
return(nrow(unique(x[, mget(id_columns)])))
if (count_unique) {
return(nrow(unique(x[, mget(id_columns)])))

} else {
return(nrow(x[, mget(id_columns)]))
}
}


Expand Down

0 comments on commit d0b2db5

Please sign in to comment.