Skip to content

Instantly share code, notes, and snippets.

@tkokof
Created September 29, 2018 03:28
Show Gist options
  • Save tkokof/b984f7ce11249badd98b552c291fdf63 to your computer and use it in GitHub Desktop.
Save tkokof/b984f7ce11249badd98b552c291fdf63 to your computer and use it in GitHub Desktop.
// desc simple implementation of matrix
// maintainer hugoyu
#ifndef __matrix_h__
#define __matrix_h__
#include <cassert>
#include <cstddef>
#include <cstring> // memset
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "common.h"
// NOTE should provide allocator here ?
template<typename T, typename Allocator = std::allocator<T>>
class matrix
{
public:
constexpr matrix(uint32 row, uint32 col)
{
assert(is_valid_size(row, col));
m_row = row;
m_col = col;
m_element_buffer = allocate(m_row, m_col);
memset(m_element_buffer, T(), m_row * m_col * sizeof(T));
}
constexpr matrix(const matrix& other)
{
m_row = other.m_row;
m_col = other.m_col;
m_element_buffer = allocate(m_row, m_col);
memcpy(m_element_buffer, other.m_element_buffer, m_row * m_col * sizeof(T));
}
constexpr matrix(matrix&& other)
{
m_row = other.m_row;
m_col = other.m_col;
m_element_buffer = other.m_element_buffer;
other.m_element_buffer = nullptr;
}
constexpr matrix& operator =(const matrix& other)
{
if (this != &other)
{
assert(is_valid_size(row, col));
deallocate(m_element_buffer, m_row, m_col);
m_row = row;
m_col = col;
m_element_buffer = allocate(m_row, m_col);
memcpy(m_element_buffer, other.m_element_buffer, m_row * m_col * sizeof(T));
}
}
constexpr matrix& operator =(matrix&& other)
{
if (this != &other)
{
deallocate(m_element_buffer, m_row, m_col);
m_row = other.m_row;
m_col = other.m_col;
m_element_buffer = other.m_element_buffer;
other.m_element_buffer = nullptr;
}
}
~matrix()
{
deallocate(m_element_buffer, m_row, m_col);
}
constexpr uint32 row() const
{
return m_row;
}
constexpr uint32 col() const
{
return m_col;
}
constexpr const T& operator ()(uint32 row, uint32 col) const
{
assert(is_valid_index(row, col));
return m_element_buffer[row * m_col + col];
}
constexpr T& operator ()(uint32 row, uint32 col)
{
assert(is_valid_index(row, col));
return m_element_buffer[row * m_col + col];
}
matrix<T> operator *(const T& right) const
{
matrix<T> temp(m_row, m_col);
for (uint32 row = 0; row < m_row; ++row)
{
for (uint32 col = 0; col < m_col; ++col)
{
temp(row, col) = (*this)(row, col) * right;
}
}
return temp;
}
matrix<T>& operator *=(const T& right)
{
for (uint32 row = 0; row < m_row; ++row)
{
for (uint32 col = 0; col < m_col; ++col)
{
(*this)(row, col) *= right;
}
}
return *this;
}
matrix<T> operator +(const matrix<T>& right) const
{
assert(row() == right.row() && col() == right.col());
matrix<T> temp(m_row, m_col);
for (uint32 row = 0; row < m_row; ++row)
{
for (uint32 col = 0; col < m_col; ++col)
{
temp(row, col) = (*this)(row, col) + right(row, col);
}
}
return temp;
}
matrix<T>& operator +=(const matrix<T>& right)
{
assert(row() == right.row() && col() == right.col());
for (uint32 row = 0; row < m_row; ++row)
{
for (uint32 col = 0; col < m_col; ++col)
{
(*this)(row, col) += right(row, col);
}
}
return *this;
}
matrix<T> operator -(const matrix<T>& right) const
{
assert(row() == right.row() && col() == right.col());
matrix<T> temp(m_row, m_col);
for (uint32 row = 0; row < m_row; ++row)
{
for (uint32 col = 0; col < m_col; ++col)
{
temp(row, col) = (*this)(row, col) - right(row, col);
}
}
return temp;
}
matrix<T> operator -=(const matrix<T>& right)
{
assert(row() == right.m_row && col() == right.m_col);
for (uint32 row = 0; row < m_row; ++row)
{
for (uint32 col = 0; col < m_col; ++col)
{
(*this)(row, col) -= right(row, col);
}
}
return *this;
}
matrix<T> operator *(const matrix<T>& right) const
{
assert(col() == right.m_row);
matrix<T> temp(m_row, m_col);
auto row = m_row;
auto col = right.m_col;
auto inn = m_col;
for (uint32 i = 0; i < row; ++i)
{
for (uint32 j = 0; j < col; ++j)
{
auto val = T();
for (uint32 k = 0; k < inn; ++k)
{
val += (*this)(i, k) * right(k, j);
}
temp(i, j) = val;
}
}
return temp;
}
constexpr matrix<T>& operator *=(const matrix<T>& right)
{
assert(col() == right.row());
// NOTE if right is *this, we can optimize space ?
auto temp = (*this) * right;
m_row = temp.m_row;
m_col = temp.m_col;
// simple swap, or we can just use move constructor
m_element_buffer = temp.m_element_buffer;
temp.m_element_buffer = nullptr;
return *this;
}
std::string to_string()
{
std::string string_buffer;
string_buffer.reserve(m_row * m_col * 8);
for (uint32 i = 0; i < m_row; ++i)
{
for (uint32 j = 0; j < m_col; ++j)
{
//string_buffer.append(this->operator()(i, j));
string_buffer.append(std::to_string((*this)(i, j)));
string_buffer.append(", ");
}
string_buffer.append("\n");
}
return string_buffer;
}
private:
constexpr T* allocate(uint32 row, uint32 col)
{
return m_allocator.allocate(row * col);
}
constexpr void deallocate(T*& ptr, uint32 row, uint32 col)
{
if (ptr != nullptr)
{
m_allocator.deallocate(ptr, row * col);
ptr = nullptr;
}
}
constexpr bool is_valid_size(uint32 row, uint32 col) const noexcept
{
return row > 0 && col > 0;
}
constexpr bool is_valid_index(uint32 row, uint32 col) const noexcept
{
return row < m_row && col < m_col;
}
private:
uint32 m_row{ 0 };
uint32 m_col{ 0 };
// use flat array
T* m_element_buffer{ nullptr };
Allocator m_allocator;
};
template<typename T>
constexpr matrix<T> operator *(const T& left, const matrix<T>& right)
{
return right * left;
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment