Last active
September 26, 2022 18:53
-
-
Save luhenry/e45b719462044b0889b481e38cb3b911 to your computer and use it in GitHub Desktop.
DotProduct
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dev.ludovic.presentation; | |
import java.util.Random; | |
public class BenchmarkSupport { | |
private static final Random rand = new Random(0); | |
static double[] randomDoubleArray(int n) { | |
double[] res = new double[n]; | |
for (int i = 0; i < n; i++) { | |
res[i] = rand.nextDouble(); | |
} | |
return res; | |
} | |
static int loopBound(int length, int stride) { | |
return length - (length % stride); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dev.ludovic.presentation; | |
import jdk.incubator.vector.DoubleVector; | |
import jdk.incubator.vector.VectorOperators; | |
import jdk.incubator.vector.VectorSpecies; | |
import static dev.ludovic.presentation.BenchmarkSupport.*; | |
public class DotProduct { | |
public static double scalar(double[] x, double[] y, int n) { | |
double sum = 0.0; | |
for (int i = 0; i < n; i++) { | |
sum += x[i] * y[i]; | |
} | |
return sum; | |
} | |
public static double scalarUnroll4(double[] x, double[] y, int n) { | |
double sum0 = 0.0; | |
double sum1 = 0.0; | |
double sum2 = 0.0; | |
double sum3 = 0.0; | |
int i = 0; | |
for (; i < loopBound(n, 4); i += 4) { | |
sum0 += x[i+0] * y[i+0]; | |
sum1 += x[i+1] * y[i+1]; | |
sum2 += x[i+2] * y[i+2]; | |
sum3 += x[i+3] * y[i+3]; | |
} | |
double sum = sum0 + sum1 + sum2 + sum3; | |
for (; i < n; i++) { | |
sum += x[i] * y[i]; | |
} | |
return sum; | |
} | |
public static double scalarUnroll32(double[] x, double[] y, int n) { | |
double sum0 = 0.0; | |
double sum1 = 0.0; | |
double sum2 = 0.0; | |
double sum3 = 0.0; | |
double sum4 = 0.0; | |
double sum5 = 0.0; | |
double sum6 = 0.0; | |
double sum7 = 0.0; | |
double sum8 = 0.0; | |
double sum9 = 0.0; | |
double sum10 = 0.0; | |
double sum11 = 0.0; | |
double sum12 = 0.0; | |
double sum13 = 0.0; | |
double sum14 = 0.0; | |
double sum15 = 0.0; | |
double sum16 = 0.0; | |
double sum17 = 0.0; | |
double sum18 = 0.0; | |
double sum19 = 0.0; | |
double sum20 = 0.0; | |
double sum21 = 0.0; | |
double sum22 = 0.0; | |
double sum23 = 0.0; | |
double sum24 = 0.0; | |
double sum25 = 0.0; | |
double sum26 = 0.0; | |
double sum27 = 0.0; | |
double sum28 = 0.0; | |
double sum29 = 0.0; | |
double sum30 = 0.0; | |
double sum31 = 0.0; | |
int i = 0; | |
for (; i < loopBound(n, 32); i += 32) { | |
sum0 += x[i+0] * y[i+0]; | |
sum1 += x[i+1] * y[i+1]; | |
sum2 += x[i+2] * y[i+2]; | |
sum3 += x[i+3] * y[i+3]; | |
sum4 += x[i+4] * y[i+4]; | |
sum5 += x[i+5] * y[i+5]; | |
sum6 += x[i+6] * y[i+6]; | |
sum7 += x[i+7] * y[i+7]; | |
sum8 += x[i+8] * y[i+8]; | |
sum9 += x[i+9] * y[i+9]; | |
sum10 += x[i+10] * y[i+10]; | |
sum11 += x[i+11] * y[i+11]; | |
sum12 += x[i+12] * y[i+12]; | |
sum13 += x[i+13] * y[i+13]; | |
sum14 += x[i+14] * y[i+14]; | |
sum15 += x[i+15] * y[i+15]; | |
sum16 += x[i+16] * y[i+16]; | |
sum17 += x[i+17] * y[i+17]; | |
sum18 += x[i+18] * y[i+18]; | |
sum19 += x[i+19] * y[i+19]; | |
sum20 += x[i+20] * y[i+20]; | |
sum21 += x[i+21] * y[i+21]; | |
sum22 += x[i+22] * y[i+22]; | |
sum23 += x[i+23] * y[i+23]; | |
sum24 += x[i+24] * y[i+24]; | |
sum25 += x[i+25] * y[i+25]; | |
sum26 += x[i+26] * y[i+26]; | |
sum27 += x[i+27] * y[i+27]; | |
sum28 += x[i+28] * y[i+28]; | |
sum29 += x[i+29] * y[i+29]; | |
sum30 += x[i+30] * y[i+30]; | |
sum31 += x[i+31] * y[i+31]; | |
} | |
double sum = sum0 + sum1 + sum2 + sum3 + sum4 + sum5 + sum6 + sum7 | |
+ sum8 + sum9 + sum10 + sum11 + sum12 + sum13 + sum14 + sum15 | |
+ sum16 + sum17 + sum18 + sum19 + sum20 + sum21 + sum22 + sum23 | |
+ sum24 + sum25 + sum26 + sum27 + sum28 + sum29 + sum30 + sum31; | |
for (; i < n; i++) { | |
sum += x[i] * y[i]; | |
} | |
return sum; | |
} | |
final static VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX; | |
public static double vector(double[] x, double[] y, int n) { | |
DoubleVector vsum = DoubleVector.zero(DMAX); | |
int i = 0; | |
for (; i < DMAX.loopBound(n); i += DMAX.length()) { | |
DoubleVector vx = DoubleVector.fromArray(DMAX, x, i); | |
DoubleVector vy = DoubleVector.fromArray(DMAX, y, i); | |
vsum = vx.fma(vy, vsum); | |
} | |
double sum = vsum.reduceLanes(VectorOperators.ADD); | |
for (; i < n; i++) { | |
sum += x[i] * y[i]; | |
} | |
return sum; | |
} | |
public static double vectorUnroll4(double[] x, double[] y, int n) { | |
DoubleVector vsum0 = DoubleVector.zero(DMAX); | |
DoubleVector vsum1 = DoubleVector.zero(DMAX); | |
DoubleVector vsum2 = DoubleVector.zero(DMAX); | |
DoubleVector vsum3 = DoubleVector.zero(DMAX); | |
int i = 0; | |
for (; i < loopBound(n, DMAX.length() * 4); i += DMAX.length() * 4) { | |
DoubleVector vx0 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 0); | |
DoubleVector vy0 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 0); | |
DoubleVector vx1 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 1); | |
DoubleVector vy1 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 1); | |
DoubleVector vx2 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 2); | |
DoubleVector vy2 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 2); | |
DoubleVector vx3 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 3); | |
DoubleVector vy3 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 3); | |
vsum0 = vx0.fma(vy0, vsum0); | |
vsum1 = vx1.fma(vy1, vsum1); | |
vsum2 = vx2.fma(vy2, vsum2); | |
vsum3 = vx3.fma(vy3, vsum3); | |
} | |
double sum = vsum0.reduceLanes(VectorOperators.ADD) | |
+ vsum1.reduceLanes(VectorOperators.ADD) | |
+ vsum2.reduceLanes(VectorOperators.ADD) | |
+ vsum3.reduceLanes(VectorOperators.ADD); | |
for (; i < n; i++) { | |
sum += x[i] * y[i]; | |
} | |
return sum; | |
} | |
public static double vectorUnroll8(double[] x, double[] y, int n) { | |
DoubleVector vsum0 = DoubleVector.zero(DMAX); | |
DoubleVector vsum1 = DoubleVector.zero(DMAX); | |
DoubleVector vsum2 = DoubleVector.zero(DMAX); | |
DoubleVector vsum3 = DoubleVector.zero(DMAX); | |
DoubleVector vsum4 = DoubleVector.zero(DMAX); | |
DoubleVector vsum5 = DoubleVector.zero(DMAX); | |
DoubleVector vsum6 = DoubleVector.zero(DMAX); | |
DoubleVector vsum7 = DoubleVector.zero(DMAX); | |
int i = 0; | |
for (; i < loopBound(n, DMAX.length() * 8); i += DMAX.length() * 8) { | |
DoubleVector vx0 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 0); | |
DoubleVector vy0 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 0); | |
DoubleVector vx1 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 1); | |
DoubleVector vy1 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 1); | |
DoubleVector vx2 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 2); | |
DoubleVector vy2 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 2); | |
DoubleVector vx3 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 3); | |
DoubleVector vy3 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 3); | |
DoubleVector vx4 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 4); | |
DoubleVector vy4 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 4); | |
DoubleVector vx5 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 5); | |
DoubleVector vy5 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 5); | |
DoubleVector vx6 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 6); | |
DoubleVector vy6 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 6); | |
DoubleVector vx7 = DoubleVector.fromArray(DMAX, x, i + DMAX.length() * 7); | |
DoubleVector vy7 = DoubleVector.fromArray(DMAX, y, i + DMAX.length() * 7); | |
vsum0 = vx0.fma(vy0, vsum0); | |
vsum1 = vx1.fma(vy1, vsum1); | |
vsum2 = vx2.fma(vy2, vsum2); | |
vsum3 = vx3.fma(vy3, vsum3); | |
vsum4 = vx4.fma(vy4, vsum4); | |
vsum5 = vx5.fma(vy5, vsum5); | |
vsum6 = vx6.fma(vy6, vsum6); | |
vsum7 = vx7.fma(vy7, vsum7); | |
} | |
double sum = vsum0.reduceLanes(VectorOperators.ADD) | |
+ vsum1.reduceLanes(VectorOperators.ADD) | |
+ vsum2.reduceLanes(VectorOperators.ADD) | |
+ vsum3.reduceLanes(VectorOperators.ADD) | |
+ vsum4.reduceLanes(VectorOperators.ADD) | |
+ vsum5.reduceLanes(VectorOperators.ADD) | |
+ vsum6.reduceLanes(VectorOperators.ADD) | |
+ vsum7.reduceLanes(VectorOperators.ADD); | |
for (; i < n; i++) { | |
sum += x[i] * y[i]; | |
} | |
return sum; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dev.ludovic.presentation; | |
import java.util.concurrent.TimeUnit; | |
import org.openjdk.jmh.annotations.*; | |
import static dev.ludovic.presentation.BenchmarkSupport.*; | |
@BenchmarkMode(Mode.Throughput) | |
@OutputTimeUnit(TimeUnit.MILLISECONDS) | |
@State(Scope.Thread) | |
@Fork(value = 3, jvmArgs = "--add-modules=jdk.incubator.vector") | |
@Warmup(iterations = 2, time = 5, timeUnit = TimeUnit.SECONDS) | |
@Measurement(iterations = 1, time = 5, timeUnit = TimeUnit.SECONDS) | |
public class DotProductBenchmark { | |
@Param({"100000"}) | |
public int n; | |
public double[] x; | |
public double[] y; | |
@Setup | |
public void setup() { | |
x = randomDoubleArray(n); | |
y = randomDoubleArray(n); | |
} | |
@Benchmark | |
public double scalar() { | |
return DotProduct.scalar(x, y, n); | |
} | |
@Benchmark | |
public double scalarUnroll4() { | |
return DotProduct.scalarUnroll4(x, y, n); | |
} | |
@Benchmark | |
public double scalarUnroll32() { | |
return DotProduct.scalarUnroll32(x, y, n); | |
} | |
@Benchmark | |
public double vector() { | |
return DotProduct.vector(x, y, n); | |
} | |
@Benchmark | |
public double vectorUnroll4() { | |
return DotProduct.vectorUnroll4(x, y, n); | |
} | |
@Benchmark | |
public double vectorUnroll8() { | |
return DotProduct.vectorUnroll8(x, y, n); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.