-
-
Save SimLeek/808266d917af8d22e54dbe829eb5f876 to your computer and use it in GitHub Desktop.
Sparse Tensor implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Fixed-size array of N elements of type T that can be built from a braced
// list, e.g. `ConstructableArray<int, 3> a{1, 2, 3};`.
//
// Only the first N initializers are used. Surplus ones are ignored instead of
// being written past the end of the storage. Elements not covered by the list
// are value-initialized (zero for arithmetic T) instead of left indeterminate.
template<typename T, std::size_t N>
class ConstructableArray{
public:
    // Copies up to N leading elements of `list`; the rest keep their
    // value-initialized state from the member initializer below.
    ConstructableArray(std::initializer_list<T> list)
    {
        std::copy_n(list.begin(), std::min(list.size(), N), elements_);
    }

    // Returns a copy of element `i`. There is no bounds checking: `i` must be < N.
    // Declared const so that const arrays and const references can be indexed.
    T operator[](std::size_t i) const {
        return elements_[i];
    }

    // Element-wise equality, so the type can serve as a lookup key.
    friend bool operator==(const ConstructableArray& lhs, const ConstructableArray& rhs) {
        return std::equal(lhs.elements_, lhs.elements_ + N, rhs.elements_);
    }

    friend bool operator!=(const ConstructableArray& lhs, const ConstructableArray& rhs) {
        return !(lhs == rhs);
    }

private:
    // Value-initialized so that elements not supplied by the constructor are
    // never read uninitialized. Copying is member-wise (Rule of Zero).
    T elements_[N]{};
};
template<typename _Tp, std::size_t _Nm> | |
std::ostream &operator<<(std::ostream &os, ConstructableArray<_Tp, _Nm> &m) { | |
os<<"["<<m[0]; | |
for (int i=1;i<_Nm;++i){ | |
os<<","<<m[i]; | |
} | |
return os << "]"; | |
} | |
template<typename _Tp, std::size_t _Nm> | |
std::string to_string(ConstructableArray<_Tp, _Nm> m){ | |
std::stringstream ss; | |
ss<<m; | |
return ss.str(); | |
} | |
template<typename _Tpi,typename _Tpr, std::size_t _Ndim> | |
class SparseTensor{ | |
public: | |
SparseTensor(){} | |
SparseTensor(const _Tpr& default_value):default_value(default_value){} | |
_Tpr& operator[](ConstructableArray<_Tpi,_Ndim> loc){ | |
//note: this can and should be improved to not used lists if time permits | |
//note: some modification necessary for higher precision floats | |
auto it=sparse_map.find(loc); | |
if(it==sparse_map.end()){ | |
sparse_map[loc]=default_value; | |
} | |
return sparse_map[loc]; | |
} | |
std::vector<_Tpr> get_array(ConstructableArray<_Tpi,_Ndim> max_loc){ | |
_Tpi num=1; | |
for(std::size_t i=0; i<_Ndim;++i){ | |
num*=max_loc[i]; | |
} | |
std::vector<_Tpr> dense_array; | |
dense_array.resize(num, default_value); | |
for(auto it: sparse_map){ | |
_Tpi pos = 0; | |
_Tpi multiplier = 1; | |
for(std::size_t i=0; i<_Ndim;++i){ | |
pos += multiplier*(it->first[i]); | |
multiplier*=max_loc[i]; | |
} | |
dense_array[pos] = it->second; | |
} | |
return dense_array; | |
} | |
_Tpr& operator[](std::initializer_list<_Tpi> list){ | |
ConstructableArray<std::size_t,_Ndim> loc(list); | |
return operator[](loc); | |
} | |
std::vector<_Tpr> get_array(std::initializer_list<_Tpi> list){ | |
ConstructableArray<std::size_t,_Ndim> loc(list); | |
return get_array(loc); | |
} | |
_Tpr default_value; | |
private: | |
std::unordered_map<ConstructableArray<_Tpi,_Ndim>, _Tpr&> sparse_map; | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
test: