Skip to content

Instantly share code, notes, and snippets.

@PhDP
Created March 8, 2020 00:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhDP/3228222357731751fc40e287f5c8dffa to your computer and use it in GitHub Desktop.
Save PhDP/3228222357731751fc40e287f5c8dffa to your computer and use it in GitHub Desktop.
Matrix multiplication with a matrix of type Integer -> Integer -> T.
// C++17
#include <array>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <stdexcept>
// Row-major, fixed-size ROW x COL matrix stored in a std::array.
// Element (r, c) lives at flat index r * COL + c.
template<std::size_t ROW, std::size_t COL, typename T>
class matrix {
public:
  // Builds the matrix from nested initializer lists, one inner list per row,
  // e.g. matrix<2, 2, int>({{1, 2}, {3, 4}}).
  // Rows or trailing columns that are not supplied keep their value-initialized
  // value (T{}, i.e. zero for arithmetic T), so matrix<R, C, T>{} is the zero
  // matrix.
  // Throws std::out_of_range if there are more than ROW rows or any row has
  // more than COL values, instead of writing past the end of the storage.
  constexpr matrix(std::initializer_list<std::initializer_list<T>> const& init) {
    if (init.size() > ROW) {
      throw std::out_of_range("matrix: too many rows in initializer");
    }
    std::size_t row_id = 0;
    for (auto const& row : init) {
      if (row.size() > COL) {
        throw std::out_of_range("matrix: too many columns in initializer row");
      }
      std::size_t col_id = 0;
      for (auto const& value : row) {
        m_arr[row_id * COL + col_id] = value;
        ++col_id;
      }
      ++row_id;
    }
  }

  // Total number of elements (ROW * COL).
  constexpr auto size() const -> std::size_t {
    return ROW * COL;
  }

  // Read-only access to element (row_id, col_id).
  // Throws std::out_of_range unless row_id < ROW and col_id < COL.
  constexpr auto operator()(std::size_t row_id, std::size_t col_id) const -> T const& {
    return m_arr[checked_index(row_id, col_id)];
  }

  // Mutable access to element (row_id, col_id).
  // Throws std::out_of_range unless row_id < ROW and col_id < COL.
  constexpr auto operator()(std::size_t row_id, std::size_t col_id) -> T& {
    return m_arr[checked_index(row_id, col_id)];
  }

  // Naive triple-loop product of this ROW x COL matrix with a COL x C matrix.
  // Matching inner dimensions are enforced at compile time by the parameter type.
  template<std::size_t C>
  constexpr auto operator*(matrix<COL, C, T> const& other) const -> matrix<ROW, C, T> {
    auto ans = matrix<ROW, C, T>{};
    for (std::size_t row = 0; row < ROW; ++row) {
      for (std::size_t col = 0; col < C; ++col) {
        // Accumulate in a local starting from T{} so the result never
        // depends on the prior contents of `ans`.
        T sum{};
        for (std::size_t k = 0; k < COL; ++k) {
          sum += (*this)(row, k) * other(k, col);
        }
        ans(row, col) = sum;
      }
    }
    return ans;
  }

private:
  // Maps (row_id, col_id) to the row-major flat index, rejecting any pair
  // outside ROW x COL. A check on the flat index alone would accept e.g.
  // (0, COL) and silently return element (1, 0).
  static constexpr auto checked_index(std::size_t row_id, std::size_t col_id) -> std::size_t {
    if (row_id >= ROW || col_id >= COL) {
      throw std::out_of_range("matrix: element index out of range");
    }
    return row_id * COL + col_id;
  }

  // Value-initialized so every element starts at T{}. This also satisfies the
  // C++17 requirement that a constexpr constructor initialize every member.
  std::array<T, (ROW * COL)> m_arr{};
};
// Writes `m` to `os` one row per line, formatted as "[a, b, c]\n".
// A matrix with zero rows or zero columns writes nothing.
// Returns `os` so calls can be chained.
template<std::size_t ROW, std::size_t COL, typename T>
auto operator<<(std::ostream& os, matrix<ROW, COL, T> const& m) -> std::ostream& {
  if (ROW == 0 || COL == 0) {
    return os;
  }
  // Unsigned indices match the size_t bounds (no signed/unsigned comparison).
  for (std::size_t r = 0; r < ROW; ++r) {
    // First element has no leading separator; the rest are ", "-prefixed.
    os << '[' << m(r, 0);
    for (std::size_t c = 1; c < COL; ++c) {
      os << ", " << m(r, c);
    }
    os << "]\n";
  }
  return os;
}
auto main() -> int {
auto const A = matrix<2, 3, double>({{1, 2, 3}, {4, 5, 6}});
auto const B = matrix<3, 2, double>({{7, 8}, {9, 10}, {11, 12}});
std::cout << A << "\n*\n\n";
std::cout << B << "\n=\n\n";
std::cout << (A * B) << '\n';
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment