Skip to content

Instantly share code, notes, and snippets.

@cynecx
Created May 30, 2016 23:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cynecx/e1b44a67e74bac257ed9e632c982fc7f to your computer and use it in GitHub Desktop.
Save cynecx/e1b44a67e74bac257ed9e632c982fc7f to your computer and use it in GitHub Desktop.
My quick-and-dirty fixed-size Matrix implementation (C++17).
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
/// Fixed-size M x N matrix of T, stored row-major in a std::array.
///
/// Element (row, column) lives at flat index row * N + column. Rows and
/// columns can be viewed through MatrixRowAccessor / MatrixColumnAccessor;
/// these views keep a reference to the matrix storage and must not outlive
/// the matrix they were obtained from. Index checks use assert() only, so
/// they are compiled out in NDEBUG builds.
///
/// @tparam M  number of rows (>= 1)
/// @tparam N  number of columns (>= 1); defaults to M (square matrix)
/// @tparam T  element type; defaults to float
template <uint16_t M, uint16_t N = M, typename T = float>
struct Matrix {
  using index_t = uint16_t;
  static_assert(M >= 1 && N >= 1, "Matrix dimensions must be at least 1");

  // Widen before multiplying: M * N would otherwise be evaluated in int and
  // overflow for large dimensions (e.g. 65535 x 65535).
  static constexpr size_t ElementsCount = static_cast<size_t>(M) * N;
  using matrix_ty = std::array<T, ElementsCount>;
  using iterator = typename matrix_ty::iterator;
  using const_iterator = typename matrix_ty::const_iterator;

  /// Read-only input iterator over the elements of a row or column view.
  ///
  /// @tparam Accessor  the view type being iterated
  /// @tparam Limit     number of elements in the view (the end position)
  template <typename Accessor, index_t Limit>
  struct MatrixAccessorIterator {
    // Standard member types so std::iterator_traits, and algorithms relying
    // on it (e.g. std::distance), accept this iterator. operator* yields the
    // element by value, hence reference is T rather than T&.
    using iterator_category = std::input_iterator_tag;
    using value_type = T;
    using difference_type = std::ptrdiff_t;
    using pointer = const T*;
    using reference = T;

   private:
    const Accessor m_Accessor;
    index_t m_CurrIndex;

   public:
    constexpr MatrixAccessorIterator(const Accessor& accessor,
                                     index_t currIndex)
        : m_Accessor{accessor}, m_CurrIndex{currIndex} {
      assert(currIndex <= Limit);
    }
    constexpr MatrixAccessorIterator(const MatrixAccessorIterator&) = default;

    MatrixAccessorIterator& operator++() {
      ++m_CurrIndex;
      assert(m_CurrIndex <= Limit);
      return *this;
    }
    MatrixAccessorIterator operator++(int) {
      MatrixAccessorIterator tmp(*this);
      ++*this;
      return tmp;
    }

    // Two iterators are equal when they walk the same view (same storage,
    // same row/column) and sit at the same position.
    bool operator==(const MatrixAccessorIterator& rhs) const {
      return m_Accessor == rhs.m_Accessor && m_CurrIndex == rhs.m_CurrIndex;
    }
    bool operator!=(const MatrixAccessorIterator& rhs) const {
      return !(*this == rhs);
    }

    T operator*() const {
      return m_Accessor[m_CurrIndex];
    }
    // operator-> must return a pointer for `it->member` to be valid. The
    // view's const operator[] returns a reference into the matrix storage,
    // so the pointer stays valid for the lifetime of the matrix.
    pointer operator->() const {
      return &m_Accessor[m_CurrIndex];
    }
  };

  /// View of a single row of a matrix.
  ///
  /// @tparam MatTy  matrix_ty or const matrix_ty; a const MatTy makes the
  ///                view read-only.
  template <typename MatTy>
  struct MatrixRowAccessor {
    using IteratorTy = MatrixAccessorIterator<MatrixRowAccessor, N>;
    // Element type handed out by the non-const operator[]: const T when the
    // underlying storage is const.
    using ValTy = std::conditional_t<std::is_const<MatTy>::value,
                                     std::add_const_t<T>, T>;

   private:
    MatTy& m_ArrRef;
    // Flat index of the first element of this row (row * N).
    const size_t m_RowIndex;

   public:
    // The row offset is computed in size_t to avoid int overflow for
    // large dimensions.
    constexpr MatrixRowAccessor(MatTy& arr, index_t row)
        : m_ArrRef{arr}, m_RowIndex(static_cast<size_t>(row) * N) {
      assert(row < M);
    }
    constexpr MatrixRowAccessor(const MatrixRowAccessor&) = default;

    /// Number of elements in the row (N).
    constexpr size_t Size() const {
      return N;
    }

    const T& operator[](index_t column) const {
      assert(column < N);
      return m_ArrRef[m_RowIndex + column];
    }
    ValTy& operator[](index_t column) {
      assert(column < N);
      return m_ArrRef[m_RowIndex + column];
    }

    // Views are equal when they refer to the same storage and row.
    bool operator==(const MatrixRowAccessor& rhs) const {
      return &m_ArrRef == &rhs.m_ArrRef && m_RowIndex == rhs.m_RowIndex;
    }
    bool operator!=(const MatrixRowAccessor& rhs) const {
      return !(*this == rhs);
    }

    IteratorTy begin() const {
      return {*this, 0};
    }
    IteratorTy end() const {
      return {*this, N};
    }
  };

  /// View of a single column of a matrix.
  ///
  /// @tparam MatTy  matrix_ty or const matrix_ty; a const MatTy makes the
  ///                view read-only.
  template <typename MatTy>
  struct MatrixColumnAccessor {
    using IteratorTy = MatrixAccessorIterator<MatrixColumnAccessor, M>;
    // Element type handed out by the non-const operator[]: const T when the
    // underlying storage is const, so a copy of a view taken from a const
    // matrix remains usable.
    using ValTy = std::conditional_t<std::is_const<MatTy>::value,
                                     std::add_const_t<T>, T>;

   private:
    MatTy& m_ArrRef;
    const index_t m_Column;

   public:
    constexpr MatrixColumnAccessor(MatTy& arr, index_t column)
        : m_ArrRef{arr}, m_Column{column} {
      assert(column < N);
    }
    constexpr MatrixColumnAccessor(const MatrixColumnAccessor& rhs) = default;

    /// Number of elements in the column (M).
    constexpr size_t Size() const {
      return M;
    }

    // Flat offsets are computed in size_t to avoid int overflow for large
    // dimensions.
    const T& operator[](index_t row) const {
      assert(row < M);
      return m_ArrRef[static_cast<size_t>(row) * N + m_Column];
    }
    ValTy& operator[](index_t row) {
      assert(row < M);
      return m_ArrRef[static_cast<size_t>(row) * N + m_Column];
    }

    // Views are equal when they refer to the same storage and column.
    bool operator==(const MatrixColumnAccessor& rhs) const {
      return &m_ArrRef == &rhs.m_ArrRef && m_Column == rhs.m_Column;
    }
    bool operator!=(const MatrixColumnAccessor& rhs) const {
      return !(*this == rhs);
    }

    IteratorTy begin() const {
      return {*this, 0};
    }
    IteratorTy end() const {
      return {*this, M};
    }
  };

  /// Row-major storage. A default-constructed matrix has its first element
  /// set to 0 and the rest value-initialized (all zero for arithmetic T).
  matrix_ty m_Values = {{0}};

  constexpr Matrix() = default;

  /// Constructs from exactly M * N values in row-major order. Each value is
  /// brace-converted to T, so narrowing conversions are rejected.
  template <typename... ElemsTy>
  constexpr Matrix(ElemsTy... values) : m_Values{{T{values}...}} {
    static_assert(sizeof...(ElemsTy) == ElementsCount,
                  "Matrix requires exactly M * N initial values");
  }
  constexpr Matrix(const Matrix&) = default;
  constexpr Matrix& operator=(const Matrix&) = default;

  /// Read-only view of row `row` (asserts row < M).
  const MatrixRowAccessor<const matrix_ty> GetRowAccessor(
      index_t row) const {
    return {m_Values, row};
  }
  /// Mutable view of row `row` (asserts row < M).
  MatrixRowAccessor<matrix_ty> GetRowAccessor(index_t row) {
    return {m_Values, row};
  }
  /// Read-only view of column `column` (asserts column < N).
  const MatrixColumnAccessor<const matrix_ty> GetColumnAccessor(
      index_t column) const {
    return {m_Values, column};
  }
  /// Mutable view of column `column` (asserts column < N).
  MatrixColumnAccessor<matrix_ty> GetColumnAccessor(index_t column) {
    return {m_Values, column};
  }

  /// m[row][column] element access via a row view.
  const MatrixRowAccessor<const matrix_ty> operator[](index_t row) const {
    return GetRowAccessor(row);
  }
  MatrixRowAccessor<matrix_ty> operator[](index_t row) {
    return GetRowAccessor(row);
  }

  /// Iteration over all elements in row-major order.
  iterator begin() {
    return m_Values.begin();
  }
  const_iterator begin() const {
    return m_Values.begin();
  }
  iterator end() {
    return m_Values.end();
  }
  const_iterator end() const {
    return m_Values.end();
  }

  /// Multiplies every element by `rhs` in place.
  Matrix& operator*=(const T& rhs) {
    std::transform(begin(), end(), begin(), [&rhs](T val) {
      return val * rhs;
    });
    return *this;
  }
  /// Returns a copy with every element multiplied by `rhs`.
  Matrix operator*(T rhs) const {
    Matrix result;
    std::transform(begin(), end(), result.begin(), [&rhs](T val) {
      return val * rhs;
    });
    return result;
  }
  /// Divides every element by `rhs` in place.
  Matrix& operator/=(const T& rhs) {
    std::transform(begin(), end(), begin(), [&rhs](T val) {
      return val / rhs;
    });
    return *this;
  }
  /// Returns a copy with every element divided by `rhs`.
  Matrix operator/(T rhs) const {
    Matrix result;
    std::transform(begin(), end(), result.begin(), [&rhs](T val) {
      return val / rhs;
    });
    return result;
  }
  /// Element-wise addition in place.
  Matrix& operator+=(const Matrix& rhs) {
    std::transform(begin(), end(), rhs.begin(), begin(),
                   [](T val, T val2) { return val + val2; });
    return *this;
  }
  /// Element-wise sum.
  Matrix operator+(const Matrix& rhs) const {
    Matrix result;
    std::transform(begin(), end(), rhs.begin(), result.begin(),
                   [](T val, T val2) { return val + val2; });
    return result;
  }
  /// Element-wise subtraction in place.
  Matrix& operator-=(const Matrix& rhs) {
    std::transform(begin(), end(), rhs.begin(), begin(),
                   [](T val, T val2) { return val - val2; });
    return *this;
  }
  /// Element-wise difference.
  Matrix operator-(const Matrix& rhs) const {
    Matrix result;
    std::transform(begin(), end(), rhs.begin(), result.begin(),
                   [](T val, T val2) { return val - val2; });
    return result;
  }

  /// Matrix product (M x N) * (N x N2) -> (M x N2).
  ///
  /// Each output element is the inner product of a row of *this with a
  /// column of rhs, starting the accumulation from T(0).
  template <uint16_t N2>
  Matrix<M, N2, T> operator*(const Matrix<N, N2, T>& rhs) const {
    Matrix<M, N2, T> result;
    for (index_t y = 0; y < M; y++) {
      const auto rowAccessor = GetRowAccessor(y);
      auto outRowAccessor = result.GetRowAccessor(y);
      for (index_t x = 0; x < N2; x++) {
        const auto columnAccessor = rhs.GetColumnAccessor(x);
        outRowAccessor[x] =
            std::inner_product(rowAccessor.begin(), rowAccessor.end(),
                               columnAccessor.begin(), T(0));
      }
    }
    return result;
  }

  /// In-place right multiplication: *this = *this * rhs.
  ///
  /// rhs must be N x N so the product keeps this matrix's M x N shape; for
  /// square matrices this is exactly the Matrix type itself. Multiplying by
  /// an M x N rhs (the previous signature) is only valid when M == N.
  Matrix& operator*=(const Matrix<N, N, T>& rhs) {
    // The product is built in a temporary before assignment because every
    // output element reads a whole row of *this (and rhs may alias *this).
    *this = *this * rhs;
    return *this;
  }
};
// Demo: multiply a 3x3 matrix by itself and print the top-left element of
// the product (1*1 + 2*4 + 3*7 = 30). std::cout needs <iostream>.
int main() {
  Matrix<3, 3, float> a{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
  Matrix<3, 3, float> b{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
  auto newMat = a * b;
  // '\n' instead of std::endl: no explicit flush is needed, since stdout is
  // flushed on normal program exit.
  std::cout << newMat[0][0] << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment