Skip to content

Instantly share code, notes, and snippets.

@dno89
Created March 27, 2022 00:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dno89/1cc0a8a4f6a6a1667e0887caac7f4199 to your computer and use it in GitHub Desktop.
Save dno89/1cc0a8a4f6a6a1667e0887caac7f4199 to your computer and use it in GitHub Desktop.
Sparse matrix multiplication
#include <iostream>
#include <map>
#include <utility>
#include <cassert>
/// A rows_ x cols_ matrix of doubles that stores only explicitly set entries.
///
/// Entries live in a std::map keyed by (row, col). std::pair compares
/// lexicographically, so the map is ordered row-major: all entries of one row
/// are contiguous and sorted by column. operator* relies on this ordering.
class SparseMatrix {
public:
    /// Creates an empty rows x cols matrix; every element reads as 0.0.
    SparseMatrix(size_t rows, size_t cols) :
        rows_(rows), cols_(cols)
    {}

    /// Returns the element at (row, col), or 0.0 when no entry is stored there.
    /// Bounds are only checked by assert (i.e. in debug builds).
    double getValue(size_t row, size_t col) const {
        assert(row < rows_ && col < cols_);
        if (auto iter = entries_.find({row, col}); iter != entries_.end()) {
            return iter->second;
        }
        return 0.0;
    }

    /// Stores `value` at (row, col), replacing any previous entry.
    /// Note: an explicit 0.0 is stored (and counted) like any other value.
    /// Bounds are only checked by assert (i.e. in debug builds).
    void setValue(size_t row, size_t col, double value) {
        assert(row < rows_ && col < cols_);
        entries_[{row, col}] = value;
    }

    /// Writes "<rows>x<cols> matrix (<n> entries): " followed by each stored
    /// entry as "{row, col: value}," in row-major order.
    friend std::ostream& operator<<(std::ostream& out, const SparseMatrix& m) {
        out << m.rows_ << "x" << m.cols_ << " matrix (" << m.entries_.size() << " entries): ";
        for (const auto& entry : m.entries_) {
            out << "{" << entry.first.first << ", " << entry.first.second << ": " << entry.second << "},";
        }
        return out;
    }

    size_t rows() const { return rows_; }
    size_t cols() const { return cols_; }

    /// Returns the product D = (*this) * b, a rows() x b.cols() matrix.
    ///
    /// Requires cols() == b.rows(); this is only checked by assert (debug builds).
    /// Only stored entries are visited. For each stored a_ik, the stored entries
    /// of row k of `b` are reached with a single lower_bound, instead of probing
    /// every one of b.cols() columns. Stored zeros in `b` contribute nothing and
    /// create no entry in D.
    SparseMatrix operator*(const SparseMatrix& b) const {
        // D = A*B
        assert(cols() == b.rows());
        SparseMatrix d(rows(), b.cols());
        for (const auto& [a_indices, a_ik] : entries_) {
            const size_t i = a_indices.first;
            const size_t k = a_indices.second;
            // Row-major ordering: (k, 0) is the smallest key of row k, so the
            // entries of row k start here and run until the row index changes.
            for (auto it = b.entries_.lower_bound(Indices{k, 0});
                 it != b.entries_.end() && it->first.first == k; ++it) {
                const size_t j = it->first.second;
                const double b_kj = it->second;
                if (b_kj != 0.0) {
                    // operator[] value-initializes a missing d_ij to 0.0, then
                    // accumulates; a single lookup instead of get + set.
                    d.entries_[{i, j}] += a_ik * b_kj;
                }
            }
        }
        return d;
    }

private:
    size_t rows_ = 0;
    size_t cols_ = 0;
    using Indices = std::pair<size_t, size_t>;
    std::map<Indices, double> entries_;
};
int main(int,char**) {
SparseMatrix m(2,2);
// std::cout << m << std::endl;
m.setValue(0,0, 1.0);
m.setValue(1,1, 1.0);
// std::cout << m << std::endl;
SparseMatrix b1(2,2);
b1.setValue(0,1, 2.0);
std::cout << "b1 " << m*b1 << std::endl;
SparseMatrix b2(2,3);
b2.setValue(0,1, 2.0);
b2.setValue(1,2, 3.0);
std::cout << "b2 " << m*b2 << std::endl;
SparseMatrix A(2,3);
A.setValue(0,0, 1.0);
A.setValue(1,1, 2.0);
A.setValue(1,2, 1.0);
SparseMatrix B(3,2);
B.setValue(0,0, 1.0);
B.setValue(0,1, 1.0);
B.setValue(2,0, 1.0);
std::cout << "D " << A*B << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment