@swook
Created March 30, 2015 16:26
hpc-gemm.cpp
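// Example build and run (a sketch, not verified here: exact flags and library
// names depend on the toolchain and on how OpenBLAS, Boost.Chrono and the
// ../lib headers are installed):
//
//   g++ -O3 -march=native -fopenmp hpc-gemm.cpp \
//       -lopenblas -lboost_chrono -lboost_system -o hpc-gemm
//   ./hpc-gemm -N 1600 -blksize 4 -NT 8
//
// Compile with -DSEQ to time the sequential variants instead of the
// BLAS/OpenMP ones. Defaults: -N 1600, -blksize 4, -NT 1.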
#include <iostream>
#include <string>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include <boost/chrono.hpp>
#include <openblas/cblas.h>
#include <nmmintrin.h>
#include <omp.h>
#include "../lib/ArgumentParser.h"
#include "../lib/matrix.hpp"

typedef double value_t;
typedef hpcse::matrix<value_t, hpcse::column_major> cmatrix_t;
typedef hpcse::matrix<value_t, hpcse::row_major>    rmatrix_t;
typedef boost::chrono::high_resolution_clock        myclock_t;
/**
 * Matrix dimensions
 * a(M*N), b(N*K), c(M*K)
 */
unsigned int M, N, K;

// Global matrices (allocated in main)
rmatrix_t a;     // Row-major: should be accessed row-by-row
cmatrix_t b;     // Column-major: should be accessed column-by-column
rmatrix_t c;     // Row-major: should be accessed row-by-row
rmatrix_t c_ref; // Reference result computed with cblas_dgemm
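// Memory layout assumed by the raw-pointer kernels below (this relies on
// hpcse::matrix storing its elements contiguously in the declared order):
//   a(j,i) -> a.data()[j*N + i]   row j of a is contiguous
//   b(i,k) -> b.data()[i + k*N]   column k of b is contiguous
//   c(j,k) -> c.data()[j*K + k]   row j of c is contiguous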
// Naive matrix multiplication c = a*b.
// With a and c row-major and b column-major, the innermost loop reads both
// a(j,:) and b(:,k) with unit stride.
void naive() {
  for (size_t j = 0; j < M; j++)
    for (size_t k = 0; k < K; k++) {
      double sum = 0.;
      for (size_t i = 0; i < N; i++)
        sum += a(j,i) * b(i,k);
      c(j,k) = sum;
    }
}
// Same as naive(), but through restrict-qualified raw pointers so the
// compiler can vectorise the inner loop.
void naive_simd() {
  double* _a = a.data();
  double* _b = b.data();
  double* _c = c.data();
  double* __restrict__ __a;
  double* __restrict__ __b;
  double sum;
  for (size_t j = 0; j < M; j++)
    for (size_t k = 0; k < K; k++) {
      sum = 0.;
      __a = _a + j*N; // row j of a
      __b = _b + k*N; // column k of b (b has N rows, so its column stride is N)
      for (size_t i = 0; i < N; i++)
        sum += __a[i] * __b[i];
      _c[j*K+k] = sum;
    }
}
// OpenMP-parallel version of naive_simd(). The pointers and the accumulator
// are declared inside the loop so every thread gets its own copies.
void naive_simd_omp() {
  double* _a = a.data();
  double* _b = b.data();
  double* _c = c.data();
  #pragma omp parallel for schedule(static)
  for (size_t j = 0; j < M; j++)
    for (size_t k = 0; k < K; k++) {
      double sum = 0.;
      const double* __restrict__ __a = _a + j*N; // row j of a
      const double* __restrict__ __b = _b + k*N; // column k of b
      for (size_t i = 0; i < N; i++)
        sum += __a[i] * __b[i];
      _c[j*K+k] = sum;
    }
}
// OpenMP-parallel version of naive().
void naive_omp() {
  #pragma omp parallel for schedule(static)
  for (size_t j = 0; j < M; j++)
    for (size_t k = 0; k < K; k++) {
      double sum = 0.;
      for (size_t i = 0; i < N; i++)
        sum += a(j,i) * b(i,k);
      c(j,k) = sum;
    }
}
// Blocked version of the naive gemm: the result c is computed one
// blksize x blksize tile at a time.
size_t blksize;

// Multiply a blksize x blksize tile of a with a blksize x blksize tile of b
// and accumulate the partial product into a blksize x blksize tile of c.
// The pointers point at the top-left corner of each tile; the tiles are
// addressed with the full leading dimensions N (a and b) and K (c).
inline void blocked_kernel(double* __restrict__ a, double* __restrict__ b, double* c) {
  double* __restrict__ _a;
  double* __restrict__ _b;
  for (size_t j = 0; j < blksize; j++)
    for (size_t k = 0; k < blksize; k++) {
      _a = a + j*N; // row j of the a tile
      _b = b + k*N; // column k of the b tile (column stride of b is N)
      double sum = 0.;
      for (size_t i = 0; i < blksize; i++)
        sum += _a[i] * _b[i];
      c[j*K+k] += sum;
    }
}
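// Block decomposition used below: with bM = M/blksize, bN = N/blksize and
// bK = K/blksize, block triple (bj, bi, bk) multiplies the tile of a at
// (bj*blksize, bi*blksize) with the tile of b at (bi*blksize, bk*blksize) and
// accumulates into the tile of c at (bj*blksize, bk*blksize). Summing over bi
// gives the full product; the asserts require M, N and K to be multiples of
// blksize.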
void blocked_simd() {
  assert(M % blksize == 0);
  assert(N % blksize == 0);
  assert(K % blksize == 0);
  size_t bM = M / blksize,
         bN = N / blksize,
         bK = K / blksize,
         _j, _i, _k;
  double* _a = a.data();
  double* _b = b.data();
  double* _c = c.data();
  for (size_t bj = 0; bj < bM; bj++)
    for (size_t bi = 0; bi < bN; bi++)
      for (size_t bk = 0; bk < bK; bk++) {
        _j = bj*blksize;
        _i = bi*blksize;
        _k = bk*blksize;
        blocked_kernel(
          _a + _j*N + _i,
          _b + _k*N + _i,
          _c + _j*K + _k
        );
      }
}
void blocked_simd_omp() {
  assert(M % blksize == 0);
  assert(N % blksize == 0);
  assert(K % blksize == 0);
  size_t bM = M / blksize,
         bN = N / blksize,
         bK = K / blksize,
         _j, _i, _k;
  double* _a = a.data();
  double* _b = b.data();
  double* _c = c.data();
  #pragma omp parallel for schedule(static) private(_j,_i,_k)
  for (size_t bj = 0; bj < bM; bj++)
    for (size_t bi = 0; bi < bN; bi++)
      for (size_t bk = 0; bk < bK; bk++) {
        _j = bj*blksize;
        _i = bi*blksize;
        _k = bk*blksize;
        blocked_kernel(
          _a + _j*N + _i,
          _b + _k*N + _i,
          _c + _j*K + _k
        );
      }
}
const double CblasZero = 0., CblasOne = 1.;
// Compute one column of c per dgemv call: c(:,i) = a * b(:,i).
void blas_gemvs() {
  // OpenBLAS bug: cannot have more than 44 threads
  size_t maxT = omp_get_max_threads() > 44 ? 44 : omp_get_max_threads();
  #pragma omp parallel for num_threads(maxT)
  for (size_t i = 0; i < K; i++)
    cblas_dgemv(CblasRowMajor, CblasNoTrans, M, N, CblasOne,
                a.data(), N,       // a is M x N row-major, lda = N
                b.data() + i*N, 1, // column i of b is contiguous
                CblasZero,
                c.data() + i, K);  // column i of c has stride K
}
// Compute one element of c per ddot call: c(j,k) = a(j,:) . b(:,k).
void blas_ddots() {
  // OpenBLAS bug: cannot have more than 44 threads
  size_t maxT = omp_get_max_threads() > 44 ? 44 : omp_get_max_threads();
  #pragma omp parallel for num_threads(maxT)
  for (size_t j = 0; j < M; j++)
    for (size_t k = 0; k < K; k++)
      c(j,k) = cblas_ddot(N, a.data() + j*N, 1,  // row j of a, unit stride
                             b.data() + k*N, 1); // column k of b, unit stride
}
/**
 * Utilities used when running the different methods
 */
bool equal_double(const double& a, const double& b) {
  return a == b || std::abs((a-b)/a) < 1e-10;
}
// Scoped timer: resets c on construction, and on destruction prints the
// elapsed time and validates c against the reference result.
class timer {
private:
  boost::chrono::time_point<myclock_t> start;
  std::string name;

public:
  timer(std::string name) : name(name) {
    // Reset matrix c before each timed run (first-touch policy)
    c = rmatrix_t(M, K);
    start = myclock_t::now();
  }
  ~timer() {
    // Calculate time taken
    boost::chrono::time_point<myclock_t> end = myclock_t::now();
    double elapsed = boost::chrono::duration<double>(end - start).count();
    std::cout << name << ":\t" << elapsed << std::endl;
    // Validate result in matrix c
    if (!std::equal(c.data(), c.data() + M*K, c_ref.data(), equal_double))
      throw std::runtime_error("Incorrect GEMM result attained. Check code.");
  }
};
int main(int argc, char* argv[]) {
  ArgumentParser parser(argc, argv);
  M = N = K = parser("-N").asInt(1600);
  blksize = parser("-blksize").asInt(4);
  const size_t NT = parser("-NT").asInt(1);
  omp_set_dynamic(0);
  omp_set_num_threads(NT);
  openblas_set_num_threads(NT);

  // Allocate matrices
  a = rmatrix_t(M, N);
  b = cmatrix_t(N, K);
  c = rmatrix_t(M, K);
  c_ref = rmatrix_t(M, K);

  // Fill matrices a and b with some values
  std::generate_n(a.data(), M*N, std::rand);
  std::generate_n(b.data(), N*K, std::rand);
  /**
   * Calculate reference solution
   *
   * void cblas_dgemm(const enum CBLAS_ORDER Order,
   *                  const enum CBLAS_TRANSPOSE TransA,
   *                  const enum CBLAS_TRANSPOSE TransB, const int M,
   *                  const int N, const int K, const double alpha,
   *                  const double *A, const int lda, const double *B,
   *                  const int ldb, const double beta, double *C,
   *                  const int ldc);
   */
  // c_ref = a * b. The data of the column-major b (N x K) reads as a row-major
  // K x N matrix, hence CblasTrans with ldb = N; a is M x N (lda = N) and
  // c_ref is M x K (ldc = K).
  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, K, N, CblasOne,
              a.data(), N, b.data(), N, CblasZero, c_ref.data(), K);
#ifdef SEQ
  {
    timer t("naive (seq)");
    naive();
  }
  {
    timer t("naive (SIMD)");
    naive_simd();
  }
  {
    timer t("blocked (SIMD)");
    blocked_simd();
  }
#else
  {
    timer t("GEMM (blas)");
    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, K, N, CblasOne,
                a.data(), N, b.data(), N, CblasZero, c.data(), K);
  }
  {
    timer t("GEMVs (blas)");
    blas_gemvs();
  }
  {
    timer t("DDOTs (blas)");
    blas_ddots();
  }
  {
    timer t("naive (OMP)");
    naive_omp();
  }
  {
    timer t("naive (SIMD+OMP)");
    naive_simd_omp();
  }
  {
    timer t("blocked (SIMD+OMP)");
    blocked_simd_omp();
  }
#endif
  return 0;
}