Skip to content

Instantly share code, notes, and snippets.

@dfalbel
Created July 29, 2021 13:57
Show Gist options
  • Save dfalbel/c053c852bc95d671acf26e4cce9a789a to your computer and use it in GitHub Desktop.
Save dfalbel/c053c852bc95d671acf26e4cce9a789a to your computer and use it in GitHub Desktop.
Benchmark torch parallel dataloaders
library(torch)
# Synthetic torch dataset used to benchmark dataloaders.
#
# Every item simulates an expensive load by sleeping `time` seconds and then
# returns a tensor created with `torch_empty(size)` (contents uninitialised).
#
# time - seconds to sleep for each item fetched.
# size - size passed straight to torch_empty() for each item.
# len  - number of items in the dataset (default 100 * 32 = 3200).
dat <- dataset(
  name = "mydataset",

  initialize = function(time, size, len = 100 * 32) {
    # Keep the configuration on the instance for .getitem() / .length().
    self$size <- size
    self$len <- len
    self$time <- time
  },

  # The index `i` is ignored: every item is built the same way.
  .getitem = function(i) {
    Sys.sleep(self$time)
    torch_empty(self$size)
  },

  .length = function() self$len
)
# Drain a dataloader iterator, discarding every batch.
#
# Calls dataloader_next() repeatedly until it returns NULL, which is how this
# loop detects that the iterator is exhausted. Only the work of producing each
# batch matters to the benchmark, so batches are not kept around.
#
# iter - an iterator created with dataloader_make_iter().
#
# Returns the number of batches consumed, invisibly.
exhaust_dl <- function(iter) {
  n_batches <- 0L
  while (!is.null(dataloader_next(iter))) {
    n_batches <- n_batches + 1L
  }
  invisible(n_batches)
}
# Time one full pass over a dataloader built from `dat`.
#
# Builds a dataset whose items each sleep `time_per_batch / batch_size`
# seconds, so one batch costs roughly `time_per_batch` seconds of loading
# work when read sequentially. The dataset is wrapped in a dataloader with
# `num_workers` workers, and the function measures the wall-clock time
# needed to drain it. The iterator is created before the timer starts, so
# its creation is not part of the measurement.
#
# num_workers    - number of dataloader workers (0 = load in this process).
# time_per_batch - simulated loading cost per batch, in seconds.
# tensor_numel   - size of each item tensor, passed on to `dat`.
# batch_size     - items per batch; defaults to the originally hard-coded 32.
# n_batches      - batches in the dataset; defaults to 100 (3200 items at
#                  the default batch size, as before).
#
# Prints and returns a one-row tibble holding the configuration and the
# elapsed time in seconds.
a <- function(num_workers, time_per_batch, tensor_numel,
              batch_size = 32, n_batches = 100) {
  # A single batch_size drives the per-item cost, the dataset length and the
  # dataloader, so the three cannot drift apart.
  ds <- dat(
    time_per_batch / batch_size,
    tensor_numel,
    len = n_batches * batch_size
  )
  dl <- dataloader(ds, batch_size = batch_size, num_workers = num_workers)
  iter <- dataloader_make_iter(dl)
  timing <- system.time({
    exhaust_dl(iter)
  })
  out <- tibble::tibble(
    num_workers = num_workers,
    time_per_batch = time_per_batch,
    tensor_numel = tensor_numel,
    time = timing[["elapsed"]]
  )
  print(out)
  out
}
# Run the full benchmark grid: every combination of worker count, per-batch
# loading cost and tensor size (4 x 4 x 4 = 64 runs).
#
# Base expand.grid() replaces the deprecated purrr::cross() + transpose()
# pair. Its first column varies fastest, which keeps the same run order.
# Using it also avoids depending on a pipe operator being attached.
bench_grid <- expand.grid(
  num_workers = c(0, 2, 4, 8),
  time_per_batch = c(0, 0.01, 0.05, 0.1),
  tensor_numel = c(100, 1e3, 1e5, 1e6),
  KEEP.OUT.ATTRS = FALSE
)
# pmap over a data frame calls a() once per row, matching columns to
# arguments by name. The per-run tibbles are row-bound into one table.
results <- purrr::pmap_dfr(bench_grid, a)
saveRDS(results, "benchmarks.rds")
library(ggplot2)
# Elapsed time against worker count. There is one panel per tensor size,
# each with its own axis ranges, and one coloured line per simulated
# per-batch loading cost. The y axis starts at 0; the upper bound follows
# the data.
timing_plot <- ggplot(
  data = results,
  mapping = aes(x = num_workers, y = time, color = as.factor(time_per_batch))
) +
  geom_point() +
  geom_line() +
  facet_wrap(vars(tensor_numel), scales = "free") +
  coord_cartesian(ylim = c(0, NA))
timing_plot
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment