Skip to content

Instantly share code, notes, and snippets.

@nathan-russell
Created December 7, 2015 15:13
Show Gist options
  • Save nathan-russell/068b8a459833609c52c2 to your computer and use it in GitHub Desktop.
Save nathan-russell/068b8a459833609c52c2 to your computer and use it in GitHub Desktop.
Rcpp cbind
#include <Rcpp.h>
// ColumnBindable
//
// Provides a common, matrix-like interface
// for use in Cbind class.
template<int RTYPE, typename T>
class ColumnBindable {
public:
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
typedef T VEC_TYPE;
private:
const VEC_TYPE& vec;
R_xlen_t len, nrows, ncols;
public:
ColumnBindable(const VEC_TYPE& vec_)
: vec(vec_), len(vec.size()),
nrows(vec.nrow()), ncols(vec.ncol())
{}
inline R_xlen_t size() const { return len; }
inline R_xlen_t nrow() const { return nrows; }
inline R_xlen_t ncol() const { return ncols; }
inline stored_type operator[](R_xlen_t i) const {
return vec[i];
}
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
return vec[i + nrows * j];
}
};
// Specialization for Vector
template <int RTYPE>
class ColumnBindable< RTYPE, Rcpp::Vector<RTYPE> > {
public:
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
typedef Rcpp::Vector<RTYPE> VEC_TYPE;
private:
const VEC_TYPE& vec;
R_xlen_t len, nrows, ncols;
public:
ColumnBindable(const VEC_TYPE& vec_)
: vec(vec_), len(vec.size()),
nrows(vec.size()), ncols(1)
{}
inline R_xlen_t size() const { return len; }
inline R_xlen_t nrow() const { return nrows; }
inline R_xlen_t ncol() const { return ncols; }
inline stored_type operator[](R_xlen_t i) const {
return vec[i];
}
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
return vec[i];
}
};
// Specialization for Matrix
template <int RTYPE>
class ColumnBindable< RTYPE, Rcpp::Matrix<RTYPE> > {
public:
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
typedef Rcpp::Matrix<RTYPE> VEC_TYPE;
private:
const VEC_TYPE& vec;
R_xlen_t len, nrows, ncols;
public:
ColumnBindable(const VEC_TYPE& vec_)
: vec(vec_), len(vec_.rows() * vec_.cols()),
nrows(vec.nrow()), ncols(vec.ncol())
{}
inline R_xlen_t size() const { return len; }
inline R_xlen_t nrow() const { return nrows; }
inline R_xlen_t ncol() const { return ncols; }
inline stored_type operator[](R_xlen_t i) const {
return vec[i];
}
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
return vec[i + nrows * j];
}
};
// Cbind implementation
//
// Expression object representing the column-wise concatenation of two
// array-like operands. No elements are copied: every access is forwarded
// to whichever operand holds the requested position. In linear index
// order the LHS elements come first, followed by the RHS elements.
//
// The operands themselves are only referenced (through ColumnBindable),
// so they must stay alive for as long as this expression is used.
template <int RTYPE, typename LHS_T, typename RHS_T>
class Cbind
    : public Rcpp::MatrixBase< RTYPE, true, Cbind<RTYPE, LHS_T, RHS_T> > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ColumnBindable<RTYPE, LHS_T> LHS_TYPE;
    typedef ColumnBindable<RTYPE, RHS_T> RHS_TYPE;

private:
    // The adapters are stored by value. They are created from the
    // constructor arguments by conversion, so holding them by reference
    // would bind to temporaries that are destroyed when the constructor
    // returns, leaving dangling references.
    LHS_TYPE lhs;
    RHS_TYPE rhs;

public:
    // Builds the expression for `lhs_` followed by `rhs_`.
    // Calls Rcpp::stop if the two operands do not have the same number
    // of rows.
    Cbind(const LHS_T& lhs_, const RHS_T& rhs_)
        : lhs(lhs_), rhs(rhs_)
    {
        if (lhs.nrow() != rhs.nrow()) {
            std::string msg =
                "Unable to construct Cbind expression. "
                "Objects must have same number of rows.\n";
            Rcpp::stop(msg);
        }
    }

    // Total number of elements across both operands.
    inline R_xlen_t size() const { return lhs.size() + rhs.size(); }

    // Number of rows (shared by both operands).
    inline R_xlen_t nrow() const { return lhs.nrow(); }

    // Number of columns: LHS columns plus RHS columns.
    inline R_xlen_t ncol() const { return lhs.ncol() + rhs.ncol(); }

    // Element at linear position `i`: positions below lhs.size() are
    // served by the LHS, the remainder by the RHS (re-based to zero).
    inline stored_type operator[](R_xlen_t i) const {
        return (i < lhs.size()) ? lhs[i] : rhs[i - lhs.size()];
    }

    // Element at row `i`, column `j`, via the column-major linear index.
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        R_xlen_t index = i + nrow() * j;
        return this->operator[](index);
    }
};
// Specialization for LHS array, RHS scalar
//
// Expression object that appends one extra column, filled with a single
// scalar value, to the right of an array-like operand. No elements are
// copied; the LHS operand is only referenced (through ColumnBindable) and
// must stay alive for as long as this expression is used.
template <int RTYPE, typename LHS_T>
class Cbind<RTYPE, LHS_T, typename Rcpp::traits::storage_type<RTYPE>::type>
    : public Rcpp::MatrixBase<
          RTYPE, true,
          Cbind<RTYPE, LHS_T, typename Rcpp::traits::storage_type<RTYPE>::type>
      > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ColumnBindable<RTYPE, LHS_T> LHS_TYPE;
    typedef stored_type RHS_TYPE;

private:
    // The adapter is stored by value. It is created from the constructor
    // argument by conversion, so holding it by reference would bind to a
    // temporary that is destroyed when the constructor returns.
    LHS_TYPE lhs;
    RHS_TYPE rhs;   // scalar used for every element of the added column

public:
    // Builds the expression for `lhs_` followed by a column of `rhs_`.
    Cbind(const LHS_T& lhs_, const stored_type rhs_)
        : lhs(lhs_), rhs(rhs_) {}

    // Total number of elements: the LHS plus one extra column of nrow().
    inline R_xlen_t size() const { return lhs.size() + lhs.nrow(); }

    // Number of rows (taken from the LHS).
    inline R_xlen_t nrow() const { return lhs.nrow(); }

    // Number of columns: LHS columns plus the scalar column.
    inline R_xlen_t ncol() const { return lhs.ncol() + 1; }

    // Element at linear position `i`: positions below lhs.size() come
    // from the LHS, every later position yields the scalar.
    inline stored_type operator[](R_xlen_t i) const {
        return (i < lhs.size()) ? lhs[i] : rhs;
    }

    // Element at row `i`, column `j`, via the column-major linear index.
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        R_xlen_t index = i + nrow() * j;
        return this->operator[](index);
    }
};
// Specialization for LHS scalar, RHS array
//
// Expression object that prepends one extra column, filled with a single
// scalar value, to the left of an array-like operand. No elements are
// copied; the RHS operand is only referenced (through ColumnBindable) and
// must stay alive for as long as this expression is used.
template <int RTYPE, typename RHS_T>
class Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, RHS_T>
    : public Rcpp::MatrixBase<
          RTYPE, true,
          Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, RHS_T>
      > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef stored_type LHS_TYPE;
    typedef ColumnBindable<RTYPE, RHS_T> RHS_TYPE;

private:
    LHS_TYPE lhs;   // scalar used for every element of the added column
    // The adapter is stored by value. It is created from the constructor
    // argument by conversion, so holding it by reference would bind to a
    // temporary that is destroyed when the constructor returns.
    RHS_TYPE rhs;

public:
    // Builds the expression for a column of `lhs_` followed by `rhs_`.
    Cbind(const stored_type lhs_, const RHS_T& rhs_)
        : lhs(lhs_), rhs(rhs_) {}

    // Total number of elements: the RHS plus one extra column of nrow().
    inline R_xlen_t size() const { return rhs.size() + rhs.nrow(); }

    // Number of rows (taken from the RHS).
    inline R_xlen_t nrow() const { return rhs.nrow(); }

    // Number of columns: the scalar column plus the RHS columns.
    inline R_xlen_t ncol() const { return rhs.ncol() + 1; }

    // Element at linear position `i`: the first nrow() positions form the
    // scalar column; later positions map to the RHS, shifted down by one
    // column's worth of elements.
    inline stored_type operator[](R_xlen_t i) const {
        return (i < rhs.nrow()) ? lhs : rhs[i - rhs.nrow()];
    }

    // Element at row `i`, column `j`, via the column-major linear index.
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        R_xlen_t index = i + nrow() * j;
        return this->operator[](index);
    }
};
// Vector, Vector
// Returns the Cbind expression placing `lhs` and `rhs` side by side as
// two columns.
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Vector<RTYPE> >
cbind(const Rcpp::Vector<RTYPE>& lhs, const Rcpp::Vector<RTYPE>& rhs) {
    typedef Rcpp::Vector<RTYPE> vector_t;
    typedef Cbind<RTYPE, vector_t, vector_t> result_t;
    return result_t(lhs, rhs);
}
// Matrix, Matrix
// Returns the Cbind expression placing the columns of `rhs` after the
// columns of `lhs`.
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Matrix<RTYPE> >
cbind(const Rcpp::Matrix<RTYPE>& lhs, const Rcpp::Matrix<RTYPE>& rhs) {
    typedef Rcpp::Matrix<RTYPE> matrix_t;
    typedef Cbind<RTYPE, matrix_t, matrix_t> result_t;
    return result_t(lhs, rhs);
}
// Vector, Matrix
// Returns the Cbind expression placing `lhs` as a single column in front
// of the columns of `rhs`.
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Matrix<RTYPE> >
cbind(const Rcpp::Vector<RTYPE>& lhs, const Rcpp::Matrix<RTYPE>& rhs) {
    typedef Rcpp::Vector<RTYPE> vector_t;
    typedef Rcpp::Matrix<RTYPE> matrix_t;
    typedef Cbind<RTYPE, vector_t, matrix_t> result_t;
    return result_t(lhs, rhs);
}
// Matrix, Vector
// Returns the Cbind expression placing `rhs` as a single column after
// the columns of `lhs`.
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Vector<RTYPE> >
cbind(const Rcpp::Matrix<RTYPE>& lhs, const Rcpp::Vector<RTYPE>& rhs) {
    typedef Rcpp::Matrix<RTYPE> matrix_t;
    typedef Rcpp::Vector<RTYPE> vector_t;
    typedef Cbind<RTYPE, matrix_t, vector_t> result_t;
    return result_t(lhs, rhs);
}
// Vector, Scalar
// Returns the Cbind expression placing `lhs` as the first column and a
// column filled with the scalar `rhs` as the second.
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, typename Rcpp::traits::storage_type<RTYPE>::type>
cbind(const Rcpp::Vector<RTYPE>& lhs, const typename Rcpp::traits::storage_type<RTYPE>::type rhs) {
    typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t;
    typedef Cbind<RTYPE, Rcpp::Vector<RTYPE>, scalar_t> result_t;
    return result_t(lhs, rhs);
}
// Matrix, Scalar
// Returns the Cbind expression placing a column filled with the scalar
// `rhs` after the columns of `lhs`.
template <int RTYPE>
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, typename Rcpp::traits::storage_type<RTYPE>::type>
cbind(const Rcpp::Matrix<RTYPE>& lhs, const typename Rcpp::traits::storage_type<RTYPE>::type rhs) {
    typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t;
    typedef Cbind<RTYPE, Rcpp::Matrix<RTYPE>, scalar_t> result_t;
    return result_t(lhs, rhs);
}
// Scalar, Vector
// Returns the Cbind expression placing a column filled with the scalar
// `lhs` first and `rhs` as the second column.
template <int RTYPE>
inline Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, Rcpp::Vector<RTYPE> >
cbind(const typename Rcpp::traits::storage_type<RTYPE>::type lhs, const Rcpp::Vector<RTYPE>& rhs) {
    typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t;
    typedef Cbind< RTYPE, scalar_t, Rcpp::Vector<RTYPE> > result_t;
    return result_t(lhs, rhs);
}
// Scalar, Matrix
// Returns the Cbind expression placing a column filled with the scalar
// `lhs` in front of the columns of `rhs`.
template <int RTYPE>
inline Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, Rcpp::Matrix<RTYPE> >
cbind(const typename Rcpp::traits::storage_type<RTYPE>::type lhs, const Rcpp::Matrix<RTYPE>& rhs) {
    typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t;
    typedef Cbind< RTYPE, scalar_t, Rcpp::Matrix<RTYPE> > result_t;
    return result_t(lhs, rhs);
}
// Enable nested calls; e.g. cbind(cbind(cbind(A, B), C), D)
// Accepts an existing Cbind expression as the left operand and returns a
// new Cbind expression with `rhs` placed after it.
template <int RTYPE, typename LHS_LT, typename LHS_RT, typename RHS_T>
inline Cbind<RTYPE, Cbind<RTYPE, LHS_LT, LHS_RT>, RHS_T>
cbind(const Cbind<RTYPE, LHS_LT, LHS_RT>& lhs, const RHS_T& rhs) {
    typedef Cbind<RTYPE, LHS_LT, LHS_RT> inner_t;
    typedef Cbind<RTYPE, inner_t, RHS_T> result_t;
    return result_t(lhs, rhs);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment