Skip to content

Instantly share code, notes, and snippets.

@harperjiang
Created March 6, 2017 16:44
Show Gist options
  • Save harperjiang/847a1f7f02fd219553761ef06154c349 to your computer and use it in GitHub Desktop.
Save harperjiang/847a1f7f02fd219553761ef06154c349 to your computer and use it in GitHub Desktop.
/**
 * Entry point: runs the concat benchmark without interleaved work ({@code fast()}) and then
 * the variant with an mmul before every concat ({@code slow()}). Each prints its own average.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    // Invocation order is part of the observable behaviour: fast() always executes first.
    fast();
    slow();
}
/**
 * Times {@code Nd4j.concat} on pre-allocated inputs with no other ND4J work between the timed
 * calls, and prints the mean time per concat in nanoseconds.
 */
static void fast() {
    runConcatBenchmark(false);
}

/**
 * Same measurement as {@link #fast()}, except that an untimed {@code fwdmap.mmul(c2v)} is
 * executed before every timed concat. Comparing the two printed averages shows how the
 * interleaved matrix multiply affects the measured concat latency.
 */
static void slow() {
    runConcatBenchmark(true);
}

/**
 * Shared body of {@link #fast()} and {@link #slow()}: pre-allocates one pair of concat inputs
 * per iteration, times only the {@code Nd4j.concat} call, and prints the integer-average
 * elapsed time in nanoseconds.
 *
 * @param mmulBeforeEachConcat if {@code true}, run {@code fwdmap.mmul(c2v)} outside the timed
 *     region before every timed concat
 */
private static void runConcatBenchmark(boolean mmulBeforeEachConcat) {
    final int hiddenDim = 200;
    final int numChar = 100;
    final int batchSize = 50;
    final int iterations = 1000;

    INDArray c2v = Nd4j.zeros(numChar, hiddenDim);
    INDArray h0 = Nd4j.zeros(batchSize, hiddenDim);
    INDArray fwdmap = Nd4j.zeros(batchSize, numChar);
    INDArray embed = fwdmap.mmul(c2v);

    // Allocate every concat input up front so allocation is not part of the timed region.
    List<INDArray> embeds = new ArrayList<>(iterations);
    List<INDArray> h0s = new ArrayList<>(iterations);
    for (int x = 0; x < iterations; x++) {
        embeds.add(Nd4j.createUninitialized(embed.shape()));
        h0s.add(Nd4j.createUninitialized(h0.shape()));
    }

    long totalNanos = 0;
    for (int x = 0; x < iterations; x++) {
        if (mmulBeforeEachConcat) {
            // Untimed perturbation under test; the result itself is never read.
            embed = fwdmap.mmul(c2v);
        }
        long start = System.nanoTime();
        Nd4j.concat(1, embeds.get(x), h0s.get(x));
        totalNanos += System.nanoTime() - start;
    }
    // NOTE(review): there is no warm-up phase, so whichever variant runs first may absorb
    // one-time JIT/backend costs in its early iterations — consider discarding a warm-up pass.
    System.out.println(totalNanos / iterations);
}
/**
 * Creates an array of the given shape sampled from {@code UniformDistribution(-bound, bound)}.
 * Here {@code bound = sqrt(3 / fanIn)}, and {@code fanIn} is the product of every dimension
 * except the last. {@code fanIn} is 1 when {@code shape} has fewer than two dimensions.
 *
 * @param shape dimensions of the array to sample
 * @return the sampled array (presumably uniform over [-bound, bound] — per UniformDistribution)
 */
static INDArray xavier(int[] shape) {
    int fanIn = 1;
    // Walk the leading dimensions from the second-to-last back to the first.
    for (int dim = shape.length - 2; dim >= 0; dim--) {
        fanIn *= shape[dim];
    }
    final double bound = Math.sqrt(3.0 / fanIn);
    UniformDistribution distribution = new UniformDistribution(-bound, bound);
    return distribution.sample(shape);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment