Skip to content

Instantly share code, notes, and snippets.

@tkokof
Created September 29, 2018 03:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tkokof/ebaa026e126faa7e4ecc69293b5096f3 to your computer and use it in GitHub Desktop.
Save tkokof/ebaa026e126faa7e4ecc69293b5096f3 to your computer and use it in GitHub Desktop.
// desc simple implementation of sparse matrix
// maintainer hugoyu
#ifndef __sparse_matrix_h__
#define __sparse_matrix_h__
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "common.h"
// sparse_matrix: a row x col matrix that only stores entries that were written.
//   T         - element type; entries never written read as a value-initialized T.
//   Container - associative container mapping a uint64 key (built by
//               gen_element_key as row << 32 | col) to T. It must provide find(),
//               end(), operator[] and iteration over (key, value) pairs, e.g.
//               std::map<uint64, T> or std::unordered_map<uint64, T>.
template<typename T, typename Container>
class sparse_matrix
{
public:
    // Creates an empty row x col matrix; both dimensions must be non-zero.
    sparse_matrix(uint32 row, uint32 col)
        : m_row{ row }, m_col{ col }
    {
        assert(is_valid_size(row, col));
    }

    // Member-wise copy/move of the dimensions and the element buffer is all
    // this class needs, so the compiler-generated versions are used.
    sparse_matrix(const sparse_matrix&) = default;
    sparse_matrix(sparse_matrix&&) = default;
    sparse_matrix& operator =(const sparse_matrix&) = default;
    sparse_matrix& operator =(sparse_matrix&&) = default;

    // Number of rows.
    constexpr uint32 row() const
    {
        return m_row;
    }

    // Number of columns.
    constexpr uint32 col() const
    {
        return m_col;
    }

    // Read-only element access; never inserts into the buffer.
    // An entry that was never written reads as T(): the returned reference then
    // refers to a shared, value-initialized static instead of a temporary, so it
    // stays valid after the call returns.
    const T& operator ()(uint32 row, uint32 col) const
    {
        assert(is_valid_index(row, col));
        auto iter = m_element_buffer.find(gen_element_key(row, col));
        if (iter != m_element_buffer.end())
        {
            return iter->second;
        }
        static const T s_default_value{};
        return s_default_value;
    }

    // Writable element access; inserts a value-initialized entry if absent.
    T& operator ()(uint32 row, uint32 col)
    {
        assert(is_valid_index(row, col));
        return m_element_buffer[gen_element_key(row, col)];
    }

    // Returns a copy in which every stored entry is computed as entry * right.
    sparse_matrix<T, Container> operator *(const T& right) const
    {
        sparse_matrix<T, Container> temp(m_row, m_col);
        for (const auto& element : m_element_buffer)
        {
            temp.m_element_buffer[element.first] = element.second * right;
        }
        return temp;
    }

    // Multiplies every stored entry by right in place.
    sparse_matrix<T, Container>& operator *=(const T& right)
    {
        // Copy first: right may refer to one of our own entries (m *= m(0, 0)),
        // whose value would otherwise change part-way through the loop.
        const T factor = right;
        for (auto& element : m_element_buffer)
        {
            element.second *= factor;
        }
        return *this;
    }

    // Element-wise sum; both matrices must have the same dimensions.
    sparse_matrix<T, Container> operator +(const sparse_matrix<T, Container>& right) const
    {
        assert(row() == right.row() && col() == right.col());
        sparse_matrix<T, Container> temp(*this);
        temp += right;
        return temp;
    }

    // In-place element-wise sum; both matrices must have the same dimensions.
    sparse_matrix<T, Container>& operator +=(const sparse_matrix<T, Container>& right)
    {
        assert(row() == right.row() && col() == right.col());
        for (const auto& element : right.m_element_buffer)
        {
            // operator[] value-initializes entries missing on the left side
            m_element_buffer[element.first] += element.second;
        }
        return *this;
    }

    // Element-wise difference; both matrices must have the same dimensions.
    sparse_matrix<T, Container> operator -(const sparse_matrix<T, Container>& right) const
    {
        assert(row() == right.row() && col() == right.col());
        sparse_matrix<T, Container> temp(*this);
        temp -= right;
        return temp;
    }

    // In-place element-wise difference; both matrices must have the same dimensions.
    sparse_matrix<T, Container>& operator -=(const sparse_matrix<T, Container>& right)
    {
        assert(row() == right.row() && col() == right.col());
        for (const auto& element : right.m_element_buffer)
        {
            // operator[] value-initializes entries missing on the left side
            m_element_buffer[element.first] -= element.second;
        }
        return *this;
    }

    // Matrix product: requires col() == right.row() and yields a
    // row() x right.col() matrix whose entry (i, j) is the sum over k of
    // (*this)(i, k) * right(k, j).
    // Only stored entries are visited: the right operand's entries are grouped
    // by row, so each stored left entry (i, k) is combined only with the stored
    // right entries of row k.
    sparse_matrix<T, Container> operator *(const sparse_matrix<T, Container>& right) const
    {
        assert(col() == right.row());
        sparse_matrix<T, Container> temp(m_row, right.m_col);

        // row index of right -> (column index, pointer to value) of its stored entries
        std::unordered_map<uint32, std::vector<std::pair<uint32, const T*>>> right_rows;
        for (const auto& element : right.m_element_buffer)
        {
            uint32 r = 0;
            uint32 c = 0;
            extract_element_key(element.first, r, c);
            right_rows[r].emplace_back(c, &element.second);
        }

        for (const auto& element : m_element_buffer)
        {
            uint32 i = 0;
            uint32 k = 0;
            extract_element_key(element.first, i, k);
            const auto matches = right_rows.find(k);
            if (matches == right_rows.end())
            {
                continue;
            }
            for (const auto& entry : matches->second)
            {
                temp.m_element_buffer[gen_element_key(i, entry.first)] += element.second * *entry.second;
            }
        }
        return temp;
    }

    // In-place matrix product; the dimensions become row() x right.col().
    sparse_matrix<T, Container>& operator *=(const sparse_matrix<T, Container>& right)
    {
        assert(col() == right.row());
        // The product is built separately and then moved in, which also makes
        // m *= m safe.
        *this = (*this) * right;
        return *this;
    }

    // Dense text dump: one line per row, each element followed by ", ".
    // Requires std::to_string to accept T. Reads through the const accessor,
    // so unwritten entries are printed as T() without being inserted.
    std::string to_string() const
    {
        std::string string_buffer;
        // widen before multiplying so large dimensions cannot wrap around in uint32
        string_buffer.reserve(static_cast<std::string::size_type>(m_row) * m_col * 8);
        for (uint32 i = 0; i < m_row; ++i)
        {
            for (uint32 j = 0; j < m_col; ++j)
            {
                string_buffer.append(std::to_string((*this)(i, j)));
                string_buffer.append(", ");
            }
            string_buffer.append("\n");
        }
        return string_buffer;
    }

private:
    // Both dimensions must be non-zero.
    constexpr bool is_valid_size(uint32 row, uint32 col) const noexcept
    {
        return row > 0 && col > 0;
    }

    // True when (row, col) lies inside this matrix.
    constexpr bool is_valid_index(uint32 row, uint32 col) const noexcept
    {
        return row < m_row && col < m_col;
    }

    // Packs (row, col) into one key: row in the high 32 bits, col in the low 32 bits.
    constexpr static uint64 gen_element_key(uint32 row, uint32 col) noexcept
    {
        return ((uint64)row << 32) | ((uint64)col);
    }

    // Inverse of gen_element_key.
    constexpr static void extract_element_key(uint64 key, uint32& row, uint32& col) noexcept
    {
        row = (uint32)((key >> 32) & 0xFFFFFFFF);
        col = (uint32)(key & 0xFFFFFFFF);
    }

private:
    uint32 m_row{ 0 };
    uint32 m_col{ 0 };
    // stored entries only, keyed by gen_element_key(row, col)
    Container m_element_buffer;
};
// Scalar-on-the-left multiplication: forwards to matrix * scalar, so every
// stored entry is computed as entry * left (T's multiplication is assumed to be
// commutative here). Not constexpr: the matrix operator it forwards to operates
// on the runtime element container.
template<typename T, typename Container>
sparse_matrix<T, Container> operator *(const T& left, const sparse_matrix<T, Container>& right)
{
    return right * left;
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment