Created
September 29, 2018 03:28
-
-
Save tkokof/b984f7ce11249badd98b552c291fdf63 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// desc simple implementation of matrix | |
// maintainer hugoyu | |
#ifndef __matrix_h__ | |
#define __matrix_h__ | |
#include <cassert>
#include <cstddef> // std::size_t
#include <cstring> // memset
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "common.h"
// NOTE should provide allocator here ? | |
template<typename T, typename Allocator = std::allocator<T>> | |
class matrix | |
{ | |
public: | |
constexpr matrix(uint32 row, uint32 col) | |
{ | |
assert(is_valid_size(row, col)); | |
m_row = row; | |
m_col = col; | |
m_element_buffer = allocate(m_row, m_col); | |
memset(m_element_buffer, T(), m_row * m_col * sizeof(T)); | |
} | |
constexpr matrix(const matrix& other) | |
{ | |
m_row = other.m_row; | |
m_col = other.m_col; | |
m_element_buffer = allocate(m_row, m_col); | |
memcpy(m_element_buffer, other.m_element_buffer, m_row * m_col * sizeof(T)); | |
} | |
constexpr matrix(matrix&& other) | |
{ | |
m_row = other.m_row; | |
m_col = other.m_col; | |
m_element_buffer = other.m_element_buffer; | |
other.m_element_buffer = nullptr; | |
} | |
constexpr matrix& operator =(const matrix& other) | |
{ | |
if (this != &other) | |
{ | |
assert(is_valid_size(row, col)); | |
deallocate(m_element_buffer, m_row, m_col); | |
m_row = row; | |
m_col = col; | |
m_element_buffer = allocate(m_row, m_col); | |
memcpy(m_element_buffer, other.m_element_buffer, m_row * m_col * sizeof(T)); | |
} | |
} | |
constexpr matrix& operator =(matrix&& other) | |
{ | |
if (this != &other) | |
{ | |
deallocate(m_element_buffer, m_row, m_col); | |
m_row = other.m_row; | |
m_col = other.m_col; | |
m_element_buffer = other.m_element_buffer; | |
other.m_element_buffer = nullptr; | |
} | |
} | |
~matrix() | |
{ | |
deallocate(m_element_buffer, m_row, m_col); | |
} | |
constexpr uint32 row() const | |
{ | |
return m_row; | |
} | |
constexpr uint32 col() const | |
{ | |
return m_col; | |
} | |
constexpr const T& operator ()(uint32 row, uint32 col) const | |
{ | |
assert(is_valid_index(row, col)); | |
return m_element_buffer[row * m_col + col]; | |
} | |
constexpr T& operator ()(uint32 row, uint32 col) | |
{ | |
assert(is_valid_index(row, col)); | |
return m_element_buffer[row * m_col + col]; | |
} | |
matrix<T> operator *(const T& right) const | |
{ | |
matrix<T> temp(m_row, m_col); | |
for (uint32 row = 0; row < m_row; ++row) | |
{ | |
for (uint32 col = 0; col < m_col; ++col) | |
{ | |
temp(row, col) = (*this)(row, col) * right; | |
} | |
} | |
return temp; | |
} | |
matrix<T>& operator *=(const T& right) | |
{ | |
for (uint32 row = 0; row < m_row; ++row) | |
{ | |
for (uint32 col = 0; col < m_col; ++col) | |
{ | |
(*this)(row, col) *= right; | |
} | |
} | |
return *this; | |
} | |
matrix<T> operator +(const matrix<T>& right) const | |
{ | |
assert(row() == right.row() && col() == right.col()); | |
matrix<T> temp(m_row, m_col); | |
for (uint32 row = 0; row < m_row; ++row) | |
{ | |
for (uint32 col = 0; col < m_col; ++col) | |
{ | |
temp(row, col) = (*this)(row, col) + right(row, col); | |
} | |
} | |
return temp; | |
} | |
matrix<T>& operator +=(const matrix<T>& right) | |
{ | |
assert(row() == right.row() && col() == right.col()); | |
for (uint32 row = 0; row < m_row; ++row) | |
{ | |
for (uint32 col = 0; col < m_col; ++col) | |
{ | |
(*this)(row, col) += right(row, col); | |
} | |
} | |
return *this; | |
} | |
matrix<T> operator -(const matrix<T>& right) const | |
{ | |
assert(row() == right.row() && col() == right.col()); | |
matrix<T> temp(m_row, m_col); | |
for (uint32 row = 0; row < m_row; ++row) | |
{ | |
for (uint32 col = 0; col < m_col; ++col) | |
{ | |
temp(row, col) = (*this)(row, col) - right(row, col); | |
} | |
} | |
return temp; | |
} | |
matrix<T> operator -=(const matrix<T>& right) | |
{ | |
assert(row() == right.m_row && col() == right.m_col); | |
for (uint32 row = 0; row < m_row; ++row) | |
{ | |
for (uint32 col = 0; col < m_col; ++col) | |
{ | |
(*this)(row, col) -= right(row, col); | |
} | |
} | |
return *this; | |
} | |
matrix<T> operator *(const matrix<T>& right) const | |
{ | |
assert(col() == right.m_row); | |
matrix<T> temp(m_row, m_col); | |
auto row = m_row; | |
auto col = right.m_col; | |
auto inn = m_col; | |
for (uint32 i = 0; i < row; ++i) | |
{ | |
for (uint32 j = 0; j < col; ++j) | |
{ | |
auto val = T(); | |
for (uint32 k = 0; k < inn; ++k) | |
{ | |
val += (*this)(i, k) * right(k, j); | |
} | |
temp(i, j) = val; | |
} | |
} | |
return temp; | |
} | |
constexpr matrix<T>& operator *=(const matrix<T>& right) | |
{ | |
assert(col() == right.row()); | |
// NOTE if right is *this, we can optimize space ? | |
auto temp = (*this) * right; | |
m_row = temp.m_row; | |
m_col = temp.m_col; | |
// simple swap, or we can just use move constructor | |
m_element_buffer = temp.m_element_buffer; | |
temp.m_element_buffer = nullptr; | |
return *this; | |
} | |
std::string to_string() | |
{ | |
std::string string_buffer; | |
string_buffer.reserve(m_row * m_col * 8); | |
for (uint32 i = 0; i < m_row; ++i) | |
{ | |
for (uint32 j = 0; j < m_col; ++j) | |
{ | |
//string_buffer.append(this->operator()(i, j)); | |
string_buffer.append(std::to_string((*this)(i, j))); | |
string_buffer.append(", "); | |
} | |
string_buffer.append("\n"); | |
} | |
return string_buffer; | |
} | |
private: | |
constexpr T* allocate(uint32 row, uint32 col) | |
{ | |
return m_allocator.allocate(row * col); | |
} | |
constexpr void deallocate(T*& ptr, uint32 row, uint32 col) | |
{ | |
if (ptr != nullptr) | |
{ | |
m_allocator.deallocate(ptr, row * col); | |
ptr = nullptr; | |
} | |
} | |
constexpr bool is_valid_size(uint32 row, uint32 col) const noexcept | |
{ | |
return row > 0 && col > 0; | |
} | |
constexpr bool is_valid_index(uint32 row, uint32 col) const noexcept | |
{ | |
return row < m_row && col < m_col; | |
} | |
private: | |
uint32 m_row{ 0 }; | |
uint32 m_col{ 0 }; | |
// use flat array | |
T* m_element_buffer{ nullptr }; | |
Allocator m_allocator; | |
}; | |
template<typename T> | |
constexpr matrix<T> operator *(const T& left, const matrix<T>& right) | |
{ | |
return right * left; | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment