Last active: March 5, 2017 23:14
Save harperjiang/5ae6d9efb0da22b4c23e0356ecf34c7a to your computer and use it in GitHub Desktop.
ND4J concat takes a considerably long time
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class Slow { | |
public static void main(String[] args) { | |
int hiddenDim = 200; | |
int numChar = 100; | |
int length = 500; | |
int batchSize = 50; | |
int[] pshape = new int[]{numChar, hiddenDim}; | |
INDArray c2v = xavier(pshape); | |
INDArray h0 = Nd4j.zeros(batchSize, hiddenDim); | |
INDArray c0 = Nd4j.zeros(batchSize, hiddenDim); | |
long start = System.currentTimeMillis(); | |
INDArray fwdmap = Nd4j.zeros(batchSize, numChar); | |
for (int i = 0; i < length; i++) { | |
INDArray embed = fwdmap.mmul(c2v); // 1 | |
INDArray concat = Nd4j.concat(1, embed, h0); // 2 | |
} | |
System.out.println(System.currentTimeMillis() - start); | |
} | |
static INDArray xavier(int[] shape) { | |
int n = 1; | |
for (int i = 0; i < shape.length - 1; i++) | |
n *= shape[i]; | |
double sd = Math.sqrt(3d / n); | |
return new UniformDistribution(-sd, sd).sample(shape); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment