Last active
July 25, 2023 21:59
-
-
Save goerch/a14a582fb31182d962dbfc191baec097 to your computer and use it in GitHub Desktop.
sgemm with clang vectorization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <functional>
#include <future>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>
// Portability shims for compiler-specific extensions. The active definitions
// target GCC/Clang; the commented-out lines are the alternatives kept by the
// original author (MSVC spellings, or an empty expansion).

// No-alias qualifier for pointer parameters.
#define restrict __restrict

// Requests 32-byte alignment for the declared object.
// #define aligned __declspec(align(32))
#define aligned __attribute__((aligned(32)))

// Tells the optimizer that `ptr` is aligned to `sz` bytes. The caller must
// guarantee this; the builtin's return value is deliberately discarded.
// #define assume_aligned(ptr, sz)
#define assume_aligned(ptr, sz) ((void)__builtin_assume_aligned(ptr, sz))

// Forces inlining of the annotated function.
// #define always_inline __forceinline
#define always_inline __attribute__((always_inline))
/// Runs `f` repeatedly and prints wall-clock statistics to std::cout.
///
/// @tparam T      Return type of `f`. Must be constructible from 0 and support
///                `+=` with another T and `/=` with an int.
/// @param f       Workload to time. Its return values are accumulated and the
///                average is printed, which also keeps the work observable.
/// @param ops     Number of floating-point operations performed by one call to
///                `f`; used to convert the measured times into GFlops.
/// @param trials  Number of timed repetitions. A non-positive value is rejected
///                with a message on std::cerr and `f` is never called.
template<class T>
void benchmark(const std::function<T()>& f, float ops, int trials = 10) {
    if (trials <= 0) {
        // Averaging over zero trials would divide by zero.
        std::cerr << "benchmark: trials must be positive\n";
        return;
    }

    // steady_clock is monotonic, which is what interval measurement needs.
    using clock = std::chrono::steady_clock;

    float min_ms = std::numeric_limits<float>::max();
    float sum_ms = 0.f;
    float max_ms = 0.f;
    T result(0);

    for (int trial = 0; trial < trials; ++trial) {
        const auto started = clock::now();
        result += f();
        const auto finished = clock::now();
        const float elapsed_ms =
            std::chrono::duration<float, std::milli>(finished - started).count();
        min_ms = std::min(min_ms, elapsed_ms);
        sum_ms += elapsed_ms;
        max_ms = std::max(max_ms, elapsed_ms);
    }

    const float avg_ms = sum_ms / trials;
    result /= trials;

    // ops per millisecond -> 1e9 ops per second.
    const auto gflops = [ops](float ms) { return ops / ms * 1.0e-6f; };

    // The slowest run gives the lowest throughput, hence the min/max swap.
    std::cout
        << "Wall clock time passed: min = " << min_ms << " ms, avg = " << avg_ms
        << " ms, max = " << max_ms << " ms\n"
        << "min = " << gflops(max_ms) << " GFlops\n"
        << "avg = " << gflops(avg_ms) << " GFlops\n"
        << "max = " << gflops(min_ms) << " GFlops\n"
        << "Result: " << result << "\n";
}
/// Small fixed-size thread pool in which the submitting thread also works.
///
/// A pool of `size` workers owns `size - 1` background threads; the remaining
/// worker is the thread that calls dequeue(), which drains the queue itself
/// before waiting for the background threads to finish.
///
/// `tasks` and `stop` are protected by `mutex`. `futures` is not locked:
/// enqueue() and dequeue() are meant to be called from one thread only.
class pool_t {
public:
    std::size_t size;                           // workers, counting the caller of dequeue()
    std::queue<std::function<void()>> tasks;    // pending work, guarded by `mutex`
    std::vector<std::thread> workers;           // the size - 1 background threads
    std::queue<std::future<void>> futures;      // one per task not yet waited for
    std::mutex mutex;
    std::condition_variable condition_variable; // signalled on new work and on stop
    bool stop;                                  // set once, by the destructor

    /// Starts `size - 1` background workers.
    /// A `size` of 0 is treated as 1 (no background threads); without this,
    /// `size - 1` would wrap around to the largest std::size_t.
    pool_t(std::size_t size) :
        size(size == 0 ? 1 : size), tasks(), workers(), futures(),
        mutex(), condition_variable(), stop(false) {
        const std::size_t background = this->size - 1;
        workers.reserve(background);
        try {
            for (std::size_t index = 0; index < background; ++index)
                workers.emplace_back([this, index] { run(static_cast<int>(index)); });
        } catch (...) {
            // Destroying a joinable std::thread terminates the program, so
            // wind down the threads that did start before propagating.
            shutdown();
            throw;
        }
    }

    pool_t(const pool_t& pool) = delete;
    pool_t& operator=(const pool_t& pool) = delete;
    pool_t(pool_t&& pool) noexcept = delete;
    pool_t& operator=(pool_t&& pool) = delete;

    /// Lets the background workers finish the queued tasks, then joins them.
    ~pool_t() {
        shutdown();
    }

    /// Queues `f(args...)` for execution and records a future for dequeue().
    /// @throws std::runtime_error if the pool is already stopping.
    template<class F, class... Args>
    void enqueue(F&& f, Args&&... args) {
        // std::function needs a copyable callable and packaged_task is
        // move-only, hence the shared_ptr indirection.
        auto task = std::make_shared<std::packaged_task<void()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...));
        std::future<void> pending = task->get_future();
        {
            std::unique_lock<std::mutex> lock(mutex);
            if (stop)
                throw std::runtime_error("pool is stopping");
            tasks.emplace([task]() { (*task)(); });
        }
        // Registered only after the task was accepted, so a rejected task
        // leaves no future behind.
        futures.push(std::move(pending));
        condition_variable.notify_one();
    }

    /// Worker loop.
    /// Background workers (`index < size - 1`) block until work arrives and
    /// return once the pool is stopping and the queue is empty. The calling
    /// thread (`index == size - 1`) never blocks: it returns as soon as the
    /// queue is empty.
    void run(int index) {
        const bool is_background = static_cast<std::size_t>(index) < size - 1;
        for (;;) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(mutex);
                if (is_background) {
                    condition_variable.wait(lock,
                        [this] { return stop || !tasks.empty(); });
                    if (stop && tasks.empty())
                        return;
                }
                else if (tasks.empty())
                    return;
                task = std::move(tasks.front());
                tasks.pop();
            }
            // Run outside the lock so other workers can fetch tasks meanwhile.
            task();
        }
    }

    /// Helps execute queued tasks on the calling thread, then waits until
    /// every task enqueued so far has finished.
    /// If tasks threw, the first exception (in enqueue order) is rethrown,
    /// but only after all tasks have completed, so nothing is still running
    /// when this function exits.
    void dequeue() {
        run(static_cast<int>(size - 1));
        std::exception_ptr first_error;
        while (!futures.empty()) {
            // Pop before get(): a future must not be queried twice.
            std::future<void> finished = std::move(futures.front());
            futures.pop();
            try {
                finished.get();
            } catch (...) {
                if (!first_error)
                    first_error = std::current_exception();
            }
        }
        if (first_error)
            std::rethrow_exception(first_error);
    }

private:
    /// Signals all background workers to stop and joins them.
    void shutdown() {
        {
            std::unique_lock<std::mutex> lock(mutex);
            stop = true;
        }
        condition_variable.notify_all();
        for (auto& worker : workers)
            if (worker.joinable())
                worker.join();
    }
};
/// Single-precision matrix multiply-accumulate: C += A * B.
///
/// All matrices are addressed row-major through a leading dimension, i.e.
/// element (r, c) of X lives at X[r * ldX + c].
///
/// @param m    Rows of A and C.
/// @param n    Columns of B and C.
/// @param k    Columns of A and rows of B.
/// @param A    m x k input; must not alias B or C.
/// @param ldA  Row stride of A, in elements.
/// @param B    k x n input; must not alias A or C.
/// @param ldB  Row stride of B, in elements.
/// @param C    m x n accumulator, updated in place (not cleared first).
/// @param ldC  Row stride of C, in elements.
///
/// The base pointers are declared 32-byte aligned to the optimizer; passing
/// less strictly aligned pointers breaks that promise.
void sgemm(const int m, const int n, const int k,
           const float* restrict A, const int ldA,
           const float* restrict B, const int ldB,
           float* restrict C, const int ldC) {
    assume_aligned(A, 32);
    assume_aligned(B, 32);
    assume_aligned(C, 32);
    for (int im = 0; im < m; ++im) {
        // Row offsets are computed in ptrdiff_t so that large matrices do not
        // overflow int arithmetic.
        const float* a_row = A + static_cast<std::ptrdiff_t>(im) * ldA;
        float* c_row = C + static_cast<std::ptrdiff_t>(im) * ldC;
        for (int ik = 0; ik < k; ++ik) {
            const float a = a_row[ik];
            const float* b_row = B + static_cast<std::ptrdiff_t>(ik) * ldB;
            // Unit-stride inner loop over one row of B and C.
            for (int in = 0; in < n; ++in)
                c_row[in] += a * b_row[in];
        }
    }
}
/// Multiply-accumulate restricted to index ranges: for every row in [bm, em)
/// and column in [bn, en), adds the partial dot product over [bk, ek) to C.
///
/// Indices are absolute positions inside the full matrices; element (r, c) of
/// X lives at X[r * ldX + c].
///
/// @param bm,em  Half-open row range of A and C.
/// @param bn,en  Half-open column range of B and C.
/// @param bk,ek  Half-open range along the shared dimension.
/// @param A      Left input; must not alias B or C.
/// @param ldA    Row stride of A, in elements.
/// @param B      Right input; must not alias A or C.
/// @param ldB    Row stride of B, in elements.
/// @param C      Accumulator, updated in place (not cleared first).
/// @param ldC    Row stride of C, in elements.
///
/// The base pointers are declared 32-byte aligned to the optimizer; passing
/// less strictly aligned pointers breaks that promise.
void sgemm(const int bm, const int em, const int bn, const int en, const int bk, int ek,
           const float* restrict A, int ldA,
           const float* restrict B, int ldB,
           float* restrict C, int ldC) {
    assume_aligned(A, 32);
    assume_aligned(B, 32);
    assume_aligned(C, 32);
    for (int im = bm; im < em; ++im) {
        // Row offsets are computed in ptrdiff_t so that large matrices do not
        // overflow int arithmetic.
        const float* a_row = A + static_cast<std::ptrdiff_t>(im) * ldA;
        float* c_row = C + static_cast<std::ptrdiff_t>(im) * ldC;
        for (int ik = bk; ik < ek; ++ik) {
            const float a = a_row[ik];
            const float* b_row = B + static_cast<std::ptrdiff_t>(ik) * ldB;
            // Unit-stride inner loop over the selected columns.
            for (int in = bn; in < en; ++in)
                c_row[in] += a * b_row[in];
        }
    }
}
void tiled_sgemm(const int m, const int n, const int k, | |
const float* restrict A, const int ldA, | |
const float* restrict B, const int ldB, | |
float* restrict C, int ldC, | |
const int tsk) { | |
assume_aligned(A, 32); | |
assume_aligned(B, 32); | |
assume_aligned(C, 32); | |
for (int btk = 0; btk < k; btk += tsk) { | |
int etk = std::min<int>(btk + tsk, k); | |
sgemm(0, m, 0, n, btk, etk, A, ldA, B, ldB, C, ldC); | |
} | |
} | |
void tiled_sgemm_parallel(const int m, const int n, const int k, | |
const float* restrict A, const int ldA, | |
const float* restrict B, const int ldB, | |
float* restrict C, int ldC, | |
const int tsm, const int tsn, const int tsk, pool_t& pool) { | |
assume_aligned(A, 32); | |
assume_aligned(B, 32); | |
assume_aligned(C, 32); | |
for (int btm = 0; btm < m; btm += tsm) { | |
int etm = std::min<int>(btm + tsm, m); | |
for (int btn = 0; btn < n; btn += tsn) { | |
int etn = std::min<int>(btn + tsn, n); | |
pool.enqueue([btm, etm, btn, etn, k, &A, ldA, &B, ldB, &C, ldC, tsk] { | |
for (int btk = 0; btk < k; btk += tsk) { | |
int etk = std::min<int>(btk + tsk, k); | |
sgemm(btm, etm, btn, etn, btk, etk, A, ldA, B, ldB, C, ldC); | |
} | |
}); | |
} | |
} | |
pool.dequeue(); | |
} | |
// Benchmark selection: enable exactly one of the variants below. The choice
// fixes the problem size (M x K times K x N) and which kernel main() times.
// #define SIMPLE
// #define TILED
#define TILED_PARALLEL

// Each variant defines M, N and K, so selecting none leaves them undefined
// and selecting several redefines them; fail early with a clear message.
#if (defined(SIMPLE) + defined(TILED) + defined(TILED_PARALLEL)) != 1
#error "Define exactly one of SIMPLE, TILED or TILED_PARALLEL"
#endif

#ifdef SIMPLE
#define M 512
#define N 512
#define K 512
#endif

#ifdef TILED
#define M 1024
#define N 1024
#define K 1024
#endif

#ifdef TILED_PARALLEL
#define M 2048
#define N 2048
#define K 2048
#endif

// Operands with static storage, 32-byte aligned to match the alignment the
// kernels assume. A is M x K, B is K x N, C is M x N, all stored row-major.
aligned float A[M * K];
aligned float B[K * N];
aligned float C[M * N];
int main() { | |
auto trials = 100; | |
for (int im = 0; im < M; ++im) | |
for (int ik = 0; ik < K; ++ik) | |
A[im * K + ik] = 1.f; | |
for (int ik = 0; ik < K; ++ik) | |
for (int in = 0; in < N; ++in) | |
B[ik * N + in] = 1.f; | |
#ifdef SIMPLE | |
std::cout << "simple" << std::endl; | |
benchmark<float>([&]() { | |
for (int im = 0; im < K; ++im) | |
for (int in = 0; in < N; ++in) | |
C[im * N + in] = 0.f; | |
sgemm(M, N, K, A, K, B, N, C, N); | |
return C[M * N - 1]; | |
}, 2.f * M * N * K, trials); | |
#endif | |
#ifdef TILED | |
std::cout << "tiled" << std::endl; | |
benchmark<float>([&]() { | |
for (int im = 0; im < K; ++im) | |
for (int in = 0; in < N; ++in) | |
C[im * N + in] = 0.f; | |
tiled_sgemm(M, N, K, A, K, B, N, C, N, 16); | |
return C[M * N - 1]; | |
}, 2.f * M * N * K, trials); | |
#endif | |
#ifdef TILED_PARALLEL | |
std::cout << "tiled parallel" << std::endl; | |
pool_t pool(12); | |
benchmark<float>([&]() { | |
for (int im = 0; im < K; ++im) | |
for (int in = 0; in < N; ++in) | |
C[im * N + in] = 0.f; | |
tiled_sgemm_parallel(M, N, K, A, K, B, N, C, N, 16, N, 16, pool); | |
return C[M * N - 1]; | |
}, 2.f * M * N * K, trials); | |
#endif | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Compiled with -march=native and -ffast-math. I believe Clang vectorizes a lot. Reaches up to around 170 GFlops on my Core i7. MKL gets around 300 - 400 GFlops.