Skip to content

Instantly share code, notes, and snippets.

@tarsa
Last active January 16, 2021 11:41
Mandelbrot from the Benchmarks Game meets the Vector API from Project Panama
package pl.tarsa;
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
/**
 * Mandelbrot program from the Computer Language Benchmarks Game, reworked to use the
 * incubating Vector API ({@code jdk.incubator.vector}).
 *
 * <p>With no command-line arguments, {@link #main} runs the JMH benchmarks declared in this
 * class. With one argument {@code N}, it writes an {@code N} x {@code N} binary PBM
 * ({@code P4}) image of the Mandelbrot set to standard output.
 *
 * <p>Each image row is processed in two parts. Every complete 64-pixel chunk is computed with
 * SIMD code ({@link #computeChunksVector}). The remaining {@code N % 64} pixels are computed with
 * scalar code ({@link #computeScalar}). A pixel becomes a 1 bit in the output when the iteration
 * {@code z -> z*z + c} has not reached {@code |z|^2 > 4} within 50 iterations.
 */
public class mandelbrot_simd_1 {

    /** Shared, read-only benchmark inputs: the real parts of c for one 10 000 px wide row. */
    @State(Scope.Benchmark)
    public static class JmhConstants {
        public static final int JMH_DATA_WIDTH = 10_000;
        public static final double FAC = 2.0 / JMH_DATA_WIDTH;

        /** Imaginary part of c for the benchmarked row; set to 0.0 in {@link #doSetup}. */
        public double Ci;

        @SuppressWarnings("unused")
        @Setup
        public void doSetup() {
            Ci = 0.0;
        }

        /** Real part of c for every pixel of the row, covering [-1.5, 0.5). */
        public final double[] aCr =
                IntStream.range(0, JMH_DATA_WIDTH)
                        .parallel()
                        .mapToDouble(x -> x * FAC - 1.5)
                        .toArray();

        /** Lookup table reversing the bit order of a byte. */
        public final byte[] bitsReversalMapping =
                computeBitsReversalMapping();
    }

    /** Per-thread output buffers, so multi-threaded benchmarks never share writes. */
    @State(Scope.Thread)
    public static class JmhWorkMem {
        private final int sideLen = JmhConstants.JMH_DATA_WIDTH;
        /** One packed output row: 1 bit per pixel, rounded up to whole bytes. */
        final byte[] row = new byte[(sideLen + 7) / 8];
        /** Escape flags for each complete 64-pixel chunk of the row. */
        final long[] rowChunks = new long[sideLen / 64];
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(1)
    @Benchmark
    public byte[] benchScalarST(JmhConstants constants, JmhWorkMem state) {
        computeScalar(constants.Ci, constants.aCr, state.row, 0, false);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(Threads.MAX)
    @Benchmark
    public byte[] benchScalarMT(JmhConstants constants, JmhWorkMem state) {
        computeScalar(constants.Ci, constants.aCr, state.row, 0, false);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(1)
    @Benchmark
    public byte[] benchScalarPairsST(JmhConstants constants, JmhWorkMem state) {
        computeScalarPairs(constants.Ci, constants.aCr, state.row, 0);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(Threads.MAX)
    @Benchmark
    public byte[] benchScalarPairsMT(JmhConstants constants, JmhWorkMem state) {
        computeScalarPairs(constants.Ci, constants.aCr, state.row, 0);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Benchmark
    public byte[] benchScalarRemainderOnly(JmhConstants constants,
                                           JmhWorkMem state) {
        computeScalar(constants.Ci, constants.aCr, state.row, 0, true);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(1)
    @Benchmark
    public long[] benchVectorST(JmhConstants constants, JmhWorkMem state) {
        computeChunksVector(constants.Ci, constants.aCr, state.rowChunks);
        return state.rowChunks;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(Threads.MAX)
    @Benchmark
    public long[] benchVectorMT(JmhConstants constants, JmhWorkMem state) {
        computeChunksVector(constants.Ci, constants.aCr, state.rowChunks);
        return state.rowChunks;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Benchmark
    public byte[] benchVectorWithTransfer(JmhConstants constants,
                                          JmhWorkMem state) {
        computeChunksVector(constants.Ci, constants.aCr, state.rowChunks);
        transferRowFlags(state.rowChunks, constants.bitsReversalMapping,
                state.row, 0);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(1)
    @Benchmark
    public byte[] benchRowST(JmhConstants constants, JmhWorkMem state) {
        computeRow(constants.Ci, constants.aCr, constants.bitsReversalMapping,
                state.rowChunks, state.row, 0);
        return state.row;
    }

    @SuppressWarnings("UnusedReturnValue")
    @Threads(Threads.MAX)
    @Benchmark
    public byte[] benchRowMT(JmhConstants constants, JmhWorkMem state) {
        computeRow(constants.Ci, constants.aCr, constants.bitsReversalMapping,
                state.rowChunks, state.row, 0);
        return state.row;
    }

    /**
     * Double-precision species used by the SIMD kernel. It is the platform's preferred species,
     * falling back to 512 bits when the preferred one has more than 8 lanes. {@link #main}
     * rejects any lane count that is not a power of two up to 8.
     */
    private static final VectorSpecies<Double> SPECIES =
            DoubleVector.SPECIES_PREFERRED.length() <= 8 ?
                    DoubleVector.SPECIES_PREFERRED : DoubleVector.SPECIES_512;
    private static final int LANES = SPECIES.length();
    private static final int LANES_LOG = Integer.numberOfTrailingZeros(LANES);

    /**
     * Runs the JMH benchmarks (no arguments) or renders an {@code N} x {@code N} PBM image to
     * standard output ({@code args[0] = N}).
     *
     * @throws IllegalStateException    if {@link #SPECIES} has an unsupported lane count
     * @throws IllegalArgumentException if the requested image size is negative
     * @throws NumberFormatException    if {@code args[0]} is not an integer
     */
    public static void main(String[] args) throws IOException, RunnerException {
        if ((LANES > 8) || (LANES != (1 << LANES_LOG))) {
            var errorMsg = "LANES must be a power of two and at most 8. " +
                    "Change SPECIES in the source code.";
            throw new IllegalStateException(errorMsg);
        }
        if (args.length == 0) {
            // java-microbenchmark-harness benchmark run
            Options opt = new OptionsBuilder()
                    .include(mandelbrot_simd_1.class.getSimpleName())
                    .forks(1)
                    .build();
            new Runner(opt).run();
        } else {
            // benchmarks game mandelbrot run
            var sideLen = Integer.parseInt(args[0]);
            if (sideLen < 0) {
                // Fail before anything is written, rather than emitting a bogus header.
                throw new IllegalArgumentException(
                        "image size must be non-negative, got " + sideLen);
            }
            try (var out = new BufferedOutputStream(makeOut2())) {
                // Integer concatenation is locale-independent (String.format("%d") is not),
                // and the PBM header must be plain ASCII whatever the platform charset is.
                var headerStr = "P4\n" + sideLen + " " + sideLen + "\n";
                out.write(headerStr.getBytes(StandardCharsets.US_ASCII));
                out.write(computeRows(sideLen));
            }
        }
    }

    @SuppressWarnings("unused")
    // the version that avoids mixing up output with JVM diagnostic messages
    private static OutputStream makeOut1() throws IOException {
        return Files.newOutputStream(Path.of("mandelbrot_simd_1.pbm"));
    }

    @SuppressWarnings("unused")
    // the version that is compatible with benchmark requirements
    private static OutputStream makeOut2() {
        return System.out;
    }

    /**
     * Renders all rows of a {@code sideLen} x {@code sideLen} image in parallel.
     *
     * <p>Row {@code y} uses imaginary part {@code y * 2 / sideLen - 1.0}. Column {@code x} uses
     * real part {@code x * 2 / sideLen - 1.5}.
     *
     * @return the packed rows back to back, each padded to {@code (sideLen + 7) / 8} bytes
     * @throws ArithmeticException if the output would not fit in a Java array
     */
    private static byte[] computeRows(int sideLen) {
        // Scratch space for the 64-pixel chunk flags, one per worker thread.
        var threadRowChunks =
                ThreadLocal.withInitial(() -> new long[sideLen / 64]);
        var rowOutputSize = (sideLen + 7) / 8;
        // multiplyExact: a silently wrapped size could allocate a too-small buffer.
        var rowsMerged = new byte[Math.multiplyExact(sideLen, rowOutputSize)];
        var numCpus = Runtime.getRuntime().availableProcessors();
        var fac = 2.0 / sideLen;
        var aCr = IntStream.range(0, sideLen).parallel()
                .mapToDouble(x -> x * fac - 1.5).toArray();
        var bitsReversalMapping = computeBitsReversalMapping();
        var computeEc = Executors.newWorkStealingPool(numCpus);
        for (var i = 0; i < sideLen; i++) {
            var y = i;
            computeEc.submit(() -> {
                var rowChunks = threadRowChunks.get();
                var rowOffset = y * rowOutputSize;
                var Ci = y * fac - 1.0;
                //noinspection CommentedOutCode
                try {
                    computeRow(Ci, aCr, bitsReversalMapping,
                            rowChunks, rowsMerged, rowOffset);
                    // computeScalar(Ci, aCr, rowsMerged, rowOffset, false);
                    // computeScalarPairs(Ci, aCr, rowsMerged, rowOffset);
                } catch (Throwable t) {
                    // The Future returned by submit() is never inspected, so any failure
                    // (including Errors such as vector class-initialization problems) must
                    // be reported here, otherwise a blank image would be emitted silently.
                    t.printStackTrace();
                    System.exit(-1);
                }
            });
        }
        computeEc.shutdown();
        var interrupted = false;
        while (!computeEc.isTerminated()) {
            try {
                @SuppressWarnings("unused")
                var ignored = computeEc.awaitTermination(1, TimeUnit.DAYS);
            } catch (InterruptedException e) {
                // Every row is needed for the output, so keep waiting. Remember the
                // interrupt and restore it for the caller once the pool has finished.
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        return rowsMerged;
    }

    /** Builds a 256-entry table mapping each byte value to the same byte with its bits reversed. */
    private static byte[] computeBitsReversalMapping() {
        var bitsReversalMapping = new byte[256];
        for (var i = 0; i < 256; i++) {
            bitsReversalMapping[i] = (byte) (Integer.reverse(i) >>> 24);
        }
        return bitsReversalMapping;
    }

    /**
     * Computes one complete packed output row.
     *
     * <p>SIMD handles the 64-pixel chunks and writes them into {@code rowsMerged}. Scalar code
     * handles the remaining tail pixels.
     */
    private static void computeRow(double Ci, double[] aCr,
                                   byte[] bitsReversalMapping, long[] rowChunks,
                                   byte[] rowsMerged, int rowOffset) {
        computeChunksVector(Ci, aCr, rowChunks);
        transferRowFlags(rowChunks, bitsReversalMapping, rowsMerged, rowOffset);
        computeRemainderScalar(Ci, aCr, rowsMerged, rowOffset);
    }

    /**
     * Computes escape flags for every complete 64-pixel chunk of one row with the Vector API.
     *
     * <p>Bit {@code k} of {@code rowChunks[c]} is set when pixel {@code 64 * c + k} escaped
     * ({@code |z|^2 > 4}) within 50 iterations. Pixels after the last complete chunk are not
     * touched. Two vectors ({@code 2 * LANES} pixels) are iterated together. Escape is tested
     * only after every 5 iterations, and iteration stops early once every lane of both vectors
     * has escaped.
     */
    private static void computeChunksVector(double Ci, double[] aCr,
                                            long[] rowChunks) {
        var sideLen = aCr.length;
        var vCi = DoubleVector.broadcast(SPECIES, Ci);
        var vZeroes = DoubleVector.zero(SPECIES);
        var vFours = DoubleVector.broadcast(SPECIES, 4.0);
        var zeroMask = VectorMask.fromLong(SPECIES, 0);
        // (1 << 6) = 64 = length of long in bits
        for (var xBase = 0; xBase < (sideLen & -(1 << 6)); xBase += (1 << 6)) {
            var cmpFlags = 0L;
            for (var xInc = 0; xInc < (1 << 6); xInc += LANES * 2) {
                var vZr1 = vZeroes;
                var vZr2 = vZeroes;
                var vZi1 = vZeroes;
                var vZi2 = vZeroes;
                var vCr1 = DoubleVector.fromArray(
                        SPECIES, aCr, xBase + xInc);
                var vCr2 = DoubleVector.fromArray(
                        SPECIES, aCr, xBase + xInc + LANES);
                // vZrN / vZiN hold Zr^2 / Zi^2 from the previous iteration.
                var vZrN1 = vZeroes;
                var vZrN2 = vZeroes;
                var vZiN1 = vZeroes;
                var vZiN2 = vZeroes;
                var cmpMask1 = zeroMask;
                var cmpMask2 = zeroMask;
                var stop = false;
                for (var outer = 0; !stop && outer < 10; outer++) {
                    for (var inner = 0; inner < 5; inner++) {
                        vZi1 = vZr1.add(vZr1).mul(vZi1).add(vCi);
                        vZi2 = vZr2.add(vZr2).mul(vZi2).add(vCi);
                        vZr1 = vZrN1.sub(vZiN1).add(vCr1);
                        vZr2 = vZrN2.sub(vZiN2).add(vCr2);
                        vZiN1 = vZi1.mul(vZi1);
                        vZiN2 = vZi2.mul(vZi2);
                        vZrN1 = vZr1.mul(vZr1);
                        vZrN2 = vZr2.mul(vZr2);
                    }
                    // I'm doing here: cmpMask = cmpMask.or(newValue);
                    // instead of just: cmpMask = newValue;
                    // because 4.lt(NaN) gives false
                    // NaN comes from Infinity - Infinity
                    // Infinity comes from numeric overflows
                    cmpMask1 = cmpMask1.or(vFours.lt(vZiN1.add(vZrN1)));
                    cmpMask2 = cmpMask2.or(vFours.lt(vZiN2.add(vZrN2)));
                    stop = cmpMask1.allTrue() & cmpMask2.allTrue();
                }
                // Lane l of a mask becomes bit l, i.e. pixel xBase + xInc (+ LANES) + l.
                cmpFlags |= cmpMask1.toLong() << xInc;
                cmpFlags |= cmpMask2.toLong() << (xInc + LANES);
            }
            rowChunks[xBase >> 6] = cmpFlags;
        }
    }

    /**
     * Converts per-chunk escape flags into PBM row bytes.
     *
     * <p>Each flag is inverted, so pixels that never escaped become 1 bits. Each byte's bit order
     * is reversed, so the leftmost pixel lands in the most significant bit. Chunk {@code i}
     * becomes output bytes {@code 8*i .. 8*i + 7}.
     */
    private static void transferRowFlags(long[] rowChunks,
                                         byte[] bitsReversalMapping,
                                         byte[] rowsMerged, int rowOffset) {
        for (var i = 0; i < rowChunks.length; i++) {
            var group = ~rowChunks[i];
            for (var j = 7; j >= 0; j--) {
                rowsMerged[rowOffset + i * 8 + j] =
                        bitsReversalMapping[0xff & (byte) (group >>> (j * 8))];
            }
        }
    }

    /** Computes, with scalar code, the pixels after the last complete 64-pixel chunk. */
    private static void computeRemainderScalar(double Ci, double[] aCr,
                                               byte[] rowsMerged,
                                               int rowOffset) {
        computeScalar(Ci, aCr, rowsMerged, rowOffset, true);
    }

    /**
     * Scalar Mandelbrot kernel that writes packed PBM bits, MSB first, for one row.
     *
     * @param remainderOnly when {@code true}, only pixels from the last complete 64-pixel
     *                      boundary onward are computed. That boundary is byte-aligned, so
     *                      earlier bytes are untouched.
     */
    private static void computeScalar(double Ci, double[] aCr,
                                      byte[] rowsMerged, int rowOffset,
                                      boolean remainderOnly) {
        var sideLen = aCr.length;
        var startX = remainderOnly ? sideLen & -(1 << 6) : 0;
        var bits = 0;
        for (var x = startX; x < sideLen; x++) {
            var Zr = 0.0;
            var Zi = 0.0;
            var Cr = aCr[x];
            var i = 50;
            var ZrN = 0.0;
            var ZiN = 0.0;
            do {
                Zi = 2.0 * Zr * Zi + Ci;
                Zr = ZrN - ZiN + Cr;
                ZiN = Zi * Zi;
                ZrN = Zr * Zr;
            } while (ZiN + ZrN <= 4.0 && --i > 0);
            bits <<= 1;
            // i reaches 0 only if the point never escaped (NaN also counts as escaped).
            bits += i == 0 ? 1 : 0;
            if (x % 8 == 7) {
                rowsMerged[rowOffset + x / 8] = (byte) bits;
                bits = 0;
            }
        }
        var tailPixels = sideLen % 8;
        if (tailPixels != 0) {
            // PBM packs pixels MSB-first, so the trailing partial byte must be
            // left-aligned. The unused low bits are padding.
            rowsMerged[rowOffset + sideLen / 8] = (byte) (bits << (8 - tailPixels));
        }
    }

    /**
     * Scalar kernel taken from the Mandelbrot Java #2 solution. It iterates two pixels at a time
     * and writes bytes with the bit for a non-escaping pixel set.
     *
     * <p>Only whole bytes are written: the trailing {@code aCr.length % 8} pixels are ignored.
     */
    static void computeScalarPairs(double Ci, double[] aCr,
                                   byte[] rowsMerged, int rowOffset) {
        var fullBytes = aCr.length / 8;
        for (int xb = 0; xb < fullBytes; xb++) {
            int res = 0;
            for (int i = 0; i < 8; i += 2) {
                // Start from z1 = c, so 49 more iterations give 50 in total.
                double Zr1 = aCr[xb * 8 + i];
                double Zi1 = Ci;
                double Zr2 = aCr[xb * 8 + i + 1];
                double Zi2 = Ci;
                int b = 0;
                int j = 49;
                do {
                    double nZr1 = Zr1 * Zr1 - Zi1 * Zi1 + aCr[xb * 8 + i];
                    double nZi1 = Zr1 * Zi1 + Zr1 * Zi1 + Ci;
                    Zr1 = nZr1;
                    Zi1 = nZi1;
                    double nZr2 = Zr2 * Zr2 - Zi2 * Zi2 + aCr[xb * 8 + i + 1];
                    double nZi2 = Zr2 * Zi2 + Zr2 * Zi2 + Ci;
                    Zr2 = nZr2;
                    Zi2 = nZi2;
                    if (Zr1 * Zr1 + Zi1 * Zi1 > 4) {
                        b |= 2;
                        //noinspection ConstantConditions
                        if (b == 3) {
                            break;
                        }
                    }
                    if (Zr2 * Zr2 + Zi2 * Zi2 > 4) {
                        b |= 1;
                        if (b == 3) {
                            break;
                        }
                    }
                } while (--j > 0);
                res = (res << 2) + b;
            }
            // b marks escaped pixels; invert so that non-escaping pixels are 1 bits.
            rowsMerged[rowOffset + xb] = (byte) (~res);
        }
    }
}
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>pl.tarsa</groupId>
    <artifactId>benchmarks-game-java</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>
    <name>JMH benchmark sample: Java</name>
    <dependencies>
        <dependency>
            <groupId>org.openjdk.jmh</groupId>
            <artifactId>jmh-core</artifactId>
            <version>${jmh.version}</version>
        </dependency>
        <dependency>
            <groupId>org.openjdk.jmh</groupId>
            <artifactId>jmh-generator-annprocess</artifactId>
            <version>${jmh.version}</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <jmh.version>1.27</jmh.version>
        <!-- Bytecode level. The jdk.incubator.vector module still requires a JDK 16+ toolchain. -->
        <javac.target>11</javac.target>
        <uberjar.name>benchmarks</uberjar.name>
        <!--
          The `java` launcher used by `mvn exec:exec`. It defaults to the JDK running Maven.
          Point it at a JDK 16+ install if Maven runs on an older JDK, e.g.
          -Dbenchmark.java.executable=/path/to/jdk-16/bin/java
        -->
        <benchmark.java.executable>${java.home}/bin/java</benchmark.java.executable>
    </properties>
    <build>
        <plugins>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>exec-maven-plugin</artifactId>
                <version>3.0.0</version>
                <configuration>
                    <executable>${benchmark.java.executable}</executable>
                    <arguments>
                        <argument>-classpath</argument>
                        <classpath/>
                        <!-- the Vector API is an incubator module and must be added explicitly -->
                        <argument>--add-modules</argument>
                        <argument>jdk.incubator.vector</argument>
                        <argument>-Djmh.blackhole.mode=COMPILER</argument>
                        <argument>-XX:-TieredCompilation</argument>
                        <argument>pl.tarsa.mandelbrot_simd_1</argument>
                    </arguments>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.0</version>
                <configuration>
                    <compilerVersion>${javac.target}</compilerVersion>
                    <source>${javac.target}</source>
                    <target>${javac.target}</target>
                    <compilerArgs>
                        <arg>--add-modules=jdk.incubator.vector</arg>
                    </compilerArgs>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.2.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <finalName>${uberjar.name}</finalName>
                            <transformers>
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>org.openjdk.jmh.Main</mainClass>
                                </transformer>
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                            </transformers>
                            <filters>
                                <filter>
                                    <!--
                                        Shading signed JARs will fail without this.
                                        http://stackoverflow.com/questions/999489/invalid-signature-file-when-attempting-to-run-a-jar
                                    -->
                                    <artifact>*:*</artifact>
                                    <excludes>
                                        <exclude>META-INF/*.SF</exclude>
                                        <exclude>META-INF/*.DSA</exclude>
                                        <exclude>META-INF/*.RSA</exclude>
                                    </excludes>
                                </filter>
                            </filters>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
        <pluginManagement>
            <plugins>
                <plugin>
                    <artifactId>maven-clean-plugin</artifactId>
                    <version>2.5</version>
                </plugin>
                <plugin>
                    <artifactId>maven-deploy-plugin</artifactId>
                    <version>2.8.1</version>
                </plugin>
                <plugin>
                    <artifactId>maven-install-plugin</artifactId>
                    <version>2.5.1</version>
                </plugin>
                <plugin>
                    <artifactId>maven-jar-plugin</artifactId>
                    <version>2.4</version>
                </plugin>
                <plugin>
                    <artifactId>maven-javadoc-plugin</artifactId>
                    <version>2.9.1</version>
                </plugin>
                <plugin>
                    <artifactId>maven-resources-plugin</artifactId>
                    <version>2.6</version>
                </plugin>
                <plugin>
                    <artifactId>maven-site-plugin</artifactId>
                    <version>3.3</version>
                </plugin>
                <plugin>
                    <artifactId>maven-source-plugin</artifactId>
                    <version>2.2.1</version>
                </plugin>
                <plugin>
                    <artifactId>maven-surefire-plugin</artifactId>
                    <version>2.17</version>
                </plugin>
            </plugins>
        </pluginManagement>
    </build>
</project>
Machine: Intel Core i5 4670 @ 3.80 GHz
Benchmarks Game results (i.e. running from the command line to produce the ~32 MB output file):
mandelbrot Java #2 takes about 3 s.
mandelbrot from this gist usually takes about 1.2–1.3 s, but sometimes up to 1.5 s.
Timings vary widely, probably due to long compilation times.
JMH results (a bit reordered):
# Run complete. Total time: 00:16:42
REMEMBER: The numbers below are just data. To gain reusable insights, you need to follow up on
why the numbers are the way they are. Use profilers (see -prof, -lprof), design factorial
experiments, perform baseline and negative tests that provide experimental control, make sure
the benchmarking environment is safe on JVM/OS/HW level, ask for reviews from the domain experts.
Do not assume the numbers tell you what you want them to tell.
Benchmark                  Mode  Cnt        Score      Error  Units
benchRowMT                thrpt    5    17755,337 ±  279,118  ops/s
benchRowST                thrpt    5     4535,280 ±    7,471  ops/s
benchScalarPairsMT        thrpt    5     4583,354 ±   89,532  ops/s
benchScalarPairsST        thrpt    5     1163,925 ±    0,469  ops/s
benchScalarMT             thrpt    5     2666,210 ±    5,004  ops/s
benchScalarST             thrpt    5      673,234 ±    0,167  ops/s
benchVectorMT             thrpt    5    18020,397 ±   54,230  ops/s
benchVectorST             thrpt    5     4567,873 ±   10,339  ops/s
benchVectorWithTransfer   thrpt    5     4557,361 ±    9,450  ops/s
benchScalarRemainderOnly  thrpt    5  7105989,167 ± 4691,311  ops/s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment