Skip to content

Commit

Permalink
No link-inv in pool_predictions, add tests (#418)
Browse files Browse the repository at this point in the history
* No link-inv in pool_predictions, add tests

* add comment

* remove duplicated test

* add test
  • Loading branch information
strengejacke authored Feb 23, 2025
1 parent b2a18b9 commit 56b26ba
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 17 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ S3method(standardize,estimate_predicted)
S3method(standardize,estimate_slopes)
S3method(summary,estimate_slopes)
S3method(summary,reshape_grouplevel)
S3method(unstandardize,estimate_contrasts)
S3method(unstandardize,estimate_means)
S3method(unstandardize,estimate_predicted)
S3method(visualisation_recipe,estimate_grouplevel)
S3method(visualisation_recipe,estimate_means)
Expand Down
21 changes: 7 additions & 14 deletions R/pool.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,11 @@ pool_predictions <- function(x, transform = NULL, ...) {
ci <- attributes(x[[1]])$ci
model <- attributes(x[[1]])$model
dof <- x[[1]]$df
link_inv <- insight::link_inverse(model)
link_fun <- insight::link_function(model)

if (is.null(link_inv)) {
link_inv <- function(x) x
}
if (is.null(link_fun)) {
link_fun <- function(x) x
}
# We deliberately do not apply the link-inverse here: the standard errors are
# calculated using the delta method, so confidence intervals built from them
# would be incorrect if we back-transformed with the link-inverse.

if (is.null(dof)) {
dof <- Inf
}
Expand All @@ -166,7 +162,7 @@ pool_predictions <- function(x, transform = NULL, ...) {

for (i in 1:n_rows) {
# pooled estimate
pooled_pred <- unlist(lapply(original_x, function(j) link_fun(j[[estimate_name]][i])), use.names = FALSE)
pooled_pred <- unlist(lapply(original_x, function(j) j[[estimate_name]][i]), use.names = FALSE)
pooled_predictions[[estimate_name]][i] <- mean(pooled_pred, na.rm = TRUE)

# pooled standard error
Expand All @@ -187,8 +183,8 @@ pool_predictions <- function(x, transform = NULL, ...) {
# confidence intervals ----
alpha <- (1 + ci) / 2
fac <- stats::qt(alpha, df = pooled_df)
pooled_predictions$CI_low <- link_inv(pooled_predictions[[estimate_name]] - fac * pooled_predictions$SE)
pooled_predictions$CI_high <- link_inv(pooled_predictions[[estimate_name]] + fac * pooled_predictions$SE)
pooled_predictions$CI_low <- pooled_predictions[[estimate_name]] - fac * pooled_predictions$SE
pooled_predictions$CI_high <- pooled_predictions[[estimate_name]] + fac * pooled_predictions$SE

# update df ----
pooled_predictions$df <- pooled_df
Expand All @@ -200,9 +196,6 @@ pool_predictions <- function(x, transform = NULL, ...) {
pooled_predictions$CI_high <- transform_fun(pooled_predictions$CI_high)
}

# backtransform
pooled_predictions[[estimate_name]] <- link_inv(pooled_predictions[[estimate_name]])

pooled_predictions
}

Expand Down
36 changes: 33 additions & 3 deletions R/standardize_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,18 @@ standardize.estimate_slopes <- standardize.estimate_contrasts
#' @method unstandardize estimate_predicted
#' @export
unstandardize.estimate_predicted <- function(x, include_response = TRUE, ...) {
model <- attributes(x)$model

# Get data of predictors
data <- insight::get_data(attributes(x)$model, verbose = FALSE, ...)
data <- insight::get_data(model, verbose = FALSE, ...)
data[[attributes(x)$response]] <- NULL # Remove resp from data

# Standardize predictors
x[names(data)] <- datawizard::unstandardize(as.data.frame(x)[names(data)], reference = data, ...)

# Standardize response
if (include_response == TRUE && insight::model_info(attributes(x)$model)$is_linear) {
resp <- insight::get_response(attributes(x)$model)
if (include_response == TRUE && insight::model_info(model)$is_linear) {
resp <- insight::get_response(model)
disp <- attributes(datawizard::standardize(resp, ...))$scale

for (col in c("Predicted", "Mean", "CI_low", "CI_high")) {
Expand All @@ -95,3 +97,31 @@ unstandardize.estimate_predicted <- function(x, include_response = TRUE, ...) {
}
x
}


# `estimate_means` objects are unstandardized with the very same function that
# handles `estimate_predicted` objects, so the method is simply aliased.
#' @export
unstandardize.estimate_means <- unstandardize.estimate_predicted


# Revert standardization of an `estimate_contrasts` object.
#
# x      : object carrying the fitted model in its "model" attribute.
# robust : if TRUE, the response's MAD is used as dispersion factor,
#          otherwise its SD.
# ...    : currently not used inside this method.
#
# For linear models, every contrast-related column that is present in `x` is
# multiplied by the dispersion of the model's response. For all other models,
# `x` is returned unchanged.
#' @export
unstandardize.estimate_contrasts <- function(x, robust = FALSE, ...) {
  model <- attributes(x)$model

  # nothing to rescale unless the model is linear
  if (!insight::model_info(model)$is_linear) {
    return(x)
  }

  # dispersion scaling factor of the response (MAD if robust, else SD)
  dispersion_fun <- if (robust) stats::mad else stats::sd
  disp <- dispersion_fun(insight::get_response(model), na.rm = TRUE)

  # put the relevant columns back on the scale of the response
  candidates <- c(
    "Difference", "Ratio", "Coefficient", "SE", "MAD", "CI_low", "CI_high"
  )
  for (col in intersect(candidates, names(x))) {
    x[col] <- x[[col]] * disp
  }

  x
}
9 changes: 9 additions & 0 deletions tests/testthat/test-mice.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ test_that("pool_predictions", {
out <- pool_predictions(predictions)
expect_equal(out$Mean, c(29.84661, 25.20021, 23.14022), tolerance = 1e-3)
expect_equal(out$CI_low, c(2.10117, 3.44548, -5.79522), tolerance = 1e-3)

# transformed response
predictions <- lapply(1:5, function(i) {
m <- lm(log1p(bmi) ~ age + hyp + chl, data = mice::complete(imp, action = i))
estimate_means(m, "age")
})
out <- pool_predictions(predictions, transform = TRUE)
expect_equal(out$Mean, c(29.67473, 24.99382, 23.19148), tolerance = 1e-3)
expect_equal(out$CI_low, c(10.58962, 11.13011, 7.43196), tolerance = 1e-3)
})


Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/test-standardize.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,20 @@ test_that("standardize() - estimate_predicted", {
expect_equal(as.vector(out1$Predicted), c(1.0906, -0.0577, -0.82805), tolerance = 1e-4)
expect_equal(as.vector(out2$Predicted), estim$Predicted, tolerance = 1e-4)
})


test_that("standardize() - estimate_contrasts()", {
  data(mtcars)

  # treat the grouping variables as categorical
  dat <- mtcars
  dat[["gear"]] <- as.factor(dat[["gear"]])
  dat[["cyl"]] <- as.factor(dat[["cyl"]])

  # Simple one-way model: contrasts between the levels of `cyl`
  model <- lm(mpg ~ cyl, data = dat)
  estim <- estimate_contrasts(model, "cyl", backend = "marginaleffects")

  # standardize, then revert: the round trip must restore the original values
  out1 <- standardize(estim)
  out2 <- unstandardize(out1)

  expect_equal(
    as.vector(out1$Difference),
    c(-1.14831, -1.91866, -0.77035),
    tolerance = 1e-4
  )
  expect_equal(
    as.vector(out2$Difference),
    estim$Difference,
    tolerance = 1e-4
  )
})

0 comments on commit 56b26ba

Please sign in to comment.