// cbind() for Rcpp matrices, vectors and scalars, built on expression templates.
// Source: GitHub gist by @nathan-russell (gist id 29da20070554744800bd),
// last active March 12, 2016 01:51.
#include <Rcpp.h>
namespace sugar {
namespace cbind_impl {
// CRTP base class shared by every bindable operand (containers, scalars and
// join expressions). Element access and shape queries are forwarded to the
// concrete derived type E.
template <int RTYPE, typename E>
class BindableExpression {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;

    // Linear element access, forwarded to E.
    stored_type operator[](R_xlen_t i) const { return self()[i]; }

    // (row, column) element access, forwarded to E.
    stored_type operator()(R_xlen_t i, R_xlen_t j) const { return self()(i, j); }

    // Shape queries, forwarded to E.
    R_xlen_t size() const { return self().size(); }
    R_xlen_t nrow() const { return self().nrow(); }
    R_xlen_t ncol() const { return self().ncol(); }

    // Implicit downcasts to the concrete type. JoinOp binds its `const E&`
    // members directly from `const BindableExpression<RTYPE, E>&`
    // arguments, which goes through these conversions.
    operator E&() { return static_cast<E&>(*this); }
    operator const E&() const { return self(); }

private:
    const E& self() const { return static_cast<const E&>(*this); }
};
// Matrix, Vector
// Bindable wrapper that holds an Rcpp container of type T by value and
// exposes its elements in column-major order. A Vector is treated as a
// single column (nrow = length, ncol = 1).
template <int RTYPE, typename T>
class ContainerBindable
    : public BindableExpression<RTYPE, ContainerBindable<RTYPE, T> > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
private:
    T vec;
    R_xlen_t len, nr, nc;

    // Allocate a container shaped like an nrow x ncol expression holding
    // `size` elements. The T* tag selects the Matrix overload when T is
    // Matrix<RTYPE>, because an exact match beats a derived-to-base
    // conversion. Writing T(size) directly would, for a Matrix, call the
    // single-int Matrix constructor and allocate a size x size square matrix.
    static Rcpp::Matrix<RTYPE> allocate(R_xlen_t /* size */, R_xlen_t rows,
                                        R_xlen_t cols,
                                        const Rcpp::Matrix<RTYPE>* /* tag */) {
        return Rcpp::Matrix<RTYPE>(static_cast<int>(rows),
                                   static_cast<int>(cols));
    }
    static Rcpp::Vector<RTYPE> allocate(R_xlen_t size, R_xlen_t /* rows */,
                                        R_xlen_t /* cols */,
                                        const Rcpp::Vector<RTYPE>* /* tag */) {
        return Rcpp::Vector<RTYPE>(size);
    }
public:
    // Wrap a matrix. The element count is widened to R_xlen_t before
    // multiplying: Rcpp's Matrix::nrow()/ncol() return int, and their
    // product overflows int for matrices with more than INT_MAX elements.
    ContainerBindable(const Rcpp::Matrix<RTYPE>& vec_)
        : vec(vec_), len(static_cast<R_xlen_t>(vec.nrow()) * vec.ncol()),
          nr(vec.nrow()), nc(vec.ncol())
    {}
    // Wrap a vector as a single column.
    ContainerBindable(const Rcpp::Vector<RTYPE>& vec_)
        : vec(vec_), len(vec.size()),
          nr(vec.size()), nc(1)
    {}
    // Materialize an arbitrary bindable expression into freshly allocated
    // storage with the expression's shape.
    template <typename S>
    ContainerBindable(const BindableExpression<RTYPE, S>& e)
        : vec(allocate(e.size(), e.nrow(), e.ncol(), static_cast<T*>(0))),
          len(e.size()), nr(e.nrow()), nc(e.ncol())
    {
        for (R_xlen_t i = 0; i < len; i++) {
            vec[i] = e[i];
        }
    }
    inline R_xlen_t size() const { return len; }
    inline R_xlen_t nrow() const { return nr; }
    inline R_xlen_t ncol() const { return nc; }
    inline stored_type operator[](R_xlen_t i) const {
        return vec[i];
    }
    // Column-major (row i, column j) access.
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        return vec[i + nr * j];
    }
};
// Maps an R type code to the C++ storage type of its elements (`type`) and
// to the corresponding Rcpp vector type (`vector_type`).
template <int RTYPE>
struct scalar_type {
    typedef Rcpp::Vector<RTYPE> vector_type;
    typedef typename Rcpp::traits::storage_type<RTYPE>::type type;
};
// scalars
// Bindable wrapper around a single value of type T. It reports a 1 x 1
// shape and yields the same value for every index.
template <typename T>
class ScalarBindable
    : public BindableExpression<
        Rcpp::traits::r_sexptype_traits<T>::rtype, ScalarBindable<T> > {
public:
    typedef T stored_type;
    enum { RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype };

    ScalarBindable(const T& value) : value_(value) {}

    inline R_xlen_t size() const { return 1; }
    inline R_xlen_t nrow() const { return 1; }
    inline R_xlen_t ncol() const { return 1; }

    // Index arguments are deliberately ignored: there is only one value.
    inline stored_type operator[](R_xlen_t) const { return value_; }
    inline stored_type operator()(R_xlen_t, R_xlen_t) const { return value_; }

private:
    T value_;
};
// binding logic; non-scalar operands
// Lazy column-wise concatenation of two bindable expressions. In linear
// (column-major) order, the first e1.size() elements come from e1 and the
// rest from e2. Operands are held by reference, so a JoinOp must not
// outlive them.
template <int RTYPE, typename E1, typename E2>
class JoinOp
    : public BindableExpression<RTYPE, JoinOp<RTYPE, E1, E2> >,
      public Rcpp::MatrixBase<RTYPE, true, JoinOp<RTYPE, E1, E2> > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
private:
    const E1& e1;
    const E2& e2;
public:
    // Calls Rcpp::stop unless both operands have the same number of rows.
    JoinOp(const BindableExpression<RTYPE, E1>& e1_,
           const BindableExpression<RTYPE, E2>& e2_)
        : e1(e1_), e2(e2_)
    {
        const bool rows_match = (e1.nrow() == e2.nrow());
        if (!rows_match) {
            const std::string msg("Error in cbind: Matrix and Vector operands "
                                  "must have equal number of rows (length).");
            Rcpp::stop(msg);
        }
    }
    inline R_xlen_t nrow() const { return e1.nrow(); }
    inline R_xlen_t ncol() const { return e1.ncol() + e2.ncol(); }
    inline R_xlen_t size() const { return e1.size() + e2.size(); }
    inline stored_type operator[](R_xlen_t i) const {
        const R_xlen_t split = e1.size();
        if (i < split) {
            return e1[i];
        }
        return e2[i - split];
    }
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        return (*this)[j * nrow() + i];
    }
};
// binding logic; rhs scalar
// Appends one column to e1 in which every entry is the scalar e2. No row
// check is needed because the scalar fills the whole column.
template <int RTYPE, typename E1>
class JoinOp<RTYPE, E1, ScalarBindable<typename scalar_type<RTYPE>::type> >
    : public BindableExpression<RTYPE,
        JoinOp<RTYPE, E1,
            ScalarBindable<typename scalar_type<RTYPE>::type> > >,
      public Rcpp::MatrixBase<RTYPE, true,
        JoinOp<RTYPE, E1,
            ScalarBindable<typename scalar_type<RTYPE>::type> > > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ScalarBindable<typename scalar_type<RTYPE>::type> E2;
private:
    const E1& e1;
    const E2& e2;
public:
    JoinOp(const BindableExpression<RTYPE, E1>& e1_,
           const BindableExpression<RTYPE, E2>& e2_)
        : e1(e1_), e2(e2_)
    {}
    inline R_xlen_t nrow() const { return e1.nrow(); }
    inline R_xlen_t ncol() const { return 1 + e1.ncol(); }
    // e1's elements followed by nrow() copies of the scalar.
    inline R_xlen_t size() const { return e1.nrow() + e1.size(); }
    inline stored_type operator[](R_xlen_t i) const {
        if (i >= e1.size()) {
            return e2[i];  // ScalarBindable ignores the index
        }
        return e1[i];
    }
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        const R_xlen_t linear = i + j * nrow();
        return operator[](linear);
    }
};
// binding logic; lhs scalar
// Prepends one column to e2 in which every entry is the scalar e1.
template <int RTYPE, typename E2>
class JoinOp<RTYPE, ScalarBindable<typename scalar_type<RTYPE>::type>, E2>
    : public BindableExpression<RTYPE,
        JoinOp<RTYPE,
            ScalarBindable<typename scalar_type<RTYPE>::type>, E2> >,
      public Rcpp::MatrixBase<RTYPE, true,
        JoinOp<RTYPE,
            ScalarBindable<typename scalar_type<RTYPE>::type>, E2> > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ScalarBindable<typename scalar_type<RTYPE>::type> E1;
private:
    const E1& e1;
    const E2& e2;
public:
    JoinOp(const BindableExpression<RTYPE, E1>& e1_,
           const BindableExpression<RTYPE, E2>& e2_)
        : e1(e1_), e2(e2_)
    {}
    inline R_xlen_t nrow() const { return e2.nrow(); }
    inline R_xlen_t ncol() const { return 1 + e2.ncol(); }
    // nrow() copies of the scalar, followed by e2's elements.
    inline R_xlen_t size() const { return e2.nrow() + e2.size(); }
    inline stored_type operator[](R_xlen_t i) const {
        const R_xlen_t rows = e2.nrow();
        if (i < rows) {
            return e1[i];  // first column: the scalar
        }
        return e2[i - rows];
    }
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        const R_xlen_t linear = i + j * nrow();
        return operator[](linear);
    }
};
// binding logic; both scalar
// Joins two scalars. The shape is derived from e2, whose ScalarBindable
// shape is 1 x 1, giving nrow 1, ncol 2, size 2.
template <int RTYPE>
class JoinOp<RTYPE, ScalarBindable<typename scalar_type<RTYPE>::type>,
             ScalarBindable<typename scalar_type<RTYPE>::type> >
    : public BindableExpression<RTYPE,
        JoinOp<RTYPE,
            ScalarBindable<typename scalar_type<RTYPE>::type>,
            ScalarBindable<typename scalar_type<RTYPE>::type> > >,
      public Rcpp::MatrixBase<RTYPE, true,
        JoinOp<RTYPE,
            ScalarBindable<typename scalar_type<RTYPE>::type>,
            ScalarBindable<typename scalar_type<RTYPE>::type> > > {
public:
    typedef typename Rcpp::traits::storage_type<RTYPE>::type stored_type;
    typedef ScalarBindable<typename scalar_type<RTYPE>::type> E1;
    typedef ScalarBindable<typename scalar_type<RTYPE>::type> E2;
private:
    const E1& e1;
    const E2& e2;
public:
    JoinOp(const BindableExpression<RTYPE, E1>& e1_,
           const BindableExpression<RTYPE, E2>& e2_)
        : e1(e1_), e2(e2_)
    {}
    inline R_xlen_t nrow() const { return e2.nrow(); }
    inline R_xlen_t ncol() const { return 1 + e2.ncol(); }
    inline R_xlen_t size() const { return e2.nrow() + e2.size(); }
    inline stored_type operator[](R_xlen_t i) const {
        const bool in_first_column = (i < e2.nrow());
        if (in_first_column) {
            return e1[i];
        }
        return e2[i];
    }
    inline stored_type operator()(R_xlen_t i, R_xlen_t j) const {
        return operator[](nrow() * j + i);
    }
};
// Wrap an Rcpp Matrix as a bindable operand.
template <int RTYPE>
inline ContainerBindable<RTYPE, Rcpp::Matrix<RTYPE> >
MakeContainerBindable(const Rcpp::Matrix<RTYPE>& x) {
    typedef ContainerBindable<RTYPE, Rcpp::Matrix<RTYPE> > result_type;
    return result_type(x);
}
// Wrap an Rcpp Vector as a bindable single-column operand.
template <int RTYPE>
inline ContainerBindable<RTYPE, Rcpp::Vector<RTYPE> >
MakeContainerBindable(const Rcpp::Vector<RTYPE>& x) {
    typedef ContainerBindable<RTYPE, Rcpp::Vector<RTYPE> > result_type;
    return result_type(x);
}
// Wrap a single value as a bindable 1 x 1 operand.
template <typename T>
inline ScalarBindable<T>
MakeScalarBindable(const T& t) {
    const ScalarBindable<T> wrapped(t);
    return wrapped;
}
// for expressions of arbitrary length
// Overloaded comma: `a, b` on two bindable expressions of the same RTYPE
// builds a lazy JoinOp. The built-in comma is left-associative, so
// `a, b, c` yields JoinOp<JoinOp<a, b>, c>.
template <int RTYPE, typename E1, typename E2>
inline JoinOp<RTYPE, E1, E2>
operator,(const BindableExpression<RTYPE, E1>& lhs,
          const BindableExpression<RTYPE, E2>& rhs)
{
    typedef JoinOp<RTYPE, E1, E2> result_type;
    return result_type(lhs, rhs);
}
// helpers
namespace detail {
// distinguish Matrix/Vector from scalar:
// has_stored_type<T>::value is true iff T declares a nested `stored_type`
// (detected by SFINAE on the two test() overloads).
template <typename T>
class has_stored_type {
private:
    typedef char yes;
    typedef struct {
        char array[2];
    } no;
    template <typename C>
    static yes test(typename C::stored_type*);
    template <typename C>
    static no test(...);
public:
    static const bool value = sizeof(test<T>(0)) == sizeof(yes);
};

// functor to dispatch appropriate wrapping behavior:
// types with stored_type -> ContainerBindable, others -> ScalarBindable
template <typename T, bool is_container = has_stored_type<T>::value>
struct MakeBindableCall {};

template <typename T>
struct MakeBindableCall<T, true> {
    typedef typename T::stored_type stored_type;
    // Take the R type from the container itself. Deriving it from
    // stored_type is lossy because several R types share a storage type
    // (LGLSXP and INTSXP both store int). A LogicalVector would map to
    // INTSXP, and the ContainerBindable<LGLSXP, ...> returned by
    // MakeContainerBindable could not convert to the declared return type.
    enum { RTYPE = T::r_type::value };
    ContainerBindable<RTYPE, T> operator()(const T& t) const {
        return MakeContainerBindable(t);
    }
};

template <typename T>
struct MakeBindableCall<T, false> {
    enum { RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype };
    ScalarBindable<T> operator()(const T& t) const {
        return MakeScalarBindable(t);
    }
};

// Return the wrapping functor for T. The argument is only used to deduce T.
template <typename T>
inline typename Rcpp::traits::enable_if<
    has_stored_type<T>::value,
    MakeBindableCall<T, true>
>::type
MakeBindable(const T& /* t */) {
    return MakeBindableCall<T, true>();
}
template <typename T>
inline typename Rcpp::traits::enable_if<
    !has_stored_type<T>::value,
    MakeBindableCall<T, false>
>::type
MakeBindable(const T& /* t */) {
    return MakeBindableCall<T, false>();
}

// determine cbind return type from first template
// parameter, agnostic of Matrix/Vector/scalar
template <typename T, bool is_container = has_stored_type<T>::value>
struct matrix_return {};
template <typename T>
struct matrix_return<T, true> {
    typedef typename T::stored_type stored_type;
    // Use the container's own R type so a LogicalMatrix/LogicalVector
    // yields a LogicalMatrix rather than an IntegerMatrix.
    enum { RTYPE = T::r_type::value };
    typedef Rcpp::Matrix<RTYPE> type;
};
template <typename T>
struct matrix_return<T, false> {
    enum { RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype };
    typedef Rcpp::Matrix<RTYPE> type;
};
} // detail
// Public trait naming the Rcpp::Matrix type that cbind() returns, chosen
// from its first argument (container or scalar). detail::matrix_return
// already handles both cases, so this simply inherits from it.
template <typename T, bool B = detail::has_stored_type<T>::value>
struct matrix_return
    : public detail::matrix_return<T, B> {};
} // cbind_impl
// Wrap x as a bindable operand: select the MakeBindableCall functor for x's
// type, then apply it to x. x is expanded twice, so only pass
// side-effect-free expressions. The macro is #undef'd after the cbind
// overloads below.
#define MakeBindable(x) (cbind_impl::detail::MakeBindable(x)(x))

// cbind(t1, ..., tN): column-bind Matrix, Vector and scalar arguments into
// a Matrix whose type is determined by the first argument.
// - A Vector contributes one column.
// - A scalar contributes a column filled with that value.
// Calls Rcpp::stop if two joined non-scalar operands have different row
// counts.
template <typename T1, typename T2>
inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2) {
    return (MakeBindable(t1), MakeBindable(t2));
}
template <typename T1, typename T2, typename T3>
inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3));
}
template <typename T1, typename T2, typename T3,
          typename T4
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3, const T4& t4) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4));
}
template <typename T1, typename T2, typename T3,
          typename T4, typename T5
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3,
      const T4& t4, const T5& t5) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4),
            MakeBindable(t5));
}
template <typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3,
      const T4& t4, const T5& t5, const T6& t6) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4),
            MakeBindable(t5), MakeBindable(t6));
}
template <typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6,
          typename T7
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3,
      const T4& t4, const T5& t5, const T6& t6,
      const T7& t7) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4),
            MakeBindable(t5), MakeBindable(t6),
            MakeBindable(t7));
}
template <typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6,
          typename T7, typename T8
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3,
      const T4& t4, const T5& t5, const T6& t6,
      const T7& t7, const T8& t8) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4),
            MakeBindable(t5), MakeBindable(t6),
            MakeBindable(t7), MakeBindable(t8));
}
template <typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6,
          typename T7, typename T8, typename T9
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3,
      const T4& t4, const T5& t5, const T6& t6,
      const T7& t7, const T8& t8, const T9& t9) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4),
            MakeBindable(t5), MakeBindable(t6),
            MakeBindable(t7), MakeBindable(t8),
            MakeBindable(t9));
}
template <typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6,
          typename T7, typename T8, typename T9,
          typename T10
> inline typename cbind_impl::matrix_return<T1>::type
cbind(const T1& t1, const T2& t2, const T3& t3,
      const T4& t4, const T5& t5, const T6& t6,
      const T7& t7, const T8& t8, const T9& t9, const T10& t10) {
    return (MakeBindable(t1), MakeBindable(t2),
            MakeBindable(t3), MakeBindable(t4),
            MakeBindable(t5), MakeBindable(t6),
            MakeBindable(t7), MakeBindable(t8),
            MakeBindable(t9), MakeBindable(t10));
}
// Overloads for more arguments follow the same pattern.

// The helper macro is only needed by the overloads above. Remove it so it
// cannot rewrite the identifier MakeBindable later in the translation unit.
#undef MakeBindable
} // sugar
// Make sugar::cbind callable unqualified in the exported wrappers, and give
// short names to the Rcpp types they take and return.
using sugar::cbind;
typedef Rcpp::NumericMatrix nmat;
typedef Rcpp::NumericVector nvec;
typedef Rcpp::ComplexMatrix cx_mat;
typedef Rcpp::ComplexVector cx_vec;
// Column-binds complex operands in argument order:
// (matrix, vector, scalar), twice.
// [[Rcpp::export]]
cx_mat cbind6(cx_mat m1, cx_vec v1, Rcomplex d1,
              cx_mat m2, cx_vec v2, Rcomplex d2) {
    cx_mat bound = cbind(m1, v1, d1, m2, v2, d2);
    return bound;
}
// Column-binds numeric operands in argument order:
// (matrix, vector, scalar) three times, then a final matrix.
// [[Rcpp::export]]
nmat cbind10(nmat m1, nvec v1, double d1, nmat m2, nvec v2,
             double d2, nmat m3, nvec v3, double d3, nmat m4) {
    nmat bound = cbind(m1, v1, d1, m2, v2, d2, m3, v3, d3, m4);
    return bound;
}
// Column-binds three numeric vectors, then transposes the result so that
// each input vector becomes a row.
// [[Rcpp::export]]
nmat cbind_t(nvec x, nvec y, nvec z) {
    const nmat columns = cbind(x, y, z);
    return Rcpp::transpose(columns);
}
// Binds two numeric scalars side by side into a 1 x 2 matrix.
// [[Rcpp::export]]
nmat sbind2(double d, double e) {
    nmat pair = cbind(d, e);
    return pair;
}
/*** R
# Demo of the exported wrappers (this /*** R block is run by Rcpp::sourceCpp).
# M is a 3 x 2 matrix, V a length-3 vector, D a scalar.
M <- matrix(1:6, nrow = 3); V <- rep(12, 3); D <- 18
# Complex operands: 2 + 1 + 1 + 2 + 1 + 1 columns -> 3 x 8 complex matrix.
cbind6(M + 1i, V + 2i, D + 3i, M + 4i, V + 5i, D + 6i)
# Numeric operands: 3 * (2 + 1 + 1) + 2 columns -> 3 x 14 numeric matrix.
cbind10(M, V, D, 2 * M, 2 * V, 2 * D, 3 * M, 3 * V, 3 * D, 4 * M)
# Transposed cbind: rows of 1s, 3s and 5s -> 3 x 3 matrix.
cbind_t(rep(1, 3), rep(3, 3), rep(5, 3))
# Two scalars -> 1 x 2 matrix.
sbind2(3.5, pi)
*/
// End of gist nathan-russell/29da20070554744800bd.