Skip to content

Instantly share code, notes, and snippets.

@nathan-russell
Created December 8, 2015 01:13
Show Gist options
  • Save nathan-russell/8bf4e874418887aa1b6c to your computer and use it in GitHub Desktop.
Save nathan-russell/8bf4e874418887aa1b6c to your computer and use it in GitHub Desktop.
Test __VA_ARGS__ cbind macro
#include <Rcpp.h>
template<int RTYPE, typename T>
class ColumnBindable {
public:
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
typedef T VEC_TYPE;
private:
const VEC_TYPE& vec;
R_xlen_t len, nrows, ncols;
public:
ColumnBindable(const VEC_TYPE& vec_)
: vec(vec_), len(vec.size()),
nrows(vec.nrow()), ncols(vec.ncol())
{}
inline R_xlen_t size() const { return len; }
inline R_xlen_t nrow() const { return nrows; }
inline R_xlen_t ncol() const { return ncols; }
inline stored_type operator[](R_xlen_t i) const {
return vec[i];
}
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
return vec[i + nrows * j];
}
};
template <int RTYPE>
class ColumnBindable< RTYPE, Rcpp::Vector<RTYPE> > {
public:
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
typedef Rcpp::Vector<RTYPE> VEC_TYPE;
private:
const VEC_TYPE& vec;
R_xlen_t len, nrows, ncols;
public:
ColumnBindable(const VEC_TYPE& vec_)
: vec(vec_), len(vec.size()),
nrows(vec.size()), ncols(1)
{}
inline R_xlen_t size() const { return len; }
inline R_xlen_t nrow() const { return nrows; }
inline R_xlen_t ncol() const { return ncols; }
inline stored_type operator[](R_xlen_t i) const {
return vec[i];
}
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
return vec[i];
}
};
template <int RTYPE>
class ColumnBindable< RTYPE, Rcpp::Matrix<RTYPE> > {
public:
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
typedef Rcpp::Matrix<RTYPE> VEC_TYPE;
private:
const VEC_TYPE& vec;
R_xlen_t len, nrows, ncols;
public:
ColumnBindable(const VEC_TYPE& vec_)
: vec(vec_), len(vec_.rows() * vec_.cols()),
nrows(vec.nrow()), ncols(vec.ncol())
{}
inline R_xlen_t size() const { return len; }
inline R_xlen_t nrow() const { return nrows; }
inline R_xlen_t ncol() const { return ncols; }
inline stored_type operator[](R_xlen_t i) const {
return vec[i];
}
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
return vec[i + nrows * j];
}
};
/**
 * Expression object representing two operands bound side by side by column.
 *
 * Both operands are wrapped in ColumnBindable adaptors. Elements are looked
 * up on demand: linear positions below the left operand's size come from the
 * left operand, all later positions come from the right operand.
 *
 * @tparam RTYPE  R type code shared by both operands.
 * @tparam LHS_T  Type of the left operand.
 * @tparam RHS_T  Type of the right operand.
 */
template <int RTYPE, typename LHS_T, typename RHS_T>
class Cbind
    : public Rcpp::MatrixBase< RTYPE, true, Cbind<RTYPE, LHS_T, RHS_T> >
{
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ColumnBindable<RTYPE, LHS_T> LHS_TYPE;
    typedef ColumnBindable<RTYPE, RHS_T> RHS_TYPE;

    /**
     * Wraps both operands and verifies that their row counts agree.
     * Calls Rcpp::stop when the row counts differ.
     */
    Cbind(const LHS_T& lhs_, const RHS_T& rhs_)
        : left(lhs_), right(rhs_)
    {
        if (left.nrow() == right.nrow()) {
            return;
        }
        const std::string msg(
            "Unable to construct Cbind expression. "
            "Objects must have same number of rows.\n");
        Rcpp::stop(msg);
    }

    /** Combined element count of both operands. */
    R_xlen_t size() const { return left.size() + right.size(); }

    /** Row count, taken from the left operand. */
    R_xlen_t nrow() const { return left.nrow(); }

    /** Sum of both operands' column counts. */
    R_xlen_t ncol() const { return left.ncol() + right.ncol(); }

    /** Element at linear position i of the combined result. */
    stored_type operator[](R_xlen_t i) const {
        const R_xlen_t boundary = left.size();
        if (i < boundary) {
            return left[i];
        }
        return right[i - boundary];
    }

    /** Element at row i, column j, resolved through the linear accessor. */
    stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        return (*this)[i + nrow() * j];
    }

private:
    LHS_TYPE left;
    RHS_TYPE right;
};
/**
 * Cbind specialization whose right operand is a single scalar value.
 *
 * The scalar contributes one extra column: every linear position past the
 * end of the left operand yields the scalar.
 *
 * @tparam RTYPE  R type code; the scalar has the matching storage type.
 * @tparam LHS_T  Type of the left operand.
 */
template <int RTYPE, typename LHS_T>
class Cbind<RTYPE, LHS_T, typename Rcpp::traits::storage_type<RTYPE>::type>
    : public Rcpp::MatrixBase<
          RTYPE, true,
          Cbind<RTYPE, LHS_T, typename Rcpp::traits::storage_type<RTYPE>::type>
      >
{
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ColumnBindable<RTYPE, LHS_T> LHS_TYPE;
    typedef stored_type RHS_TYPE;

    /** Wraps the left operand and stores a copy of the scalar. */
    Cbind(const LHS_T& lhs_, stored_type rhs_)
        : left(lhs_), fill(rhs_)
    {}

    /** Left operand's element count plus one extra column of nrow() elements. */
    R_xlen_t size() const { return left.size() + left.nrow(); }

    /** Row count of the left operand. */
    R_xlen_t nrow() const { return left.nrow(); }

    /** Left operand's column count plus one. */
    R_xlen_t ncol() const { return left.ncol() + 1; }

    /** Element at linear position i: left operand first, then the scalar. */
    stored_type operator[](R_xlen_t i) const {
        if (i < left.size()) {
            return left[i];
        }
        return fill;
    }

    /** Element at row i, column j, resolved through the linear accessor. */
    stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        return (*this)[i + nrow() * j];
    }

private:
    LHS_TYPE left;
    RHS_TYPE fill;
};
/**
 * Cbind specialization whose left operand is a single scalar value.
 *
 * The scalar contributes the first column: linear positions below the right
 * operand's row count yield the scalar, and later positions are shifted back
 * by that row count and read from the right operand.
 *
 * @tparam RTYPE  R type code; the scalar has the matching storage type.
 * @tparam RHS_T  Type of the right operand.
 */
template <int RTYPE, typename RHS_T>
class Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, RHS_T>
    : public Rcpp::MatrixBase<
          RTYPE, true,
          Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, RHS_T>
      >
{
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef stored_type LHS_TYPE;
    typedef ColumnBindable<RTYPE, RHS_T> RHS_TYPE;

    /** Stores a copy of the scalar and wraps the right operand. */
    Cbind(const stored_type lhs_, const RHS_T& rhs_)
        : fill(lhs_), right(rhs_)
    {}

    /** Right operand's element count plus one extra column of nrow() elements. */
    R_xlen_t size() const { return right.size() + right.nrow(); }

    /** Row count of the right operand. */
    R_xlen_t nrow() const { return right.nrow(); }

    /** Right operand's column count plus one. */
    R_xlen_t ncol() const { return right.ncol() + 1; }

    /** Element at linear position i: the scalar column first, then the right operand. */
    stored_type operator[](R_xlen_t i) const {
        const R_xlen_t first_column = right.nrow();
        if (i < first_column) {
            return fill;
        }
        return right[i - first_column];
    }

    /** Element at row i, column j, resolved through the linear accessor. */
    stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        return (*this)[i + nrow() * j];
    }

private:
    LHS_TYPE fill;
    RHS_TYPE right;
};
/**
 * Column-binds two vectors of the same R type.
 *
 * @param lhs  Left vector.
 * @param rhs  Right vector.
 * @return     Cbind expression over both vectors.
 */
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Vector<RTYPE> >
cbind_impl(const Rcpp::Vector<RTYPE>& lhs, const Rcpp::Vector<RTYPE>& rhs) {
    typedef Cbind< RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Vector<RTYPE> > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds two matrices of the same R type.
 *
 * @param lhs  Left matrix.
 * @param rhs  Right matrix.
 * @return     Cbind expression over both matrices.
 */
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Matrix<RTYPE> >
cbind_impl(const Rcpp::Matrix<RTYPE>& lhs, const Rcpp::Matrix<RTYPE>& rhs) {
    typedef Cbind< RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Matrix<RTYPE> > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds a vector (left) with a matrix (right) of the same R type.
 *
 * @param lhs  Left vector.
 * @param rhs  Right matrix.
 * @return     Cbind expression over both operands.
 */
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Matrix<RTYPE> >
cbind_impl(const Rcpp::Vector<RTYPE>& lhs, const Rcpp::Matrix<RTYPE>& rhs) {
    typedef Cbind< RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Matrix<RTYPE> > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds a matrix (left) with a vector (right) of the same R type.
 *
 * @param lhs  Left matrix.
 * @param rhs  Right vector.
 * @return     Cbind expression over both operands.
 */
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Vector<RTYPE> >
cbind_impl(const Rcpp::Matrix<RTYPE>& lhs, const Rcpp::Vector<RTYPE>& rhs) {
    typedef Cbind< RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Vector<RTYPE> > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds a vector (left) with a scalar (right).
 *
 * @param lhs  Left vector.
 * @param rhs  Scalar of the vector's storage type.
 * @return     Cbind expression using the scalar-right specialization.
 */
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, typename Rcpp::traits::storage_type<RTYPE>::type>
cbind_impl(const Rcpp::Vector<RTYPE>& lhs, const typename Rcpp::traits::storage_type<RTYPE>::type rhs) {
    typedef Cbind<
        RTYPE,
        Rcpp::Vector<RTYPE>,
        typename Rcpp::traits::storage_type<RTYPE>::type
    > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds a matrix (left) with a scalar (right).
 *
 * @param lhs  Left matrix.
 * @param rhs  Scalar of the matrix's storage type.
 * @return     Cbind expression using the scalar-right specialization.
 */
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, typename Rcpp::traits::storage_type<RTYPE>::type>
cbind_impl(const Rcpp::Matrix<RTYPE>& lhs, const typename Rcpp::traits::storage_type<RTYPE>::type rhs) {
    typedef Cbind<
        RTYPE,
        Rcpp::Matrix<RTYPE>,
        typename Rcpp::traits::storage_type<RTYPE>::type
    > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds a scalar (left) with a vector (right).
 *
 * @param lhs  Scalar of the vector's storage type.
 * @param rhs  Right vector.
 * @return     Cbind expression using the scalar-left specialization.
 */
template <int RTYPE>
inline Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, Rcpp::Vector<RTYPE> >
cbind_impl(const typename Rcpp::traits::storage_type<RTYPE>::type lhs, const Rcpp::Vector<RTYPE>& rhs) {
    typedef Cbind<
        RTYPE,
        typename Rcpp::traits::storage_type<RTYPE>::type,
        Rcpp::Vector<RTYPE>
    > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds a scalar (left) with a matrix (right).
 *
 * @param lhs  Scalar of the matrix's storage type.
 * @param rhs  Right matrix.
 * @return     Cbind expression using the scalar-left specialization.
 */
template <int RTYPE>
inline Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, Rcpp::Matrix<RTYPE> >
cbind_impl(const typename Rcpp::traits::storage_type<RTYPE>::type lhs, const Rcpp::Matrix<RTYPE>& rhs) {
    typedef Cbind<
        RTYPE,
        typename Rcpp::traits::storage_type<RTYPE>::type,
        Rcpp::Matrix<RTYPE>
    > expr_type;
    return expr_type(lhs, rhs);
}
/**
 * Column-binds an existing Cbind expression (left) with one more operand,
 * which is what allows bind expressions to be nested.
 *
 * @param lhs  Previously built Cbind expression.
 * @param rhs  Operand to append on the right.
 * @return     Cbind expression whose left operand is the given expression.
 */
template <int RTYPE, typename LHS_LT, typename LHS_RT, typename RHS_T>
inline Cbind<RTYPE, Cbind<RTYPE, LHS_LT, LHS_RT>, RHS_T>
cbind_impl(const Cbind<RTYPE, LHS_LT, LHS_RT>& lhs, const RHS_T& rhs) {
    typedef Cbind<RTYPE, LHS_LT, LHS_RT> inner_type;
    typedef Cbind<RTYPE, inner_type, RHS_T> expr_type;
    return expr_type(lhs, rhs);
}
#include "cbind_macro_dispatch.h"
// Short alias for Rcpp::NumericVector; keeps the twenty-parameter signature
// of bind_20 readable.
typedef Rcpp::NumericVector vec;
// R snippets that print the parameter list and the argument list used by
// bind_20:
// cat(paste0("vec ", letters[1:20], collapse = ", ")),
// cat(paste0(letters[1:20], collapse = ", "))
// Column-binds twenty numeric vectors into a single numeric matrix.
// Exercises the variadic `cbind` macro (presumably supplied by
// "cbind_macro_dispatch.h" -- confirm against that header) with the largest
// argument count tested here.
// [[Rcpp::export]]
Rcpp::NumericMatrix bind_20(
    vec a, vec b, vec c, vec d, vec e,
    vec f, vec g, vec h, vec i, vec j,
    vec k, vec l, vec m, vec n, vec o,
    vec p, vec q, vec r, vec s, vec t)
{
    // The bind expression is converted to a matrix inside this one statement.
    Rcpp::NumericMatrix bound = cbind(
        a, b, c, d, e,
        f, g, h, i, j,
        k, l, m, n, o,
        p, q, r, s, t);
    return bound;
}
/*** R
## Sanity check: call the compiled bind_20() with twenty length-one values
## and compare the result against base R's cbind() on the same inputs.
## The argument lists below can be printed with:
## cat(paste0(1:20, collapse = ", "))
cpp_res <- bind_20(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20
)
r_res <- cbind(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20
)
## all.equal() returns TRUE when both results match.
all.equal(cpp_res, r_res)
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment