Last active
May 4, 2017 02:37
-
-
Save gooooloo/38b25a9107394522a0fffeb7294f416b to your computer and use it in GitHub Desktop.
Code implementing the approximations of exp(x), tanh(x) and sigmoid(x) from the Deep Voice paper: https://arxiv.org/pdf/1702.07825.pdf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Code implementing the approximations of exp(x), tanh(x) and sigmoid(x) from the paper https://arxiv.org/pdf/1702.07825.pdf
 */
import java.util.Random; | |
/**
 * Fast approximations of {@code exp}, {@code tanh} and the logistic sigmoid, following
 * Appendix C of https://arxiv.org/pdf/1702.07825.pdf, together with a small harness that
 * compares each approximation against a {@link Math}-based reference for speed and for
 * relative error.
 */
public class Test {

    /** Natural logarithm of 2; dividing by it turns a base-e exponent into a base-2 exponent. */
    private static final double LN2 = 0.6931471805599453;

    /** Seed shared by every sampling loop so that runs are repeatable (ad-hoc selection). */
    private static final long SEED = 2342358787L;

    /** Number of samples drawn by each timing loop and each error-rate loop. */
    private static final int SAMPLES = 10000000;

    /** Relative-error thresholds reported by {@link #testSuite}. */
    private static final float[] EPSILONS = {
        0.00001f, 0.00002f, 0.00003f, 0.00004f, 0.00005f, 0.0001f, 0.001f
    };

    /** Receives the timing loops' checksums so that the timed calls have an observable effect. */
    private static float sink;

    /**
     * Polynomial stand-in for {@code exp(|x|)} shared by the tanh and sigmoid approximations:
     * {@code 1 + |x| + 0.5658 x^2 + 0.143 x^4}.
     *
     * @param x any float
     * @return the polynomial value; +Infinity once {@code x^4} no longer fits in a float,
     *         NaN for NaN input
     */
    static float eTilde(float x) {
        final float x2 = x * x;
        final float x4 = x2 * x2;
        return (float) (1 + Math.abs(x) + 0.5658 * x2 + 0.143 * x4);
    }

    /** A scalar float function; implemented by every approximation and every reference. */
    interface I {
        float compute(float x);
    }

    /** Approximate tanh built on {@link #eTilde}. */
    static class FastTanh implements I {
        /**
         * @param x any float
         * @return {@code sign(x) * (e - 1/e) / (e + 1/e)} with {@code e = eTilde(x)};
         *         exactly 0 for |x| below 1e-6, +/-1 when {@code e} overflows, NaN for NaN
         */
        @Override
        public float compute(float x) {
            // check Appendices C.1. in paper https://arxiv.org/pdf/1702.07825.pdf
            final float e = eTilde(x);
            final float epsilon = 0.000001f;
            final int sign = (x > epsilon) ? 1 : (x > -epsilon) ? 0 : -1;
            if (Float.isInfinite(e)) {
                // (inf - 0) / (inf + 0) would evaluate to NaN; the limit is +/-1.
                return sign;
            }
            final float eInv = 1 / e;
            return sign * (e - eInv) / (e + eInv);
        }
    }

    /** Reference tanh. */
    static class SlowTanh implements I {
        /**
         * @param x any float
         * @return {@code Math.tanh(x)} narrowed to float
         */
        @Override
        public float compute(float x) {
            // Math.tanh saturates to +/-1 for large |x|. The former exp(2x) / (1 + exp(2x))
            // form overflowed the float exp to Infinity and produced NaN for x > ~44.
            return (float) Math.tanh(x);
        }
    }

    /** Approximate exp that assembles the IEEE-754 bit pattern of {@code 2^(x / ln 2)}. */
    static class FastExp implements I {
        /**
         * @param x any float
         * @return an approximation of {@code exp(x)}; 0 when the base-2 exponent is below
         *         -126 (results that would be subnormal are flushed to zero), +Infinity when
         *         it reaches the float overflow range, NaN for NaN input
         */
        @Override
        public float compute(float x) {
            // check Appendices C.2. in paper https://arxiv.org/pdf/1702.07825.pdf
            if (Float.isNaN(x)) {
                return Float.NaN;
            }
            final double p = x / LN2; // exp(x) == 2^p
            if (p >= 128.0) {
                return Float.POSITIVE_INFINITY; // also covers x == +Infinity
            }
            if (p < -126.0) {
                // The bit trick below only produces normal floats; a negative scaled value
                // would be reinterpreted as a negative number. Also covers x == -Infinity.
                return 0.0f;
            }
            final double z = p - Math.floor(p); // fractional part, in [0, 1)
            // Rational correction so that the mantissa field approximates 2^z - 1.
            final double gz = -4.7259162 + 27.7280233 / (4.84252568 - z) - 1.49012907 * z;
            // Integer part lands in the exponent field, fraction in the mantissa field.
            // Capping at 255 yields exactly the +Infinity pattern instead of a NaN pattern
            // when the correction term nudges the value past the largest finite exponent.
            final double biased = Math.min(p + 126 + gz, 255.0);
            return Float.intBitsToFloat((int) (biased * (1 << 23)));
        }
    }

    /** Reference exp. */
    static class SlowExp implements I {
        /**
         * @param x any float
         * @return {@code Math.exp(x)} narrowed to float
         */
        @Override
        public float compute(float x) {
            return (float) Math.exp(x);
        }
    }

    /** Approximate logistic sigmoid built on {@link #eTilde}. */
    static class FastSigmoid implements I {
        /**
         * @param x any float
         * @return {@code e / (1 + e)} for {@code x >= 0} and {@code 1 / (1 + e)} otherwise,
         *         with {@code e = eTilde(x)}; 1 or 0 when {@code e} overflows, NaN for NaN
         */
        @Override
        public float compute(float x) {
            // check Appendices C.1. in paper https://arxiv.org/pdf/1702.07825.pdf
            final float inv = 1 / (1 + eTilde(x));
            // e / (1 + e) == 1 - 1 / (1 + e); this form stays finite when e is Infinity,
            // where e / (1 + e) would be inf / inf == NaN.
            return (x >= 0) ? 1 - inv : inv;
        }
    }

    /** Reference logistic sigmoid. */
    static class SlowSigmoid implements I {
        /**
         * @param x any float
         * @return {@code 1 / (1 + exp(-x))} evaluated in double and narrowed to float
         */
        @Override
        public float compute(float x) {
            // With exp(-x) in the denominator an overflowing exponential gives 1 / inf == 0
            // rather than the inf / inf == NaN of exp(x) / (1 + exp(x)).
            return (float) (1.0 / (1.0 + Math.exp(-x)));
        }
    }

    /**
     * Decides whether {@code actual} misses {@code expected} by more than {@code epsilon} in
     * relative terms. Unlike a bare {@code > epsilon} comparison, a NaN or otherwise
     * undefined relative error counts as a miss instead of silently passing.
     *
     * @param expected reference value
     * @param actual   value under test
     * @param epsilon  largest acceptable relative error
     * @return {@code false} when the values are equal (including equal infinities), when
     *         both are NaN, or when the relative error is at most {@code epsilon};
     *         {@code true} otherwise
     */
    static boolean exceedsRelativeError(float expected, float actual, double epsilon) {
        if (expected == actual) {
            return false;
        }
        if (Float.isNaN(expected) && Float.isNaN(actual)) {
            return false;
        }
        final double relative = Math.abs(((double) actual - expected) / expected);
        return !(relative <= epsilon); // written this way so that NaN yields true
    }

    /**
     * Timing workload: evaluates {@code exp} on {@link #SAMPLES} pseudo-random inputs.
     *
     * @param exp function to exercise
     * @param lo  inclusive lower bound of the sampled inputs
     * @param hi  upper bound of the sampled inputs
     * @return the sum of all results, returned so the calls are not dead code
     */
    private static float test(I exp, float lo, float hi) {
        final Random r = new Random(SEED);
        final float span = hi - lo;
        float checksum = 0;
        for (int i = 0; i < SAMPLES; i++) {
            checksum += exp.compute(lo + r.nextFloat() * span);
        }
        return checksum;
    }

    /**
     * Measures how often {@code bad} deviates from {@code good} by more than {@code epsilon}.
     *
     * @param good    reference implementation
     * @param bad     implementation under test
     * @param epsilon relative-error threshold
     * @param lo      inclusive lower bound of the sampled inputs
     * @param hi      upper bound of the sampled inputs
     * @return the fraction (0..1) of samples for which {@link #exceedsRelativeError} is true
     */
    private static float staterror(I good, I bad, double epsilon, float lo, float hi) {
        final Random r = new Random(SEED);
        final float span = hi - lo;
        int cnt = 0;
        for (int i = 0; i < SAMPLES; i++) {
            final float x = lo + r.nextFloat() * span;
            if (exceedsRelativeError(good.compute(x), bad.compute(x), epsilon)) {
                cnt++;
            }
        }
        return (float) (1.0 * cnt / SAMPLES);
    }

    /**
     * Times both implementations over {@code [lo, hi)} and prints the share of samples whose
     * relative error exceeds each threshold in {@link #EPSILONS}.
     *
     * @param fastBad  approximation under test
     * @param slowGood reference implementation
     * @param text     label used in the printed report
     * @param lo       inclusive lower bound of the sampled inputs
     * @param hi       upper bound of the sampled inputs
     */
    private static void testSuite(I fastBad, I slowGood, String text, float lo, float hi) {
        System.out.println(text + " test start");
        System.out.println();

        long ts = System.nanoTime();
        float checksum = test(slowGood, lo, hi);
        final long slowduration = (System.nanoTime() - ts) / 1000000L;

        ts = System.nanoTime();
        checksum += test(fastBad, lo, hi);
        final long fastduration = (System.nanoTime() - ts) / 1000000L;
        sink = checksum;

        System.out.println("Fast way takes " + fastduration + " ms, slow way takes " + slowduration + " ms");
        for (float epsilon : EPSILONS) {
            System.out.println("" + (100 * staterror(slowGood, fastBad, epsilon, lo, hi))
                    + "% have relative error rate > " + epsilon);
        }
        System.out.println(text + " test end.");
        System.out.println();
    }

    /**
     * Runs the tanh, sigmoid and exp comparisons over input ranges in which the reference
     * values are finite.
     *
     * <p>The figures below were recorded by the original author with the earlier version of
     * this harness. That version sampled x from [0, 1e7), where the references evaluated to
     * NaN or Infinity for nearly every sample, and a NaN relative error never compares
     * greater than a threshold, so such samples were not counted. The figures are kept for
     * history only and do not describe the current harness.
     * <pre>
     * tanh test start
     * Fast way takes 138 ms, slow way takes 3377 ms
     * 1.0999999E-4% have relative error rate > 1.0E-5
     * 1.0E-4% have relative error rate > 2.0E-5
     * 1.0E-4% have relative error rate > 3.0E-5
     * 7.0E-5% have relative error rate > 4.0E-5
     * 6.0000002E-5% have relative error rate > 5.0E-5
     * 6.0000002E-5% have relative error rate > 1.0E-4
     * 5.0E-5% have relative error rate > 0.001
     * tanh test end.
     * sigmoid test start
     * Fast way takes 209 ms, slow way takes 3747 ms
     * 3.2999998E-4% have relative error rate > 1.0E-5
     * 3.0E-4% have relative error rate > 2.0E-5
     * 2.8E-4% have relative error rate > 3.0E-5
     * 2.7000002E-4% have relative error rate > 4.0E-5
     * 2.4999998E-4% have relative error rate > 5.0E-5
     * 2.1000001E-4% have relative error rate > 1.0E-4
     * 8.9999994E-5% have relative error rate > 0.001
     * sigmoid test end.
     * exp test start
     * Fast way takes 384 ms, slow way takes 3140 ms
     * 8.3000003E-4% have relative error rate > 1.0E-5
     * 4.8000002E-4% have relative error rate > 2.0E-5
     * 8.0E-5% have relative error rate > 3.0E-5
     * 0.0% have relative error rate > 4.0E-5
     * 0.0% have relative error rate > 5.0E-5
     * 0.0% have relative error rate > 1.0E-4
     * 0.0% have relative error rate > 0.001
     * exp test end.
     * </pre>
     *
     * @param args unused
     */
    public static void main(String[] args) {
        testSuite(new FastTanh(), new SlowTanh(), "tanh", -10f, 10f);
        testSuite(new FastSigmoid(), new SlowSigmoid(), "sigmoid", -10f, 10f);
        testSuite(new FastExp(), new SlowExp(), "exp", -80f, 80f);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment