Skip to content

Instantly share code, notes, and snippets.

@Ben1980

Ben1980/simpleMatrix.h

Last active Nov 12, 2019
Embed
What would you like to do?
/// @brief Dense matrix of arithmetic values, stored as a std::vector of rows.
///
/// Both constructors validate the storage via AssertData(). A successfully
/// constructed Matrix therefore has at least one row and one column, holds
/// exactly rows() row vectors, and every row holds exactly columns() elements.
template<typename T>
class Matrix {
    static_assert(std::is_arithmetic<T>::value, "T must be numeric");
public:
    /// @brief Builds a matrix by copying rows*columns values from a flat,
    ///        row-major buffer.
    /// @param rows    Number of rows (must be > 0).
    /// @param columns Number of columns (must be > 0).
    /// @param m       Pointer to at least rows*columns elements; element
    ///                (r, c) is read from m[r*columns + c]. Must not be null.
    /// @throws std::domain_error if m is null or either dimension is 0.
    Matrix(size_t rows, size_t columns, const T *m)
        : nbRows(rows), nbColumns(columns)
    {
        // Reading through a null pointer would be undefined behaviour.
        if (m == nullptr) {
            throw std::domain_error("Matrix data pointer is null.");
        }
        matrix.reserve(nbRows);
        for (size_t rowIndex = 0; rowIndex < nbRows; ++rowIndex) {
            const T *rowBegin = m + rowIndex * nbColumns;
            // Range-construct each row directly from its slice of the buffer.
            matrix.emplace_back(rowBegin, rowBegin + nbColumns);
        }
        AssertData(*this);
    }

    /// @brief Builds a rows x columns matrix with every element value-initialised (zero).
    /// @throws std::domain_error if either dimension is 0.
    Matrix(size_t rows, size_t columns)
        : nbRows(rows), nbColumns(columns),
          matrix(std::vector<std::vector<T>>(rows, std::vector<T>(columns)))
    {
        AssertData(*this);
    }

    /// @brief Read-only element access; indices are not bounds-checked.
    const T & operator()(size_t row, size_t column) const { //(1)
        return matrix[row][column];
    }

    /// @brief Mutable element access; indices are not bounds-checked.
    T & operator()(size_t row, size_t column) { //(2)
        return matrix[row][column];
    }

    /// @brief Number of rows.
    [[nodiscard]] const size_t & rows() const {
        return nbRows;
    }

    /// @brief Number of columns.
    [[nodiscard]] const size_t & columns() const {
        return nbColumns;
    }

    template<typename U>
    friend Matrix<U> operator*(const Matrix<U> &lhs, const Matrix<U> & rhs); //(4)

private:
    /// @brief Verifies the internal storage is a well-formed, non-empty matrix.
    /// @throws std::domain_error if the matrix is empty, the number of stored
    ///         rows differs from nbRows, or any row's length differs from nbColumns.
    static void AssertData(const Matrix<T> &m) { //(3)
        if (m.matrix.empty() || m.matrix.front().empty() || m.matrix.size() != m.nbRows) {
            throw std::domain_error("Invalid defined matrix.");
        }
        for (const auto & row : m.matrix) {
            // Every row must have exactly nbColumns entries (rectangular shape).
            if (row.size() != m.nbColumns) {
                throw std::domain_error("Matrix rows have unequal length.");
            }
        }
    }

    size_t nbRows{0};
    size_t nbColumns{0};
    std::vector<std::vector<T>> matrix;
};
/// @brief Multiplies two matrices with the straightforward triple loop.
/// @param lhs Left operand of shape (r x n).
/// @param rhs Right operand of shape (n x c).
/// @return The (r x c) product lhs * rhs.
/// @throws std::domain_error if either operand fails validation, or if the
///         inner dimensions disagree (lhs.columns() != rhs.rows()).
template<typename U>
Matrix<U> operator*(const Matrix<U> &lhs, const Matrix<U> & rhs) { //(4)
    Matrix<U>::AssertData(lhs);
    Matrix<U>::AssertData(rhs);
    // The product is only defined when the inner dimensions match:
    // (r x n) * (n x c). Reject the mismatched case.
    if (lhs.columns() != rhs.rows()) {
        throw std::domain_error("Matrices have unequal size.");
    }
    const size_t lhsRows = lhs.rows();
    const size_t rhsColumns = rhs.columns();
    const size_t innerDim = lhs.columns();
    Matrix<U> C(lhsRows, rhsColumns);
    for (size_t i = 0; i < lhsRows; ++i) {
        for (size_t k = 0; k < rhsColumns; ++k) {
            // Dot product of row i of lhs with column k of rhs.
            U sum{};
            for (size_t j = 0; j < innerDim; ++j) {
                sum += lhs(i, j) * rhs(j, k); //(5)
            }
            C(i, k) = sum;
        }
    }
    return C;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.