Skip to content

Instantly share code, notes, and snippets.

@sd2k
Created June 21, 2015 00:08
Show Gist options
  • Save sd2k/6e94e9dc590502473746 to your computer and use it in GitHub Desktop.
Save sd2k/6e94e9dc590502473746 to your computer and use it in GitHub Desktop.
Methods to integrate SparkR & dplyr, and in doing so allow non-standard evaluation (e.g. select(df, age) instead of select(df, df$age))
#' dplyr compatibility methods.
#'
#' These allow Spark DataFrames to cooperate with standard dplyr verbs in a
#' more-or-less ordinary way. This also means we can chain commands together
#' using pipes (`%>%`) when selecting, filtering, adding to or summarising a
#' DataFrame.
#'
#' Currently requires dplyr to be loaded after SparkR, since SparkR clobbers
#' many of the dplyr verbs.
#'
#' `select_.DataFrame` turns every captured expression into a column-name
#' string, strips a leading `<table>$` qualifier (where `<table>` is the
#' expression that was supplied as `.data`), and forwards the resulting names
#' to `SparkR::select()`.
#'
#' @param .data A SparkR DataFrame.
#' @param ... Columns to select, captured by `lazyeval::all_dots()`.
#' @param .dots Columns to select, supplied as a list or character vector.
#' @return The value of `SparkR::select()` (presumably a DataFrame; confirm
#'   against the SparkR version in use).
#' @importFrom lazyeval all_dots
#' @export
select_.DataFrame <- function(.data, ..., .dots) {
  # Deparse (rather than coerce) the table expression so that a call such as
  # `f(x)` still yields exactly one prefix string for gsub().
  table_prefix <- paste0(
    paste(deparse(substitute(.data)), collapse = ""),
    "$"
  )
  dots <- all_dots(.dots, ...)

  # as.character(list(expr)) gives one string per expression: the bare name
  # for symbols, the deparsed text for calls, and the value itself for
  # string constants.
  col_names <- vapply(
    dots,
    function(lazy) as.character(list(lazy$expr)),
    character(1),
    USE.NAMES = FALSE
  )
  col_names <- gsub(table_prefix, "", col_names, fixed = TRUE)

  # Column names are passed positionally (unnamed) after the DataFrame.
  do.call(SparkR::select, c(list(.data), as.list(col_names)))
}
#' Collect a Spark DataFrame via `SparkR::collect()`.
#'
#' Thin S3 method so that `collect` keeps working on SparkR DataFrames, e.g.
#' at the end of a `%>%` chain.
#'
#' @param .data A SparkR DataFrame.
#' @param ... Further arguments forwarded unchanged to `SparkR::collect()`.
#'   Accepting `...` keeps the method compatible with a generic declared as
#'   `collect(x, ...)`.
#' @return The value of `SparkR::collect()` (presumably a local data.frame;
#'   confirm against the SparkR version in use).
#' @export
collect.DataFrame <- function(.data, ...) {
  SparkR::collect(.data, ...)
}
#' Filter the rows of a Spark DataFrame using bare column names.
#'
#' Every condition is evaluated with the DataFrame's columns available as
#' variables, so `age > 10` is treated like `df$age > 10`. Names that are not
#' columns are looked up in the environment the condition was captured in.
#' Several conditions are combined with `&`. With no conditions the DataFrame
#' is returned unchanged.
#'
#' @param .data A SparkR DataFrame.
#' @param ... Filter conditions, captured by `lazyeval::all_dots()`.
#' @param .dots Filter conditions supplied as a list of lazy expressions.
#' @return The value of `SparkR::filter()`, or `.data` itself when no
#'   condition was supplied.
#' @importFrom lazyeval all_dots
#' @export
filter_.DataFrame <- function(.data, ..., .dots) {
  dots <- all_dots(.dots, ...)
  if (length(dots) == 0) {
    return(.data)
  }

  col_names <- SparkR::columns(.data)

  # Evaluate one lazy expression against a mask that holds a Column object for
  # each DataFrame column the expression mentions. `.data` is exposed as well
  # so that explicit `.data$col` references resolve.
  eval_with_columns <- function(lazy) {
    used <- intersect(all.vars(lazy$expr), col_names)
    mask <- lapply(used, function(nm) .data[[nm]])
    names(mask) <- used
    eval(lazy$expr, c(list(.data = .data), mask), lazy$env)
  }

  conditions <- lapply(dots, eval_with_columns)

  # All conditions must hold, so fold them together with `&`.
  SparkR::filter(.data, Reduce(`&`, conditions))
}
#' Summarise a Spark DataFrame using bare column names.
#'
#' Every summary expression is evaluated with the DataFrame's columns
#' available as variables; names that are not columns are looked up in the
#' environment the expression was captured in. The evaluated values are
#' passed, under their original argument names, to `SparkR::summarize()`.
#'
#' @param .data A SparkR DataFrame.
#' @param ... Summary expressions, captured by `lazyeval::all_dots()`.
#' @param .dots Summary expressions supplied as a list of lazy expressions.
#' @return The value of `SparkR::summarize()`.
#' @importFrom lazyeval all_dots
#' @export
summarise_.DataFrame <- function(.data, ..., .dots) {
  dots <- all_dots(.dots, ...)
  col_names <- SparkR::columns(.data)

  # Evaluate one lazy expression against a mask that holds a Column object for
  # each DataFrame column the expression mentions. `.data` is exposed as well
  # so that explicit `.data$col` references resolve.
  eval_with_columns <- function(lazy) {
    used <- intersect(all.vars(lazy$expr), col_names)
    mask <- lapply(used, function(nm) .data[[nm]])
    names(mask) <- used
    eval(lazy$expr, c(list(.data = .data), mask), lazy$env)
  }

  aggregates <- lapply(dots, eval_with_columns)
  do.call(SparkR::summarize, c(list(.data), aggregates))
}
#' Summarise grouped Spark data.
#'
#' Every summary argument is evaluated in the environment it was captured in
#' and passed, under its original name, to `SparkR::summarize()`. A string
#' constant such as `age2 = "sum"` therefore reaches SparkR unchanged, while a
#' call such as `total = sum(df$age2)` is evaluated first. Bare column names
#' are NOT resolved here, because this method never obtains the column list of
#' the grouped data.
#'
#' @param .data A SparkR GroupedData object.
#' @param ... Summary arguments, captured by `lazyeval::all_dots()`.
#' @param .dots Summary arguments supplied as a list of lazy expressions.
#' @return The value of `SparkR::summarize()`.
#' @importFrom lazyeval all_dots
#' @export
summarise_.GroupedData <- function(.data, ..., .dots) {
  dots <- all_dots(.dots, ...)
  aggregates <- lapply(dots, function(lazy) eval(lazy$expr, lazy$env))
  do.call(SparkR::summarize, c(list(.data), aggregates))
}
#' Group a Spark DataFrame by columns given as bare names.
#'
#' Every captured expression is turned into a column-name string, a leading
#' `<table>$` qualifier is stripped (where `<table>` is the expression that was
#' supplied as `.data`), and the resulting names are forwarded to
#' `SparkR::group_by()`.
#'
#' @param .data A SparkR DataFrame.
#' @param ... Grouping columns, captured by `lazyeval::all_dots()`.
#' @param .dots Grouping columns, supplied as a list or character vector.
#' @param add Accepted for signature compatibility only; this method does not
#'   use it.
#' @return The value of `SparkR::group_by()`.
#' @importFrom lazyeval all_dots
#' @export
group_by_.DataFrame <- function(.data, ..., .dots, add = FALSE) {
  # Deparse (rather than coerce) the table expression so that a call such as
  # `f(x)` still yields exactly one prefix string for gsub().
  table_prefix <- paste0(
    paste(deparse(substitute(.data)), collapse = ""),
    "$"
  )
  dots <- all_dots(.dots, ...)

  # as.character(list(expr)) gives one string per expression: the bare name
  # for symbols, the deparsed text for calls, and the value itself for
  # string constants.
  col_names <- vapply(
    dots,
    function(lazy) as.character(list(lazy$expr)),
    character(1),
    USE.NAMES = FALSE
  )
  col_names <- gsub(table_prefix, "", col_names, fixed = TRUE)

  # Column names are passed positionally (unnamed) after the DataFrame.
  do.call(SparkR::group_by, c(list(.data), as.list(col_names)))
}
#' Add columns to a Spark DataFrame using bare column names.
#'
#' Every new-column expression is evaluated with the DataFrame's existing
#' columns available as variables; names that are not columns are looked up in
#' the environment the expression was captured in. The evaluated values are
#' passed, under their original argument names, to `SparkR::mutate()`.
#'
#' @param .data A SparkR DataFrame.
#' @param ... New-column expressions, captured by `lazyeval::all_dots()`.
#' @param .dots New-column expressions supplied as a list of lazy expressions.
#' @return The value of `SparkR::mutate()`.
#' @importFrom lazyeval all_dots
#' @export
mutate_.DataFrame <- function(.data, ..., .dots) {
  dots <- all_dots(.dots, ...)
  col_names <- SparkR::columns(.data)

  # Evaluate one lazy expression against a mask that holds a Column object for
  # each DataFrame column the expression mentions. `.data` is exposed as well
  # so that explicit `.data$col` references resolve.
  eval_with_columns <- function(lazy) {
    used <- intersect(all.vars(lazy$expr), col_names)
    mask <- lapply(used, function(nm) .data[[nm]])
    names(mask) <- used
    eval(lazy$expr, c(list(.data = .data), mask), lazy$env)
  }

  new_columns <- lapply(dots, eval_with_columns)
  do.call(SparkR::mutate, c(list(.data), new_columns))
}
## Examples

# Attach SparkR from the Spark installation under SPARK_HOME, then create a
# local Spark context and its SQL context.
library(SparkR, lib.loc = file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))
sc <- sparkR.init(master = "local")
sqlContext <- sparkRSQL.init(sc)

# dplyr is attached after SparkR so that the dplyr verbs are the ones called.
library(dplyr)

# Load the example JSON file found under SPARK_HOME.
df <- jsonFile(
  sqlContext,
  file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json")
)

# select: bare names, string names via `.dots`, and piped.
select(df, name) %>% collect()
select_(df, .dots = c("age", "name")) %>% collect()
df %>% select(age) %>% collect()

# mutate: one or several derived columns.
mutate(df, age2 = age * 2) %>% collect()
df %>% mutate(age2 = age * 2, age3 = age * 3) %>% collect()

# filter: compound and simple conditions.
filter(df, age > 10 & name == "Justin") %>% collect()
df %>% filter(name != "Andy") %>% collect()

# A longer chain ending in a grouped summary.
df2 <- df %>%
  mutate(age2 = age * 2) %>%
  filter(age > 10) %>%
  select(name, age2) %>%
  group_by(name) %>%
  summarise(age2 = "sum")
# df2 %>% collect()  # takes quite a while
# NOTE(review): the original note read `test %>% collect`; `df2` is presumably
# what was meant, since no `test` object is defined here.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment