Skip to content

Commit

Permalink
aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
lskatz committed Jan 26, 2024
1 parent 6ad2ece commit 0250eaf
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 0 deletions.
43 changes: 43 additions & 0 deletions R/aggregate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#' Aggregate predictions
#'
#' Pools the `predicted` scores from multiple calls of `prediction()` and
#' summarises every category (column) with its mean and standard deviation.
#'
#' The `predicted` element of each prediction object is stacked row-wise
#' with `rbind()`, so every prediction must provide the same columns. The
#' summary is then computed column by column over the stacked scores.
#'
#' Note that this function shares its name with `stats::aggregate()`.
#'
#' @param predictions (list of predict.rfsrc() objects) The prediction
#'   objects from `prediction()`. Must contain at least one object, and each
#'   object must carry a `predicted` element.
#'
#' @return A numeric matrix with the rows `"mean"` and `"sd"` and one column
#'   per column of the stacked `predicted` scores. The `"sd"` row is `NA`
#'   when only a single row of scores is available.
#'
#' @export
#'
#' @examples
#' \dontrun{
#' # Example usage:
#' predictions <- list()
#' predictions[[1]] <- prediction(model_filename = "bs23.rds", ...)
#' predictions[[2]] <- prediction(model_filename = "bs24.rds", ...)
#' result <- aggregate(predictions)
#' }
#'
#' @importFrom logger log_info
#' @importFrom utils read.csv write.table
#' @importFrom magrittr `%>%`
#' @importFrom randomForestSRC predict.rfsrc
aggregate <- function(predictions) {
  # Fail early with a readable message; otherwise the empty case surfaces
  # as an opaque dimension error inside apply().
  if (length(predictions) == 0L) {
    stop(
      "`predictions` must contain at least one prediction object.",
      call. = FALSE
    )
  }

  # Get all the confidence scores.
  # `[[` matches the element name exactly, whereas `$` on a list would
  # partially match any other element whose name starts with "predicted".
  scores <- lapply(predictions, function(pred) pred[["predicted"]])

  # rbind() silently drops NULL entries, which would bias the summary
  # towards the remaining predictions, so reject them explicitly.
  has_no_scores <- vapply(scores, is.null, logical(1))
  if (any(has_no_scores)) {
    stop(
      "Every element of `predictions` must have a `predicted` element; ",
      "missing at position(s): ",
      paste(which(has_no_scores), collapse = ", "),
      call. = FALSE
    )
  }

  # Stack the scores of all sources, one row per prediction row
  all_scores <- do.call(rbind, scores)

  # TODO return averaged VIMP list
  # TODO make the return variable an object holding all this stuff
  # TODO document this better

  # Calculate the averages and standard deviations for each category
  apply(all_scores, 2, function(x) c(mean = mean(x), sd = stats::sd(x)))
}

28 changes: 28 additions & 0 deletions man/aggregate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions tests/testthat/test-20_aggregation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
suppressPackageStartupMessages(library("data.table"))
suppressPackageStartupMessages(library("dplyr"))
suppressPackageStartupMessages(library("logger"))
suppressPackageStartupMessages(library("gt"))
suppressPackageStartupMessages(library("randomForestSRC"))
suppressPackageStartupMessages(library("tidyverse"))


test_that("Aggregating LMO0003 with example_query", {
  ncores <- 1

  # One rfsrc prediction object per model file (bs23, bs24, bs25), each
  # run against the same query file.
  model_numbers <- seq(1, 3) + 22
  predictions <- lapply(model_numbers, function(model_number) {
    prediction(
      model_filename = paste0("test-results/bs", model_number, ".rds"),
      query = "example_query.csv",
      ncores = ncores
    )
  })

  bootstrapped_prediction <- aggregate(predictions = predictions)

  # Reference summary: one column per category, rows "mean" and "sd".
  expected <- rbind(
    mean = c(
      dairy = 0.1501267932,
      fruit = 0.1578980673,
      meat = 0.3111404334,
      seafood = 0.1697799469,
      vegetable = 0.211054759
    ),
    sd = c(
      dairy = 0.0006860643,
      fruit = 0.0008925544,
      meat = 0.0002023788,
      seafood = 0.0009216065,
      vegetable = 0.000183757
    )
  )

  # Check a single cell first, then the whole matrix, both with a
  # tolerance of 0.1.
  expect_equal(
    bootstrapped_prediction["mean", "dairy"],
    expected["mean", "dairy"],
    tolerance = 0.1
  )
  expect_equal(bootstrapped_prediction, expected, tolerance = 0.1)
})

0 comments on commit 0250eaf

Please sign in to comment.