Skip to content

Commit

Permalink
make fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
DominiqueMakowski committed Feb 11, 2024
1 parent 9e2bb40 commit 717c4f3
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 117 deletions.
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,11 @@ VignetteBuilder:
knitr
Encoding: UTF-8
Language: en-US
RoxygenNote: 7.2.3.9000
RoxygenNote: 7.3.1
Config/testthat/edition: 3
Config/testthat/parallel: true
Roxygen: list(markdown = TRUE)
Config/Needs/website:
rstudio/bslib,
r-lib/pkgdown,
easystats/easystatstemplate
Remotes: easystats/insight
40 changes: 13 additions & 27 deletions R/estimate_means.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,47 +53,29 @@ estimate_means <- function(model,
ci = 0.95,
backend = "emmeans",
...) {
# Compute means

if (backend == "emmeans") {
# Emmeans ------------------------------------------------------------------
estimated <- get_emmeans(model, at, fixed, transform = transform, ...)
means <- .format_emmeans_means(estimated, model, ci, transform, ...)

# Summarize and clean
if (insight::model_info(model)$is_bayesian) {
means <- parameters::parameters(estimated, ci = ci, ...)
means <- .clean_names_bayesian(means, model, transform, type = "mean")
means <- cbind(estimated@grid, means)
means$`.wgt.` <- NULL # Drop the weight column
} else {
means <- as.data.frame(stats::confint(estimated, level = ci))
means$df <- NULL
means <- .clean_names_frequentist(means)
}
# Remove the "1 - overall" column that can appear in cases like at = NULL
means <- means[names(means) != "1"]

info <- attributes(estimated)
} else {
means <- .get_marginalmeans(model, at, fixed, transform = transform, ...)

info <- attributes(means)
# Marginalmeans ------------------------------------------------------------
estimated <- .get_marginalmeans(model, at, ci=ci, ...)

Check warning on line 64 in R/estimate_means.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/estimate_means.R,line=64,col=50,[infix_spaces_linter] Put spaces around all infix operators.
means <- .format_marginaleffects_means(estimated, model, ...)
}

# Restore factor levels
means <- datawizard::data_restoretype(means, insight::get_data(model))


# Table formatting

attr(means, "table_title") <- c("Estimated Marginal Means", "blue")
attr(means, "table_footer") <- .estimate_means_footer(means, info$at, type = "means")
attr(means, "table_footer") <- .estimate_means_footer(means, type = "means")

# Add attributes
attr(means, "model") <- model
attr(means, "response") <- insight::find_response(model)
attr(means, "ci") <- ci
attr(means, "transform") <- transform
attr(means, "at") <- info$at
attr(means, "fixed") <- info$fixed

attr(means, "coef_name") <- intersect(c("Mean", "Probability"), names(means))


Expand All @@ -114,6 +96,8 @@ estimate_means <- function(model,
# Levels
if (!is.null(at) && length(at) > 0) {
table_footer <- paste0(table_footer, " estimated at ", toString(at))
} else {
table_footer <- paste0(table_footer, " estimated at ", attr(x, "at"))
}

# P-value adjustment footer
Expand All @@ -125,6 +109,8 @@ estimate_means <- function(model,
}
}

if (table_footer == "") table_footer <- NULL
if (all(table_footer == "")) table_footer <- NULL

Check warning on line 112 in R/estimate_means.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/estimate_means.R,line=112,col=11,[nzchar_linter] Use !nzchar(x) instead of x == "". Note that unlike nzchar(), EQ coerces to character, so you'll have to use as.character() if x is a factor. Whenever missing data is possible, please take care to use nzchar(., keepNA = TRUE); nzchar(NA) is TRUE by default.
c(table_footer, "blue")
}

Check warning on line 115 in R/estimate_means.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/estimate_means.R,line=115,col=1,[trailing_blank_lines_linter] Remove trailing blank lines.

80 changes: 60 additions & 20 deletions R/get_emmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,69 @@ get_emmeans <- function(model,
model_emmeans <- get_emmeans



# =========================================================================
# HELPERS ----------------------------------------------------------------
# =========================================================================
# This function is the actual equivalent of .get_marginalmeans(); both being used
# in estimate_means

#' @keywords internal
.format_emmeans_means <- function(estimated, model, ci = 0.95, transform = "response", ...) {
# Summarize and clean
if (insight::model_info(model)$is_bayesian) {
means <- parameters::parameters(estimated, ci = ci, ...)
means <- .clean_names_bayesian(means, model, transform, type = "mean")
means <- cbind(estimated@grid, means)
means$`.wgt.` <- NULL # Drop the weight column
} else {
means <- as.data.frame(stats::confint(estimated, level = ci))
means$df <- NULL
means <- .clean_names_frequentist(means)
}
# Remove the "1 - overall" column that can appear in cases like at = NULL
means <- means[names(means) != "1"]

# Restore factor levels
means <- datawizard::data_restoretype(means, insight::get_data(model))


info <- attributes(estimated)

attr(means, "at") <- info$at
attr(means, "fixed") <- info$fixed
means
}



# =========================================================================
# HELPERS (guess arguments) -----------------------------------------------
# =========================================================================

#' @keywords internal
.guess_emmeans_arguments <- function(model,
at = NULL,
fixed = NULL,
...) {
# Gather info
predictors <- insight::find_predictors(model, effects = "fixed", flatten = TRUE, ...)
data <- insight::get_data(model)

# Guess arguments
if (!is.null(at) && length(at) == 1 && at == "auto") {
at <- predictors[!sapply(data[predictors], is.numeric)]
if (!length(at) || all(is.na(at))) {
stop("Model contains no categorical factor. Please specify 'at'.", call. = FALSE)
}
message("We selected `at = c(", toString(paste0('"', at, '"')), ")`.")
}

args <- list(at = at, fixed = fixed)
.format_emmeans_arguments(model, args, data, ...)
}


#' @keywords internal
.format_emmeans_arguments <- function(model, args, data, ...) {
# Create the data_matrix
Expand Down Expand Up @@ -174,7 +233,7 @@ model_emmeans <- get_emmeans
for (var_at in names(args$emmeans_at)) {
term <- terms[grepl(var_at, terms, fixed = TRUE)]
if (any(grepl(paste0("as.factor(", var_at, ")"), term, fixed = TRUE)) ||
any(grepl(paste0("as.character(", var_at, ")"), term, fixed = TRUE))) {
any(grepl(paste0("as.character(", var_at, ")"), term, fixed = TRUE))) {
args$retransform[[var_at]] <- args$emmeans_at[[var_at]]
args$emmeans_at[[var_at]] <- as.numeric(as.character(args$emmeans_at[[var_at]]))
}
Expand All @@ -187,24 +246,5 @@ model_emmeans <- get_emmeans



#' @keywords internal
.guess_emmeans_arguments <- function(model,
at = NULL,
fixed = NULL,
...) {
# Gather info
predictors <- insight::find_predictors(model, effects = "fixed", flatten = TRUE, ...)
data <- insight::get_data(model)

# Guess arguments
if (!is.null(at) && length(at) == 1 && at == "auto") {
at <- predictors[!sapply(data[predictors], is.numeric)]
if (!length(at) || all(is.na(at))) {
stop("Model contains no categorical factor. Please specify 'at'.", call. = FALSE)
}
message("We selected `at = c(", toString(paste0('"', at, '"')), ")`.")
}

args <- list(at = at, fixed = fixed)
.format_emmeans_arguments(model, args, data, ...)
}
94 changes: 66 additions & 28 deletions R/get_marginalmeans.R
Original file line number Diff line number Diff line change
@@ -1,42 +1,80 @@
#' @keywords internal
.get_marginalmeans <- function(model,
at = "auto",
fixed = NULL,
transform = "response",
ci = 0.95,
marginal = FALSE,
...) {
# check if available
insight::check_if_installed("marginaleffects")

# Guess arguments
args <- .guess_emmeans_arguments(model, at, fixed, ...)

# Run emmeans
means <- marginaleffects::marginalmeans(model, variables = args$at, conf_level = ci)

# TODO: this should be replaced by parameters::parameters(means)
# Format names
names(means)[names(means) %in% "conf.low"] <- "CI_low"
names(means)[names(means) %in% "conf.high"] <- "CI_high"
names(means)[names(means) %in% "std.error"] <- "SE"
names(means)[names(means) %in% "marginalmean"] <- "Mean"
names(means)[names(means) %in% "p.value"] <- "p"
names(means)[names(means) %in% "statistic"] <- ifelse(insight::find_statistic(model) == "t-statistic", "t", "statistic")

# Format terms
term <- unique(means$term) # Get name of variable
if (length(term) > 1L) {
insight::format_error("marignalmeans backend can currently only deal with one 'at' variable.")
}
names(means)[names(means) %in% c("value")] <- term # Replace 'value' col by var name
means$term <- NULL
args <- .guess_arguments_means(model, at, ...)

# Drop stats
means$p <- NULL
means$t <- NULL
# Get corresponding datagrid (and deal with particular ats)
datagrid <- insight::get_datagrid(model, at = args$at, ...)
# Drop random effects
datagrid <- datagrid[insight::find_predictors(model, effects="fixed", flatten = TRUE)]
at_specs <- attributes(datagrid)$at_specs

# Store attributes
attr(means, "at") <- args$at

if (marginal == FALSE) {
if(insight::is_mixed_model(model)) {
means <- marginaleffects::predictions(model,
newdata=datagrid,
by=at_specs$varname,
conf_level = ci,
re.form=NA)
} else {
means <- marginaleffects::predictions(model,
newdata=datagrid,
by=at_specs$varname,
conf_level = ci)
}
} else {
means <- marginaleffects::predictions(model,
newdata=insight::get_data(model),
by=at_specs$varname,
conf_level = ci)
}
attr(means, "at") <- args$at
means
}


# Format ------------------------------------------------------------------


#' @keywords internal
.format_marginaleffects_means <- function(means, model, ...) {
# Format
params <- parameters::parameters(means) |>
datawizard::data_relocate(c("Predicted", "SE", "CI_low", "CI_high"), after=-1) |>
datawizard::data_rename("Predicted", "Mean") |>
datawizard::data_remove(c("p", "Statistic", "s.value", "S", "CI")) |>
datawizard::data_restoretype(insight::get_data(model))

# Store info
attr(params, "at") <- attr(means, "at")
params
}

# Guess -------------------------------------------------------------------

#' @keywords internal
.guess_arguments_means <- function(model, at = NULL, ...) {
# Gather info and data from model
predictors <- insight::find_predictors(model, flatten = TRUE, ...)
data <- insight::get_data(model)

# Guess arguments ('at' and 'fixed')
if (!is.null(at) && length(at) == 1 && at == "auto") {
# Find categorical predictors
at <- predictors[!sapply(data[predictors], is.numeric)]
if (!length(at) || all(is.na(at))) {
stop("Model contains no categorical factor. Please specify 'at'.", call. = FALSE)
}
message("We selected `at = c(", toString(paste0('"', at, '"')), ")`.")
}

list(at=at)
}
4 changes: 2 additions & 2 deletions R/visualisation_recipe.estimate_means.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
#' x <- estimate_means(model, at = c("new_factor", "wt"))
#' plot(visualisation_recipe(x))
#'
#' x <- estimate_means(model, at = c("new_factor", "cyl", "wt"))
#' plot(visualisation_recipe(x))
#' # x <- estimate_means(model, at = c("new_factor", "cyl", "wt"))
#' # plot(visualisation_recipe(x)) # TODO: broken
#'
#' #' # GLMs ---------------------
#' data <- data.frame(vs = mtcars$vs, cyl = as.factor(mtcars$cyl))
Expand Down
1 change: 0 additions & 1 deletion man/modelbased-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/visualisation_recipe.estimate_predicted.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 717c4f3

Please sign in to comment.