Skip to content

Instantly share code, notes, and snippets.

@lionel-
Last active August 29, 2015 14:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lionel-/0b41c6b9d3554725807a to your computer and use it in GitHub Desktop.
Save lionel-/0b41c6b9d3554725807a to your computer and use it in GitHub Desktop.
Making lowliner understand data frames better
# Group a data frame by the given columns.
#
# .d:    a data frame (checked with stopifnot()).
# .cols: NULL, a character vector of column names, or a numeric vector of
#        column positions (translated to names before grouping).
#
# When .cols is NULL, dplyr::group_by_() is called with an empty .dots list,
# which is how unset_groups() drops grouping. Otherwise each column name is
# passed as a separate argument to dplyr::group_by_() via map_call().
set_groups <- function(.d, .cols = NULL) {
  stopifnot(is.data.frame(.d))
  if (is.null(.cols)) {
    # Namespace-qualified like the branch below, so this path does not
    # depend on dplyr being attached.
    return(dplyr::group_by_(.d, .dots = list()))
  }
  if (is.numeric(.cols)) {
    .cols <- names(.d)[.cols]
  }
  .cols %>% map_call(dplyr::group_by_, .data = .d)
}
# Drop all grouping from a data frame by asking set_groups() for no
# grouping columns.
unset_groups <- function(.d) {
  set_groups(.d = .d, .cols = NULL)
}
# Apply a function to each group of a grouped data frame and bind the
# results next to the grouping labels.
#
# .d:  a data frame. If it does not inherit from "grouped_df", .f is simply
#      applied to the whole of .d and that result is returned as-is.
# .f:  a function, or a formula (converted with lowliner's as_function()).
# ...: further arguments passed on to .f.
#
# Each group is passed to .f as a data frame of its non-grouping columns.
# If every result is a data frame, the results are row-bound and each
# group's label row is repeated once per result row. Otherwise the results
# go into a single list-column named `out`, one row per group.
by_group <- function(.d, .f, ...) {
  if (is.formula(.f)) {
    .f <- lowliner:::as_function(.f)
  }
  if (!inherits(.d, "grouped_df")) {
    return(.f(.d, ...))
  }

  # The "indices" attribute is treated as 0-based; shift to R's 1-based.
  indices <- attr(.d, "indices") %>% map(partial(`+`, 1))
  group_vars <- attr(.d, "vars") %>% vapply(as.character, character(1))

  # Split every non-grouping column into one piece per group.
  unpiled <- .d[-match(group_vars, names(.d))] %>%
    lapply(function(col) {
      lapply(indices, . %>% col[.])
    })

  # Subset groups and apply function
  out <- lapply(seq_along(indices), function(i) {
    rows <- lapply(unpiled, function(col) col[[i]])
    .f(dplyr::as_data_frame(rows), ...)
  })

  # If data frame, record number of rows in each group before
  # merging. If any other kind, return a list-column.
  if (every(out, is.data.frame)) {
    group_sizes <- vapply(out, nrow, numeric(1))
    out <- dplyr::bind_rows(out)
  } else {
    group_sizes <- rep(1, length(out))
    out <- list(out = out) %>% dplyr::as_data_frame()
  }

  # Recycle labels to the output size in each group. dplyr's subset
  # method is used because it always returns a data frame.
  labels <- attr(.d, "labels")
  label_rows <- rep(seq_len(nrow(labels)), times = group_sizes)
  labels <- dplyr::tbl_df(labels)[label_rows, ]
  dplyr::bind_cols(labels, out) %>% dplyr::tbl_df()
}
#include <Rcpp.h>
#include <dplyr.h>
// [[Rcpp::depends(dplyr)]]
using namespace Rcpp;
// Predicate used by by_slice_impl() to decide whether per-group results get
// wrapped in a list-column.
// NOTE(review): this is only a declaration; no definition appears in this
// file, so it presumably comes from another translation unit. If none is
// linked in, loading the compiled code will fail. It appears to play the
// role of R's is.object() (compare by_slice3), for which Rf_isObject() from
// R's C API would be the equivalent. TODO confirm.
int is_object(SEXP obj);
// Split a grouped data frame into its groups and call `fun` on each one.
//
// data: a grouped data frame whose "indices" attribute holds one integer
//       vector of row indices per group (presumably 0-based, since by_group
//       adds 1 to them on the R side; TODO confirm).
// fun:  an R function called once per group with that group's rows, as a
//       data frame of class c("tbl_df", "tbl", "data.frame").
// Returns a list with fun's result for each group, in the order of the
// "indices" attribute.
List apply_slices(List data, Function fun) {
  ListOf<IntegerVector> indices(data.attr("indices"));
  int n_slices = indices.size();

  CharacterVector classes = Rcpp::CharacterVector::create(
    "tbl_df", "tbl", "data.frame"
  );
  dplyr::DataFrameVisitors visitors(data);

  // Apply fun on each slice
  List out(n_slices);
  for (int i = 0; i < n_slices; ++i) {
    out[i] = fun(visitors.subset(indices[i], classes));
  }
  return out;
}
// Apply `fun` to every group of `data` (via apply_slices()) and prepare the
// results for dplyr::bind_rows() on the R side.
//
// If every result is an R object (per is_object()) that is not a data frame,
// each result is wrapped in a one-row data frame with a single list-column
// `.out`. Its class is c("tbl_df", "tbl", "data.frame"), matching the slices
// built by apply_slices() and dplyr::as_data_frame(). Otherwise the results
// are returned unchanged.
// [[Rcpp::export]]
List by_slice_impl(const List data, Function fun) {
  List out = apply_slices(data, fun);
  R_xlen_t n_out = out.size();

  // Make a list-column only if all outputs are non-data frame
  // objects. In all other cases, we let dplyr::bind_rows() check that
  // the outputs are compatible.
  bool all_objects = true;
  for (R_xlen_t i = 0; i < n_out; ++i) {
    SEXP res = out[i];
    if (!is_object(res) || Rf_inherits(res, "data.frame")) {
      all_objects = false;
      break;
    }
  }
  if (!all_objects) {
    return out;
  }

  for (R_xlen_t i = 0; i < n_out; ++i) {
    List out_slice = List::create(_[".out"] = List::create(out[i]));
    // Compact row names c(NA, -1) mark a one-row data frame.
    out_slice.attr("row.names") = IntegerVector::create(IntegerVector::get_na(), -1);
    out_slice.attr("class") = CharacterVector::create("tbl_df", "tbl", "data.frame");
    out[i] = out_slice;
  }
  return out;
}
// Split a grouped data frame into one data frame per group.
//
// data: a grouped data frame whose "indices" attribute holds one integer
//       vector of row indices per group (presumably 0-based, as in
//       apply_slices(); TODO confirm).
// Returns a list of data frames of class c("tbl_df", "tbl", "data.frame"),
// one per group, in the order of the "indices" attribute. Unlike
// apply_slices(), no function is applied; by_slice3 does that from R.
// [[Rcpp::export]]
List subset_slices(const List data) {
  ListOf<IntegerVector> indices(data.attr("indices"));
  int n_slices = indices.size();

  CharacterVector classes = Rcpp::CharacterVector::create(
    "tbl_df", "tbl", "data.frame"
  );
  dplyr::DataFrameVisitors visitors(data);

  List out(n_slices);
  for (int i = 0; i < n_slices; ++i) {
    out[i] = visitors.subset(indices[i], classes);
  }
  return out;
}
/*** R
# Calls ..f from Rcpp: by_slice_impl() runs ..f on every group in C++, and
# the per-group results are then row-bound in R.
by_slice2 <- function(.x, ..f) {
  slices <- by_slice_impl(.x, ..f)
  dplyr::bind_rows(slices)
}
# Calls ..f from R: groups are split in C++ by subset_slices(), then ..f is
# applied to each one with lapply().
#
# Mirrors by_slice_impl(): outputs are wrapped in a `.out` list-column only
# when every output is a non-data-frame object. Otherwise
# dplyr::bind_rows() combines the outputs directly and checks that they
# are compatible.
by_slice3 <- function(.x, ..f) {
  out <- subset_slices(.x) %>% lapply(..f)
  is_non_df_object <- function(x) is.object(x) && !is.data.frame(x)
  if (every(out, is_non_df_object)) {
    out <- lapply(out, function(x) list(.out = list(x)) %>% dplyr::as_data_frame())
  }
  dplyr::bind_rows(out)
}
# Benchmark data: 1000 stacked copies of mtcars, grouped by cyl and vs.
# NOTE(review): rerun(), group_by(), do(), summarise_each(), funs(),
# as_data_frame(), partial(), map(), by_slice() and microbenchmark() are
# called unqualified. lowliner, dplyr and microbenchmark are therefore
# assumed to be attached. by_slice() is not defined in this file (presumably
# lowliner's pure-R version).
data <- rerun(1000, mtcars) %>% dplyr::bind_rows() %>% group_by(cyl, vs)
# Benchmark 1: an expensive per-group function (lm) that returns a
# non-data-frame object, so results end up in a list-column.
# Observed by the author: the R version was far faster than Cpp1
# (by_slice2, ..f called from C++) and slightly slower than Cpp2
# (by_slice3, ..f called from R).
fnu <- partial(lm, disp ~ gear)
microbenchmark(
R = data %>% by_slice(fnu),
dplyr = data %>% do(as_data_frame(list(.out = list(fnu(.))))),
Cpp1 = data %>% by_slice2(fnu),
Cpp2 = data %>% by_slice3(fnu)
)
# Benchmark 2: a cheap per-group function (column means).
# Observed by the author: Cpp1 (by_slice2) was the fastest by_slice
# variant, almost on par with dplyr's summarise_each().
fnu2 <- partial(map, .f = mean)
microbenchmark(
R = data %>% by_slice(fnu2),
dplyr = data %>% summarise_each(funs(mean)),
Cpp1 = data %>% by_slice2(fnu2),
Cpp2 = data %>% by_slice3(fnu2)
)
*/
# Apply .f row-wise over the columns of .d (via map_n()), coerce each result
# to a data frame with coerce_rows(), and row-bind them.
#
# When .trace is TRUE, the results are column-bound to the original .d and
# returned as a tbl_df. Otherwise only the bound results are returned.
map_rows <- function(.d, .f, ..., .trace = TRUE) {
  per_row <- lapply(map_n(.d, .f, ...), coerce_rows)
  combined <- dplyr::bind_rows(per_row)
  if (.trace) {
    return(dplyr::tbl_df(dplyr::bind_cols(.d, combined)))
  }
  combined
}
# Turn a single row-wise result into a data frame:
# - a bare atomic vector becomes a one-row data frame with columns named
#   "1", "2", ... (one per element);
# - a data frame is returned unchanged;
# - anything else is stored in a one-row list-column named `out`.
coerce_rows <- function(x) {
  if (is_bare_atomic(x)) {
    cells <- as.list(x)
    names(cells) <- seq_along(x)
    return(dplyr::as_data_frame(cells))
  }
  if (is.data.frame(x)) {
    return(x)
  }
  dplyr::data_frame(out = list(x))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment