Created March 27, 2022 00:42.
Save gist dno89/1cc0a8a4f6a6a1667e0887caac7f4199 to your computer and use it in GitHub Desktop.
Sparse matrix multiplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <map> | |
#include <utility> | |
#include <cassert> | |
// Sparse matrix of doubles that stores only its non-zero entries, keyed by
// (row, col) in a std::map. std::map orders std::pair keys lexicographically,
// so all entries of one row form a single contiguous, column-ordered range;
// operator* relies on that ordering.
class SparseMatrix {
public:
    // Creates an empty (all-zero) rows x cols matrix.
    SparseMatrix(size_t rows, size_t cols) :
        rows_(rows), cols_(cols)
    {}

    // Returns the value at (row, col), or 0.0 when no entry is stored there.
    // Bounds are checked with assert() only (i.e. in debug builds).
    double getValue(size_t row, size_t col) const {
        assert(row < rows_ && col < cols_);
        if (auto iter = entries_.find({row, col}); iter != entries_.end()) {
            return iter->second;
        }
        return 0.0;
    }

    // Sets the value at (row, col). Writing 0.0 removes any stored entry, so
    // explicit zeros are never kept and the matrix stays structurally sparse.
    // Bounds are checked with assert() only.
    void setValue(size_t row, size_t col, double value) {
        assert(row < rows_ && col < cols_);
        if (value == 0.0) {
            entries_.erase(Indices{row, col});
        } else {
            entries_[Indices{row, col}] = value;
        }
    }

    // Prints "<rows>x<cols> matrix (<n> entries): " followed by one
    // "{row, col: value}," group per stored entry, in (row, col) order.
    friend std::ostream& operator<<(std::ostream& out, const SparseMatrix& m) {
        out << m.rows_ << "x" << m.cols_ << " matrix (" << m.entries_.size() << " entries): ";
        for (const auto& entry : m.entries_) {
            out << "{" << entry.first.first << ", " << entry.first.second << ": " << entry.second << "},";
        }
        return out;
    }

    size_t rows() const { return rows_; }
    size_t cols() const { return cols_; }

    // Returns D = (*this) * b. Requires cols() == b.rows() (assert only).
    // Only stored entries are visited: for each A(i,k), the stored entries
    // of row k of b are walked directly. Products that cancel to exactly 0.0
    // are dropped so the result holds no explicit zeros.
    SparseMatrix operator*(const SparseMatrix& b) const {
        assert(cols() == b.rows());
        SparseMatrix d(rows(), b.cols());

        for (const auto& [a_idx, a_ik] : entries_) {
            const size_t i = a_idx.first;
            const size_t k = a_idx.second;
            // Row k of b starts at the first key >= (k, 0) and ends where the
            // row index changes.
            for (auto it = b.entries_.lower_bound(Indices{k, 0});
                 it != b.entries_.end() && it->first.first == k; ++it) {
                // operator[] value-initialises a missing d_ij to 0.0: one lookup per update.
                d.entries_[Indices{i, it->first.second}] += a_ik * it->second;
            }
        }

        // Remove entries whose contributions cancelled out exactly.
        for (auto it = d.entries_.begin(); it != d.entries_.end();) {
            if (it->second == 0.0) {
                it = d.entries_.erase(it);
            } else {
                ++it;
            }
        }
        return d;
    }

private:
    size_t rows_ = 0;
    size_t cols_ = 0;
    using Indices = std::pair<size_t, size_t>;
    // Stored (non-zero) entries, ordered by (row, col).
    std::map<Indices, double> entries_;
};
int main(int,char**) { | |
SparseMatrix m(2,2); | |
// std::cout << m << std::endl; | |
m.setValue(0,0, 1.0); | |
m.setValue(1,1, 1.0); | |
// std::cout << m << std::endl; | |
SparseMatrix b1(2,2); | |
b1.setValue(0,1, 2.0); | |
std::cout << "b1 " << m*b1 << std::endl; | |
SparseMatrix b2(2,3); | |
b2.setValue(0,1, 2.0); | |
b2.setValue(1,2, 3.0); | |
std::cout << "b2 " << m*b2 << std::endl; | |
SparseMatrix A(2,3); | |
A.setValue(0,0, 1.0); | |
A.setValue(1,1, 2.0); | |
A.setValue(1,2, 1.0); | |
SparseMatrix B(3,2); | |
B.setValue(0,0, 1.0); | |
B.setValue(0,1, 1.0); | |
B.setValue(2,0, 1.0); | |
std::cout << "D " << A*B << std::endl; | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.