Skip to content

Instantly share code, notes, and snippets.

@wanji
Last active January 1, 2018 00:28
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wanji/25d6f6a02e12da324b60a734b63d8c94 to your computer and use it in GitHub Desktop.
Save wanji/25d6f6a02e12da324b60a734b63d8c94 to your computer and use it in GitHub Desktop.
Understanding GEMM

Useful links

Memo

void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
                 const int K, const float alpha, const float *A,
                 const int lda, const float *B, const int ldb,
                 const float beta, float *C, const int ldc);

C = alpha * op(A) * op(B) + beta * C

  1. op(A) = A' if TransA is set, otherwise op(A) = A. op(B) is similar.
  2. op(A) is MxK, op(B) is KxN, C is MxN
  3. With row-major storage, lda is always (at least) the number of columns of A as it is stored in memory, regardless of whether TransA is set. ldb is the same for B.

NOTES on caffe

Matrices are stored in row-major order on the CPU, but are treated as column-major on the GPU (cuBLAS convention). So caffe_cpu_gemm computes C=A*B while caffe_gpu_gemm computes C'=B'*A', which yields the same row-major C.

/*************************************************************************
> File Name: blastoy.cpp
> Copyright (C) 2013 Wan Ji<wanji@live.com>
> Created Time: 2016年04月19日 星期二 16时55分52秒
> Descriptions:
************************************************************************/
#include <cblas.h>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
using namespace std;
/**
 * Minimal owning 2-D buffer of T used to feed raw pointers to BLAS routines.
 *
 * Rows are laid out one after another, `step` bytes apart. When `step` is 0
 * (the default) rows are packed tightly, i.e. step = nc * sizeof(T).
 *
 * The buffer is zero-initialized on construction so that printing a matrix
 * that has not been filled yet is well-defined. The class is non-copyable
 * because it owns a raw allocation; an implicit copy would free it twice.
 *
 * NOTE(review): a non-zero `step` smaller than nc * sizeof(T) makes rows
 * overlap / overrun the buffer; callers are expected to pass a sane value.
 */
template<typename T>
class Matrix {
public:
  /**
   * @param nr   number of rows
   * @param nc   number of columns
   * @param step distance between consecutive rows in BYTES; 0 means packed
   */
  Matrix(int nr, int nc, int step=0)
      : nr_(nr),
        nc_(nc),
        step_(step != 0 ? step : nc * static_cast<int>(sizeof(T))),
        data_(new char[nr_ * step_]()) {  // trailing () zero-fills the buffer
  }

  ~Matrix() {
    delete [] data_;
  }

  // Owning raw pointer: forbid copies instead of allowing a double delete.
  Matrix(const Matrix &) = delete;
  Matrix & operator=(const Matrix &) = delete;

  /** Writes the matrix to stdout, tab-separated, framed by dashed lines. */
  void print() const {
    std::cout << "--------------------------------" << std::endl;
    for (int r = 0; r < nr_; ++r) {
      const T * row = ptr(r);
      for (int c = 0; c < nc_; ++c) {
        std::cout << row[c] << '\t';
      }
      std::cout << '\n';
    }
    std::cout << "--------------------------------" << std::endl;
  }

  /** @return pointer to the first element of the first row. */
  T * ptr() {
    return ptr(0);
  }
  const T * ptr() const {
    return ptr(0);
  }

  /** @return pointer to the first element of row `r` (no bounds check). */
  T * ptr(int r) {
    return reinterpret_cast<T *>(data_ + r * step_);
  }
  const T * ptr(int r) const {
    return reinterpret_cast<const T *>(data_ + r * step_);
  }

private:
  // Declaration order matters: data_ is sized from nr_ and step_.
  int nr_, nc_, step_;
  char * data_;
};
/**
 * Demo driver for cblas_sgemm with row-major matrices.
 *
 * Builds A (MxK) and B (KxN) filled with 0,1,2,... and then computes
 *   C = A  * B    (MxN)
 *   D = A' * C    (KxN)
 *   E = D  * B'   (KxK)
 * printing every operand and result. In each call the leading dimension is
 * the column count of the matrix AS STORED, independent of the transpose flag.
 */
int main(int argc, char * argv[]) {
  (void)argc;  // command-line arguments are not used
  (void)argv;

  const int M = 2;
  const int K = 3;
  const int N = 4;

  Matrix<float> a(M, K);
  Matrix<float> b(K, N);
  Matrix<float> c(M, N);
  Matrix<float> d(K, N);
  Matrix<float> e(K, K);

  // Rows are packed (default step), so flat indexing covers the whole matrix.
  // Bounds are derived from the dimensions instead of hard-coded 6 / 12.
  float * a_data = a.ptr();
  for (int i = 0; i < M * K; ++i) {
    a_data[i] = static_cast<float>(i);
  }
  float * b_data = b.ptr();
  for (int i = 0; i < K * N; ++i) {
    b_data[i] = static_cast<float>(i);
  }

  a.print();
  b.print();

  const float alpha = 1.0f;
  const float beta = 0.0f;  // beta == 0: the output's previous contents are ignored

  std::cout << "### C = A * B" << std::endl;
  // op(A) = A is MxK, op(B) = B is KxN.
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
              M, N, K,
              alpha, a.ptr(), K, b.ptr(), N,
              beta, c.ptr(), N);
  c.print();

  std::cout << "### D = A' * C" << std::endl;
  // op(A) = A' is KxM, so the GEMM sizes are (K, N, M); lda stays K because
  // A is still stored as MxK.
  cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
              K, N, M,
              alpha, a.ptr(), K, c.ptr(), N,
              beta, d.ptr(), N);
  d.print();

  std::cout << "### E = D * B'" << std::endl;
  // op(B) = B' is NxK, so the GEMM sizes are (K, K, N); ldb stays N because
  // B is still stored as KxN.
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              K, K, N,
              alpha, d.ptr(), N, b.ptr(), N,
              beta, e.ptr(), K);
  e.print();

  return 0;
}
#!/usr/bin/env python
# coding: utf-8
"""
File Name: blastoy.py
Author: Wan Ji
E-mail: wanji@live.com
Created on: 2016年04月19日 星期二 19时03分09秒
Description:
"""
import numpy as np
# Reference computation: A is 2x3 and B is 3x4, both filled with 0, 1, 2, ...
a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)

# C = A * B, D = A' * C, E = D * B'
c = a.dot(b)
d = a.T.dot(c)
e = d.dot(b.T)

# print(x) with a single argument behaves identically on Python 2 and 3.
for matrix in (a, b, c, d, e):
    print(matrix)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment