Created
April 29, 2023 03:09
-
-
Save ray1422/d203eb4e38b8a71a93d9500019428703 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <bits/stdc++.h> | |
// Forward declaration of the primary Tensor template: element type first,
// followed by the extent of every dimension.
template <class T, unsigned... S>
class Tensor;
// Fixed-shape, heap-backed N-dimensional tensor.
//
//   T    : element type (must be default-constructible).
//   S... : extent of each dimension. Elements are stored row-major, so the
//          last index varies fastest. An empty pack gives a rank-0 tensor
//          holding exactly one element.
//
// Storage is owned through std::unique_ptr<T[]>. A moved-from tensor owns no
// storage; only assigning to it or destroying it is valid afterwards.
// The default/copy constructors, swap and assignment print trace messages to
// std::cout so the copy/move behaviour of expressions can be observed.
template <typename T, unsigned... S>
class Tensor {
public:
    using Shape = const std::array<unsigned, sizeof...(S)>;
    // Extents of every dimension, e.g. {3, 2, 2} for Tensor<T, 3, 2, 2>.
    static constexpr Shape shape = {S...};

    // Build from a flat, row-major list of elements. Elements the list does
    // not supply are value-initialized (zero for arithmetic T).
    // Throws std::invalid_argument if the list has more elements than fit.
    Tensor(std::initializer_list<T> l) : data(allocate()) {
        if (l.size() > kSize) {
            throw std::invalid_argument("Tensor initializer list too long");
        }
        std::copy(l.begin(), l.end(), data.get());
    }

    // All elements value-initialized (zero for arithmetic T).
    Tensor() : data(allocate()) { std::cout << "Tensor created" << std::endl; }

    // Copy exactly kSize elements from a raw, non-owned buffer. `old_data`
    // must point to at least that many readable elements; it is not retained.
    // Explicit so that e.g. a literal 0 can never silently become a Tensor.
    // Throws std::invalid_argument if `old_data` is null.
    explicit Tensor(const T* old_data) : data(allocate()) {
        if (old_data == nullptr) {
            throw std::invalid_argument("Tensor source pointer is null");
        }
        std::copy(old_data, old_data + kSize, data.get());
    }

    friend void swap(Tensor& first, Tensor& second) noexcept {
        using std::swap;
        std::cout << "Tensor swapped" << std::endl;
        swap(first.data, second.data);
    }

    // Deep copy. Precondition: `t` is not in the moved-from state.
    Tensor(const Tensor& t) : data(allocate()) {
        std::copy(t.data.get(), t.data.get() + kSize, data.get());
        std::cout << "Tensor copied" << std::endl;
    }

    // Steals the buffer; `t` is left owning nothing.
    Tensor(Tensor&& t) noexcept : data(std::move(t.data)) {}

    // Copy- or move-assignment via copy-and-swap: `t` is built by the copy or
    // move constructor, then swapped into *this.
    Tensor& operator=(Tensor t) {
        std::cout << "Tensor copy or move-assigned" << std::endl;
        using std::swap;
        swap(*this, t);
        return *this;
    }

    ~Tensor() = default;

    // Row-major element access, e.g. t[{i, j, k}].
    // Passing fewer indices than the rank addresses the first element of the
    // selected sub-block (omitted trailing indices behave as 0).
    // Throws std::out_of_range if more indices than the rank are given or any
    // index is not smaller than its dimension.
    T& operator[](std::initializer_list<unsigned> idxes) { return data[offset(idxes)]; }
    const T& operator[](std::initializer_list<unsigned> idxes) const {
        return data[offset(idxes)];
    }

    // Element-wise and scalar arithmetic. Each operator takes its tensor
    // operand by value, mutates that copy in place, and returns it (moved out).
    friend Tensor operator+(Tensor lhs, const Tensor& rhs) { lhs.add(rhs); return lhs; }
    friend Tensor operator+(Tensor lhs, const T& rhs) { lhs.add(rhs); return lhs; }
    friend Tensor operator+(const T& rhs, Tensor lhs) { lhs.add(rhs); return lhs; }
    friend Tensor operator-(Tensor lhs, const Tensor& rhs) { lhs.subtract(rhs); return lhs; }
    friend Tensor operator-(Tensor lhs) { lhs.negate(); return lhs; }
    friend Tensor operator-(Tensor lhs, const T& rhs) { lhs.subtract(rhs); return lhs; }
    friend Tensor operator-(const T& rhs, Tensor lhs) { lhs.be_subtracted(rhs); return lhs; }
    friend Tensor operator*(Tensor lhs, const Tensor& rhs) { lhs.multiply(rhs); return lhs; }
    friend Tensor operator*(Tensor lhs, const T& rhs) { lhs.multiply(rhs); return lhs; }
    friend Tensor operator*(const T& rhs, Tensor lhs) { lhs.multiply(rhs); return lhs; }
    friend Tensor operator/(Tensor lhs, const Tensor& rhs) { lhs.divide(rhs); return lhs; }
    friend Tensor operator/(Tensor lhs, const T& rhs) { lhs.divide(rhs); return lhs; }
    // scalar / tensor: each element becomes rhs / element.
    friend Tensor operator/(const T& rhs, Tensor lhs) { lhs.be_divided(rhs); return lhs; }

    // In-place operations; each returns *this.
    // this[i] += t[i]
    Tensor& add(const Tensor& t) { return zip_inplace(t, std::plus<T>()); }
    // this[i] += t
    Tensor& add(const T& t) {
        return map_inplace([t](const T& x) { return x + t; });
    }
    // this[i] = -this[i]
    Tensor& negate() { return map_inplace(std::negate<T>()); }
    // this[i] -= t[i]
    Tensor& subtract(const Tensor& t) { return zip_inplace(t, std::minus<T>()); }
    // this[i] -= t, e.g. t1 - 1
    Tensor& subtract(const T& t) {
        return map_inplace([t](const T& x) { return x - t; });
    }
    // this[i] = t - this[i], e.g. 1 - t1
    Tensor& be_subtracted(const T& t) {
        return map_inplace([t](const T& x) { return t - x; });
    }
    // this[i] *= t[i]
    Tensor& multiply(const Tensor& t) { return zip_inplace(t, std::multiplies<T>()); }
    // this[i] *= t, e.g. t1 * 2
    Tensor& multiply(const T& t) {
        return map_inplace([t](const T& x) { return x * t; });
    }
    // this[i] /= t[i]
    Tensor& divide(const Tensor& t) { return zip_inplace(t, std::divides<T>()); }
    // this[i] /= t, e.g. t1 / 2
    Tensor& divide(const T& t) {
        return map_inplace([t](const T& x) { return x / t; });
    }
    // this[i] = t / this[i], e.g. 2 / t1
    Tensor& be_divided(const T& t) {
        return map_inplace([t](const T& x) { return t / x; });
    }

    // Copy of the same row-major elements viewed with a new shape; the total
    // element count must match (checked at compile time).
    template <unsigned... NewS>
    Tensor<T, NewS...> reshape() const {
        static_assert(kSize == (std::size_t{1} * ... * NewS), "reshape size mismatch");
        return Tensor<T, NewS...>(data.get());
    }

    // Writes the shape as "(d0, d1, ...)" to `os` and returns `os`.
    std::ostream& print_shape(std::ostream& os) const {
        os << "(";
        for (std::size_t i = 0; i < shape.size(); ++i) {
            if (i != 0) {
                os << ", ";
            }
            os << shape[i];
        }
        return os << ")";
    }

private:
    // Total element count; the binary fold makes a rank-0 tensor hold 1.
    static constexpr std::size_t kSize = (std::size_t{1} * ... * S);

    // Fresh value-initialized buffer of kSize elements.
    static std::unique_ptr<T[]> allocate() { return std::make_unique<T[]>(kSize); }

    // Flat row-major offset of a (possibly partial) multi-index; see operator[].
    static std::size_t offset(std::initializer_list<unsigned> idxes) {
        // Checked first so shape[] below is never read past the rank.
        if (idxes.size() > sizeof...(S)) {
            throw std::out_of_range("Index out of range");
        }
        std::size_t idx = 0;
        std::size_t dim = 0;
        for (unsigned i : idxes) {
            if (i >= shape[dim]) {
                throw std::out_of_range("Index out of range");
            }
            idx = idx * shape[dim] + i;  // Horner form of the row-major offset
            ++dim;
        }
        // Scale by the dimensions of any omitted trailing indices.
        for (; dim < sizeof...(S); ++dim) {
            idx *= shape[dim];
        }
        return idx;
    }

    // this[i] = f(this[i]) for every element.
    template <typename F>
    Tensor& map_inplace(F f) {
        std::transform(data.get(), data.get() + kSize, data.get(), f);
        return *this;
    }

    // this[i] = f(this[i], other[i]) for every element (other may be *this).
    template <typename F>
    Tensor& zip_inplace(const Tensor& other, F f) {
        std::transform(data.get(), data.get() + kSize, other.data.get(), data.get(), f);
        return *this;
    }

    std::unique_ptr<T[]> data;
};
template <unsigned... S> | |
using TensorI = Tensor<int, S...>; | |
template <unsigned... S> | |
using TensorF64 = Tensor<double, S...>; | |
// output stream operator << for Tensor::Shape | |
int main() { | |
using namespace std; | |
Tensor<int, 3, 2, 2> t1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; | |
auto t2 = t1; // copy constructed | |
auto t3 = t2; | |
std::cout << "should one copy each block" << std::endl; | |
// auto t4 = t1; | |
auto t4 = t1 * (t2 * t3); | |
std::cout << "----------" << std::endl; | |
auto t5 = t1 - t2 + t3 + std::move(t4); | |
std::cout << "----------" << std::endl; | |
auto t6 = t1 - t3; | |
auto t10 = t1.reshape<2, 2, 3>(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment