Skip to content

Commit

Permalink
Merge pull request #334 from easystats/cor_sort_improvements
Browse files Browse the repository at this point in the history
[Feature] cor_sort() can deal with non-square matrices
  • Loading branch information
strengejacke authored Dec 29, 2024
2 parents fced785 + b95e053 commit ac9ef1e
Show file tree
Hide file tree
Showing 17 changed files with 267 additions and 194 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: correlation
Title: Methods for Correlation Analysis
Version: 0.8.6
Version: 0.8.6.1
Authors@R:
c(person(given = "Dominique",
family = "Makowski",
Expand Down Expand Up @@ -57,8 +57,8 @@ Imports:
bayestestR (>= 0.15.0),
datasets,
datawizard (>= 0.13.0),
insight (>= 0.20.5),
parameters (>= 0.22.2),
insight (>= 1.0.0),
parameters (>= 0.24.0),
stats
Suggests:
BayesFactor,
Expand Down
90 changes: 80 additions & 10 deletions R/cor_sort.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ cor_sort <- function(x, distance = "correlation", hclust_method = "complete", ..

#' @export
cor_sort.easycorrelation <- function(x, distance = "correlation", hclust_method = "complete", ...) {
col_order <- .cor_sort_order(as.matrix(x), distance = distance, hclust_method = hclust_method, ...)
x$Parameter1 <- factor(x$Parameter1, levels = col_order)
x$Parameter2 <- factor(x$Parameter2, levels = col_order)
m <- cor_sort(as.matrix(x), distance = distance, hclust_method = hclust_method, ...)
x$Parameter1 <- factor(x$Parameter1, levels = rownames(m))
x$Parameter2 <- factor(x$Parameter2, levels = colnames(m))
reordered <- x[order(x$Parameter1, x$Parameter2), ]

# Restore class and attributes
Expand All @@ -38,6 +38,8 @@ cor_sort.easycorrelation <- function(x, distance = "correlation", hclust_method
)

# Make sure Parameter columns are character
# Was added to fix a test, but makes the function not work
# (See https://github.com/easystats/correlation/issues/259)
# reordered$Parameter1 <- as.character(reordered$Parameter1)
# reordered$Parameter2 <- as.character(reordered$Parameter2)

Expand All @@ -55,18 +57,32 @@ cor_sort.easycormatrix <- function(x, distance = "correlation", hclust_method =
m <- x
row.names(m) <- x$Parameter
m <- as.matrix(m[names(m)[names(m) != "Parameter"]])
col_order <- .cor_sort_order(m, distance = distance, hclust_method = hclust_method, ...)

# If non-redundant matrix, fail (## TODO: fix that)
if (anyNA(m)) {
insight::format_error("Non-redundant matrices are not supported yet. Try again by setting summary(..., redundant = TRUE)")

Check warning on line 63 in R/cor_sort.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/cor_sort.R,line=63,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 126 characters.
}

# Get sorted matrix
m <- cor_sort(m, distance = distance, hclust_method = hclust_method, ...)

# Reorder
x$Parameter <- factor(x$Parameter, levels = col_order)
reordered <- x[order(x$Parameter), c("Parameter", col_order)]
x$Parameter <- factor(x$Parameter, levels = row.names(m))
reordered <- x[order(x$Parameter), c("Parameter", colnames(m))]

# Restore class and attributes
attributes(reordered) <- utils::modifyList(
attributes(x)[!names(attributes(x)) %in% c("names", "row.names")],
attributes(reordered)
)

# Reorder attributes (p-values) etc.
for (id in c("p", "CI", "CI_low", "CI_high", "BF", "Method", "n_Obs", "df_error", "t")) {
if (id %in% names(attributes(reordered))) {
attributes(reordered)[[id]] <- attributes(reordered)[[id]][order(x$Parameter), names(reordered)]
}
}

# make sure Parameter columns are character
reordered$Parameter <- as.character(reordered$Parameter)

Expand All @@ -76,8 +92,13 @@ cor_sort.easycormatrix <- function(x, distance = "correlation", hclust_method =

#' @export
cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "complete", ...) {
col_order <- .cor_sort_order(x, distance = distance, hclust_method = hclust_method, ...)
reordered <- x[col_order, col_order]
if (isSquare(x) && all(colnames(x) %in% rownames(x))) {
i <- .cor_sort_square(x, distance = distance, hclust_method = hclust_method, ...)
} else {
i <- .cor_sort_nonsquare(x, distance = "euclidean", ...)
}

reordered <- x[i$row_order, i$col_order]

# Restore class and attributes
attributes(reordered) <- utils::modifyList(
Expand All @@ -91,7 +112,7 @@ cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "comple
# Utils -------------------------------------------------------------------


.cor_sort_order <- function(m, distance = "correlation", hclust_method = "complete", ...) {
.cor_sort_square <- function(m, distance = "correlation", hclust_method = "complete", ...) {
if (distance == "correlation") {
d <- stats::as.dist((1 - m) / 2) # r = -1 -> d = 1; r = 1 -> d = 0
} else if (distance == "raw") {
Expand All @@ -101,5 +122,54 @@ cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "comple
}

hc <- stats::hclust(d, method = hclust_method)
row.names(m)[hc$order]
row_order <- row.names(m)[hc$order]
list(row_order = row_order, col_order = row_order)
}


# Order the rows and columns of a (possibly non-square) matrix.
#
# Rows and columns are first ordered independently via average-linkage
# hierarchical clustering on `stats::dist(..., method = distance)`. The row
# order is then refined greedily: walking through the clustered columns from
# left to right, each column claims the not-yet-used row with the largest
# absolute value in that column, which pulls strong associations towards the
# diagonal. Rows that are never claimed (more rows than columns) are appended
# in their clustered order.
#
# @param m A numeric matrix.
# @param distance Distance measure passed to `stats::dist()`.
# @param ... Currently unused.
#
# @return A list with elements `row_order` and `col_order`. These hold the
#   row / column names of `m` in their new order, or integer positions for a
#   dimension that has no names (so the result can always be used to index
#   `m`).
.cor_sort_nonsquare <- function(m, distance = "euclidean", ...) {
  # Step 1: Perform clustering on rows and columns independently.
  # hclust() requires at least two objects, so a dimension with fewer than
  # two entries simply keeps its original order.
  cluster_order <- function(mat) {
    if (nrow(mat) < 2L) {
      return(seq_len(nrow(mat)))
    }
    stats::hclust(stats::dist(mat, method = distance), method = "average")$order
  }

  row_order <- cluster_order(m)
  col_order <- cluster_order(t(m))

  # drop = FALSE keeps the matrix shape for single-row / single-column input
  clustered_matrix <- m[row_order, col_order, drop = FALSE]

  # Step 2: Refine alignment to emphasize strong correlations along the diagonal
  n_rows <- nrow(clustered_matrix)
  n_cols <- ncol(clustered_matrix)

  # Missing cells are treated as the weakest possible association, so they
  # never win over an observed value (and never break the comparison).
  strength <- abs(clustered_matrix)
  strength[is.na(strength)] <- -Inf

  # Only the first min(n_rows, n_cols) columns can claim a row
  n_assign <- min(n_rows, n_cols)
  used_rows <- logical(n_rows)
  refined_row_order <- integer(n_assign)

  for (col in seq_len(n_assign)) {
    candidates <- which(!used_rows)
    # which.max() returns the first maximum, i.e. ties go to the earlier row
    best_row <- candidates[which.max(strength[candidates, col])]
    refined_row_order[col] <- best_row
    used_rows[best_row] <- TRUE
  }

  # Append any unused rows at the end
  refined_row_order <- c(refined_row_order, which(!used_rows))

  # Translate positions in the clustered matrix back to positions in `m`
  row_idx <- row_order[refined_row_order]

  # Return names where available, positions otherwise
  row_names <- rownames(m)
  col_names <- colnames(m)
  list(
    row_order = if (is.null(row_names)) row_idx else row_names[row_idx],
    col_order = if (is.null(col_names)) col_order else col_names[col_order]
  )
}
82 changes: 39 additions & 43 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,9 @@
#' `stats` package are supported.
#' }
#'
#' @examplesIf requireNamespace("poorman", quietly = TRUE) && requireNamespace("psych", quietly = TRUE)
#'
#' @examplesIf all(insight::check_if_installed(c("psych", "datawizard"), quietly = TRUE)) && getRversion() >= "4.1.0"
#' library(correlation)
#' library(poorman)
#' data(iris)
#'
#' results <- correlation(iris)
#'
Expand All @@ -192,22 +191,20 @@
#' summary(results, redundant = TRUE)
#'
#' # pipe-friendly usage with grouped dataframes from {dplyr} package
#' iris %>%
#' iris |>
#' correlation(select = "Petal.Width", select2 = "Sepal.Length")
#'
#' # Grouped dataframe
#' # grouped correlations
#' iris %>%
#' group_by(Species) %>%
#' iris |>
#' datawizard::data_group(Species) |>
#' correlation()
#'
#' # selecting specific variables for correlation
#' mtcars %>%
#' group_by(am) %>%
#' correlation(
#' select = c("cyl", "wt"),
#' select2 = c("hp")
#' )
#' data(mtcars)
#' mtcars |>
#' datawizard::data_group(am) |>
#' correlation(select = c("cyl", "wt"), select2 = "hp")
#'
#' # supplying custom variable names
#' correlation(anscombe, select = c("x1", "x2"), rename = c("var1", "var2"))
Expand Down Expand Up @@ -425,8 +422,36 @@ correlation <- function(data,
ungrouped_x <- as.data.frame(data)
xlist <- split(ungrouped_x, ungrouped_x[groups], sep = " - ")

# If data 2 is provided
if (!is.null(data2)) {
# If data 2 is not provided
if (is.null(data2)) {
modelframe <- data.frame()
out <- data.frame()
for (i in names(xlist)) {
xlist[[i]][groups] <- NULL
rez <- .correlation(
xlist[[i]],
data2,
method = method,
p_adjust = p_adjust,
ci = ci,
bayesian = bayesian,
bayesian_prior = bayesian_prior,
bayesian_ci_method = bayesian_ci_method,
bayesian_test = bayesian_test,
redundant = redundant,
include_factors = include_factors,
partial = partial,
partial_bayesian = partial_bayesian,
multilevel = multilevel,
ranktransform = ranktransform,
winsorize = winsorize
)
modelframe_current <- rez$data
rez$params$Group <- modelframe_current$Group <- i
out <- rbind(out, rez$params)
modelframe <- rbind(modelframe, modelframe_current)
}
} else {
if (inherits(data2, "grouped_df")) {
groups2 <- setdiff(colnames(attributes(data2)$groups), ".rows")
if (!all.equal(groups, groups2)) {
Expand Down Expand Up @@ -463,35 +488,6 @@ correlation <- function(data,
modelframe <- rbind(modelframe, modelframe_current)
}
}
# else
} else {
modelframe <- data.frame()
out <- data.frame()
for (i in names(xlist)) {
xlist[[i]][groups] <- NULL
rez <- .correlation(
xlist[[i]],
data2,
method = method,
p_adjust = p_adjust,
ci = ci,
bayesian = bayesian,
bayesian_prior = bayesian_prior,
bayesian_ci_method = bayesian_ci_method,
bayesian_test = bayesian_test,
redundant = redundant,
include_factors = include_factors,
partial = partial,
partial_bayesian = partial_bayesian,
multilevel = multilevel,
ranktransform = ranktransform,
winsorize = winsorize
)
modelframe_current <- rez$data
rez$params$Group <- modelframe_current$Group <- i
out <- rbind(out, rez$params)
modelframe <- rbind(modelframe, modelframe_current)
}
}

# Group as first column
Expand Down
3 changes: 2 additions & 1 deletion R/display.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#' @name display.easycormatrix
#'
#' @description Export tables (i.e. data frame) into different output formats.
#' `print_md()` is a alias for `display(format = "markdown")`.
#' `print_md()` is an alias for `display(format = "markdown")`. Note that
#' you can use `format()` to get the formatted table as a dataframe.
#'
#' @param object,x An object returned by
#' [`correlation()`][correlation] or its summary.
Expand Down
17 changes: 12 additions & 5 deletions R/methods_format.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,17 @@ format.easycormatrix <- function(x,
# final new line
footer <- paste0(footer, "\n")

# for html/markdown, create list
# for html/markdown, modify footer format
if (!is.null(format) && format != "text") {
# no line break if not text format
footer <- unlist(strsplit(footer, "\n", fixed = TRUE))
footer <- as.list(footer[nzchar(footer, keepNA = TRUE)])
# remove empty elements
footer <- footer[nzchar(footer, keepNA = TRUE)]
# create list or separate by ";"
footer <- switch(format,
html = paste(footer, collapse = "; "),
as.list(footer)
)
}

footer
Expand All @@ -168,7 +175,9 @@ format.easycormatrix <- function(x,

#' @keywords internal
.format_easycorrelation_caption <- function(x, format = NULL) {
if (!is.null(attributes(x)$method)) {
if (is.null(attributes(x)$method)) {
caption <- NULL
} else {
if (isTRUE(attributes(x)$smoothed)) {
prefix <- "Smoothed Correlation Matrix ("
} else {
Expand All @@ -179,8 +188,6 @@ format.easycormatrix <- function(x,
} else {
caption <- paste0(prefix, unique(attributes(x)$method), "-method)")
}
} else {
caption <- NULL
}

caption
Expand Down
10 changes: 5 additions & 5 deletions R/methods_print.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#' @export
print.easycorrelation <- function(x, ...) {
cat(insight::export_table(format(x, ...), format = "text"))
cat(insight::export_table(format(x, ...), ...))
invisible(x)
}

Expand All @@ -13,9 +13,9 @@ print.easycormatrix <- function(x, ...) {
# If real matrix, print as matrix
if (colnames(formatted)[1] == "Variables") {
formatted$Variables <- NULL
print(as.matrix(formatted))
print(as.matrix(formatted), ...)
} else {
cat(insight::export_table(format(x, ...), format = "text"))
cat(insight::export_table(format(x, ...), ...))
}
invisible(x)
}
Expand All @@ -31,7 +31,7 @@ print.easymatrixlist <- function(x, cols = "auto", ...) {

for (i in cols) {
cat(" ", i, " ", "\n", rep("-", nchar(i) + 2), "\n", sep = "")
print(x[[i]])
print(x[[i]], ...)
cat("\n")
}
}
Expand All @@ -40,7 +40,7 @@ print.easymatrixlist <- function(x, cols = "auto", ...) {
print.grouped_easymatrixlist <- function(x, cols = "auto", ...) {
for (i in names(x)) {
cat(rep("=", nchar(i) + 2), "\n ", i, " ", "\n", rep("=", nchar(i) + 2), "\n\n", sep = "")
print(x[[i]])
print(x[[i]], ...)
cat("\n")
}
}
Expand Down
1 change: 1 addition & 0 deletions correlation.Rproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: a2737226-16da-4377-8659-b462bf604f1e

RestoreWorkspace: No
SaveWorkspace: No
Expand Down
12 changes: 6 additions & 6 deletions man/correlation-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ac9ef1e

Please sign in to comment.