Skip to content

Instantly share code, notes, and snippets.

@lakshya-sky
Last active February 3, 2021 15:35
Show Gist options
  • Save lakshya-sky/944c8a92186c0ec70f03fb04768c5f22 to your computer and use it in GitHub Desktop.
Save lakshya-sky/944c8a92186c0ec70f03fb04768c5f22 to your computer and use it in GitHub Desktop.
/// Common state shared by every TensorAccessor specialization.
///
/// Holds a pointer to the first element of the (sub-)tensor and a pointer to
/// its stride array. The accessor never allocates or frees either pointer:
/// both are supplied by the caller and must stay valid for as long as this
/// accessor, or any accessor obtained from it by indexing, is used.
///
/// @tparam T         Element type.
/// @tparam N         Number of dimensions covered by this accessor (>= 1).
/// @tparam PtrTraits Trait template whose nested `PtrType` is the type used to
///                   store the data pointer (its exact form depends on the
///                   traits supplied by the caller).
/// @tparam index_t   Integer type used for strides and indices.
template <typename T,
          size_t N,
          template <typename U> class PtrTraits = DefaultPtrTraits,
          typename index_t = int64_t>
class TensorAccessorBase {
  // A zero-dimensional accessor has no stride to index with, and the derived
  // accessor's operator[] computes N - 1 on an unsigned type, which would wrap.
  static_assert(N >= 1, "TensorAccessor requires at least one dimension");

 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  /// @param data    Pointer to the element at index 0 along every dimension.
  /// @param strides Pointer to (at least) N strides, measured in elements;
  ///                strides[0] is the step of the outermost dimension here.
  TensorAccessorBase(PtrType data, const index_t* strides)
      : data_(data), strides_(strides) {}

  /// Returns the stride (in elements) of dimension `i` of this accessor.
  index_t stride(index_t i) const { return strides_[i]; }

 protected:
  PtrType data_;
  const index_t* strides_;
};
// N-dimensional accessor, intended for N > 1 (N == 1 selects the partial
// specialization that follows this template).
//
// Rust port note: this maps to MultiDimensionalTensorAccessor<const N>,
// since the dimension count here is greater than one.
template <typename T,
          size_t N,
          template <typename U> class PtrTraits = DefaultPtrTraits,
          typename index_t = int64_t>
class TensorAccessor : public TensorAccessorBase<T, N, PtrTraits, index_t> {
 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  // `data` points at the first element; `strides` must hold at least N
  // strides. Neither pointer is owned by the accessor.
  TensorAccessor(PtrType data, const index_t* strides)
      : TensorAccessorBase<T, N, PtrTraits, index_t>(data, strides) {}

  // Selects slice `i` along the outermost dimension and returns an accessor
  // over the remaining N - 1 dimensions: the data pointer advances by
  // strides_[0] * i and the stride array is shifted by one entry.
  //
  // Rust port note: depending on N - 1 this yields either
  // MultiDimensionalTensorAccessor<const N - 1> or, when N - 1 == 1,
  // SingleDimensionalTensorAccessor.
  TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
        this->data_ + this->strides_[0] * i, this->strides_ + 1);
  }

  // Const overload, so a const accessor (or a const reference to one) can be
  // indexed. The sub-accessor is deliberately returned as a const object:
  // a further operator[] on it then resolves to a const-qualified overload
  // instead of a mutating one.
  const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
      index_t i) const {
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
        this->data_ + this->strides_[0] * i, this->strides_ + 1);
  }
};
// One-dimensional accessor (N == 1). Indexing yields a reference to an
// element rather than another accessor, which ends the N -> N - 1 recursion
// of the primary template's operator[].
//
// Rust port note: this maps to SingleDimensionalTensorAccessor.
template <typename T, template <typename U> class PtrTraits, typename index_t>
class TensorAccessor<T, 1, PtrTraits, index_t>
    : public TensorAccessorBase<T, 1, PtrTraits, index_t> {
 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  // `data` points at element 0; `strides` must hold at least one stride.
  // Neither pointer is owned by the accessor.
  TensorAccessor(PtrType data, const index_t* strides)
      : TensorAccessorBase<T, 1, PtrTraits, index_t>(data, strides) {}

  // Returns a mutable reference to element `i`, located strides_[0] * i
  // elements past the data pointer.
  T& operator[](index_t i) { return this->data_[this->strides_[0] * i]; }

  // Const overload: read-only element access through a const accessor.
  const T& operator[](index_t i) const {
    return this->data_[this->strides_[0] * i];
  }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment