-
-
Save wwwind/0b7413833cf8a7596bee084403224fa5 to your computer and use it in GitHub Desktop.
Accuracy comparison between the TensorFlow (gemmlowp) tanh/sigmoid implementations and the proposed LUT-based implementations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "fixedpoint.h" | |
#include <stdio.h> | |
using namespace gemmlowp; | |
/* TensorFlow implementation of tanh */ | |
/* -8 to +8 input and -1 to +1 tanh result */ | |
int tf_tanh(int i) | |
{ | |
using F3 = gemmlowp::FixedPoint<std::int16_t, 3>; | |
F3 x = F3::FromRaw(i); | |
int z = gemmlowp::tanh(x).raw(); | |
return z; | |
} | |
/* TensorFlow implementation of sigmoid */ | |
/* -8 to +8 input and -1 to 1 sigmoid result - even though really 0-1 output range */ | |
int tf_sigmoid(int i) | |
{ | |
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>; | |
using F3 = gemmlowp::FixedPoint<std::int16_t, 3>; | |
F3 x = F3::FromRaw(i); | |
int z = gemmlowp::logistic(x).raw(); | |
return z; | |
} | |
/* Double-precision reference tanh.
 *
 * i is interpreted as Q3.12 (x = i / 4096), matching tf_tanh's input format.
 * The result is returned as a plain double in real units; it is NOT scaled
 * to fixed point. The caller multiplies by 2^15 to compare against Q0.15 outputs.
 */
double ref_tanh(int i)
{
    const double x = i / 4096.0;  // 2^12 fractional bits
    return std::tanh(x);
}
/* Logistic function 1 / (1 + e^-x), evaluated in double precision. */
double sigmoid(double x)
{
    const double denominator = 1.0 + std::exp(-x);
    return 1.0 / denominator;
}
double ref_sigmoid(int i) | |
{ | |
double y = sigmoid((double)i/(double)(1<<12)); | |
return y; | |
} | |
/* Lookup table for sigmoid on x >= 0, used by test_sigmoid and test_tanh.
 *
 * Entry k approximates sigmoid(k / 24) in unsigned 0.16 fixed point
 * (value * 65536, rounded), so the 256 entries cover x in [0, 255/24].
 * Entry 0 is exactly 0.5, the entries are non-decreasing, and the last entry
 * is pinned to 0xFFFF.
 *
 * NOTE(review): presumably generated by sigmoid_table(), which also nudges
 * each point up by half the midpoint-vs-chord gap to reduce interpolation
 * error — confirm.
 */
const uint16_t sigmoid_table_u16[256] = {
    /*   0 */ 32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513,
    /*   8 */ 38180, 38841, 39498, 40149, 40794, 41432, 42064, 42688,
    /*  16 */ 43304, 43912, 44511, 45102, 45683, 46255, 46817, 47369,
    /*  24 */ 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409,
    /*  32 */ 51865, 52311, 52745, 53169, 53581, 53983, 54374, 54755,
    /*  40 */ 55125, 55485, 55834, 56174, 56503, 56823, 57133, 57433,
    /*  48 */ 57724, 58007, 58280, 58544, 58800, 59048, 59288, 59519,
    /*  56 */ 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110,
    /*  64 */ 61279, 61441, 61599, 61750, 61896, 62036, 62172, 62302,
    /*  72 */ 62428, 62549, 62666, 62778, 62886, 62990, 63090, 63186,
    /*  80 */ 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    /*  88 */ 63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308,
    /*  96 */ 64357, 64405, 64450, 64494, 64536, 64576, 64614, 64652,
    /* 104 */ 64687, 64721, 64754, 64786, 64816, 64845, 64873, 64900,
    /* 112 */ 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079,
    /* 120 */ 65097, 65115, 65132, 65149, 65164, 65179, 65194, 65208,
    /* 128 */ 65221, 65234, 65246, 65258, 65269, 65280, 65291, 65301,
    /* 136 */ 65310, 65319, 65328, 65337, 65345, 65352, 65360, 65367,
    /* 144 */ 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415,
    /* 152 */ 65420, 65425, 65429, 65433, 65438, 65442, 65445, 65449,
    /* 160 */ 65453, 65456, 65459, 65462, 65465, 65468, 65471, 65474,
    /* 168 */ 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    /* 176 */ 65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504,
    /* 184 */ 65505, 65507, 65508, 65509, 65510, 65511, 65512, 65513,
    /* 192 */ 65514, 65515, 65516, 65517, 65517, 65518, 65519, 65520,
    /* 200 */ 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524,
    /* 208 */ 65525, 65525, 65526, 65526, 65526, 65527, 65527, 65528,
    /* 216 */ 65528, 65528, 65529, 65529, 65529, 65529, 65530, 65530,
    /* 224 */ 65530, 65530, 65531, 65531, 65531, 65531, 65531, 65532,
    /* 232 */ 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533,
    /* 240 */ 65533, 65533, 65533, 65533, 65533, 65534, 65534, 65534,
    /* 248 */ 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65535
};
uint16_t sigmoid_table(int i) | |
{ | |
int j = (i<0) ? -i : i; | |
double y = sigmoid((double)j/(double)(3<<3)); | |
/* Shift up table points to minimize absolute error */ | |
double ym = sigmoid(((double)j+0.5)/(double)(3<<3)); | |
double ye = sigmoid(((double)j+1.0)/(double)(3<<3)); | |
double yd = ym - ((y+ye)/2.0); | |
if (i>0) y = y + 0.5*yd; // add half mid point error to table point | |
uint16_t z = (int)floor(y*(double)(1<<16)+0.5); | |
uint16_t v = (j<255) ? z : 0xFFFF; // 256 x 16-bit entry table | |
return (i<0) ? (1<<16)-v : v; | |
} | |
// Computes sigmoid using the table - proposed solution. | |
int16_t test_sigmoid(int i) | |
{ | |
i = 3*i; // scale by 3/4 to expand range [-8,8]->[-10.7,10.7] | |
unsigned ui = (i>=0) ? i : -i; | |
unsigned uh = ui>>9; | |
unsigned ua = sigmoid_table_u16[uh]; | |
unsigned ub = sigmoid_table_u16[uh+1]; | |
unsigned ut = ui & 0x1FF; | |
unsigned ur = (ua<<9) + ut*(ub-ua); | |
ur = (i>=0) ? (ur+(1<<9))>>(10) : | |
((1<<(16+9))-ur+(1<<9)-1)>>(10); | |
return ur; | |
} | |
// Computes tanh using the table - proposed solution. | |
int16_t test_tanh(int i) | |
{ | |
i = 3*i; // scale by 3/4 to expand range [-8,8]->[-10.7,10.7] | |
unsigned ui = (i>=0) ? i : -i; | |
unsigned uh = ui>>8; | |
unsigned ur; | |
if (uh>=255) | |
{ | |
/* Saturate value to maximum */ | |
ur = 0xFFFF<<8; | |
} | |
else | |
{ | |
unsigned ua = sigmoid_table_u16[uh]; | |
unsigned ub = sigmoid_table_u16[uh+1]; | |
unsigned ut = ui & 0xFF; | |
ur = (ua<<8) + ut*(ub-ua); | |
} | |
ur = (i>=0) ? (ur - (1<<(14+9)) + (1<<(9-2)))>>(9-1) : | |
((1<<(14+9)) - (int)ur + (1<<(9-2))-1)>>(9-1); | |
return ur; | |
} | |
void test_function(int tanh_flag) | |
{ | |
int i; | |
float tf_err_neg = 0; | |
float tf_err_pos = 0; | |
double tf_err_sum = 0; | |
float test_err_neg = 0; | |
float test_err_pos = 0; | |
double test_err_sum = 0; | |
for (i=-32768; i<=32767; i++) | |
{ | |
float ref, tf, test, scale; | |
if (tanh_flag) | |
{ | |
scale = (float)(1<<15); | |
// Reference float value | |
ref = ref_tanh(i)*scale; | |
// Tensorflow value | |
tf = tf_tanh(i); | |
// Proposed implementation | |
test = test_tanh(i); | |
} | |
else | |
{ | |
scale = (float)(1<<15); | |
ref = ref_sigmoid(i)*scale; | |
tf = tf_sigmoid(i); | |
test = test_sigmoid(i); | |
} | |
// Error of the current tensorflow implementation | |
float tf_err = tf-ref; | |
// Error of the proposed implementation | |
float test_err = test-ref; | |
tf_err_sum += tf_err; | |
test_err_sum += test_err; | |
// Maximum/minimum of the error of tensorflow implementation | |
if (tf_err < tf_err_neg) tf_err_neg = tf_err; | |
if (tf_err > tf_err_pos) tf_err_pos = tf_err; | |
// Maximum/minimum of the error of the proposed implementation | |
if (test_err < test_err_neg) test_err_neg = test_err; | |
if (test_err > test_err_pos) test_err_pos = test_err; | |
if (test_err > 4 || test_err<-4) { | |
printf("i=%6d ref=%f tf=%f test=%f\n", i, ref, tf, test); | |
} | |
} | |
const char *name = tanh_flag ? "tanh " : "sigmoid"; | |
printf("tf %s implementation error %f to %f, error sum %f\n", name, tf_err_neg, tf_err_pos, (float)tf_err_sum); | |
printf("test %s implementation error %f to %f, error sum %f\n", name, test_err_neg, test_err_pos, (float)test_err_sum); | |
} | |
int main(void) | |
{ | |
test_function(0); | |
test_function(1); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
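Output (the "Sigmoid:"/"Tanh:" headers were added, and the program's "tf"/"test" labels are shown as "tensorflow"/"proposed"):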
Sigmoid:
tensorflow sigmoid implementation error -6.392090 to 5.392578, error sum -32767.042969
proposed sigmoid implementation error -1.015625 to 1.015625, error sum -0.043635
Tanh:
tensorflow tanh implementation error -11.791016 to 11.791016, error sum 0.992188
proposed tanh implementation error -1.476563 to 1.476563, error sum 0.992188