Skip to content

Instantly share code, notes, and snippets.

@prise6
Last active December 12, 2022 14:23
Show Gist options
  • Save prise6/5210fc2fd254be28d389c2517b94d3de to your computer and use it in GitHub Desktop.
Save prise6/5210fc2fd254be28d389c2517b94d3de to your computer and use it in GitHub Desktop.
Exploration of different ways to quickly compute time differences (lags) by group in R
# ------------------------------------------------------------------------
#
# MICROBENCHMARK ABOUT LAG OPERATION WITH R (dplyr, data.table)
#
# ------------------------------------------------------------------------
#
# Notes:
#
# This script's goal is to show that, when using "group by" operations
# (dplyr or data.table), doing "big computations" inside each group can be
# expensive.
# The use case is computing time differences with a lag function;
# dplyr can be really slow at this.
#
# The following is just an example, but with bigger datasets the
# differences become serious.
#
# I tried different ways to approach it:
#
# one method with dplyr
# and five with data.table (lag.dt1 to lag.dt5);
# the data.table ways are similar.
#
library(dplyr)
library(data.table)
library(lubridate)
# Simulated input: one row per (id, day) over N consecutive days.
# Note that `id_unique` is NOT unique: "A" appears three times and every other
# letter twice, so each id holds every date more than once (3 * N rows for "A",
# 2 * N rows for the others). The vector is kept as-is so the timings recorded
# at the bottom of this script remain comparable.
N <- 10000
id_unique <- c(LETTERS[1:26], c("A", LETTERS[1:26]))
data <- data.frame(
  id = rep(id_unique, each = N),
  time = rep(
    seq(as.Date("2016-01-01"), as.Date("2016-01-01") + (N - 1), by = "day"),
    length(id_unique)
  ),
  stringsAsFactors = FALSE
)
dim(data)
# dplyr way ---------------------------------------------------------------
# Reference implementation: sort by (id, time), then within each id take the
# previous time with dplyr::lag() and the gap to it in days.
# Returns the resulting (ungrouped) tibble invisibly.
lag.dplyr <- function() {
  lagged <- data %>%
    arrange(id, time) %>%
    group_by(id) %>%
    mutate(
      time_lag = dplyr::lag(time),
      time_diff = difftime(time, time_lag, units = "days")
    ) %>%
    ungroup()
  invisible(lagged)
}
# data.table way ----------------------------------------------------------
# 1
# rewrite dplyr way
lag.dt1 <- function() {
  # Literal translation of lag.dplyr(): both the lagged time and the
  # difference are computed inside every group, so the shifted vector is
  # built twice per id. Returns the modified data.table invisibly.
  tbl <- data.table(data)
  setorder(tbl, id, time)
  tbl[,
    `:=`(
      time_lag = c(as.Date(NA), time[-length(time)]),
      time_diff = difftime(
        time, c(as.Date(NA), time[-length(time)]), units = "days"
      )
    ),
    by = "id"
  ]
}
# 2
# avoid "big computation" in group by
lag.dt2 <- function() {
  tbl <- data.table(data)
  setorder(tbl, id, time)
  # Only the cheap shift is done per group; the first row of each id gets NA.
  tbl[, time_lag := c(as.Date(NA), time[-length(time)]), by = "id"]
  # difftime() then runs once over the whole column, outside any grouping.
  tbl[, time_diff := difftime(time, time_lag, units = "days")]
}
# 3
# avoid "big computation" in group by bis
lag.dt3 <- function() {
  tbl <- data.table(data)
  setorder(tbl, id, time)
  # Shift over the whole table at once, ignoring id boundaries.
  tbl[, time_lag := c(as.Date(NA), time[-length(time)])]
  # Row number within each id; row 1 of every id picked up the previous id's
  # last time from the ungrouped shift, so it is reset to NA.
  tbl[, tmp := 1:.N, by = "id"]
  tbl[tmp == 1, time_lag := NA]
  # Compute the difference over the whole column and drop the helper column.
  tbl[, `:=`(time_diff = difftime(time, time_lag, units = "days"), tmp = NULL)]
}
# 4
# avoid group by
lag.dt4 <- function() {
  tbl <- data.table(data)
  # setkey() sorts by (id, time) and makes id the leading join key used below.
  setkey(tbl, id, time)
  # Ungrouped shift over the whole table.
  tbl[, time_lag := c(as.Date(NA), time[-length(time)])]
  # Keyed join on the distinct ids: mult = "first" with which = TRUE returns
  # the row number of the first row of every id. Those rows received the
  # previous id's last time from the shift above and must be cleared.
  first_rows <- tbl[unique(id), , mult = "first", which = TRUE]
  tbl[first_rows, time_lag := NA]
  tbl[, time_diff := difftime(time, time_lag, units = "days")]
}
# 5
# avoid group by bis
lag.dt5 <- function() {
  tbl <- data.table(data)
  setorder(tbl, id, time)
  # shift() works on the whole table, so the previous row's id is carried
  # along too; it marks group boundaries without any grouping.
  tbl[, c("time_lag", "id_lag") := shift(list(time, id), 1L, type = "lag")]
  # Equivalent base-R shift:
  # tbl[, `:=` (time_lag = c(as.Date(NA), time[-length(time)]), id_lag = c(NA, id[-length(id)]))]
  # The first row of each id inherited the previous id's last time; clear it so
  # time_lag (and therefore time_diff) is per-id like the other methods.
  tbl[is.na(id_lag) | id != id_lag, time_lag := NA]
  tbl[, time_diff := difftime(time, time_lag, units = "days")]
  tbl[, id_lag := NULL]
}
# Tests -------------------------------------------------------------------
# Every data.table variant must reproduce the dplyr reference exactly.
# A bare identical() at top level is only auto-printed (and not at all under
# source()), so the results are collected, printed explicitly, and a warning
# names any method that does not match.
ref <- data.frame(lag.dplyr())
checks <- c(
  dt1 = identical(ref, data.frame(lag.dt1())),
  dt2 = identical(ref, data.frame(lag.dt2())),
  dt3 = identical(ref, data.frame(lag.dt3())),
  dt4 = identical(ref, data.frame(lag.dt4())),
  # lag.dt5() is compared on id, time and time_diff only.
  dt5 = identical(
    ref[, c("id", "time", "time_diff")],
    data.frame(lag.dt5())[c("id", "time", "time_diff")]
  )
)
print(checks)
if (!all(checks)) {
  warning(
    "Results differ from lag.dplyr() for: ",
    paste(names(checks)[!checks], collapse = ", "),
    call. = FALSE
  )
}
print(head(lag.dt1()))
# Benchmark ---------------------------------------------------------------
# Each implementation is run 100 times; rows of the result are labelled by
# the call itself (e.g. "lag.dt1()"), in the order listed here.
microbenchmark::microbenchmark(
  times = 100,
  lag.dplyr(), lag.dt1(), lag.dt2(),
  lag.dt3(), lag.dt4(), lag.dt5()
)
# Results below were obtained with:
# N = 10000
# id_unique = c(LETTERS[1:26], c("A", LETTERS[1:26]))
# neval = 100
# Unit: milliseconds
# expr min lq mean median uq max
# lag.dplyr() 593.72174 597.94157 838.61884 609.22381 708.54437 3787.7640
# lag.dt1() 68.90300 71.68860 108.42908 75.25099 170.97737 392.1275
# lag.dt2() 59.22160 62.68490 93.21577 66.43231 138.76776 343.1273
# lag.dt3() 67.04714 69.61819 106.06478 74.16524 158.91851 357.6966
# lag.dt4() 65.23839 68.46359 96.95182 72.16294 94.77179 369.2292
# lag.dt5() 67.67634 73.89122 108.80821 77.52069 166.39375 384.9868
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment