Last active: February 3, 2021 15:35
-
-
Save lakshya-sky/944c8a92186c0ec70f03fb04768c5f22 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template <typename T, | |
size_t N, | |
template <typename U> class PtrTraits = DefaultPtrTraits, | |
typename index_t = int64_t> | |
class TensorAccessorBase { | |
public: | |
typedef typename PtrTraits<T>::PtrType PtrType; | |
TensorAccessorBase(PtrType data_, const index_t* strides_) | |
: data_(data_), strides_(strides_) {} | |
protected: | |
PtrType data_; | |
const index_t* strides_; | |
}; | |
// This can be implemented as MultiDimensionalTensorAccessor<const N> because it has N > 1. | |
template <typename T, | |
size_t N, | |
template <typename U> class PtrTraits = DefaultPtrTraits, | |
typename index_t = int64_t> | |
class TensorAccessor : public TensorAccessorBase<T, N, PtrTraits, index_t> { | |
public: | |
typedef typename PtrTraits<T>::PtrType PtrType; | |
TensorAccessor(PtrType data_, const index_t* strides_) | |
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_, strides_) {} | |
// Indexing this returns a TensorAccessor with N-1 as a generic argument. | |
// In rust, depending on N-1 it may return MultiDimensionalTensorAccessor<const N-1> | |
// or SingleDimensionalTensorAccessor if N-1 is 1. | |
TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) { | |
return TensorAccessor<T, N - 1, PtrTraits, index_t>( | |
this->data_ + this->strides_[0] * i, this->strides_ + 1); | |
} | |
}; | |
// This can be implemented as SingleDimensionalTensorAccessor because it has N = 1. | |
template <typename T, template <typename U> class PtrTraits, typename index_t> | |
class TensorAccessor<T, 1, PtrTraits, index_t> | |
: public TensorAccessorBase<T, 1, PtrTraits, index_t> { | |
public: | |
typedef typename PtrTraits<T>::PtrType PtrType; | |
TensorAccessor(PtrType data_, const index_t* strides_) | |
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, strides_) {} | |
// Here it only returns T. | |
T& operator[](index_t i) { return this->data_[this->strides_[0] * i]; } | |
}; |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.