/// @brief Owning, row-major, square matrix of arithmetic elements.
///
/// Elements live contiguously in a heap buffer of rows()*columns() values;
/// element (r, c) is stored at index r*columns() + c. Constructors that take
/// dimensions reject zero-sized and non-square shapes with std::domain_error.
/// A moved-from matrix is left empty (0 x 0, no buffer).
///
/// @tparam T arithmetic element type (enforced by static_assert).
template<typename T>
class Matrix {
    static_assert(std::is_arithmetic<T>::value, "T must be numeric");
public:
    ~Matrix() = default;

    /// @brief Builds a rows x columns matrix from a row-major buffer.
    /// @param rows    number of rows (must be non-zero and equal to columns).
    /// @param columns number of columns.
    /// @param m       source buffer holding at least rows*columns values.
    /// @throws std::domain_error if a dimension is zero or rows != columns.
    Matrix(size_t rows, size_t columns, T *m)
        : nbRows(rows), nbColumns(columns), matrix(std::make_unique<T[]>(rows * columns))
    {
        // Validate before touching m, so a rejected shape never reads the buffer.
        AssertData(*this);
        std::copy(m, m + nbRows * nbColumns, matrix.get());
    }

    /// @brief Builds a zero-filled rows x columns matrix.
    /// @throws std::domain_error if a dimension is zero or rows != columns.
    Matrix(size_t rows, size_t columns)
        : nbRows(rows), nbColumns(columns), matrix(std::make_unique<T[]>(rows * columns))
    {
        // make_unique<T[]>(n) value-initializes its elements, so the buffer
        // is already all zeros; no explicit fill is needed.
        AssertData(*this);
    }

    /// @brief Deep copy. Does not re-validate, so copying an empty
    ///        (moved-from) matrix is allowed and yields another empty matrix.
    Matrix(const Matrix<T> &m)
        : nbRows(m.nbRows), nbColumns(m.nbColumns),
          matrix(std::make_unique<T[]>(m.nbRows * m.nbColumns))
    {
        // size_t, not int: the element count may exceed INT_MAX.
        const size_t size = nbRows * nbColumns;
        std::copy(m.matrix.get(), m.matrix.get() + size, matrix.get());
    }

    /// @brief Steals m's buffer and leaves m empty (0 x 0, null buffer).
    Matrix(Matrix<T> &&m) noexcept
        : nbRows(m.nbRows), nbColumns(m.nbColumns), matrix(std::move(m.matrix))
    {
        m.nbRows = 0;
        m.nbColumns = 0;
    }

    /// @brief Copy assignment via copy-and-swap; strong exception guarantee
    ///        and safe under self-assignment.
    Matrix<T> &operator=(const Matrix<T> &m) {
        Matrix tmp(m);
        swapContents(tmp);  // tmp now owns the old buffer and frees it
        return *this;
    }

    /// @brief Move assignment; m is left empty (0 x 0, null buffer).
    Matrix<T> &operator=(Matrix<T> &&m) noexcept {
        Matrix tmp(std::move(m));
        swapContents(tmp);
        return *this;
    }

    /// @brief Unchecked read access to element (row, column).
    const T &operator()(size_t row, size_t column) const {
        return matrix[row * nbColumns + column];
    }

    /// @brief Unchecked read/write access to element (row, column).
    T &operator()(size_t row, size_t column) {
        return matrix[row * nbColumns + column];
    }

    [[nodiscard]] size_t rows() const {
        return nbRows;
    }

    [[nodiscard]] size_t columns() const {
        return nbColumns;
    }

    template<typename U>
    friend Matrix<U> operator*(const Matrix<U> &lhs, const Matrix<U> &rhs);

private:
    /// @brief Exchanges dimensions and buffer with other; never throws.
    void swapContents(Matrix<T> &other) noexcept {
        std::swap(nbRows, other.nbRows);
        std::swap(nbColumns, other.nbColumns);
        matrix.swap(other.matrix);
    }

    /// @brief Rejects empty and non-square matrices.
    /// @throws std::domain_error describing which requirement failed.
    static void AssertData(const Matrix<T> &m) {
        if (m.nbRows == 0 || m.nbColumns == 0) {
            throw std::domain_error("Invalid defined matrix.");
        }
        if (m.nbRows != m.nbColumns) {
            throw std::domain_error("Matrix is not square.");
        }
    }

    size_t nbRows{0};
    size_t nbColumns{0};
    std::unique_ptr<T[]> matrix;
};
template<typename U> | |
Matrix<U> operator*(const Matrix<U> &lhs, const Matrix<U> & rhs) { | |
Matrix<U>::AssertData(lhs); | |
Matrix<U>::AssertData(rhs); | |
if(lhs.rows() != rhs.rows()) { | |
throw std::domain_error("Matrices have unequal size."); | |
} | |
const size_t lhsRows = lhs.rows(); | |
const size_t rhsColumns = rhs.columns(); | |
const size_t lhsColumns = lhs.columns(); | |
Matrix<U> C(lhsRows, rhsColumns); | |
for (size_t i = 0; i < lhsRows; ++i) { | |
for (size_t k = 0; k < rhsColumns; ++k) { | |
for (size_t j = 0; j < lhsColumns; ++j) { | |
C(i, k) += lhs(i, j) * rhs(j, k); | |
} | |
} | |
} | |
return C; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment