Last active
March 12, 2016 01:51
-
-
Save nathan-russell/29da20070554744800bd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Rcpp.h> | |
namespace sugar { | |
namespace cbind_impl { | |
// CRTP base class | |
template <int RTYPE, typename E> | |
class BindableExpression { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
stored_type operator[](R_xlen_t i) const { | |
return static_cast<const E&>(*this)[i]; | |
} | |
stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
return static_cast<const E&>(*this)(i, j); | |
} | |
R_xlen_t size() const { | |
return static_cast<const E&>(*this).size(); | |
} | |
R_xlen_t nrow() const { | |
return static_cast<const E&>(*this).nrow(); | |
} | |
R_xlen_t ncol() const { | |
return static_cast<const E&>(*this).ncol(); | |
} | |
operator E&() { return static_cast<E&>(*this); } | |
operator const E&() const { return static_cast<const E&>(*this); } | |
}; | |
// Matrix, Vector | |
template <int RTYPE, typename T> | |
class ContainerBindable | |
: public BindableExpression<RTYPE, ContainerBindable<RTYPE, T> > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
private: | |
T vec; | |
R_xlen_t len, nr, nc; | |
public: | |
ContainerBindable(const Rcpp::Matrix<RTYPE>& vec_) | |
: vec(vec_), len(vec.ncol() * vec.nrow()), | |
nr(vec.nrow()), nc(vec.ncol()) | |
{} | |
ContainerBindable(const Rcpp::Vector<RTYPE>& vec_) | |
: vec(vec_), len(vec.size()), | |
nr(vec.size()), nc(1) | |
{} | |
template <typename S> | |
ContainerBindable(const BindableExpression<RTYPE, S>& e) | |
: vec(e.size()), len(e.size()), | |
nr(e.nrow()), nc(e.ncol()) | |
{ | |
for (R_xlen_t i = 0; i < len; i++) { | |
vec[i] = e[i]; | |
} | |
} | |
inline R_xlen_t size() const { return len; } | |
inline R_xlen_t nrow() const { return nr; } | |
inline R_xlen_t ncol() const { return nc; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return vec[i]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
return vec[i + nr * j]; | |
} | |
}; | |
template <int RTYPE> | |
struct scalar_type { | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type type; | |
typedef Rcpp::Vector<RTYPE> vector_type; | |
}; | |
// scalars | |
template <typename T> | |
class ScalarBindable | |
: public BindableExpression< | |
Rcpp::traits::r_sexptype_traits<T>::rtype, ScalarBindable<T> > { | |
public: | |
typedef T stored_type; | |
enum { RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype }; | |
private: | |
T t; | |
public: | |
ScalarBindable(const T& t_) : t(t_) {} | |
inline R_xlen_t size() const { return 1; } | |
inline R_xlen_t nrow() const { return 1; } | |
inline R_xlen_t ncol() const { return 1; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return t; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
return t; | |
} | |
}; | |
// binding logic; non-scalar operands | |
template <int RTYPE, typename E1, typename E2> | |
class JoinOp | |
: public BindableExpression<RTYPE, JoinOp<RTYPE, E1, E2> >, | |
public Rcpp::MatrixBase<RTYPE, true, JoinOp<RTYPE, E1, E2> > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
private: | |
const E1& e1; | |
const E2& e2; | |
public: | |
JoinOp(const BindableExpression<RTYPE, E1>& e1_, | |
const BindableExpression<RTYPE, E2>& e2_) | |
: e1(e1_), e2(e2_) | |
{ | |
if (e1.nrow() != e2.nrow()) { | |
std::string msg = | |
"Error in cbind: " | |
"Matrix and Vector operands " | |
"must have equal " | |
"number of rows (length)."; | |
Rcpp::stop(msg); | |
} | |
} | |
inline R_xlen_t size() const { return e1.size() + e2.size(); } | |
inline R_xlen_t nrow() const { return e1.nrow(); } | |
inline R_xlen_t ncol() const { | |
return e1.ncol() + e2.ncol(); | |
} | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < e1.size()) ? e1[i] : e2[i - e1.size()]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return (*this)[index]; | |
} | |
}; | |
// binding logic; rhs scalar | |
template <int RTYPE, typename E1> | |
class JoinOp<RTYPE, E1, ScalarBindable<typename scalar_type<RTYPE>::type> > | |
: public BindableExpression<RTYPE, | |
JoinOp<RTYPE, E1, | |
ScalarBindable<typename scalar_type<RTYPE>::type> > >, | |
public Rcpp::MatrixBase<RTYPE, true, | |
JoinOp<RTYPE, E1, | |
ScalarBindable<typename scalar_type<RTYPE>::type> > > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef ScalarBindable<typename scalar_type<RTYPE>::type> E2; | |
private: | |
const E1& e1; | |
const E2& e2; | |
public: | |
JoinOp(const BindableExpression<RTYPE, E1>& e1_, | |
const BindableExpression<RTYPE, E2>& e2_) | |
: e1(e1_), e2(e2_) | |
{} | |
inline R_xlen_t size() const { return e1.size() + e1.nrow(); } | |
inline R_xlen_t nrow() const { return e1.nrow(); } | |
inline R_xlen_t ncol() const { return e1.ncol() + 1; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < e1.size()) ? e1[i] : e2[i]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return (*this)[index]; | |
} | |
}; | |
// binding logic; lhs scalar | |
template <int RTYPE, typename E2> | |
class JoinOp<RTYPE, ScalarBindable<typename scalar_type<RTYPE>::type>, E2> | |
: public BindableExpression<RTYPE, | |
JoinOp<RTYPE, | |
ScalarBindable<typename scalar_type<RTYPE>::type>, E2> >, | |
public Rcpp::MatrixBase<RTYPE, true, | |
JoinOp<RTYPE, | |
ScalarBindable<typename scalar_type<RTYPE>::type>, E2> > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef ScalarBindable<typename scalar_type<RTYPE>::type> E1; | |
private: | |
const E1& e1; | |
const E2& e2; | |
public: | |
JoinOp(const BindableExpression<RTYPE, E1>& e1_, | |
const BindableExpression<RTYPE, E2>& e2_) | |
: e1(e1_), e2(e2_) | |
{} | |
inline R_xlen_t size() const { return e2.size() + e2.nrow(); } | |
inline R_xlen_t nrow() const { return e2.nrow(); } | |
inline R_xlen_t ncol() const { return e2.ncol() + 1; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < e2.nrow()) ? e1[i] : e2[i - e2.nrow()]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return (*this)[index]; | |
} | |
}; | |
// binding logic; both scalar | |
template <int RTYPE> | |
class JoinOp<RTYPE, ScalarBindable<typename scalar_type<RTYPE>::type>, | |
ScalarBindable<typename scalar_type<RTYPE>::type> > | |
: public BindableExpression<RTYPE, | |
JoinOp<RTYPE, | |
ScalarBindable<typename scalar_type<RTYPE>::type>, | |
ScalarBindable<typename scalar_type<RTYPE>::type> > >, | |
public Rcpp::MatrixBase<RTYPE, true, | |
JoinOp<RTYPE, | |
ScalarBindable<typename scalar_type<RTYPE>::type>, | |
ScalarBindable<typename scalar_type<RTYPE>::type> > > { | |
public: | |
typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type; | |
typedef ScalarBindable<typename scalar_type<RTYPE>::type> E1; | |
typedef ScalarBindable<typename scalar_type<RTYPE>::type> E2; | |
private: | |
const E1& e1; | |
const E2& e2; | |
public: | |
JoinOp(const BindableExpression<RTYPE, E1>& e1_, | |
const BindableExpression<RTYPE, E2>& e2_) | |
: e1(e1_), e2(e2_) | |
{} | |
inline R_xlen_t size() const { return e2.size() + e2.nrow(); } | |
inline R_xlen_t nrow() const { return e2.nrow(); } | |
inline R_xlen_t ncol() const { return e2.ncol() + 1; } | |
inline stored_type operator[](R_xlen_t i) const { | |
return (i < e2.nrow()) ? e1[i] : e2[i]; | |
} | |
inline stored_type operator()(R_xlen_t i, R_xlen_t j) const { | |
R_xlen_t index = i + nrow() * j; | |
return (*this)[index]; | |
} | |
}; | |
template <int RTYPE> | |
inline ContainerBindable<RTYPE, Rcpp::Matrix<RTYPE> > | |
MakeContainerBindable(const Rcpp::Matrix<RTYPE>& x) { | |
return ContainerBindable<RTYPE, Rcpp::Matrix<RTYPE> >(x); | |
} | |
template <int RTYPE> | |
inline ContainerBindable<RTYPE, Rcpp::Vector<RTYPE> > | |
MakeContainerBindable(const Rcpp::Vector<RTYPE>& x) { | |
return ContainerBindable<RTYPE, Rcpp::Vector<RTYPE> >(x); | |
} | |
template <typename T> | |
inline ScalarBindable<T> | |
MakeScalarBindable(const T& t) { | |
return ScalarBindable<T>(t); | |
} | |
// for expressions of arbitrary length | |
template <int RTYPE, typename E1, typename E2> | |
inline JoinOp<RTYPE, E1, E2> operator,( | |
const BindableExpression<RTYPE, E1>& e1, | |
const BindableExpression<RTYPE, E2>& e2) | |
{ | |
return JoinOp<RTYPE, E1, E2>(e1, e2); | |
} | |
// helpers | |
namespace detail { | |
// distinguish Matrix/Vector from scalar | |
template <typename T> | |
class has_stored_type { | |
private: | |
typedef char yes; | |
typedef struct { | |
char array[2]; | |
} no; | |
template <typename C> | |
static yes test(typename C::stored_type*); | |
template <typename C> | |
static no test(...); | |
public: | |
static const bool value = sizeof(test<T>(0)) == sizeof(yes); | |
}; | |
// functor to dispatch appropriate wrapping behavior | |
template <typename T, bool is_container = has_stored_type<T>::value> | |
struct MakeBindableCall {}; | |
template <typename T> | |
struct MakeBindableCall<T, true> { | |
typedef typename T::stored_type stored_type; | |
enum { RTYPE = Rcpp::traits::r_sexptype_traits<stored_type>::rtype }; | |
ContainerBindable<RTYPE, T> operator()(const T& t) const { | |
return MakeContainerBindable(t); | |
} | |
}; | |
template <typename T> | |
struct MakeBindableCall<T, false> { | |
enum { RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype }; | |
ScalarBindable<T> operator()(const T& t) const { | |
return MakeScalarBindable(t); | |
} | |
}; | |
template <typename T> | |
inline typename Rcpp::traits::enable_if< | |
has_stored_type<T>::value, | |
MakeBindableCall<T, true> | |
>::type | |
MakeBindable(const T& t) { | |
return MakeBindableCall<T, true>(); | |
} | |
template <typename T> | |
inline typename Rcpp::traits::enable_if< | |
!has_stored_type<T>::value, | |
MakeBindableCall<T, false> | |
>::type | |
MakeBindable(const T& t) { | |
return MakeBindableCall<T, false>(); | |
} | |
// determine cbind return type from first template | |
// parameter, agnostic of Matrix/Vector/scalar | |
template <typename T, bool is_container = has_stored_type<T>::value> | |
struct matrix_return {}; | |
template <typename T> | |
struct matrix_return<T, true> { | |
typedef typename T::stored_type stored_type; | |
enum { RTYPE = Rcpp::traits::r_sexptype_traits<stored_type>::rtype }; | |
typedef Rcpp::Matrix<RTYPE> type; | |
}; | |
template <typename T> | |
struct matrix_return<T, false> { | |
enum { RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype }; | |
typedef Rcpp::Matrix<RTYPE> type; | |
}; | |
} // detail | |
// Public trait: matrix_return<T>::type is the Rcpp::Matrix type cbind()
// returns when T is its first operand (container or scalar alike).
// detail::matrix_return already has partial specialisations for both values
// of B, so forwarding to it covers every case; a separate <T, false>
// specialisation here would only duplicate the primary template.
template <typename T, bool B = detail::has_stored_type<T>::value>
struct matrix_return
    : public detail::matrix_return<T, B> {};
} // cbind_impl | |
// Wrap an operand in its Bindable type (ContainerBindable for types that
// declare stored_type, ScalarBindable otherwise). A macro ignores
// namespaces, so the helper is fully qualified; the expansion then resolves
// even when used outside namespace sugar. Note: x is evaluated twice.
#define MakeBindable(x) (::sugar::cbind_impl::detail::MakeBindable(x)(x))
template <typename T1, typename T2> | |
inline typename cbind_impl::matrix_return<T1>::type | |
cbind(const T1& t1, const T2& t2) { | |
return (MakeBindable(t1), MakeBindable(t2)); | |
} | |
template <typename T1, typename T2, typename T3> | |
inline typename cbind_impl::matrix_return<T1>::type | |
cbind(const T1& t1, const T2& t2, const T3& t3) { | |
return (MakeBindable(t1), MakeBindable(t2), | |
MakeBindable(t3)); | |
} | |
template <typename T1, typename T2, typename T3, | |
typename T4 | |
> inline typename cbind_impl::matrix_return<T1>::type | |
cbind(const T1& t1, const T2& t2, const T3& t3, const T4& t4) { | |
return (MakeBindable(t1), MakeBindable(t2), | |
MakeBindable(t3), MakeBindable(t4)); | |
} | |
template <typename T1, typename T2, typename T3, | |
typename T4, typename T5 | |
> inline typename cbind_impl::matrix_return<T1>::type | |
cbind(const T1& t1, const T2& t2, const T3& t3, | |
const T4& t4, const T5& t5) { | |
return (MakeBindable(t1), MakeBindable(t2), | |
MakeBindable(t3), MakeBindable(t4), | |
MakeBindable(t5)); | |
} | |
template<typename T1, typename T2, typename T3, | |
typename T4, typename T5, typename T6 | |
> inline typename cbind_impl::matrix_return<T1>::type | |
cbind(const T1& t1, const T2& t2, const T3& t3, | |
const T4& t4, const T5& t5, const T6& t6) { | |
return (MakeBindable(t1), MakeBindable(t2), | |
MakeBindable(t3), MakeBindable(t4), | |
MakeBindable(t5), MakeBindable(t6)); | |
} | |
// ... | |
template<typename T1, typename T2, typename T3, | |
typename T4, typename T5, typename T6, | |
typename T7, typename T8, typename T9, | |
typename T10 | |
> inline typename cbind_impl::matrix_return<T1>::type | |
cbind(const T1& t1, const T2& t2, const T3& t3, | |
const T4& t4, const T5& t5, const T6& t6, | |
const T7& t7, const T8& t8, const T9& t9, const T10& t10) { | |
return (MakeBindable(t1), MakeBindable(t2), | |
MakeBindable(t3), MakeBindable(t4), | |
MakeBindable(t5), MakeBindable(t6), | |
MakeBindable(t7), MakeBindable(t8), | |
MakeBindable(t9), MakeBindable(t10)); | |
} | |
// etc... | |
} // sugar | |
using sugar::cbind; | |
typedef Rcpp::ComplexMatrix cx_mat; | |
typedef Rcpp::ComplexVector cx_vec; | |
typedef Rcpp::NumericMatrix nmat; | |
typedef Rcpp::NumericVector nvec; | |
// Demo: column-bind complex operands in the order
// matrix, vector, scalar, matrix, vector, scalar.
// [[Rcpp::export]]
cx_mat cbind6(cx_mat m1, cx_vec v1, Rcomplex d1,
              cx_mat m2, cx_vec v2, Rcomplex d2) {
    const cx_mat bound = cbind(m1, v1, d1, m2, v2, d2);
    return bound;
}
// Demo: ten numeric operands mixing matrices, vectors and scalars.
// [[Rcpp::export]]
nmat cbind10(nmat m1, nvec v1, double d1, nmat m2, nvec v2,
             double d2, nmat m3, nvec v3, double d3, nmat m4) {
    const nmat bound = cbind(m1, v1, d1, m2, v2, d2, m3, v3, d3, m4);
    return bound;
}
// Demo: bind three numeric vectors as columns, then transpose so each input
// vector becomes a row of the result.
// [[Rcpp::export]]
nmat cbind_t(nvec x, nvec y, nvec z) {
    const nmat columns = cbind(x, y, z);
    return Rcpp::transpose(columns);
}
// Demo: bind two numeric scalars into a 1 x 2 matrix.
// [[Rcpp::export]]
nmat sbind2(double d, double e) {
    nmat out = cbind(d, e);
    return out;
}
/*** R | |
M <- matrix(1:6, nrow = 3); V <- rep(12, 3); D <- 18 | |
cbind6(M + 1i, V + 2i, D + 3i, M + 4i, V + 5i, D + 6i) | |
cbind10(M, V, D, 2 * M, 2 * V, 2 * D, 3 * M, 3 * V, 3 * D, 4 * M) | |
cbind_t(rep(1, 3), rep(3, 3), rep(5, 3)) | |
sbind2(3.5, pi) | |
*/ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment