-
-
Save cynecx/e1b44a67e74bac257ed9e632c982fc7f to your computer and use it in GitHub Desktop.
My crappy Matrix implementation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
/// Fixed-size M x N matrix whose elements live in a flat, row-major
/// std::array (element (row, column) is stored at index row * N + column).
///
/// @tparam M Number of rows (>= 1).
/// @tparam N Number of columns (>= 1); defaults to M, i.e. a square matrix.
/// @tparam T Element type.
template <uint16_t M, uint16_t N = M, typename T = float>
struct Matrix {
    using index_t = uint16_t;

    static_assert(M >= 1 && N >= 1);

    // Widen before multiplying: both operands would otherwise be promoted to
    // int, and the product of two large uint16_t values does not fit in int.
    static constexpr size_t ElementsCount = static_cast<size_t>(M) * N;

    using matrix_ty = std::array<T, ElementsCount>;
    using iterator = typename matrix_ty::iterator;
    using const_iterator = typename matrix_ty::const_iterator;

    /// Read-only input iterator that walks an accessor (a row or a column)
    /// by index, from 0 up to Limit (Limit itself is the past-the-end position).
    ///
    /// The iterator keeps its own copy of the accessor, so it stays usable
    /// after the accessor object it was created from has been destroyed.
    template <typename Accessor, index_t Limit>
    struct MatrixAccessorIterator {
        // Member types read by std::iterator_traits / standard algorithms.
        using iterator_category = std::input_iterator_tag;
        using value_type = T;
        using difference_type = std::ptrdiff_t;
        using pointer = const T*;
        using reference = const T&;

    private:
        Accessor m_Accessor;
        index_t m_CurrIndex;

    public:
        constexpr MatrixAccessorIterator(const Accessor& accessor,
                                         index_t currIndex)
            : m_Accessor(accessor), m_CurrIndex{currIndex} {
            assert(currIndex <= Limit);
        }

        /// Pre-increment: advances to the next element.
        MatrixAccessorIterator& operator++() {
            ++m_CurrIndex;
            assert(m_CurrIndex <= Limit);
            return *this;
        }

        /// Post-increment: advances and returns the previous position.
        MatrixAccessorIterator operator++(int) {
            MatrixAccessorIterator previous(*this);
            ++*this;
            return previous;
        }

        /// Two iterators are equal when they view the same row/column at the
        /// same position.
        bool operator==(const MatrixAccessorIterator& rhs) const {
            return m_Accessor == rhs.m_Accessor && m_CurrIndex == rhs.m_CurrIndex;
        }

        bool operator!=(const MatrixAccessorIterator& rhs) const {
            return !(*this == rhs);
        }

        /// Returns the element at the current position (read-only).
        reference operator*() const {
            return m_Accessor[m_CurrIndex];
        }

        /// Returns a pointer to the element at the current position.
        pointer operator->() const {
            return &m_Accessor[m_CurrIndex];
        }
    };

    /// Lightweight view of a single row of the underlying array.
    ///
    /// @tparam MatTy Either matrix_ty (mutable view) or const matrix_ty
    ///               (read-only view). The view does not own the array.
    template <typename MatTy>
    struct MatrixRowAccessor {
        using IteratorTy = MatrixAccessorIterator<MatrixRowAccessor, N>;
        using ValTy = std::conditional_t<std::is_const<MatTy>::value,
                                         std::add_const_t<T>, T>;

    private:
        // A pointer (rather than a reference) keeps the accessor, and the
        // iterators that embed it, copy-assignable.
        MatTy* m_Arr;
        size_t m_RowIndex;  // Flat index of the row's first element.

    public:
        constexpr MatrixRowAccessor(MatTy& arr, index_t row)
            : m_Arr{&arr}, m_RowIndex{static_cast<size_t>(row) * N} {
            assert(row < M);
        }

        /// Number of elements in the row.
        constexpr size_t Size() const {
            return N;
        }

        const T& operator[](index_t column) const {
            assert(column < N);
            return (*m_Arr)[m_RowIndex + column];
        }

        ValTy& operator[](index_t column) {
            assert(column < N);
            return (*m_Arr)[m_RowIndex + column];
        }

        /// Equal when both accessors view the same row of the same array.
        bool operator==(const MatrixRowAccessor& rhs) const {
            return m_Arr == rhs.m_Arr && m_RowIndex == rhs.m_RowIndex;
        }

        bool operator!=(const MatrixRowAccessor& rhs) const {
            return !(*this == rhs);
        }

        IteratorTy begin() const {
            return {*this, 0};
        }

        IteratorTy end() const {
            return {*this, N};
        }
    };

    /// Lightweight view of a single column of the underlying array.
    ///
    /// @tparam MatTy Either matrix_ty (mutable view) or const matrix_ty
    ///               (read-only view). The view does not own the array.
    template <typename MatTy>
    struct MatrixColumnAccessor {
        using IteratorTy = MatrixAccessorIterator<MatrixColumnAccessor, M>;
        using ValTy = std::conditional_t<std::is_const<MatTy>::value,
                                         std::add_const_t<T>, T>;

    private:
        MatTy* m_Arr;
        index_t m_Column;

    public:
        constexpr MatrixColumnAccessor(MatTy& arr, index_t column)
            : m_Arr{&arr}, m_Column{column} {
            assert(column < N);
        }

        /// Number of elements in the column.
        constexpr size_t Size() const {
            return M;
        }

        const T& operator[](index_t row) const {
            assert(row < M);
            return (*m_Arr)[static_cast<size_t>(row) * N + m_Column];
        }

        // ValTy (not T) so that a view over a const array hands out const
        // references instead of failing to compile.
        ValTy& operator[](index_t row) {
            assert(row < M);
            return (*m_Arr)[static_cast<size_t>(row) * N + m_Column];
        }

        /// Equal when both accessors view the same column of the same array.
        bool operator==(const MatrixColumnAccessor& rhs) const {
            return m_Arr == rhs.m_Arr && m_Column == rhs.m_Column;
        }

        bool operator!=(const MatrixColumnAccessor& rhs) const {
            return !(*this == rhs);
        }

        IteratorTy begin() const {
            return {*this, 0};
        }

        IteratorTy end() const {
            return {*this, M};
        }
    };

    /// Row-major element storage; value-initialized by default.
    matrix_ty m_Values{};

    constexpr Matrix() = default;

    /// Builds the matrix from exactly M * N values given in row-major order.
    /// Each value is converted with T{...}, so narrowing conversions are
    /// rejected at compile time.
    template <typename... ElemsTy>
    constexpr Matrix(ElemsTy... values) : m_Values{{T{values}...}} {
        static_assert(sizeof...(ElemsTy) == ElementsCount);
    }

    constexpr Matrix(const Matrix&) = default;
    constexpr Matrix& operator=(const Matrix&) = default;

    /// Read-only view of row `row` (asserts row < M).
    const MatrixRowAccessor<const matrix_ty> GetRowAccessor(index_t row) const {
        return {m_Values, row};
    }

    /// Mutable view of row `row` (asserts row < M).
    MatrixRowAccessor<matrix_ty> GetRowAccessor(index_t row) {
        return {m_Values, row};
    }

    /// Read-only view of column `column` (asserts column < N).
    const MatrixColumnAccessor<const matrix_ty> GetColumnAccessor(
        index_t column) const {
        return {m_Values, column};
    }

    /// Mutable view of column `column` (asserts column < N).
    MatrixColumnAccessor<matrix_ty> GetColumnAccessor(index_t column) {
        return {m_Values, column};
    }

    /// `matrix[row][column]` element access through a row view.
    const MatrixRowAccessor<const matrix_ty> operator[](index_t row) const {
        return GetRowAccessor(row);
    }

    MatrixRowAccessor<matrix_ty> operator[](index_t row) {
        return GetRowAccessor(row);
    }

    // Flat, row-major iteration over every element.
    iterator begin() {
        return m_Values.begin();
    }

    const_iterator begin() const {
        return m_Values.begin();
    }

    iterator end() {
        return m_Values.end();
    }

    const_iterator end() const {
        return m_Values.end();
    }

    /// Multiplies every element by `rhs` in place.
    Matrix& operator*=(const T& rhs) {
        // Copy first: `rhs` may refer to an element of this matrix
        // (e.g. `m *= m[0][0]`), which is overwritten while we iterate.
        const T factor = rhs;
        for (T& element : m_Values) {
            element = element * factor;
        }
        return *this;
    }

    /// Returns a copy with every element multiplied by `rhs`.
    Matrix operator*(T rhs) const {
        Matrix result = *this;
        result *= rhs;
        return result;
    }

    /// Divides every element by `rhs` in place.
    Matrix& operator/=(const T& rhs) {
        // Copy first for the same aliasing reason as in operator*=.
        const T divisor = rhs;
        for (T& element : m_Values) {
            element = element / divisor;
        }
        return *this;
    }

    /// Returns a copy with every element divided by `rhs`.
    Matrix operator/(T rhs) const {
        Matrix result = *this;
        result /= rhs;
        return result;
    }

    /// Element-wise addition in place.
    Matrix& operator+=(const Matrix& rhs) {
        std::transform(begin(), end(), rhs.begin(), begin(),
                       [](const T& lhsVal, const T& rhsVal) {
                           return lhsVal + rhsVal;
                       });
        return *this;
    }

    /// Element-wise sum.
    Matrix operator+(const Matrix& rhs) const {
        Matrix result = *this;
        result += rhs;
        return result;
    }

    /// Element-wise subtraction in place.
    Matrix& operator-=(const Matrix& rhs) {
        std::transform(begin(), end(), rhs.begin(), begin(),
                       [](const T& lhsVal, const T& rhsVal) {
                           return lhsVal - rhsVal;
                       });
        return *this;
    }

    /// Element-wise difference.
    Matrix operator-(const Matrix& rhs) const {
        Matrix result = *this;
        result -= rhs;
        return result;
    }

    /// Matrix product: (M x N) * (N x N2) -> (M x N2).
    /// Each output element is the inner product of a row of this matrix with
    /// a column of `rhs`.
    template <uint16_t N2>
    Matrix<M, N2, T> operator*(const Matrix<N, N2, T>& rhs) const {
        Matrix<M, N2, T> result;
        for (index_t y = 0; y < M; ++y) {
            const auto row = GetRowAccessor(y);
            auto outRow = result.GetRowAccessor(y);
            for (index_t x = 0; x < N2; ++x) {
                const auto column = rhs.GetColumnAccessor(x);
                outRow[x] = std::inner_product(row.begin(), row.end(),
                                               column.begin(), T(0));
            }
        }
        return result;
    }

    /// In-place matrix product; only meaningful for square matrices, because
    /// the result must have the same shape as this matrix.
    Matrix& operator*=(const Matrix& rhs) {
        static_assert(M == N,
                      "in-place matrix multiplication requires a square matrix");
        // The product is computed into a temporary, so `m *= m` is safe.
        *this = *this * rhs;
        return *this;
    }
};
int main() { | |
Matrix<3, 3, float> a{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; | |
Matrix<3, 3, float> b{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; | |
auto newMat = a * b; | |
std::cout << newMat[0][0] << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment