Methods to integrate SparkR and dplyr, allowing non-standard evaluation (e.g. select(df, age) instead of select(df, df$age)).
#' dplyr compatibility methods.
#'
#' These allow Spark DataFrames to cooperate with standard dplyr verbs in a
#' more-or-less ordinary way. This also means we can chain commands together
#' using pipes (`%>%`) when selecting, filtering, adding to or summarising a
#' DataFrame.
#'
#' Currently requires dplyr to be loaded after SparkR, since SparkR clobbers
#' many of the dplyr verbs.
#' @importFrom lazyeval all_dots
#' @export
select_.DataFrame <- function(.data, ..., .dots) {
  tablename <- substitute(.data)
  dots <- all_dots(.dots, ...)
  # Extract the quoted expressions from the lazy dots
  args <- sapply(dots, function(x) x$expr)
  if (is.list(args)) {
    args <- as.list(as.character(args))
  }
  # Strip any explicit "df$" prefix so SparkR::select receives bare column names
  args <- sapply(args, function(x) {
    gsub(paste0(tablename, "$"),
         "",
         x,
         fixed = TRUE)
  })
  args <- c(.data, args)
  do.call(SparkR::select, args)
}
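# Hedged sketch (not part of the original gist): assuming `df` is a SparkR
# DataFrame with an `age` column, the method above turns either calling style
# into roughly the same SparkR::select(df, "age") call.
# select(df, age)       # bare column name (non-standard evaluation)
# select(df, df$age)    # fully qualified column; the "df$" prefix is stripped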
#' @export
collect.DataFrame <- function(.data) {
  SparkR::collect(.data)
}
#' @importFrom lazyeval all_dots
#' @export
filter_.DataFrame <- function(.data, ..., .dots) {
  tablename <- substitute(.data)
  dots <- all_dots(.dots, ...)
  condition <- dots[[1]]$expr
  x <- deparse(condition)
  # Prefix bare column names with the DataFrame name unless already qualified
  for (col in columns(.data)) {
    x <- gsub(
      paste0("(?<!", tablename, "\\$)", col),
      paste0(tablename, "$", col),
      x,
      perl = TRUE
    )
  }
  condition <- eval(parse(text = x))
  SparkR::filter(.data, condition)
}
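# Hedged sketch (not in the original gist): for filter(df, age > 10) the
# deparsed condition "age > 10" is rewritten by the lookbehind regex above to
# "df$age > 10", evaluated into a SparkR Column, and the call ends up as
# roughly:
# SparkR::filter(df, df$age > 10)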
#' @importFrom lazyeval all_dots
#' @export
summarise_.DataFrame <- function(.data, ..., .dots) {
  tablename <- substitute(.data)
  dots <- all_dots(.dots, ...)
  dots <- c(.data, sapply(dots, function(x) x$expr))
  args <- sapply(
    dots,
    function(x) if (class(x) != "DataFrame") {
      x <- deparse(x)
      # Qualify bare column names with the DataFrame name before evaluating
      for (col in columns(.data)) {
        x <- gsub(
          paste0("(?<!", tablename, "\\$)", col),
          paste0(tablename, "$", col),
          x,
          perl = TRUE
        )
      }
      eval(parse(text = x))
    } else {
      x
    }
  )
  do.call(SparkR::summarize, args)
}
#' @importFrom lazyeval all_dots
#' @export
summarise_.GroupedData <- function(.data, ..., .dots) {
  dots <- all_dots(.dots, ...)
  # Aggregations arrive as named strings (e.g. age2 = "sum"); pass them through
  # to SparkR::summarize along with the GroupedData object
  args <- c(.data, sapply(dots, function(x) x$expr))
  do.call(SparkR::summarize, args)
}
#' @importFrom lazyeval all_dots
#' @export
group_by_.DataFrame <- function(.data, ..., .dots, add = FALSE) {
  tablename <- substitute(.data)
  dots <- all_dots(.dots, ...)
  args <- sapply(dots, function(x) x$expr)
  if (is.list(args)) {
    args <- as.list(as.character(args))
  }
  # Strip any explicit "df$" prefix so SparkR::group_by receives bare column names
  args <- sapply(args, function(x) {
    gsub(paste0(tablename, "$"),
         "",
         x,
         fixed = TRUE)
  })
  args <- c(.data, args)
  do.call(SparkR::group_by, args)
}
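# Hedged sketch (not in the original gist): chaining the two grouped verbs
# above, group_by(df, name) %>% summarise(age2 = "sum") ends up as roughly:
# SparkR::summarize(SparkR::group_by(df, "name"), age2 = "sum")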
#' @importFrom lazyeval all_dots
#' @export
mutate_.DataFrame <- function(.data, ..., .dots) {
  tablename <- substitute(.data)
  dots <- all_dots(.dots, ...)
  dots <- c(.data, sapply(dots, function(x) x$expr))
  args <- sapply(
    dots,
    function(x) if (class(x) != "DataFrame") {
      x <- deparse(x)
      # Qualify bare column names with the DataFrame name before evaluating
      for (col in columns(.data)) {
        x <- gsub(
          paste0("(?<!", tablename, "\\$)", col),
          paste0(tablename, "$", col),
          x,
          perl = TRUE
        )
      }
      eval(parse(text = x))
    } else {
      x
    }
  )
  do.call(SparkR::mutate, args)
}
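# Hedged sketch (not in the original gist): mutate(df, age2 = age * 2) deparses
# the expression, rewrites it to df$age * 2, evaluates it to a SparkR Column,
# and then makes roughly this call:
# SparkR::mutate(df, age2 = df$age * 2)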
## Examples
library(SparkR, lib.loc = file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))
sc <- sparkR.init(master = "local")
sqlContext <- sparkRSQL.init(sc)
library(dplyr)  # must be loaded after SparkR so dplyr's verbs mask SparkR's

df <- jsonFile(sqlContext, file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json"))

select(df, name) %>% collect
select_(df, .dots = c("age", "name")) %>% collect
df %>% select(age) %>% collect

mutate(df, age2 = age * 2) %>% collect
df %>% mutate(age2 = age * 2, age3 = age * 3) %>% collect

filter(df, age > 10 & name == "Justin") %>% collect
df %>% filter(name != "Andy") %>% collect

df2 <- df %>%
  mutate(age2 = age * 2) %>%
  filter(age > 10) %>%
  select(name, age2) %>%
  group_by(name) %>%
  summarise(age2 = "sum")
# df2 %>% collect  # takes quite a while
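# Optional (assumption, not in the original gist): stop the local Spark
# context when finished.
# sparkR.stop()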