Skip to content

Instantly share code, notes, and snippets.

@ray1422
Created April 29, 2023 03:09
Show Gist options
  • Save ray1422/d203eb4e38b8a71a93d9500019428703 to your computer and use it in GitHub Desktop.
Save ray1422/d203eb4e38b8a71a93d9500019428703 to your computer and use it in GitHub Desktop.
#include <bits/stdc++.h>
/// Dense tensor of `T` whose shape is fixed at compile time.
///
/// Elements live in a single heap buffer in row-major order (the last index
/// varies fastest). The special member functions write trace lines to
/// std::cout, so a driver can observe when copies, swaps and assignments occur.
///
/// A moved-from Tensor holds no buffer. It may be destroyed, assigned to,
/// copied (yielding another empty Tensor) or used as the left operand of an
/// in-place operation (a no-op). reshape() on it throws std::invalid_argument.
/// Using it as the right-hand Tensor operand of an element-wise operation is
/// not supported.
///
/// @tparam T element type; must be default-constructible.
/// @tparam S extent of each dimension. An empty pack gives a rank-0 tensor
///           holding exactly one element.
template <typename T, unsigned... S>
class Tensor {
public:
    /// One extent per dimension, in declaration order.
    using Shape = const std::array<unsigned, sizeof...(S)>;
    constexpr static Shape shape = {S...};

    /// Number of stored elements: the product of all extents (1 for rank 0).
    static constexpr std::size_t size() noexcept { return kSize; }

    /// Builds a tensor from a flat, row-major list of values.
    /// Elements not covered by `l` are value-initialised (0 for arithmetic T).
    /// @throws std::invalid_argument if `l` holds more than size() values.
    Tensor(std::initializer_list<T> l) : data_(allocate()) {
        // Guard against writing past the end of the buffer.
        if (l.size() > kSize) {
            throw std::invalid_argument("Tensor: too many initializer values");
        }
        std::copy(l.begin(), l.end(), data_.get());
    }

    /// Builds a tensor with every element value-initialised.
    Tensor() : data_(allocate()) { std::cout << "Tensor created" << '\n'; }

    /// Copies size() elements out of a raw buffer. The buffer is only read;
    /// ownership stays with the caller. This ctor stays implicit for backward
    /// compatibility, so a null check guards the `Tensor t = 0;` path.
    /// @throws std::invalid_argument if `old_data` is null.
    Tensor(const T* old_data) : data_(allocate()) {
        if (old_data == nullptr) {
            throw std::invalid_argument("Tensor: null source buffer");
        }
        std::copy(old_data, old_data + kSize, data_.get());
    }

    /// Exchanges the buffers of two tensors without copying any element.
    friend void swap(Tensor& first, Tensor& second) noexcept {
        std::cout << "Tensor swapped" << '\n';
        first.data_.swap(second.data_);
    }

    /// Deep copy. Copying a moved-from (empty) tensor yields an empty tensor.
    Tensor(const Tensor& t) : data_(t.data_ ? allocate() : std::unique_ptr<T[]>()) {
        if (t.data_) {
            std::copy(t.begin(), t.end(), data_.get());
        }
        std::cout << "Tensor copied" << '\n';
    }

    /// Steals the buffer of `t`, leaving `t` empty. It is noexcept so that
    /// standard containers move (rather than copy) Tensors when they grow.
    Tensor(Tensor&& t) noexcept = default;

    /// Unified copy/move assignment (copy-and-swap). The argument is copied or
    /// moved at the call site and then swapped in; the old buffer is released
    /// when `t` goes out of scope.
    Tensor& operator=(Tensor t) {
        std::cout << "Tensor copy or move-assigned" << '\n';
        using std::swap;
        swap(*this, t);
        return *this;
    }

    ~Tensor() = default;

    /// Pointers over the flat, row-major element buffer. The range is empty
    /// for a moved-from tensor.
    T* begin() noexcept { return data_.get(); }
    T* end() noexcept { return data_ ? data_.get() + kSize : nullptr; }
    const T* begin() const noexcept { return data_.get(); }
    const T* end() const noexcept { return data_ ? data_.get() + kSize : nullptr; }

    /// Element access by multi-index, e.g. `t[{i, j, k}]`.
    /// Missing trailing indices are treated as 0, so `t[{i}]` addresses the
    /// first element of the i-th slice.
    /// @throws std::out_of_range if more indices than dimensions are given, or
    ///         if an index is not below its extent.
    T& operator[](std::initializer_list<unsigned> idxes) { return data_[offset_of(idxes)]; }
    const T& operator[](std::initializer_list<unsigned> idxes) const {
        return data_[offset_of(idxes)];
    }

    // Element-wise arithmetic. The tensor operand is taken by value, updated
    // in place and returned; returning a by-value parameter moves it out.
    friend Tensor operator+(Tensor lhs, const Tensor& rhs) { lhs.add(rhs); return lhs; }
    friend Tensor operator+(Tensor lhs, const T& rhs) { lhs.add(rhs); return lhs; }
    // scalar + tensor: assumes `+` commutes for T, as scalar * tensor does.
    friend Tensor operator+(const T& lhs, Tensor rhs) { rhs.add(lhs); return rhs; }
    friend Tensor operator-(Tensor lhs, const Tensor& rhs) { lhs.subtract(rhs); return lhs; }
    friend Tensor operator-(Tensor t) { t.negate(); return t; }
    friend Tensor operator-(Tensor lhs, const T& rhs) { lhs.subtract(rhs); return lhs; }
    // scalar - tensor: each element x becomes lhs - x.
    friend Tensor operator-(const T& lhs, Tensor rhs) { rhs.be_subtracted(lhs); return rhs; }
    friend Tensor operator*(Tensor lhs, const Tensor& rhs) { lhs.multiply(rhs); return lhs; }
    friend Tensor operator*(Tensor lhs, const T& rhs) { lhs.multiply(rhs); return lhs; }
    friend Tensor operator*(const T& lhs, Tensor rhs) { rhs.multiply(lhs); return rhs; }
    friend Tensor operator/(Tensor lhs, const Tensor& rhs) { lhs.divide(rhs); return lhs; }
    friend Tensor operator/(Tensor lhs, const T& rhs) { lhs.divide(rhs); return lhs; }
    // scalar / tensor: the scalar is the dividend, so each element x becomes lhs / x.
    friend Tensor operator/(const T& lhs, Tensor rhs) { rhs.be_divided(lhs); return rhs; }

    // In-place element-wise operations; each returns *this.
    // Scalars are captured by value, because `t` may alias one of our own
    // elements, which the loop overwrites.

    /// this[i] += t[i]
    Tensor& add(const Tensor& t) { return zip_in_place(t, std::plus<T>()); }
    /// this[i] += t
    Tensor& add(const T& t) { return map_in_place([t](const T& x) { return x + t; }); }
    /// this[i] = -this[i]
    Tensor& negate() { return map_in_place(std::negate<T>()); }
    /// this[i] -= t[i]
    Tensor& subtract(const Tensor& t) { return zip_in_place(t, std::minus<T>()); }
    /// this[i] -= t   (e.g. t1 - 1)
    Tensor& subtract(const T& t) { return map_in_place([t](const T& x) { return x - t; }); }
    /// this[i] = t - this[i]   (e.g. 1 - t1)
    Tensor& be_subtracted(const T& t) { return map_in_place([t](const T& x) { return t - x; }); }
    /// this[i] *= t[i]
    Tensor& multiply(const Tensor& t) { return zip_in_place(t, std::multiplies<T>()); }
    /// this[i] *= t   (e.g. t1 * 2)
    Tensor& multiply(const T& t) { return map_in_place([t](const T& x) { return x * t; }); }
    /// this[i] /= t[i]
    Tensor& divide(const Tensor& t) { return zip_in_place(t, std::divides<T>()); }
    /// this[i] /= t   (e.g. t1 / 2)
    Tensor& divide(const T& t) { return map_in_place([t](const T& x) { return x / t; }); }
    /// this[i] = t / this[i]   (e.g. 2 / t1)
    Tensor& be_divided(const T& t) { return map_in_place([t](const T& x) { return t / x; }); }

    /// Returns a copy of this tensor's elements, viewed with shape `NewS...`.
    /// The flat, row-major element order is unchanged.
    /// @throws std::invalid_argument if this tensor has been moved from.
    template <unsigned... NewS>
    Tensor<T, NewS...> reshape() const {
        static_assert((std::size_t{1} * ... * NewS) == kSize, "reshape size mismatch");
        return Tensor<T, NewS...>(data_.get());
    }

    /// Writes the shape as "(d0, d1, ...)" and returns `os`.
    std::ostream& print_shape(std::ostream& os) const {
        os << "(";
        for (std::size_t i = 0; i < kRank; ++i) {
            if (i != 0) {
                os << ", ";
            }
            os << shape[i];
        }
        os << ")";
        return os;
    }

private:
    static constexpr std::size_t kRank = sizeof...(S);
    // Binary fold with an initial value: well-formed (== 1) for an empty pack.
    static constexpr std::size_t kSize = (std::size_t{1} * ... * S);

    // Allocates a value-initialised buffer of kSize elements.
    static std::unique_ptr<T[]> allocate() { return std::make_unique<T[]>(kSize); }

    // Row-major flat offset of a (possibly partial) multi-index.
    // Uses Horner's scheme, ((i0 * s1 + i1) * s2 + i2) ..., with missing
    // trailing indices taken as 0.
    static std::size_t offset_of(std::initializer_list<unsigned> idxes) {
        if (idxes.size() > kRank) {
            throw std::out_of_range("Too many indices");
        }
        std::size_t offset = 0;
        auto it = idxes.begin();
        for (std::size_t dim = 0; dim < kRank; ++dim) {
            unsigned i = 0;
            if (it != idxes.end()) {
                i = *it++;
                if (i >= shape[dim]) {
                    throw std::out_of_range("Index out of range");
                }
            }
            offset = offset * shape[dim] + i;
        }
        return offset;
    }

    // Applies `op` to every element in place.
    template <typename Op>
    Tensor& map_in_place(Op op) {
        std::transform(begin(), end(), begin(), op);
        return *this;
    }

    // Combines each element with the matching element of `t` in place.
    template <typename Op>
    Tensor& zip_in_place(const Tensor& t, Op op) {
        std::transform(begin(), end(), t.begin(), begin(), op);
        return *this;
    }

    std::unique_ptr<T[]> data_;  // null only after being moved from
};
/// Integer-valued tensor with compile-time extents `Dims...`.
template <unsigned... Dims>
using TensorI = Tensor<int, Dims...>;
/// Double-precision tensor with compile-time extents `Dims...`.
template <unsigned... Dims>
using TensorF64 = Tensor<double, Dims...>;
// Demo driver. Tensor's special members log every copy, swap and assignment,
// so the output shows how many copies each expression costs.
int main() {
    Tensor<int, 3, 2, 2> t1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    auto t2 = t1;  // copy constructed
    auto t3 = t2;  // copy constructed
    std::cout << "should one copy each block" << std::endl;
    auto t4 = t1 * (t2 * t3);
    std::cout << "----------" << std::endl;
    // operator+ takes its right operand as `const Tensor&`, so wrapping t4 in
    // std::move would not move anything; pass it plainly.
    auto t5 = t1 - t2 + t3 + t4;
    std::cout << "----------" << std::endl;
    auto t6 = t1 - t3;
    auto t10 = t1.reshape<2, 2, 3>();
    // Print the results so the computed tensors are actually observed.
    std::cout << "t5[0,0,0] = " << t5[{0, 0, 0}] << ", t6[0,0,0] = " << t6[{0, 0, 0}] << '\n';
    t10.print_shape(std::cout << "t10 shape: ") << '\n';
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment