Created
October 25, 2015 23:34
-
-
Save petewarden/eca9f31f0815cf638d13 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Reference (unoptimized) eight-bit integer matrix multiply with
// requantization of the result back to eight bits.
//
// For every i in [0, m) and j in [0, n) this computes
//   total  = sum over l in [0, k) of (A(i, l) + a_offset) * (B(l, j) + b_offset)
//   output = ((total + c_offset) * c_mult_int + rounding) >> c_shift
// where rounding is 2^(c_shift - 1) (or 0 when c_shift is 0), clamps output
// to [0, 255] and stores it in C(i, j).
//
// Element addressing:
//   A(i, l) is a[i + l * lda]  (transpose_a false) or a[i * lda + l] (true).
//   B(l, j) is b[l + j * ldb]  (transpose_b false) or b[l * ldb + j] (true).
//   C(i, j) is c[i + j * ldc]  (transpose_c false) or c[i * ldc + j] (true).
//
// Parameters:
//   transpose_a/b/c  select the addressing scheme above for each matrix.
//   m, n, k          loop extents; nothing is written if m or n is <= 0.
//   a, b             input buffers, must be non-null.
//   a_offset,
//   b_offset         added to each input byte before multiplying.
//   lda, ldb, ldc    leading-dimension strides, in elements.
//   c                output buffer, must be non-null.
//   c_offset         added to the accumulated total.
//   c_mult_int       integer multiplier applied after c_offset.
//   c_shift          right shift applied last; must be in [0, 32].
//
// The per-element accumulation is done in 32 bits, as before. The
// requantization step is done in 64 bits so that the whole documented
// c_shift range (including 32) is well defined and so that the product with
// c_mult_int does not overflow a 32-bit intermediate.
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const uint8_t* a, int32_t a_offset, int lda,
                              const uint8_t* b, int32_t b_offset, int ldb,
                              uint8_t* c, int32_t c_offset, int32_t c_mult_int,
                              int32_t c_shift, int ldc) {
  assert((c_shift >= 0) && (c_shift <= 32));
  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  // Strides for stepping along each logical index of each matrix.
  const int a_i_stride = transpose_a ? lda : 1;
  const int a_l_stride = transpose_a ? 1 : lda;
  const int b_j_stride = transpose_b ? 1 : ldb;
  const int b_l_stride = transpose_b ? ldb : 1;
  const int c_i_stride = transpose_c ? ldc : 1;
  const int c_j_stride = transpose_c ? 1 : ldc;

  // Half of the divisor implied by the shift, so the shift rounds to nearest
  // instead of truncating. Built in 64 bits: for c_shift == 32 the value
  // 2^31 is not representable in int32_t.
  const std::int64_t kRoundingTerm =
      (c_shift < 1) ? 0 : (std::int64_t{1} << (c_shift - 1));

  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      int32_t total = 0;
      for (int l = 0; l < k; ++l) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const int32_t a_as_int = static_cast<int32_t>(a[a_index]) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const int32_t b_as_int = static_cast<int32_t>(b[b_index]) + b_offset;
        total += a_as_int * b_as_int;
      }

      // Requantize in 64 bits; see the function comment for why.
      const std::int64_t scaled =
          (static_cast<std::int64_t>(total) + c_offset) * c_mult_int;
      std::int64_t output = (scaled + kRoundingTerm) >> c_shift;

      // Saturate to the range of the eight-bit output type.
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }

      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<uint8_t>(output);
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment