Skip to content

Instantly share code, notes, and snippets.

@harperjiang
Last active March 5, 2017 23:14
Show Gist options
  • Save harperjiang/5ae6d9efb0da22b4c23e0356ecf34c7a to your computer and use it in GitHub Desktop.
Save harperjiang/5ae6d9efb0da22b4c23e0356ecf34c7a to your computer and use it in GitHub Desktop.
ND4J concat takes a considerably long time
/**
 * Micro-benchmark reproducing a slow ND4J {@code concat}.
 *
 * <p>Each iteration multiplies a (batchSize x numChar) matrix by a (numChar x hiddenDim)
 * matrix and then concatenates the (batchSize x hiddenDim) product with a
 * (batchSize x hiddenDim) zero matrix along dimension 1.
 */
public class Slow {

    /**
     * Runs {@code length} mmul + concat iterations and prints the elapsed time of the loop,
     * in milliseconds, to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int hiddenDim = 200;
        int numChar = 100;
        int length = 500;
        int batchSize = 50;

        INDArray c2v = xavier(new int[] {numChar, hiddenDim});
        INDArray h0 = Nd4j.zeros(batchSize, hiddenDim);
        // Allocated before the clock starts so that only the loop body is measured.
        INDArray fwdmap = Nd4j.zeros(batchSize, numChar);

        // nanoTime is monotonic and intended for elapsed-time measurement; currentTimeMillis
        // follows the wall clock and can jump if the system time is adjusted mid-run.
        // NOTE(review): there is no warm-up phase, so JIT compilation and any first-call
        // overheads are included in the measured time.
        long start = System.nanoTime();
        for (int i = 0; i < length; i++) {
            INDArray embed = fwdmap.mmul(c2v); // 1: (batchSize x numChar) * (numChar x hiddenDim)
            Nd4j.concat(1, embed, h0);         // 2: result discarded on purpose; only its cost matters
        }
        long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
        System.out.println(elapsedMillis);
    }

    /**
     * Samples an array of the given shape from the uniform distribution U(-sd, sd).
     *
     * <p>Here {@code sd = sqrt(3 / n)}, and {@code n} is the product of all dimensions of
     * {@code shape} except the last one. For a rank-1 shape, {@code n = 1}. The variance of
     * U(-sd, sd) is {@code sd^2 / 3}, which equals {@code 1 / n}.
     *
     * <p>NOTE(review): this is fan-in-only (LeCun-style) uniform scaling. Glorot/Xavier
     * initialisation proper uses {@code sqrt(6 / (fanIn + fanOut))}. It is kept as-is here to
     * preserve existing results.
     *
     * @param shape desired array shape; every dimension except the last counts toward fan-in
     * @return an array of the given shape sampled from U(-sd, sd)
     */
    static INDArray xavier(int[] shape) {
        int n = 1;
        for (int i = 0; i < shape.length - 1; i++) {
            n *= shape[i];
        }
        double sd = Math.sqrt(3d / n);
        return new UniformDistribution(-sd, sd).sample(shape);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment