Skip to content

Instantly share code, notes, and snippets.

@EoinTravers
Last active June 1, 2021 13:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EoinTravers/086a597e9bbab76be36f5dc65fc4b5be to your computer and use it in GitHub Desktop.
Save EoinTravers/086a597e9bbab76be36f5dc65fc4b5be to your computer and use it in GitHub Desktop.
Plot a (correlation/covariance) matrix in R
#' Plot matrix as heatmap
#' @description
#' `plot_matrix()` draws every cell of `mat` as a tile, coloured on a
#' diverging scale and labelled with its rounded value.
#' `plot_covariance_matrix()` adds an appropriate `fill_label`.
#' `plot_correlation_matrix()` also sets the colour-scale limits to ±1.
#'
#' Make sure to load dplyr and ggplot2 (or the tidyverse)!
#' @param mat Matrix to plot. Row names are used for the x axis and column
#'   names for the y axis (first column at the top).
#' @param labeller Function or list used to rename rows/cols. A function is
#'   called on the vector of current names and must return the new names;
#'   a list is passed to `replace_if_found()` together with the current names.
#' @param digits Digits to round the cell labels to (default 2).
#' @param limit Limit (±) of the default colour scale. Defaults to the
#'   largest absolute value in `mat` (ignoring `NA`), so no cell falls
#'   outside the scale.
#' @param fill_label Label to use for the colour scale.
#' @param tilt_x If `TRUE` (default), angle the x-axis tick labels using
#'   `tilt_x_ticks()`.
#' @param fill_gradient Optional ggplot2 fill scale added instead of the
#'   default red (low) - white - blue (high) gradient. When supplied,
#'   `limit` is not used.
#' @param ... For the wrappers, further arguments passed to `plot_matrix()`.
#' @return A ggplot object.
#'
#' @examples
#' iris %>%
#'   select(where(is.numeric)) %>%
#'   cor() %>%
#'   plot_correlation_matrix()
#'
#' @export
plot_matrix = function(mat, labeller=NULL,
                       digits=2, limit=NULL, fill_label=NULL,
                       tilt_x = TRUE,
                       fill_gradient=NULL) {
  var_order1 = rownames(mat)
  var_order2 = colnames(mat)
  # One row per cell. var2's levels are reversed so that the first column of
  # `mat` ends up at the top of the y axis, as when the matrix is printed.
  df = gather_matrix(mat) %>%
    mutate(var1 = factor(var1, levels=var_order1),
           var2 = factor(var2, levels=rev(var_order2)))
  if (is.null(limit)) {
    # Largest magnitude of either sign. The previous abs(max(value)) was too
    # small whenever the most extreme value was negative (those tiles fell
    # outside the scale limits) and returned NA if `mat` contained NA.
    limit = max(abs(df$value), na.rm=TRUE)
  }
  # Handle labels: only the factor levels are renamed, so cell positions and
  # axis order are unaffected.
  if (is.function(labeller)) {
    levels(df$var1) = labeller(levels(df$var1))
    levels(df$var2) = labeller(levels(df$var2))
  } else if (is.list(labeller)) {
    levels(df$var1) = replace_if_found(levels(df$var1), labeller)
    levels(df$var2) = replace_if_found(levels(df$var2), labeller)
  }
  g = df %>%
    ggplot(aes(var1, var2, fill=value, label=round(value, digits))) +
    geom_tile() +
    geom_label(fill='white') +
    coord_fixed() +
    labs(x='', y='', fill=fill_label)
  # Default: diverging red (low) - white - blue (high) gradient with limits
  # symmetric around zero; otherwise add the caller's own scale unchanged.
  if (is.null(fill_gradient)) {
    g = g + scale_fill_gradient2(low='red', mid='white', high='blue',
                                 limits=c(-limit, limit))
  } else {
    g = g + fill_gradient
  }
  if (tilt_x) g = g + tilt_x_ticks()
  g
}
#' @rdname plot_matrix
#' @param cov_mat Covariance matrix to plot. The colour-scale label defaults
#'   to "(Co)variance" and may be overridden with `fill_label =`.
#' @export
plot_covariance_matrix = function(cov_mat, ..., fill_label='(Co)variance'){
  # `fill_label` sits after `...` so callers can override it by (exact) name
  # instead of hitting "formal argument matched by multiple actual
  # arguments"; positional arguments in `...` are forwarded as before.
  plot_matrix(cov_mat, ..., fill_label=fill_label)
}
#' @rdname plot_matrix
#' @param cor_mat Correlation matrix to plot. The colour-scale limit defaults
#'   to 1 and the label to "Correlation"; both may be overridden by name
#'   (`limit =`, `fill_label =`).
#' @export
plot_correlation_matrix = function(cor_mat, ..., limit=1,
                                   fill_label='Correlation'){
  # `limit` and `fill_label` sit after `...` so callers can override them by
  # (exact) name without a "matched by multiple actual arguments" error;
  # positional arguments in `...` are forwarded exactly as before.
  plot_matrix(cor_mat, ..., limit=limit, fill_label=fill_label)
}
#' Gathers a matrix into a long data frame
#'
#' Base-R replacement for the former `rownames_to_column()` + superseded
#' `tidyr::gather()` pipeline: same columns, same column types and the same
#' (column-major) row order, without the tibble/tidyr dependency.
#'
#' @param mat A matrix, or anything `as.matrix()` accepts (e.g. a data frame).
#' @return A data.frame with one row per cell and columns `var1` (row name,
#'   character), `var2` (column name, character) and `value` (cell value).
#'   Rows run down the first column, then the second, and so on. Missing
#'   row/column names are replaced by their positions ("1", "2", ...).
gather_matrix = function(mat) {
  mat = as.matrix(mat)
  row_names = rownames(mat)
  col_names = colnames(mat)
  # Without dimnames the old pipeline failed (names(df) = NULL); fall back to
  # positions, matching what rownames_to_column() produced for rows.
  if (is.null(row_names)) row_names = as.character(seq_len(nrow(mat)))
  if (is.null(col_names)) col_names = as.character(seq_len(ncol(mat)))
  # as.vector() flattens column by column, so var1 cycles fastest and var2
  # slowest -- the same order gather() returned.
  data.frame(
    var1 = rep(row_names, times = ncol(mat)),
    var2 = rep(col_names, each = nrow(mat)),
    value = as.vector(mat),
    stringsAsFactors = FALSE
  )
}
#' Rotate the x-axis tick labels.
#'
#' @param angle Rotation angle for the x-axis tick labels (default 45).
#' @param vjust,hjust Vertical and horizontal justification of the labels
#'   (both default 1); passed straight to `element_text()`.
#' @return The result of `theme()`, meant to be added to a plot with `+`.
tilt_x_ticks = function(angle=45, vjust=1, hjust=1){
  rotated_labels = element_text(angle = angle, vjust = vjust, hjust = hjust)
  theme(axis.text.x = rotated_labels)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment