Created
December 7, 2015 15:13
-
-
Save nathan-russell/068b8a459833609c52c2 to your computer and use it in GitHub Desktop.
Rcpp cbind
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Rcpp.h>
// ColumnBindable | |
// | |
// Provides a common, matrix-like interface | |
// for use in Cbind class. | |
template<int RTYPE, typename T> | |
class ColumnBindable { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef T VEC_TYPE; | |
private: | |
const VEC_TYPE& vec; | |
R_xlen_t len, nrows, ncols; | |
public: | |
ColumnBindable(const VEC_TYPE& vec_) | |
: vec(vec_), len(vec.size()), | |
nrows(vec.nrow()), ncols(vec.ncol()) | |
{} | |
inline R_xlen_t size() const { return len; } | |
inline R_xlen_t nrow() const { return nrows; } | |
inline R_xlen_t ncol() const { return ncols; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return vec[i]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
return vec[i + nrows * j]; | |
} | |
}; | |
// Specialization for Vector | |
template <int RTYPE> | |
class ColumnBindable< RTYPE, Rcpp::Vector<RTYPE> > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef Rcpp::Vector<RTYPE> VEC_TYPE; | |
private: | |
const VEC_TYPE& vec; | |
R_xlen_t len, nrows, ncols; | |
public: | |
ColumnBindable(const VEC_TYPE& vec_) | |
: vec(vec_), len(vec.size()), | |
nrows(vec.size()), ncols(1) | |
{} | |
inline R_xlen_t size() const { return len; } | |
inline R_xlen_t nrow() const { return nrows; } | |
inline R_xlen_t ncol() const { return ncols; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return vec[i]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
return vec[i]; | |
} | |
}; | |
// Specialization for Matrix | |
template <int RTYPE> | |
class ColumnBindable< RTYPE, Rcpp::Matrix<RTYPE> > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef Rcpp::Matrix<RTYPE> VEC_TYPE; | |
private: | |
const VEC_TYPE& vec; | |
R_xlen_t len, nrows, ncols; | |
public: | |
ColumnBindable(const VEC_TYPE& vec_) | |
: vec(vec_), len(vec_.rows() * vec_.cols()), | |
nrows(vec.nrow()), ncols(vec.ncol()) | |
{} | |
inline R_xlen_t size() const { return len; } | |
inline R_xlen_t nrow() const { return nrows; } | |
inline R_xlen_t ncol() const { return ncols; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return vec[i]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
return vec[i + nrows * j]; | |
} | |
}; | |
// Cbind implementation | |
template <int RTYPE, typename LHS_T, typename RHS_T> | |
class Cbind | |
: public Rcpp::MatrixBase< RTYPE, true, Cbind<RTYPE, LHS_T, RHS_T> | |
> { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef ColumnBindable<RTYPE, LHS_T> LHS_TYPE; | |
typedef ColumnBindable<RTYPE, RHS_T> RHS_TYPE; | |
private: | |
const LHS_TYPE& lhs; | |
const RHS_TYPE& rhs; | |
public: | |
Cbind(const LHS_T& lhs_, const RHS_T& rhs_) | |
: lhs(lhs_), rhs(rhs_) | |
{ | |
if (lhs.nrow() != rhs.nrow()) { | |
std::string msg = | |
"Unable to construct Cbind expression. " | |
"Objects must have same number of rows.\n"; | |
Rcpp::stop(msg); | |
} | |
} | |
inline R_xlen_t size() const { return lhs.size() + rhs.size(); } | |
inline R_xlen_t nrow() const { return lhs.nrow(); } | |
inline R_xlen_t ncol() const { return lhs.ncol() + rhs.ncol(); } | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < lhs.size()) ? lhs[i] : rhs[i - lhs.size()]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return this->operator[](index); | |
} | |
}; | |
// Specialization for LHS array, RHS scalar | |
template <int RTYPE, typename LHS_T> | |
class Cbind<RTYPE, LHS_T, typename Rcpp::traits::storage_type<RTYPE>::type> | |
: public Rcpp::MatrixBase< | |
RTYPE, true, | |
Cbind<RTYPE, LHS_T, typename Rcpp::traits::storage_type<RTYPE>::type> | |
> { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef ColumnBindable<RTYPE, LHS_T> LHS_TYPE; | |
typedef stored_type RHS_TYPE; | |
private: | |
const LHS_TYPE& lhs; | |
RHS_TYPE rhs; | |
public: | |
Cbind(const LHS_T& lhs_, const stored_type rhs_) | |
: lhs(lhs_), rhs(rhs_) {} | |
inline R_xlen_t size() const { return lhs.size() + lhs.nrow(); } | |
inline R_xlen_t nrow() const { return lhs.nrow(); } | |
inline R_xlen_t ncol() const { return lhs.ncol() + 1; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < lhs.size()) ? lhs[i] : rhs; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return this->operator[](index); | |
} | |
}; | |
// Specialization for LHS scalar, RHS array | |
template <int RTYPE, typename RHS_T> | |
class Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, RHS_T> | |
: public Rcpp::MatrixBase< | |
RTYPE, true, | |
Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, RHS_T> | |
> { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef stored_type LHS_TYPE; | |
typedef ColumnBindable<RTYPE, RHS_T> RHS_TYPE; | |
private: | |
LHS_TYPE lhs; | |
const RHS_TYPE& rhs; | |
public: | |
Cbind(const stored_type lhs_, const RHS_T& rhs_) | |
: lhs(lhs_), rhs(rhs_) {} | |
inline R_xlen_t size() const { return rhs.size() + rhs.nrow(); } | |
inline R_xlen_t nrow() const { return rhs.nrow(); } | |
inline R_xlen_t ncol() const { return rhs.ncol() + 1; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < rhs.nrow()) ? lhs : rhs[i - rhs.nrow()]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return this->operator[](index); | |
} | |
}; | |
// Vector, Vector | |
template <int RTYPE> | |
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Vector<RTYPE> > | |
cbind(const Rcpp::Vector<RTYPE>& lhs, const Rcpp::Vector<RTYPE>& rhs) { | |
return Cbind< RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Vector<RTYPE> >(lhs, rhs); | |
} | |
// Matrix, Matrix | |
template <int RTYPE> | |
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Matrix<RTYPE> > | |
cbind(const Rcpp::Matrix<RTYPE>& lhs, const Rcpp::Matrix<RTYPE>& rhs) { | |
return Cbind< RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Matrix<RTYPE> >(lhs, rhs); | |
} | |
// Vector, Matrix | |
template <int RTYPE> | |
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Matrix<RTYPE> > | |
cbind(const Rcpp::Vector<RTYPE>& lhs, const Rcpp::Matrix<RTYPE>& rhs) { | |
return Cbind< RTYPE, Rcpp::Vector<RTYPE>, Rcpp::Matrix<RTYPE> >(lhs, rhs); | |
} | |
// Matrix, Vector | |
template <int RTYPE> | |
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Vector<RTYPE> > | |
cbind(const Rcpp::Matrix<RTYPE>& lhs, const Rcpp::Vector<RTYPE>& rhs) { | |
return Cbind< RTYPE, Rcpp::Matrix<RTYPE>, Rcpp::Vector<RTYPE> >(lhs, rhs); | |
} | |
// Vector, Scalar | |
template <int RTYPE> | |
inline Cbind<RTYPE, Rcpp::Vector<RTYPE>, typename Rcpp::traits::storage_type<RTYPE>::type> | |
cbind(const Rcpp::Vector<RTYPE>& lhs, const typename Rcpp::traits::storage_type<RTYPE>::type rhs) { | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t; | |
return Cbind<RTYPE, Rcpp::Vector<RTYPE>, scalar_t>(lhs, rhs); | |
} | |
// Matrix, Scalar | |
template <int RTYPE> | |
inline Cbind<RTYPE, Rcpp::Matrix<RTYPE>, typename Rcpp::traits::storage_type<RTYPE>::type> | |
cbind(const Rcpp::Matrix<RTYPE>& lhs, const typename Rcpp::traits::storage_type<RTYPE>::type rhs) { | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t; | |
return Cbind<RTYPE, Rcpp::Matrix<RTYPE>, scalar_t>(lhs, rhs); | |
} | |
// Scalar, Vector | |
template <int RTYPE> | |
inline Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, Rcpp::Vector<RTYPE> > | |
cbind(const typename Rcpp::traits::storage_type<RTYPE>::type lhs, const Rcpp::Vector<RTYPE>& rhs) { | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t; | |
return Cbind< RTYPE, scalar_t, Rcpp::Vector<RTYPE> >(lhs, rhs); | |
} | |
// Scalar, Matrix | |
template <int RTYPE> | |
inline Cbind<RTYPE, typename Rcpp::traits::storage_type<RTYPE>::type, Rcpp::Matrix<RTYPE> > | |
cbind(const typename Rcpp::traits::storage_type<RTYPE>::type lhs, const Rcpp::Matrix<RTYPE>& rhs) { | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type scalar_t; | |
return Cbind< RTYPE, scalar_t, Rcpp::Matrix<RTYPE> >(lhs, rhs); | |
} | |
// Enable nested calls; e.g. cbind(cbind(cbind(A, B), C), D) | |
template <int RTYPE, typename LHS_LT, typename LHS_RT, typename RHS_T> | |
inline Cbind<RTYPE, Cbind<RTYPE, LHS_LT, LHS_RT>, RHS_T> | |
cbind(const Cbind<RTYPE, LHS_LT, LHS_RT>& lhs, const RHS_T& rhs) { | |
return Cbind<RTYPE, Cbind<RTYPE, LHS_LT, LHS_RT>, RHS_T>(lhs, rhs); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment