Last active
June 1, 2021 13:08
-
-
Save EoinTravers/086a597e9bbab76be36f5dc65fc4b5be to your computer and use it in GitHub Desktop.
Plot a (correlation/covariance) matrix in R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Plot matrix as heatmap
#' @description
#' `plot_covariance_matrix()` adds an appropriate `fill_label`
#' `plot_correlation_matrix()` also sets limits to ±1
#'
#' Make sure to import dplyr and ggplot (or tidyverse)!
#' @param mat Matrix to plot. Its row and column names become the axis labels.
#' @param labeller Function or list used to rename rows/cols. A function is
#'   called on the current labels and must return the new labels; a list is
#'   handed to `replace_if_found()`.
#' @param digits Digits to round the printed cell values to (default 2)
#' @param limit Limit (±) of colour scale. Defaults to `max(abs(value))`
#'   (ignoring `NA`), so every cell lies inside the symmetric scale.
#' @param fill_label Label to use for colourscale
#' @param tilt_x If `TRUE` (default), rotate the x-axis tick labels with
#'   `tilt_x_ticks()`.
#' @param fill_gradient Optional ggplot2 fill scale used instead of the default
#'   diverging gradient (red for negative, white at zero, blue for positive).
#' @return A ggplot object.
#'
#' @examples
#' iris %>%
#'   select_if(is.numeric) %>%
#'   cor() %>%
#'   plot_correlation_matrix()
#'
#' @export
plot_matrix <- function(mat, labeller = NULL,
                        digits = 2, limit = NULL, fill_label = NULL,
                        tilt_x = TRUE,
                        fill_gradient = NULL) {
  var_order1 <- rownames(mat)
  var_order2 <- colnames(mat)
  # Rows keep matrix order along x; column levels are reversed so the first
  # column is drawn at the top of the (discrete) y axis, mirroring the matrix.
  df <- gather_matrix(mat) %>%
    mutate(var1 = factor(var1, levels = var_order1),
           var2 = factor(var2, levels = rev(var_order2)))
  if (is.null(limit)) {
    # Largest magnitude, not abs() of the largest value: otherwise strongly
    # negative cells would fall outside c(-limit, limit) and lose their fill.
    limit <- max(abs(df$value), na.rm = TRUE)
  }
  # Handle labels
  if (is.function(labeller)) {
    levels(df$var1) <- labeller(levels(df$var1))
    levels(df$var2) <- labeller(levels(df$var2))
  } else if (is.list(labeller)) {
    levels(df$var1) <- replace_if_found(levels(df$var1), labeller)
    levels(df$var2) <- replace_if_found(levels(df$var2), labeller)
  }
  g <- df %>%
    ggplot(aes(var1, var2, fill = value, label = round(value, digits))) +
    geom_tile() +
    geom_label(fill = "white") +
    coord_fixed() +
    labs(x = "", y = "", fill = fill_label)
  # Default diverging gradient: red (negative) - white (zero) - blue
  # (positive), symmetric around zero; otherwise apply the caller's scale.
  if (is.null(fill_gradient)) {
    g <- g + scale_fill_gradient2(low = "red", mid = "white", high = "blue",
                                  limits = c(-limit, limit))
  } else {
    g <- g + fill_gradient
  }
  if (tilt_x) {
    g <- g + tilt_x_ticks()
  }
  g
}
#' @rdname plot_matrix
#' @param cov_mat Covariance matrix to plot.
#' @param ... Further arguments passed on to `plot_matrix()`.
#' @export
plot_covariance_matrix <- function(cov_mat, ..., fill_label = "(Co)variance") {
  # fill_label is declared after `...` so it can only be set by exact name:
  # callers may now override it, while positional extras still flow through
  # `...` to plot_matrix() exactly as before.
  plot_matrix(cov_mat, fill_label = fill_label, ...)
}
#' @rdname plot_matrix
#' @param cor_mat Correlation matrix to plot.
#' @export
plot_correlation_matrix <- function(cor_mat, ..., limit = 1,
                                    fill_label = "Correlation") {
  # limit and fill_label follow `...` so they must be named explicitly; this
  # lets callers override them without changing how positional arguments are
  # forwarded to plot_matrix().
  plot_matrix(cor_mat, limit = limit, fill_label = fill_label, ...)
}
#' Gathers a matrix into a long data frame
#'
#' Produces one row per cell in column-major order (every row of the first
#' column, then every row of the second column, and so on).
#' @param mat Matrix (or data frame, coerced with `as.matrix()`) to reshape.
#' @return A data frame with character columns `var1` (row name) and `var2`
#'   (column name), plus `value` holding the cell values. Missing row or column
#'   names are replaced by their 1-based position as a string.
gather_matrix <- function(mat) {
  mat <- as.matrix(mat)
  n_row <- nrow(mat)
  n_col <- ncol(mat)
  row_names <- rownames(mat)
  if (is.null(row_names)) {
    row_names <- as.character(seq_len(n_row))
  }
  col_names <- colnames(mat)
  if (is.null(col_names)) {
    col_names <- as.character(seq_len(n_col))
  }
  # as.vector() flattens column-major, which lines up with the rep() patterns:
  # row names cycle fastest, column names repeat once per row.
  data.frame(
    var1 = rep(row_names, times = n_col),
    var2 = rep(col_names, each = n_row),
    value = as.vector(mat),
    stringsAsFactors = FALSE
  )
}
#' Place x-axis ticks at an angle.
#'
#' @param angle Rotation of the x-axis tick labels, in degrees (default 45).
#' @param vjust Vertical justification of the tick labels (default 1).
#' @param hjust Horizontal justification of the tick labels (default 1).
#' @return A ggplot2 theme element that can be added to a plot with `+`.
tilt_x_ticks <- function(angle = 45, vjust = 1, hjust = 1) {
  rotated_text <- element_text(angle = angle, vjust = vjust, hjust = hjust)
  theme(axis.text.x = rotated_text)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment