Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] cor_sort() can deal with non-square matrices #334

Merged
merged 26 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: correlation
Title: Methods for Correlation Analysis
Version: 0.8.6
Version: 0.8.6.1
Authors@R:
c(person(given = "Dominique",
family = "Makowski",
Expand Down Expand Up @@ -57,8 +57,8 @@ Imports:
bayestestR (>= 0.15.0),
datasets,
datawizard (>= 0.13.0),
insight (>= 0.20.5),
parameters (>= 0.22.2),
insight (>= 1.0.0),
parameters (>= 0.24.0),
stats
Suggests:
BayesFactor,
Expand Down
90 changes: 80 additions & 10 deletions R/cor_sort.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

#' @export
cor_sort.easycorrelation <- function(x, distance = "correlation", hclust_method = "complete", ...) {
col_order <- .cor_sort_order(as.matrix(x), distance = distance, hclust_method = hclust_method, ...)
x$Parameter1 <- factor(x$Parameter1, levels = col_order)
x$Parameter2 <- factor(x$Parameter2, levels = col_order)
m <- cor_sort(as.matrix(x), distance = distance, hclust_method = hclust_method, ...)
x$Parameter1 <- factor(x$Parameter1, levels = rownames(m))
x$Parameter2 <- factor(x$Parameter2, levels = colnames(m))
reordered <- x[order(x$Parameter1, x$Parameter2), ]

# Restore class and attributes
Expand All @@ -38,6 +38,8 @@
)

# Make sure Parameter columns are character
# Was added to fix a test, but makes the function not work
# (See https://github.com/easystats/correlation/issues/259)
# reordered$Parameter1 <- as.character(reordered$Parameter1)
# reordered$Parameter2 <- as.character(reordered$Parameter2)

Expand All @@ -55,18 +57,32 @@
m <- x
row.names(m) <- x$Parameter
m <- as.matrix(m[names(m)[names(m) != "Parameter"]])
col_order <- .cor_sort_order(m, distance = distance, hclust_method = hclust_method, ...)

# If non-redundant matrix, fail (## TODO: fix that)
if (anyNA(m)) {
insight::format_error("Non-redundant matrices are not supported yet. Try again by setting summary(..., redundant = TRUE)")

Check warning on line 63 in R/cor_sort.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/cor_sort.R,line=63,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 126 characters.

Check warning on line 63 in R/cor_sort.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/cor_sort.R,line=63,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 126 characters.
}

# Get sorted matrix
m <- cor_sort(m, distance = distance, hclust_method = hclust_method, ...)

# Reorder
x$Parameter <- factor(x$Parameter, levels = col_order)
reordered <- x[order(x$Parameter), c("Parameter", col_order)]
x$Parameter <- factor(x$Parameter, levels = row.names(m))
reordered <- x[order(x$Parameter), c("Parameter", colnames(m))]

# Restore class and attributes
attributes(reordered) <- utils::modifyList(
attributes(x)[!names(attributes(x)) %in% c("names", "row.names")],
attributes(reordered)
)

# Reorder attributes (p-values) etc.
for (id in c("p", "CI", "CI_low", "CI_high", "BF", "Method", "n_Obs", "df_error", "t")) {
if (id %in% names(attributes(reordered))) {
attributes(reordered)[[id]] <- attributes(reordered)[[id]][order(x$Parameter), names(reordered)]
}
}

# make sure Parameter columns are character
reordered$Parameter <- as.character(reordered$Parameter)

Expand All @@ -76,8 +92,13 @@

#' @export
cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "complete", ...) {
col_order <- .cor_sort_order(x, distance = distance, hclust_method = hclust_method, ...)
reordered <- x[col_order, col_order]
if (isSquare(x) && all(colnames(x) %in% rownames(x))) {
i <- .cor_sort_square(x, distance = distance, hclust_method = hclust_method, ...)
} else {
i <- .cor_sort_nonsquare(x, distance = "euclidean", ...)
}

reordered <- x[i$row_order, i$col_order]

# Restore class and attributes
attributes(reordered) <- utils::modifyList(
Expand All @@ -91,7 +112,7 @@
# Utils -------------------------------------------------------------------


.cor_sort_order <- function(m, distance = "correlation", hclust_method = "complete", ...) {
.cor_sort_square <- function(m, distance = "correlation", hclust_method = "complete", ...) {
if (distance == "correlation") {
d <- stats::as.dist((1 - m) / 2) # r = -1 -> d = 1; r = 1 -> d = 0
} else if (distance == "raw") {
Expand All @@ -101,5 +122,54 @@
}

hc <- stats::hclust(d, method = hclust_method)
row.names(m)[hc$order]
row_order <- row.names(m)[hc$order]
list(row_order = row_order, col_order = row_order)
}


.cor_sort_nonsquare <- function(m, distance = "euclidean", ...) {
# Step 1: Perform clustering on rows and columns independently
row_dist <- stats::dist(m, method = distance) # Distance between rows
col_dist <- stats::dist(t(m), method = distance) # Distance between columns

row_hclust <- stats::hclust(row_dist, method = "average")
col_hclust <- stats::hclust(col_dist, method = "average")

# Obtain clustering orders
row_order <- row_hclust$order
col_order <- col_hclust$order

# Reorder matrix based on clustering
clustered_matrix <- m[row_order, col_order]

# Step 2: Refine alignment to emphasize strong correlations along the diagonal
n_rows <- nrow(clustered_matrix)
n_cols <- ncol(clustered_matrix)

used_rows <- logical(n_rows)
refined_row_order <- integer(0)

for (col in seq_len(n_cols)) {
max_value <- -Inf
best_row <- NA

for (row in seq_len(n_rows)[!used_rows]) {
if (abs(clustered_matrix[row, col]) > max_value) {
max_value <- abs(clustered_matrix[row, col])
best_row <- row
}
}

if (!is.na(best_row)) {
refined_row_order <- c(refined_row_order, best_row)
used_rows[best_row] <- TRUE
}
}

# Append any unused rows at the end
refined_row_order <- c(refined_row_order, which(!used_rows))

# Apply
m <- clustered_matrix[refined_row_order, ]
list(row_order = rownames(m), col_order = colnames(m))
}
82 changes: 39 additions & 43 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' Performs a correlation analysis.
#' You can easily visualize the result using [`plot()`][visualisation_recipe.easycormatrix()]
#' (see examples [**here**](https://easystats.github.io/correlation/reference/visualisation_recipe.easycormatrix.html#ref-examples)).

Check warning on line 5 in R/correlation.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/correlation.R,line=5,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 133 characters.
#'
#' @param data A data frame.
#' @param data2 An optional data frame. If specified, all pair-wise correlations
Expand Down Expand Up @@ -180,10 +180,9 @@
#' `stats` package are supported.
#' }
#'
#' @examplesIf requireNamespace("poorman", quietly = TRUE) && requireNamespace("psych", quietly = TRUE)
#'
#' @examplesIf all(insight::check_if_installed(c("psych", "datawizard"), quietly = TRUE)) && getRversion() >= "4.1.0"
#' library(correlation)
#' library(poorman)
#' data(iris)
#'
#' results <- correlation(iris)
#'
Expand All @@ -192,22 +191,20 @@
#' summary(results, redundant = TRUE)
#'
#' # pipe-friendly usage with grouped dataframes from {dplyr} package
#' iris %>%
#' iris |>
#' correlation(select = "Petal.Width", select2 = "Sepal.Length")
#'
#' # Grouped dataframe
#' # grouped correlations
#' iris %>%
#' group_by(Species) %>%
#' iris |>
#' datawizard::data_group(Species) |>
#' correlation()
#'
#' # selecting specific variables for correlation
#' mtcars %>%
#' group_by(am) %>%
#' correlation(
#' select = c("cyl", "wt"),
#' select2 = c("hp")
#' )
#' data(mtcars)
#' mtcars |>
#' datawizard::data_group(am) |>
#' correlation(select = c("cyl", "wt"), select2 = "hp")
#'
#' # supplying custom variable names
#' correlation(anscombe, select = c("x1", "x2"), rename = c("var1", "var2"))
Expand Down Expand Up @@ -388,7 +385,7 @@
attr(out, "additional_arguments") <- list(...)

if (inherits(data, "grouped_df")) {
class(out) <- unique(c("easycorrelation", "see_easycorrelation", "grouped_easycorrelation", "parameters_model", class(out)))

Check warning on line 388 in R/correlation.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/correlation.R,line=388,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 128 characters.
} else {
class(out) <- unique(c("easycorrelation", "see_easycorrelation", "parameters_model", class(out)))
}
Expand Down Expand Up @@ -425,9 +422,37 @@
ungrouped_x <- as.data.frame(data)
xlist <- split(ungrouped_x, ungrouped_x[groups], sep = " - ")

# If data 2 is provided
if (!is.null(data2)) {
# If data 2 is not provided
if (is.null(data2)) {
modelframe <- data.frame()
out <- data.frame()
for (i in names(xlist)) {
xlist[[i]][groups] <- NULL
rez <- .correlation(
xlist[[i]],
data2,
method = method,
p_adjust = p_adjust,
ci = ci,
bayesian = bayesian,
bayesian_prior = bayesian_prior,
bayesian_ci_method = bayesian_ci_method,
bayesian_test = bayesian_test,
redundant = redundant,
include_factors = include_factors,
partial = partial,
partial_bayesian = partial_bayesian,
multilevel = multilevel,
ranktransform = ranktransform,
winsorize = winsorize
)
modelframe_current <- rez$data
rez$params$Group <- modelframe_current$Group <- i
out <- rbind(out, rez$params)
modelframe <- rbind(modelframe, modelframe_current)
}
} else {
if (inherits(data2, "grouped_df")) {

Check warning on line 455 in R/correlation.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/correlation.R,line=455,col=5,[unnecessary_nesting_linter] Simplify this condition by using 'else if' instead of 'else { if.
groups2 <- setdiff(colnames(attributes(data2)$groups), ".rows")
if (!all.equal(groups, groups2)) {
insight::format_error("'data2' should have the same grouping characteristics as data.")
Expand Down Expand Up @@ -463,35 +488,6 @@
modelframe <- rbind(modelframe, modelframe_current)
}
}
# else
} else {
modelframe <- data.frame()
out <- data.frame()
for (i in names(xlist)) {
xlist[[i]][groups] <- NULL
rez <- .correlation(
xlist[[i]],
data2,
method = method,
p_adjust = p_adjust,
ci = ci,
bayesian = bayesian,
bayesian_prior = bayesian_prior,
bayesian_ci_method = bayesian_ci_method,
bayesian_test = bayesian_test,
redundant = redundant,
include_factors = include_factors,
partial = partial,
partial_bayesian = partial_bayesian,
multilevel = multilevel,
ranktransform = ranktransform,
winsorize = winsorize
)
modelframe_current <- rez$data
rez$params$Group <- modelframe_current$Group <- i
out <- rbind(out, rez$params)
modelframe <- rbind(modelframe, modelframe_current)
}
}

# Group as first column
Expand Down Expand Up @@ -526,7 +522,7 @@

if (ncol(data) <= 2L && any(sapply(data, is.factor)) && !include_factors) {
if (isTRUE(verbose)) {
insight::format_warning("It seems like there is not enough continuous variables in your data. Maybe you want to include the factors? We're setting `include_factors=TRUE` for you.")

Check warning on line 525 in R/correlation.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/correlation.R,line=525,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 186 characters.
}
include_factors <- TRUE
}
Expand Down
3 changes: 2 additions & 1 deletion R/display.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#' @name display.easycormatrix
#'
#' @description Export tables (i.e. data frame) into different output formats.
#' `print_md()` is a alias for `display(format = "markdown")`.
#' `print_md()` is a alias for `display(format = "markdown")`. Note that
#' you can use `format()` to get the formatted table as a dataframe.
#'
#' @param object,x An object returned by
#' [`correlation()`][correlation] or its summary.
Expand Down
17 changes: 12 additions & 5 deletions R/methods_format.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,17 @@ format.easycormatrix <- function(x,
# final new line
footer <- paste0(footer, "\n")

# for html/markdown, create list
# for html/markdown, modify footer format
if (!is.null(format) && format != "text") {
# no line break if not text format
footer <- unlist(strsplit(footer, "\n", fixed = TRUE))
footer <- as.list(footer[nzchar(footer, keepNA = TRUE)])
# remove empty elements
footer <- footer[nzchar(footer, keepNA = TRUE)]
# create list or separate by ";"
footer <- switch(format,
html = paste(footer, collapse = "; "),
as.list(footer)
)
}

footer
Expand All @@ -168,7 +175,9 @@ format.easycormatrix <- function(x,

#' @keywords internal
.format_easycorrelation_caption <- function(x, format = NULL) {
if (!is.null(attributes(x)$method)) {
if (is.null(attributes(x)$method)) {
caption <- NULL
} else {
if (isTRUE(attributes(x)$smoothed)) {
prefix <- "Smoothed Correlation Matrix ("
} else {
Expand All @@ -179,8 +188,6 @@ format.easycormatrix <- function(x,
} else {
caption <- paste0(prefix, unique(attributes(x)$method), "-method)")
}
} else {
caption <- NULL
}

caption
Expand Down
10 changes: 5 additions & 5 deletions R/methods_print.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#' @export
print.easycorrelation <- function(x, ...) {
cat(insight::export_table(format(x, ...), format = "text"))
cat(insight::export_table(format(x, ...), ...))
invisible(x)
}

Expand All @@ -13,9 +13,9 @@ print.easycormatrix <- function(x, ...) {
# If real matrix, print as matrix
if (colnames(formatted)[1] == "Variables") {
formatted$Variables <- NULL
print(as.matrix(formatted))
print(as.matrix(formatted), ...)
} else {
cat(insight::export_table(format(x, ...), format = "text"))
cat(insight::export_table(format(x, ...), ...))
}
invisible(x)
}
Expand All @@ -31,7 +31,7 @@ print.easymatrixlist <- function(x, cols = "auto", ...) {

for (i in cols) {
cat(" ", i, " ", "\n", rep("-", nchar(i) + 2), "\n", sep = "")
print(x[[i]])
print(x[[i]], ...)
cat("\n")
}
}
Expand All @@ -40,7 +40,7 @@ print.easymatrixlist <- function(x, cols = "auto", ...) {
print.grouped_easymatrixlist <- function(x, cols = "auto", ...) {
for (i in names(x)) {
cat(rep("=", nchar(i) + 2), "\n ", i, " ", "\n", rep("=", nchar(i) + 2), "\n\n", sep = "")
print(x[[i]])
print(x[[i]], ...)
cat("\n")
}
}
Expand Down
1 change: 1 addition & 0 deletions correlation.Rproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: a2737226-16da-4377-8659-b462bf604f1e

RestoreWorkspace: No
SaveWorkspace: No
Expand Down
12 changes: 6 additions & 6 deletions man/correlation-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading