Created December 2, 2016 08:13
-
-
Save htfy96/81914dc5fe8063317ed0f1d8119a01e3 to your computer and use it in GitHub Desktop.
Fast matrix multiplication with threads and AVX instructions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <thread>
#include <vector>
#include <x86intrin.h>
using namespace std;
/// Bounded, blocking FIFO queue shared between threads.
///
/// put() blocks while the queue already holds MAX_LEN elements; fetch() blocks while it
/// is empty. Every operation on the underlying std::queue happens under q_mutex.
template<typename T, std::size_t MAX_LEN = 200>
class Queue
{
    static_assert(MAX_LEN > 0, "a zero-capacity Queue would block every put() forever");

    std::queue<T> q;
    std::mutex q_mutex;
    std::condition_variable cv_not_empty;
    std::condition_variable cv_not_full;
public:
    /// Appends a copy of v, blocking until there is room for it.
    void put(const T& v)
    {
        {
            std::unique_lock<std::mutex> lk(q_mutex);
            // wait(pred) returns immediately if the predicate already holds, so no separate
            // "not full" fast path is needed.
            cv_not_full.wait(lk, [&]() { return q.size() < MAX_LEN; });
            q.push(v); // if push throws, lk's destructor still releases the mutex
        }
        // Notify after unlocking so the woken consumer does not immediately block on q_mutex.
        cv_not_empty.notify_one();
    }

    /// Removes the oldest element and assigns it to v, blocking until one is available.
    /// Uses an out-parameter instead of returning T: the element is popped only after the
    /// assignment to v succeeded, so a throwing copy-assignment leaves the queue unchanged.
    void fetch(T& v)
    {
        {
            std::unique_lock<std::mutex> lk(q_mutex);
            cv_not_empty.wait(lk, [&]() { return !q.empty(); });
            v = q.front();
            q.pop();
        }
        // Exactly one slot was freed, so wake at most one blocked producer. (No consumer
        // needs waking here: each put() already notifies one consumer per element.)
        cv_not_full.notify_one();
    }
};
/// Dense W x H matrix of int32_t stored as W rows of ACTUAL_HEIGHT contiguous elements.
///
/// m[w] points at row w (0 <= w < WIDTH); m[w][h] is element h of that row (0 <= h < HEIGHT).
/// Each row is padded up to a multiple of 8 elements and starts on a 32-byte boundary, so
/// a row can be processed as whole 256-bit vectors. All storage, padding included, is zero
/// after construction and after zero(). The class is move-only because it owns buf.
template<std::size_t W, std::size_t H>
class Matrix
{
public:
    std::unique_ptr<int32_t[]> buf; // owning allocation, including alignment slack
    int32_t *start;                 // first element: 32-byte aligned address inside buf
    static constexpr std::size_t WIDTH = W, HEIGHT = H;
    static constexpr std::size_t ACTUAL_HEIGHT = HEIGHT % 8 == 0 ? HEIGHT : (HEIGHT / 8 + 1) * 8;
    static constexpr std::size_t HEIGHT_PACK = ACTUAL_HEIGHT / 8;
    // Size of buf in BYTES: the padded element data plus 48 bytes of slack for alignment.
    static constexpr std::size_t BUF_SIZE = ACTUAL_HEIGHT * WIDTH * sizeof(int32_t) + 48;

    /// Allocates zero-initialized storage and aligns `start` to 32 bytes.
    /// @throws std::runtime_error if the aligned data region does not fit (should not happen).
    Matrix(): buf(new int32_t[BUF_SIZE / sizeof(int32_t)]()), start(nullptr)
    {
        void *p = buf.get();
        std::size_t space = BUF_SIZE;
        // Request room for the whole data region, not just one element.
        if (!std::align(32, ACTUAL_HEIGHT * WIDTH * sizeof(int32_t), p, space))
            throw std::runtime_error("size too small");
        start = static_cast<int32_t *>(p);
    }

    /// Pointer to the first element of row w (rows are ACTUAL_HEIGHT elements apart).
    int32_t* operator[](std::size_t w)
    {
        return start + ACTUAL_HEIGHT * w;
    }
    const int32_t* operator[](std::size_t w) const
    {
        return start + ACTUAL_HEIGHT * w;
    }

    /// Fills every logical element with ::rand() % max (row by row); padding is untouched.
    void rand(int max = 1024)
    {
        for (std::size_t i = 0; i < W; ++i)
            for (std::size_t j = 0; j < H; ++j)
                (*this)[i][j] = ::rand() % max;
    }

    /// Clears the entire allocation, padding and alignment slack included.
    void zero()
    {
        std::memset(buf.get(), 0, BUF_SIZE);
    }
};
template<std::size_t W, std::size_t H> | |
ostream& operator<< (ostream &o, const Matrix<W, H> &m) | |
{ | |
for (int i=0; i<W; ++i) | |
{ | |
for (int j=0; j<H; ++j) | |
o << m[i][j] << '\t'; | |
if (i + 1 != W) o << endl; | |
} | |
return o; | |
} | |
template<std::size_t W, std::size_t K, std::size_t N> | |
void mult_worker(Queue<int> &q, const Matrix<W, K>& a, const Matrix<K, N>& b, Matrix<W, N> &result) | |
{ | |
int pos; | |
for (;;) | |
{ | |
q.fetch(pos); | |
if (pos< 0) | |
return; | |
for (int i=0; i<K; ++i) | |
{ | |
const __m256i factor = _mm256_broadcastd_epi32((__m128i) (__v4si) {a[pos][i], a[pos][i], a[pos][i], a[pos][i]}); | |
for (int j=0; j<N; j+=8) | |
{ | |
__m256i res = _mm256_mullo_epi32(factor, (__m256i&)b[i][j]); | |
(__m256i &)(result[pos][j]) += res; | |
//result[pos][j] += factor * b[i][j]; | |
} | |
} | |
} | |
} | |
template<std::size_t W, std::size_t K, std::size_t N> | |
Matrix<W, N> mult(const Matrix<W, K>& a, const Matrix<K, N>& b) | |
{ | |
Matrix<W, N> result; | |
result.zero(); | |
Queue<int> q; | |
constexpr int worker_num = 8; | |
vector<thread> threads; | |
threads.reserve(worker_num); | |
for (int i=0; i<worker_num; ++i) | |
threads.push_back(thread(mult_worker<W, K, N>, ref(q), cref(a), cref(b), ref(result))); | |
for (int i=0; i<W; ++i) | |
q.put(i); | |
for (int i=0; i<worker_num; ++i) | |
q.put(-1); | |
cout << "Put!" << endl; | |
for_each(begin(threads), end(threads), [](thread &t) { | |
cout << "joining" << endl; | |
t.join(); | |
}); | |
return result; | |
} | |
template<std::size_t W, std::size_t K, std::size_t N> | |
Matrix<W, N> mult_raw(const Matrix<W, K>& a, const Matrix<K, N>& b) | |
{ | |
Matrix<W, N> result; | |
result.zero(); | |
for (int j=0; j<K; ++j) | |
for (int i=0; i<W; ++i) | |
for (int k=0; k<N; ++k) | |
result[i][k] += a[i][j] * b[j][k]; | |
return result; | |
} | |
int main() | |
{ | |
static constexpr bool PARALLEL_ENABLED = true; | |
static constexpr bool BENCHMARK_ENABLED = true; | |
if (BENCHMARK_ENABLED) | |
{ | |
Matrix<2000, 2000> m1, m2; | |
cout << "Benchmarking parallel ..." << endl; | |
m1.rand(); | |
m2.rand(); | |
auto start_time = chrono::high_resolution_clock::now(); | |
auto result1 = mult(m1, m2); | |
auto end_time = chrono::high_resolution_clock::now(); | |
auto time1_ms = chrono::duration_cast<chrono::milliseconds>(end_time-start_time).count(); | |
cout << " time: " << time1_ms << "ms" << endl; | |
cout << "Benchmarking common ..." << endl; | |
start_time = chrono::high_resolution_clock::now(); | |
auto result2 = mult_raw(m1, m2); | |
end_time = chrono::high_resolution_clock::now(); | |
auto time2_ms = chrono::duration_cast<chrono::milliseconds>(end_time-start_time).count(); | |
cout << " time: " << time2_ms << "ms" << endl; | |
cout << "Acceleration ratio: " << time2_ms * 1.0 / time1_ms << endl; | |
cout << "Comparing result1 & result2..." << endl; | |
bool ok = true; | |
for (int i=0; i<decltype(result1)::WIDTH; ++i) | |
for (int j=0; j<decltype(result1)::HEIGHT;++j) | |
if (result1[i][j] != result2[i][j]) | |
{ | |
cout << "failed at " << i << "," << j << endl; | |
ok = false; | |
break; | |
} | |
if (ok) | |
cout << "Correct!" << endl; | |
else | |
cout << "Failed!" << endl; | |
} | |
Matrix<2000, 2000> m1; m1.rand(); | |
Matrix<2000, 2000> m2; m2.rand(); | |
//cout << m1 << endl; | |
//cout << m2 << endl; | |
if (PARALLEL_ENABLED) | |
{ | |
volatile auto result = mult(m1, m2); | |
} else | |
{ | |
volatile auto result = mult_raw(m1, m2); | |
} | |
//cout << result << endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment