Skip to content

Instantly share code, notes, and snippets.

@luhenry
Last active September 26, 2022 18:53
Show Gist options
  • Save luhenry/e45b719462044b0889b481e38cb3b911 to your computer and use it in GitHub Desktop.
Save luhenry/e45b719462044b0889b481e38cb3b911 to your computer and use it in GitHub Desktop.
DotProduct
package dev.ludovic.presentation;
import java.util.Random;
/**
 * Small helpers shared by the dot-product kernels and their benchmarks.
 */
public class BenchmarkSupport {

    /** Single generator with a fixed seed ({@code 0}); every call below draws from it. */
    private static final Random SEEDED_RANDOM = new Random(0);

    /**
     * Allocates an array of {@code n} doubles and fills it, in index order, with successive
     * {@link Random#nextDouble()} values taken from the shared seeded generator.
     *
     * @param n number of elements to generate
     * @return a freshly allocated array of length {@code n}
     */
    static double[] randomDoubleArray(int n) {
        final double[] values = new double[n];
        int filled = 0;
        while (filled < values.length) {
            values[filled++] = SEEDED_RANDOM.nextDouble();
        }
        return values;
    }

    /**
     * Rounds {@code length} down by removing the remainder of its division by {@code stride}.
     *
     * @param length total element count
     * @param stride number of elements consumed per loop iteration
     * @return {@code length - (length % stride)}
     */
    static int loopBound(int length, int stride) {
        final int remainder = length % stride;
        return length - remainder;
    }
}
package dev.ludovic.presentation;
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import static dev.ludovic.presentation.BenchmarkSupport.*;
/**
 * Dot-product kernels over the first {@code n} elements of two {@code double} arrays.
 *
 * <p>The class offers plain scalar loops and {@link DoubleVector}-based loops, each with
 * several manual unroll factors. Every unrolled variant finishes with a scalar tail loop that
 * handles the elements left over after the last full step.
 */
public class DotProduct {

    /** Vector species used by all vectorized kernels in this class. */
    static final VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX;

    /**
     * Straightforward scalar loop with a single accumulator.
     *
     * @param x first operand
     * @param y second operand
     * @param n number of leading elements to process
     * @return the accumulated sum of {@code x[i] * y[i]} for {@code i} in {@code [0, n)}
     */
    public static double scalar(double[] x, double[] y, int n) {
        double acc = 0.0;
        int k = 0;
        while (k < n) {
            acc += x[k] * y[k];
            k++;
        }
        return acc;
    }

    /**
     * Scalar loop unrolled by hand four times, using four independent accumulators that are
     * added together (left to right) before the scalar tail loop runs.
     *
     * @param x first operand
     * @param y second operand
     * @param n number of leading elements to process
     * @return the accumulated dot product of the first {@code n} elements
     */
    public static double scalarUnroll4(double[] x, double[] y, int n) {
        // Largest multiple of 4 that does not exceed n; computed once up front.
        final int bound = loopBound(n, 4);
        double a0 = 0.0;
        double a1 = 0.0;
        double a2 = 0.0;
        double a3 = 0.0;
        int k = 0;
        for (; k < bound; k += 4) {
            a0 += x[k] * y[k];
            a1 += x[k + 1] * y[k + 1];
            a2 += x[k + 2] * y[k + 2];
            a3 += x[k + 3] * y[k + 3];
        }
        double total = a0 + a1 + a2 + a3;
        // Tail: fewer than 4 elements remain.
        for (; k < n; k++) {
            total += x[k] * y[k];
        }
        return total;
    }

    /**
     * Scalar loop unrolled by hand thirty-two times, using thirty-two independent accumulators
     * that are added together (left to right) before the scalar tail loop runs.
     *
     * @param x first operand
     * @param y second operand
     * @param n number of leading elements to process
     * @return the accumulated dot product of the first {@code n} elements
     */
    public static double scalarUnroll32(double[] x, double[] y, int n) {
        // Largest multiple of 32 that does not exceed n; computed once up front.
        final int bound = loopBound(n, 32);
        double a0 = 0.0;
        double a1 = 0.0;
        double a2 = 0.0;
        double a3 = 0.0;
        double a4 = 0.0;
        double a5 = 0.0;
        double a6 = 0.0;
        double a7 = 0.0;
        double a8 = 0.0;
        double a9 = 0.0;
        double a10 = 0.0;
        double a11 = 0.0;
        double a12 = 0.0;
        double a13 = 0.0;
        double a14 = 0.0;
        double a15 = 0.0;
        double a16 = 0.0;
        double a17 = 0.0;
        double a18 = 0.0;
        double a19 = 0.0;
        double a20 = 0.0;
        double a21 = 0.0;
        double a22 = 0.0;
        double a23 = 0.0;
        double a24 = 0.0;
        double a25 = 0.0;
        double a26 = 0.0;
        double a27 = 0.0;
        double a28 = 0.0;
        double a29 = 0.0;
        double a30 = 0.0;
        double a31 = 0.0;
        int k = 0;
        for (; k < bound; k += 32) {
            a0 += x[k] * y[k];
            a1 += x[k + 1] * y[k + 1];
            a2 += x[k + 2] * y[k + 2];
            a3 += x[k + 3] * y[k + 3];
            a4 += x[k + 4] * y[k + 4];
            a5 += x[k + 5] * y[k + 5];
            a6 += x[k + 6] * y[k + 6];
            a7 += x[k + 7] * y[k + 7];
            a8 += x[k + 8] * y[k + 8];
            a9 += x[k + 9] * y[k + 9];
            a10 += x[k + 10] * y[k + 10];
            a11 += x[k + 11] * y[k + 11];
            a12 += x[k + 12] * y[k + 12];
            a13 += x[k + 13] * y[k + 13];
            a14 += x[k + 14] * y[k + 14];
            a15 += x[k + 15] * y[k + 15];
            a16 += x[k + 16] * y[k + 16];
            a17 += x[k + 17] * y[k + 17];
            a18 += x[k + 18] * y[k + 18];
            a19 += x[k + 19] * y[k + 19];
            a20 += x[k + 20] * y[k + 20];
            a21 += x[k + 21] * y[k + 21];
            a22 += x[k + 22] * y[k + 22];
            a23 += x[k + 23] * y[k + 23];
            a24 += x[k + 24] * y[k + 24];
            a25 += x[k + 25] * y[k + 25];
            a26 += x[k + 26] * y[k + 26];
            a27 += x[k + 27] * y[k + 27];
            a28 += x[k + 28] * y[k + 28];
            a29 += x[k + 29] * y[k + 29];
            a30 += x[k + 30] * y[k + 30];
            a31 += x[k + 31] * y[k + 31];
        }
        // Keep this as one left-to-right chain so the floating-point rounding is unchanged.
        double total = a0 + a1 + a2 + a3 + a4 + a5 + a6 + a7
                + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15
                + a16 + a17 + a18 + a19 + a20 + a21 + a22 + a23
                + a24 + a25 + a26 + a27 + a28 + a29 + a30 + a31;
        // Tail: fewer than 32 elements remain.
        for (; k < n; k++) {
            total += x[k] * y[k];
        }
        return total;
    }

    /**
     * Vector API loop with a single vector accumulator updated through
     * {@link DoubleVector#fma}; the lanes are reduced with {@link VectorOperators#ADD} and the
     * leftover elements are handled by a scalar tail loop.
     *
     * @param x first operand
     * @param y second operand
     * @param n number of leading elements to process
     * @return the accumulated dot product of the first {@code n} elements
     */
    public static double vector(double[] x, double[] y, int n) {
        final int bound = DMAX.loopBound(n);
        final int lanes = DMAX.length();
        DoubleVector acc = DoubleVector.zero(DMAX);
        int k = 0;
        for (; k < bound; k += lanes) {
            acc = DoubleVector.fromArray(DMAX, x, k)
                    .fma(DoubleVector.fromArray(DMAX, y, k), acc);
        }
        double total = acc.reduceLanes(VectorOperators.ADD);
        for (; k < n; k++) {
            total += x[k] * y[k];
        }
        return total;
    }

    /**
     * Vector API loop unrolled by hand four times: four vector accumulators, each reduced with
     * {@link VectorOperators#ADD} and then added together (left to right) before the scalar
     * tail loop runs.
     *
     * @param x first operand
     * @param y second operand
     * @param n number of leading elements to process
     * @return the accumulated dot product of the first {@code n} elements
     */
    public static double vectorUnroll4(double[] x, double[] y, int n) {
        final int lanes = DMAX.length();
        final int step = lanes * 4; // elements consumed per unrolled iteration
        final int bound = loopBound(n, step);
        DoubleVector acc0 = DoubleVector.zero(DMAX);
        DoubleVector acc1 = DoubleVector.zero(DMAX);
        DoubleVector acc2 = DoubleVector.zero(DMAX);
        DoubleVector acc3 = DoubleVector.zero(DMAX);
        int k = 0;
        for (; k < bound; k += step) {
            acc0 = DoubleVector.fromArray(DMAX, x, k)
                    .fma(DoubleVector.fromArray(DMAX, y, k), acc0);
            acc1 = DoubleVector.fromArray(DMAX, x, k + lanes)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes), acc1);
            acc2 = DoubleVector.fromArray(DMAX, x, k + lanes * 2)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 2), acc2);
            acc3 = DoubleVector.fromArray(DMAX, x, k + lanes * 3)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 3), acc3);
        }
        final double r0 = acc0.reduceLanes(VectorOperators.ADD);
        final double r1 = acc1.reduceLanes(VectorOperators.ADD);
        final double r2 = acc2.reduceLanes(VectorOperators.ADD);
        final double r3 = acc3.reduceLanes(VectorOperators.ADD);
        double total = r0 + r1 + r2 + r3;
        for (; k < n; k++) {
            total += x[k] * y[k];
        }
        return total;
    }

    /**
     * Vector API loop unrolled by hand eight times: eight vector accumulators, each reduced
     * with {@link VectorOperators#ADD} and then added together (left to right) before the
     * scalar tail loop runs.
     *
     * @param x first operand
     * @param y second operand
     * @param n number of leading elements to process
     * @return the accumulated dot product of the first {@code n} elements
     */
    public static double vectorUnroll8(double[] x, double[] y, int n) {
        final int lanes = DMAX.length();
        final int step = lanes * 8; // elements consumed per unrolled iteration
        final int bound = loopBound(n, step);
        DoubleVector acc0 = DoubleVector.zero(DMAX);
        DoubleVector acc1 = DoubleVector.zero(DMAX);
        DoubleVector acc2 = DoubleVector.zero(DMAX);
        DoubleVector acc3 = DoubleVector.zero(DMAX);
        DoubleVector acc4 = DoubleVector.zero(DMAX);
        DoubleVector acc5 = DoubleVector.zero(DMAX);
        DoubleVector acc6 = DoubleVector.zero(DMAX);
        DoubleVector acc7 = DoubleVector.zero(DMAX);
        int k = 0;
        for (; k < bound; k += step) {
            acc0 = DoubleVector.fromArray(DMAX, x, k)
                    .fma(DoubleVector.fromArray(DMAX, y, k), acc0);
            acc1 = DoubleVector.fromArray(DMAX, x, k + lanes)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes), acc1);
            acc2 = DoubleVector.fromArray(DMAX, x, k + lanes * 2)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 2), acc2);
            acc3 = DoubleVector.fromArray(DMAX, x, k + lanes * 3)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 3), acc3);
            acc4 = DoubleVector.fromArray(DMAX, x, k + lanes * 4)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 4), acc4);
            acc5 = DoubleVector.fromArray(DMAX, x, k + lanes * 5)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 5), acc5);
            acc6 = DoubleVector.fromArray(DMAX, x, k + lanes * 6)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 6), acc6);
            acc7 = DoubleVector.fromArray(DMAX, x, k + lanes * 7)
                    .fma(DoubleVector.fromArray(DMAX, y, k + lanes * 7), acc7);
        }
        final double r0 = acc0.reduceLanes(VectorOperators.ADD);
        final double r1 = acc1.reduceLanes(VectorOperators.ADD);
        final double r2 = acc2.reduceLanes(VectorOperators.ADD);
        final double r3 = acc3.reduceLanes(VectorOperators.ADD);
        final double r4 = acc4.reduceLanes(VectorOperators.ADD);
        final double r5 = acc5.reduceLanes(VectorOperators.ADD);
        final double r6 = acc6.reduceLanes(VectorOperators.ADD);
        final double r7 = acc7.reduceLanes(VectorOperators.ADD);
        double total = r0 + r1 + r2 + r3 + r4 + r5 + r6 + r7;
        for (; k < n; k++) {
            total += x[k] * y[k];
        }
        return total;
    }
}
package dev.ludovic.presentation;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.*;
import static dev.ludovic.presentation.BenchmarkSupport.*;
/**
 * JMH throughput benchmarks that run each {@link DotProduct} kernel on the same pair of
 * randomly filled input arrays of length {@link #n}.
 */
@State(Scope.Thread)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 3, jvmArgs = "--add-modules=jdk.incubator.vector")
@Warmup(iterations = 2, time = 5, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 1, time = 5, timeUnit = TimeUnit.SECONDS)
public class DotProductBenchmark {

    /** Number of elements in each input array; injected by JMH. */
    @Param({"100000"})
    public int n;

    /** First operand, populated in {@link #setup()}. */
    public double[] x;

    /** Second operand, populated in {@link #setup()}. */
    public double[] y;

    /** Fills both operand arrays with {@code n} random doubles each. */
    @Setup
    public void setup() {
        this.x = BenchmarkSupport.randomDoubleArray(this.n);
        this.y = BenchmarkSupport.randomDoubleArray(this.n);
    }

    /** Measures {@link DotProduct#scalar}. */
    @Benchmark
    public double scalar() {
        return DotProduct.scalar(this.x, this.y, this.n);
    }

    /** Measures {@link DotProduct#scalarUnroll4}. */
    @Benchmark
    public double scalarUnroll4() {
        return DotProduct.scalarUnroll4(this.x, this.y, this.n);
    }

    /** Measures {@link DotProduct#scalarUnroll32}. */
    @Benchmark
    public double scalarUnroll32() {
        return DotProduct.scalarUnroll32(this.x, this.y, this.n);
    }

    /** Measures {@link DotProduct#vector}. */
    @Benchmark
    public double vector() {
        return DotProduct.vector(this.x, this.y, this.n);
    }

    /** Measures {@link DotProduct#vectorUnroll4}. */
    @Benchmark
    public double vectorUnroll4() {
        return DotProduct.vectorUnroll4(this.x, this.y, this.n);
    }

    /** Measures {@link DotProduct#vectorUnroll8}. */
    @Benchmark
    public double vectorUnroll8() {
        return DotProduct.vectorUnroll8(this.x, this.y, this.n);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment