Skip to content

Commit

Permalink
get_predicted() supports dpar (#988)
Browse files Browse the repository at this point in the history
* `get_predicted()` supports dpar

* lintr

* fix
  • Loading branch information
strengejacke authored Jan 5, 2025
1 parent e94ec60 commit 0937c16
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 10 deletions.
7 changes: 3 additions & 4 deletions R/find_formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -1688,8 +1688,10 @@ find_formula.model_fit <- function(x, verbose = TRUE, ...) {
f_mu <- f$pforms$mu
f_nu <- f$pforms$nu
f_shape <- f$pforms$shape
f_alpha <- f$pforms$alpha
f_beta <- f$pforms$beta
f_phi <- f$pforms$phi
f_xi <- f$pforms$xi
f_hu <- f$pforms$hu
f_ndt <- f$pforms$ndt
f_zoi <- f$pforms$zoi
Expand All @@ -1710,10 +1712,7 @@ find_formula.model_fit <- function(x, verbose = TRUE, ...) {
# by the above exceptions.

# auxiliary names
auxiliary_names <- c(
"sigma", "mu", "nu", "shape", "beta", "phi", "hu", "ndt", "zoi", "coi",
"kappa", "bias", "bs", "zi"
)
auxiliary_names <- .brms_aux_elements()

# check if any further pforms exist
if (all(names(f$pforms) %in% auxiliary_names)) {
Expand Down
4 changes: 4 additions & 0 deletions R/get_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
#' type (for instance, to a factor).
#' * Other strings are passed directly to the `type` argument of the `predict()`
#' method supplied by the modelling package.
#' * Specifically for models of class `brmsfit` (package *brms*), the `predict`
#' argument can be any valid option for the `dpar` argument, to predict
#' distributional parameters (such as `"sigma"`, `"beta"`, `"kappa"`, `"phi"`
#' and so on, see `?brms::brmsfamily`).
#' * When `predict = NULL`, alternative arguments such as `type` will be captured
#' by the `...` ellipsis and passed directly to the `predict()` method supplied
#' by the modelling package. Note that this might result in conflicts with
Expand Down
12 changes: 10 additions & 2 deletions R/get_predicted_args.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# check whether user possibly used the "type" instead of "predict" argument
dots <- list(...)
dpar <- NULL

# one of "type" or "predict" must be provided...
if (is.null(dots$type) && is.null(predict)) {
Expand Down Expand Up @@ -94,7 +95,7 @@
# retrieve model object's predict-method prediction-types (if any)
type_methods <- suppressWarnings(eval(formals(predict_method)$type))
# and together, these prediction-types are supported...
supported <- c(easystats_methods, type_methods)
supported <- c(easystats_methods, type_methods, .brms_aux_elements())

# check aliases - ignore "expected" when this is a valid type-argument (e.g. coxph)
if (predict %in% c("expected", "response") && !"expected" %in% supported) {
Expand All @@ -104,6 +105,12 @@
predict <- "prediction"
}

# brms-exceptions: predict distributional parameters
if (predict %in% .brms_aux_elements()) {
dpar <- predict
predict <- "expectation"
}

# Warn if get_predicted() is not called with an easystats- or
# model-supported predicted type
if (isTRUE(verbose) && !is.null(predict) && !predict %in% supported) {
Expand Down Expand Up @@ -337,6 +344,7 @@
scale = scale_arg,
transform = my_transform,
info = info,
allow_new_levels = allow_new_levels
allow_new_levels = allow_new_levels,
distributional_parameter = dpar
)
}
1 change: 1 addition & 0 deletions R/get_predicted_bayesian.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ get_predicted.stanreg <- function(x,
fun_args <- list(x,
newdata = my_args$data,
re.form = my_args$re.form,
dpar = my_args$distributional_parameter,
draws = iterations
)

Expand Down
12 changes: 9 additions & 3 deletions R/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,21 +229,27 @@
"extra", "scale", "marginal", "alpha", "beta", "survival", "infrequent_purchase",
"auxiliary", "mix", "shiftprop", "phi", "ndt", "hu", "xi", "coi", "zoi",
"aux", "dist", "selection", "outcome", "time_dummies", "sigma_random",
"beta_random", "car", "nominal", "bidrange"
"beta_random", "car", "nominal", "bidrange", "mu", "kappa", "bias"
)
}

.aux_elements <- function() {
c(
"sigma", "alpha", "beta", "dispersion", "precision", "nu", "tau", "shape",
"phi", "(phi)", "ndt", "hu", "xi", "coi", "zoi", "mix", "shiftprop", "auxiliary",
"aux", "dist",

"aux", "dist", "mu", "kappa", "bias",
# random parameters
"dispersion_random", "sigma_random", "beta_random"
)
}

.brms_aux_elements <- function() {
c(
"sigma", "mu", "nu", "shape", "beta", "phi", "hu", "ndt", "zoi", "coi",
"kappa", "bias", "bs", "zi", "alpha", "xi"
)
}

.get_elements <- function(effects, component, model = NULL) {
# all elements of a model
elements <- .all_elements()
Expand Down
4 changes: 4 additions & 0 deletions man/get_predicted.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 22 additions & 1 deletion tests/testthat/test-get_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,26 @@ test_that("get_predicted - rstanarm", {
})


test_that("get_predicted - brms, auxiliary", {
skip_on_cran()
skip_if_not_installed("brms")
skip_if_not_installed("httr2")

m <- insight::download_model("brms_sigma_2")
dg <- get_datagrid(m, reference = "grid", include_random = TRUE)
out <- get_predicted(m, data = dg, predict = "sigma")
expect_equal(
as.numeric(out),
c(
1.02337, 0.82524, 0.58538, 0.74573, 0.66292, 1.0336, 0.94714,
0.74541, 0.71533, 0.7032, 0.63151, 0.65244, 0.58731, 0.45177,
0.75789
),
tolerance = 1e-4
)
})


# FA / PCA ----------------------------------------------------------------
# =========================================================================

Expand Down Expand Up @@ -560,7 +580,7 @@ test_that("bugfix: used to fail with matrix variables", {
foo <- function() {
mtcars2 <- mtcars
mtcars2$wt <- scale(mtcars2$wt)
return(lm(mpg ~ wt + cyl + gear + disp, data = mtcars2))
lm(mpg ~ wt + cyl + gear + disp, data = mtcars2)
}
pred <- get_predicted(foo())
expect_s3_class(pred, c("get_predicted", "numeric"))
Expand All @@ -582,6 +602,7 @@ test_that("bugfix: used to fail with matrix variables", {
expect_equal(pred, pred2, ignore_attr = TRUE)
})


test_that("brms: `type` in ellipsis used to produce the wrong intervals", {
skip_on_cran()
skip_if_not_installed("brms")
Expand Down

0 comments on commit 0937c16

Please sign in to comment.