Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed Dec 15, 2024
1 parent 5e279fc commit 164bc11
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
30 changes: 26 additions & 4 deletions R/get_marginalmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
fun_args <- list(
model,
by = at_specs$varname,
newdata = datagrid,
conf_level = ci,
type = type,
hypothesis = hypothesis
Expand All @@ -37,6 +38,7 @@

attr(means, "at") <- my_args$by
attr(means, "by") <- my_args$by
attr(means, "focal_terms") <- my_args$focal_terms
means
}

Expand All @@ -47,7 +49,7 @@
#' @keywords internal
.format_marginaleffects_means <- function(means, model, ...) {
model_data <- insight::get_data(model)
non_focal <- setdiff(colnames(model_data), attr(means, "by"))
non_focal <- setdiff(colnames(model_data), attr(means, "focal_terms"))
# Format
params <- parameters::parameters(means)
params <- datawizard::data_relocate(params, c("Predicted", "SE", "CI_low", "CI_high"), after = -1)
Expand All @@ -71,15 +73,35 @@
predictors <- insight::find_predictors(model, flatten = TRUE, ...)
model_data <- insight::get_data(model)

# Guess arguments ('by' and 'fixed')
# Guess arguments 'by'
if (identical(by, "auto")) {
# Find categorical predictors
by <- predictors[!vapply(model_data[predictors], is.numeric, logical(1))]
if (!length(by) || all(is.na(by))) {
insight::format_error("Model contains no categorical factor. Please specify 'by'.")
insight::format_error("Model contains no categorical predictor. Please specify `by`.")
}
insight::format_alert(paste0("We selected `by = c(", toString(paste0('"', by, '"')), ")`."))
}

list(by = by)
# in "focal_terms", we want the variable names.
focal_terms <- by

# This is needed when we have something like
# `by = "Species=c('versicolor', 'virginica')")`
# we need the variable names for selecting columns in the output
focals_to_fix <- vapply(by, function(i) grepl("=", i, fixed = TRUE), logical(1))
if (any(focals_to_fix)) {
for (i in seq_along(focal_terms)) {
if (focals_to_fix[i]) {
focal_terms[i] <- insight::trim_ws(unlist(strsplit(by[i], "=", fixed = TRUE), use.names = FALSE))[1]
}
}
}

# exceptions: by = "all"
if (all(focal_terms == "all")) {
focal_terms <- predictors
}

list(by = by, focal_terms = focal_terms)
}
14 changes: 9 additions & 5 deletions tests/testthat/test-estimate_means.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
test_that("estimate_means() - core", {
skip_if_not_installed("rstanarm")
skip_if_not_installed("emmeans")
skip_if_not_installed("marginaleffects")

# library(testthat)

dat <- mtcars
dat$gear <- as.factor(dat$gear)
dat$cyl <- as.factor(dat$cyl)
Expand All @@ -23,8 +20,11 @@ test_that("estimate_means() - core", {
estim1 <- suppressMessages(estimate_means(model))
expect_identical(dim(estim1), c(3L, 5L))
estim2 <- suppressMessages(estimate_means(model, backend = "marginaleffects"))
expect_identical(dim(estim2), c(3L, 6L))
expect_identical(dim(estim2), c(3L, 5L))
expect_lt(max(estim1$Mean - estim2$Mean), 1e-10)
expect_equal(estim1$Mean, estim2$Mean, tolerance = 1e-4)
expect_named(estim1, c("gear", "Mean", "SE", "CI_low", "CI_high"))
expect_named(estim2, c("gear", "Mean", "SE", "CI_low", "CI_high"))

# At specific levels
model <- lm(Sepal.Width ~ Species, data = iris)
Expand All @@ -33,6 +33,9 @@ test_that("estimate_means() - core", {
estim2 <- suppressMessages(estimate_means(model, by = "Species=c('versicolor', 'virginica')", backend = "marginaleffects"))
expect_identical(dim(estim2), c(2L, 5L))
expect_lt(max(estim1$Mean - estim2$Mean), 1e-10)
expect_equal(estim1$Mean, estim2$Mean, tolerance = 1e-4)
expect_named(estim1, c("Species", "Mean", "SE", "CI_low", "CI_high"))
expect_named(estim2, c("Species", "Mean", "SE", "CI_low", "CI_high"))

# Interactions between factors
dat <- iris
Expand Down Expand Up @@ -107,6 +110,8 @@ test_that("estimate_means() - core", {
expect_equal(dim(estim), c(5, 5))
estim <- suppressMessages(estimate_means(model, by = "Sepal.Width=c(2, 4)"))
expect_identical(dim(estim), c(2L, 5L))
estim1 <- suppressMessages(estimate_means(model, by = c("Species=c('versicolor', 'setosa')", "Sepal.Width=c(2, 4)")))
estim2 <- suppressMessages(estimate_means(model, by = c("Species=c('versicolor', 'setosa')", "Sepal.Width=c(2, 4)"), backend = "marginalmeans"))

# Two factors
dat <- iris
Expand Down Expand Up @@ -166,7 +171,6 @@ test_that("estimate_means() - core", {
})

test_that("estimate_means() - mixed models", {
skip_if_not_installed("rstanarm")
skip_if_not_installed("emmeans")
skip_if_not_installed("lme4")
skip_if_not_installed("glmmTMB")
Expand Down

0 comments on commit 164bc11

Please sign in to comment.