// Source: GitHub gist by @ezyang, created April 9, 2018.
// https://gist.github.com/ezyang/0fed850e57bb87780baa181c743260c7
// (GitHub page chrome from the scrape converted to a comment so the file parses.)
#include <cstdint>
// Placeholder for asserts; ignore them for now
#define AT_ASSERT(cond, ...)
// ArrayRef (comes from LLVM, ATen uses it, we think it's pretty good)
//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// ATen: modified from llvm::ArrayRef.
// removed llvm-specific functionality
// removed some implicit const -> non-const conversions that rely on
// complicated std::enable_if meta-programming
// removed a bunch of slice variants for simplicity...
#include <algorithm>        // std::equal (ArrayRef::equals)
#include <array>
#include <cstddef>          // size_t
#include <initializer_list> // std::initializer_list constructor of ArrayRef
#include <iterator>
#include <stdexcept>        // std::runtime_error thrown by the TensorImpl stubs
#include <type_traits>      // std::enable_if / std::is_same and the trait asserts
#include <vector>
/// ArrayRef - Represent a constant reference to an array (0 or more elements
/// consecutively in memory), i.e. a start pointer and a length. It allows
/// various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the ArrayRef. For this reason, it is not in general
/// safe to store an ArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
template<typename T>
class ArrayRef {
public:
  using iterator = const T*;
  using const_iterator = const T*;
  using size_type = size_t;
  using reverse_iterator = std::reverse_iterator<iterator>;

private:
  /// The start of the array, in an external buffer.
  const T *Data;
  /// The number of elements.
  size_type Length;

public:
  /// @name Constructors
  /// @{

  /// Construct an empty ArrayRef.
  /*implicit*/ ArrayRef() : Data(nullptr), Length(0) {}

  /// Construct an ArrayRef from a single element.
  /// NB: binds a reference; the element must outlive the ArrayRef.
  /*implicit*/ ArrayRef(const T &OneElt)
      : Data(&OneElt), Length(1) {}

  /// Construct an ArrayRef from a pointer and length.
  /*implicit*/ ArrayRef(const T *data, size_t length)
      : Data(data), Length(length) {}

  /// Construct an ArrayRef from a [begin, end) range.
  ArrayRef(const T *begin, const T *end)
      : Data(begin), Length(end - begin) {}

  /// Construct an ArrayRef from a std::vector.
  template<typename A>
  /*implicit*/ ArrayRef(const std::vector<T, A> &Vec)
      : Data(Vec.data()), Length(Vec.size()) {}

  /// Construct an ArrayRef from a std::array.
  template <size_t N>
  /*implicit*/ constexpr ArrayRef(const std::array<T, N> &Arr)
      : Data(Arr.data()), Length(N) {}

  /// Construct an ArrayRef from a C array.
  template <size_t N>
  /*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}

  /// Construct an ArrayRef from a std::initializer_list.
  /// NB: dangles if the initializer_list is a temporary that dies before this
  /// ArrayRef does — same caveat as the LLVM original.
  /*implicit*/ ArrayRef(const std::initializer_list<T> &Vec)
      : Data(Vec.begin() == Vec.end() ? static_cast<const T*>(nullptr)
                                      : Vec.begin()),
        Length(Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @{

  iterator begin() const { return Data; }
  iterator end() const { return Data + Length; }
  reverse_iterator rbegin() const { return reverse_iterator(end()); }
  reverse_iterator rend() const { return reverse_iterator(begin()); }

  /// empty - Check if the array is empty.
  bool empty() const { return Length == 0; }

  const T *data() const { return Data; }

  /// size - Get the array size.
  size_t size() const { return Length; }

  /// front - Get the first element.
  const T &front() const {
    AT_ASSERT(!empty(), "Empty list!");
    return Data[0];
  }

  /// back - Get the last element.
  const T &back() const {
    AT_ASSERT(!empty(), "Empty list!");
    return Data[Length-1];
  }

  /// equals - Check for element-wise equality.
  bool equals(ArrayRef RHS) const {
    if (Length != RHS.Length)
      return false;
    return std::equal(begin(), end(), RHS.begin());
  }

  /// slice(n, m) - Chop off the first N elements of the array, and keep M
  /// elements in the array.
  ArrayRef<T> slice(size_t N, size_t M) const {
    AT_ASSERT(N+M <= size(), "Invalid specifier");
    return ArrayRef<T>(data()+N, M);
  }

  /// slice(n) - Chop off the first N elements of the array.
  ArrayRef<T> slice(size_t N) const {
    // BUG FIX: assert before computing size() - N. For N > size() the
    // unsigned subtraction wraps, and N+M in slice(N, M) could wrap back into
    // range and slip past that assert.
    AT_ASSERT(N <= size(), "Invalid specifier");
    return slice(N, size() - N);
  }

  /// @}
  /// @name Operator Overloads
  /// @{

  const T &operator[](size_t Index) const {
    // Consistency fix: bounds-check like at() does (LLVM's operator[] asserts
    // too). AT_ASSERT is currently a no-op, so release behavior is unchanged.
    AT_ASSERT(Index < Length, "Invalid index!");
    return Data[Index];
  }

  /// Vector compatibility.
  const T &at(size_t Index) const {
    AT_ASSERT(Index < Length, "Invalid index!");
    return Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type &
  operator=(U &&Temporary) = delete;

  /// Disallow accidental assignment from an initializer list.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type &
  operator=(std::initializer_list<U>) = delete;

  /// @}
  /// @name Expensive Operations
  /// @{

  /// Copy the referenced elements into an owning std::vector.
  std::vector<T> vec() const {
    return std::vector<T>(Data, Data+Length);
  }

  /// @}
  /// @name Conversion operators
  /// @{

  /// Implicit conversion to an owning std::vector (copies the elements).
  operator std::vector<T>() const {
    return std::vector<T>(Data, Data+Length);
  }

  /// @}
};
//===--- TypeId.h -----------------------------------------------*- C++ -*-===//
// A compact identifier which stores all of the information necessary to
// carry out a dispatch on a type. This is NOT NECESSARILY in one-to-one
// correspondence with the type hierarchy of TensorImpl, because we may decide
// that we want to refine dispatch on a runtime property of a tensor which is
// NOT reflected by the class hierarchy.
//
// Other note: there is no NATIVE notion of a subtyping relationship between
// these TypeIds. We are planning to design one but we haven't decided on
// its specifics yet.
//
// ezyang: CC smessmer; I know you wanted to have this line up exactly
// with the concrete TensorImpl subclasses, but I don't want to commit
// to that at the moment
//
// TODO: Does this also contain per Tensor properties, like contiguity?
// A compact dispatch identifier. See the commentary above for design notes.
class TypeId final {
  int64_t id_;
  // Private: TypeIds are only minted via the named constants below.
  constexpr explicit TypeId(int64_t id) : id_(id) {}

public:
  /// Raw numeric value, e.g. for switch-based dispatch or logging.
  constexpr int64_t id() const { return id_; }

  constexpr bool operator==(TypeId other) const { return id_ == other.id_; }
  constexpr bool operator!=(TypeId other) const { return id_ != other.id_; }

  // BUG FIX: the original declared these as in-class
  // `static constexpr TypeId X = {...};`, which is ill-formed — TypeId is
  // incomplete at that point (and the ctor wasn't constexpr). Declare here,
  // define as C++17 inline variables below the class.
  static const TypeId Undefined;
  static const TypeId CPUTensor;
  static const TypeId StridedCPUTensor;
};

inline const TypeId TypeId::Undefined{0};
inline const TypeId TypeId::CPUTensor{1};
inline const TypeId TypeId::StridedCPUTensor{2};

// BUG FIX: `std::is_pod<TypeId>()` is false — the user-provided constructor
// makes TypeId non-trivial, so the old assert could never pass. What we
// actually rely on is cheap pass-by-value, i.e. trivial copyability.
static_assert(std::is_trivially_copyable<TypeId>::value,
              "TypeId must be trivially copyable");
//===--- TensorImpl.h -----------------------------------------------*- C++ -*-===//
// TODO: Fill in an actual SmallVector implementation here. Both Folly and LLVM's
// implementation are a bit annoying to make standalone. Maybe this can be made
// simpler by assuming T is POD.
template <typename T>
using SmallVector = std::vector<T>;
// For now: try using empty tensors for type (I think we'll probably add a Type
// object)
// NB: Use of virtual functions means that this is NOT a plain old data class.
// This means that we don't get inlineable C API functions which access the representation
// directly
// Abstract base of all tensor implementations. Holds the dispatch TypeId and
// the sizes; storage, strides and refcounting are supplied by subclasses.
class TensorImpl {
  // Used for dispatch on the object.
  TypeId type_id_;
  // We have an interesting problem here, which regards our short term plan for
  // integrating PyTorch and Caffe2 without having to rewrite all of Torch/Caffe2's
  // operators. Recall that both Torch and Caffe2 have their own, existing tensor
  // types, which record sizes by themselves.
  SmallVector<int64_t> size_;

public:
  explicit TensorImpl(TypeId type_id) : type_id_(type_id) {}

  /// TypeId this impl dispatches on (see TypeId above).
  TypeId type_id() const { return type_id_; }

  // Inline? Virtual? See the admonition above.
  virtual ArrayRef<int64_t> size() const {
    return size_;
  }

  // Stub: subclasses with strides override this.
  virtual ArrayRef<int64_t> stride() const {
    throw std::runtime_error("TensorImpl::stride()");
  }

  virtual int64_t dim() const {
    return static_cast<int64_t>(size().size());
  }

  // Stub: subclasses with actual storage override this.
  virtual void* data_ptr() const {
    throw std::runtime_error("TensorImpl::data_ptr()");
  }

  // Manual refcount interface; see the design notes in Tensor.h below.
  virtual void retain() = 0;
  virtual void release() = 0;

  // BUG FIX: was `virtual ~Tensor() = 0;` — wrong name (this class is
  // TensorImpl, not Tensor), and a pure virtual destructor would still need
  // an out-of-line definition. A defaulted virtual dtor is correct: deleting
  // through a base pointer is safe, and the class stays abstract via
  // retain()/release().
  virtual ~TensorImpl() = default;
};
// See design notes on Tensor.h, where this is hardcoded a few times.
class UndefinedTensorImpl : public TensorImpl {
UndefinedTensorImpl() : TensorImpl(TypeId::Undefined) {};
static UndefinedTensorImpl singleton_;
public:
virtual ArrayRef<int64_t> size() const {
throw std::runtime_error("UndefinedTensorImpl::sizes()");
}
virtual int64_t dim() const {
throw std::runtime_error("UndefinedTensorImpl::sizes()");
}
virtual void retain() {}
virtual void release() {}
static inline UndefinedTensorImpl* singleton() {
return &singleton_;
}
};
class CPUTensorImpl : public TensorImpl {
void* data_ptr_;
public:
CPUTensorImpl() : TensorImpl(TypeId::CPUTensor) {};
};
class StridedCPUTensorImpl : public CPUTensorImpl {
SmallVector<int64_t> stride_;
StridedCPUTensorImpl() : TensorImpl(TypeId::StridedCPUTensor) {};
};
// NB: From ATen I dropped the following methods:
// - toString()
//===--- Tensor.h -----------------------------------------------*- C++ -*-===//
// Design notes:
// - Manual retain/release instead of shared_ptr. Reasons:
// - PRIMARY: It's possible to work with the underlying retained object using
// a C API, which is basically impossible to do with shared_ptr because
// it doesn't expose a manual retain()/release() API
// - SECONDARY: A true intrusive reference count has some nice properties
// which you don't get from use of std::make_shared (to put the refcount
// metadata next to the regular dynamic allocation) and
// std::enabled_shared_from_this (which generally needs to store a weak pointer
// to the control block).
// - UndefinedTensorImpl instead of null pointer. Reasons:
// - We originally had a null pointer in ATen, but this meant that when we
// incorrectly attempted to use such a null pointer, we would segfault and
// crash, which is very unfriendly for our Python users. Using an UndefinedTensorImpl
// as our default constructor is much better for us.
// - Fixed the mismatch between PyTorch and C++ methods
// - sizes() is now size()
class Tensor final {
TensorImpl * pImpl;
public:
// This is a relatively unsafe constructor which you should avoid using if you
// don't need it. The retain parameter specifies whether or not this constructor
// takes ownership of the passed Impl or not (when retain = true, the caller retains
// their reference.)
Tensor(TensorImpl* self, bool retain)
: pImpl(self) {
if (pImpl == nullptr) {
throw std::runtime_error("Tensor with nullptr not supported");
}
if(retain && pImpl != UndefinedTensorImpl::singleton())
pImpl->retain();
}
// Normal constructors
Tensor(): Tensor(UndefinedTensorImpl::singleton(), false) {}
Tensor(const Tensor & rhs)
: pImpl(rhs.pImpl) {
if (pImpl != UndefinedTensorImpl::singleton())
pImpl->retain();
}
Tensor(Tensor && rhs) noexcept
: pImpl(rhs.pImpl) {
rhs.pImpl = UndefinedTensorImpl::singleton();
}
// Destructor
~Tensor() {
if (pImpl != UndefinedTensorImpl::singleton())
pImpl->release();
}
// Copy assignment
Tensor & operator=(Tensor && rhs) & noexcept {
rhs.swap(*this);
return *this;
}
Tensor & operator=(Tensor const & rhs) & {
//TensorBase ctor retains original rhs.pImpl
//then rhs.pImpl is swapped with this->pImpl
//finally TensorBase dtor releases rhs.pImpl, which was originally this->pImpl
Tensor(rhs).swap(*this);
return *this;
}
// Direct PIMPL manipulation methods. Use of these is discouraged, but may be needed
// in binding code.
void reset() {
Tensor().swap(*this);
}
void reset(TensorImpl * rhs) {
Tensor(rhs, true).swap(*this);
}
void reset(TensorImpl * rhs, bool retain) {
Tensor(rhs, retain).swap(*this );
}
void swap(Tensor & rhs) noexcept {
TensorImpl * tmp = pImpl;
pImpl = rhs.pImpl;
rhs.pImpl = tmp;
}
TensorImpl * get() const {
return pImpl;
}
TensorImpl * detach() {
TensorImpl * ret = pImpl;
pImpl = UndefinedTensorImpl::singleton();
return ret;
}
// We do a lot of null-pointer checks in our code, good to have this be cheap.
inline bool defined() const {
return pImpl != UndefinedTensorImpl::singleton();
}
// These methods are SO important, they are currently implemented via virtual dispatch
// via our implementation classes. Most non-core methods should be implemented by
// the generic dispatch mechanism.
int64_t dim() const {
return pImpl->dim();
}
inline int64_t ndimension() const {
return dim();
}
ArrayRef<int64_t> size() const {
return pImpl->size();
}
ArrayRef<int64_t> stride() const {
return pImpl->stride();
}
template<typename T>
T * data() const {
return static_cast<T*>(pImpl->data_ptr());
}
// TODO: work out the type() situation
// TODO: work out the storage() situation
// The "well known" Tensor functions will call into the dispatch mechanism (yet to be
// implemented)
};
// (End of gist source; GitHub comment-section footer removed from code path.)