Skip to content

Instantly share code, notes, and snippets.

@pervognsen
Last active February 15, 2024 18:05
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pervognsen/ead20b921d138aa06e60bfe76c9ec720 to your computer and use it in GitHub Desktop.
Save pervognsen/ead20b921d138aa06e60bfe76c9ec720 to your computer and use it in GitHub Desktop.
// Interface to an m x n float matrix whose storage layout is hidden behind
// block copy operations. A "block at ib, jb of size mb, nb" is the mb x nb
// submatrix whose top-left element is at row ib, column jb.
struct AbstractMatrix {
    int m = 0; // number of rows
    int n = 0; // number of columns

    // Polymorphic base: a virtual destructor makes deleting a derived matrix
    // through an AbstractMatrix* well-defined.
    virtual ~AbstractMatrix() = default;

    // Pack block at ib, jb of size mb, nb into dest in row-major format.
    // dest must have room for mb*nb floats.
    virtual void pack_rowmajor(int ib, int jb, int mb, int nb, float *dest) const = 0;
    // Unpack row-major matrix from src into block at ib, jb of size mb, nb.
    // src must hold mb*nb floats.
    virtual void unpack_rowmajor(int ib, int jb, int mb, int nb, const float *src) = 0;
    // Pack block at ib, jb of size mb, nb into dest in column-major format.
    // dest must have room for mb*nb floats.
    virtual void pack_colmajor(int ib, int jb, int mb, int nb, float *dest) const = 0;
    // Unpack column-major matrix from src into block at ib, jb of size mb, nb.
    // src must hold mb*nb floats.
    virtual void unpack_colmajor(int ib, int jb, int mb, int nb, const float *src) = 0;
};
struct StridedMatrix : AbstractMatrix {
float *base;
int di; // row stride
int dj; // column stride
virtual void pack_rowmajor(int ib, int jb, int mb, int nb, float *dest) const {
float *A = base + ib*di + jb*dj;
for (int i = 0; i < mb; i++) {
for (int j = 0; j < nb; j++) {
*dest++ = A[i*di + j*dj];
}
}
}
virtual void unpack_rowmajor(int ib, int jb, int mb, int nb, const float *src) {
float *A = base + ib*di + jb*dj;
for (int i = 0; i < mb; i++) {
for (int j = 0; j < nb; j++) {
A[i*di + j*dj] = *src++;
}
}
}
virtual void pack_colmajor(int ib, int jb, int mb, int nb, float *dest) const {
float *A = base + ib*di + jb*dj;
for (int j = 0; j < nb; j++) {
for (int i = 0; i < mb; i++) {
*dest++ = A[i*di + j*dj];
}
}
}
virtual void unpack_colmajor(int ib, int jb, int mb, int nb, const float *src) {
float *A = base + ib*di + jb*dj;
for (int j = 0; j < nb; j++) {
for (int i = 0; i < mb; i++) {
A[i*di + j*dj] = *src++;
}
}
}
};
// Blockwise transpose: sets A(i, j) = B(j, i) for 0 <= i < A->m, 0 <= j < A->n.
// B must be A->n x A->m (checked with assert). Each block goes through a fixed
// MB x NB stack buffer with one pack call on B and one unpack call on A; the
// last block row/column is clipped to the matrix bounds, so m and n need not
// be multiples of the block size.
// NOTE(review): A and B must not share storage — A's (i, j) block is written
// before B's (i, j) block is read for A's (j, i) block.
void transpose_abstract(AbstractMatrix *A, const AbstractMatrix *B) {
    int m = A->m;
    int n = A->n;
    assert(B->m == n);
    assert(B->n == m);
    enum { MB = 64, NB = 64 };
    float buf[MB * NB];
    for (int i = 0; i < m; i += MB) {
        // Height of this block row: a full MB except possibly the last one.
        int mb = (m - i < MB) ? m - i : MB;
        for (int j = 0; j < n; j += NB) {
            // Width of this block column: a full NB except possibly the last one.
            int nb = (n - j < NB) ? n - j : NB;
            // B's (j, i) block is nb x mb. Packing it row-major and unpacking the
            // same buffer column-major as an mb x nb block writes its transpose
            // into A's (i, j) block.
            B->pack_rowmajor(j, i, nb, mb, buf);
            A->unpack_colmajor(i, j, mb, nb, buf);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment