@swankjesse
Created May 26, 2020 22:33

JUnit Test Sharding

This is a JUnit extension that shards the tests for parallel execution.

It selects tests by hashing the test class and method name.

Configuration

Tests are sharded if these environment variables are present:

KOCHIKU_WORKER_CHUNK=1
KOCHIKU_TOTAL_WORKERS=3

The set of tests is divided into KOCHIKU_TOTAL_WORKERS parts, and each worker runs only the tests in part KOCHIKU_WORKER_CHUNK. KOCHIKU_WORKER_CHUNK is a 1-based index, so with 3 workers the chunks are 1, 2, and 3. If either variable is absent it defaults to 1, which runs every test.
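As a rough illustration of how a test ends up in a chunk, here is a minimal standalone sketch. It uses only the JDK (`MessageDigest` and `ByteBuffer`) in place of Okio, and the `chunkFor` helper and the `com.example` test names are hypothetical, not part of the extension itself.

```kotlin
import java.nio.ByteBuffer
import java.security.MessageDigest

/** Returns the 1-based chunk for a test, derived from a SHA-256 of "class.method". */
fun chunkFor(testClassName: String, testMethodName: String, totalWorkers: Int): Int {
  require(totalWorkers > 0) { "unexpected total workers: $totalWorkers" }
  val hash = MessageDigest.getInstance("SHA-256")
      .digest("$testClassName.$testMethodName".toByteArray(Charsets.UTF_8))
  // Read the first 4 bytes as a big-endian Int, then reinterpret it as unsigned.
  val hashUnsignedInt = ByteBuffer.wrap(hash).getInt().toLong() and 0xffffffffL
  // Shift from a 0-based remainder to the 1-based chunk index.
  return (hashUnsignedInt % totalWorkers).toInt() + 1
}

fun main() {
  val totalWorkers = 3 // Mirrors KOCHIKU_TOTAL_WORKERS=3 above.
  // Hypothetical test names, for illustration only.
  val tests = listOf(
      "com.example.FooTest" to "creates",
      "com.example.FooTest" to "deletes",
      "com.example.BarTest" to "roundTrips"
  )
  for ((className, methodName) in tests) {
    println("$className.$methodName -> chunk ${chunkFor(className, methodName, totalWorkers)}")
  }
}
```

Because the chunk depends only on the test's name and the worker count, every worker computes the same assignment independently, and each test runs on exactly one worker.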

To use this in Gradle, enable autodetection in the test {} section of your build.gradle:

test {
  systemProperty 'junit.jupiter.extensions.autodetection.enabled', 'true' // For test sharding.
  ...
}
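
For builds that use the Gradle Kotlin DSL instead, the same setting might look roughly like the sketch below. The `useJUnitPlatform()` call is an assumption about what the elided part of the block contains, and `tasks.test` assumes the `java` or Kotlin JVM plugin is applied.

```kotlin
// build.gradle.kts (sketch, not taken from the original project)
tasks.test {
  useJUnitPlatform() // Assumption: needed for Gradle to run JUnit 5 tests at all.
  systemProperty("junit.jupiter.extensions.autodetection.enabled", "true") // For test sharding.
}
```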

You'll also need to add this module to your test's runtime classpath:

dependencies {
  testRuntimeOnly dep.testSharding
}

# This file goes in META-INF/services/org.junit.jupiter.api.extension.Extension
# It must be on the classpath for JUnit to discover it
com.squareup.polyrepo.testsharding.TestShardingExtension

package com.squareup.polyrepo.testsharding

import okio.Buffer
import okio.ByteString.Companion.encodeUtf8
import org.junit.jupiter.api.extension.ConditionEvaluationResult
import org.junit.jupiter.api.extension.ExecutionCondition
import org.junit.jupiter.api.extension.Extension
import org.junit.jupiter.api.extension.ExtensionContext

/**
 * Implement test sharding as a JUnit extension. This class is registered using
 * [java.util.ServiceLoader] and so must be on the test suite's class path to run.
 *
 * TODO: Use ranges rather than hashing for better distribution.
 * TODO: Get JUnit to include a feature like this by default.
 */
class TestShardingExtension(
  private val totalWorkers: Int,
  private val workerChunk: Int
) : Extension, ExecutionCondition {
  constructor() : this(
    totalWorkers = System.getenv("KOCHIKU_TOTAL_WORKERS")?.toInt() ?: 1,
    workerChunk = System.getenv("KOCHIKU_WORKER_CHUNK")?.toInt() ?: 1
  )

  init {
    require(totalWorkers > 0) { "unexpected total workers: $totalWorkers" }
    require(workerChunk in 1..totalWorkers) { "unexpected worker $workerChunk" }
  }

  override fun evaluateExecutionCondition(context: ExtensionContext): ConditionEvaluationResult {
    if (context.testInstance.isEmpty) {
      return ConditionEvaluationResult.enabled("filtering child nodes only")
    }

    if (totalWorkers == 1) {
      return ConditionEvaluationResult.enabled("testing all chunks")
    }

    val testClassName = context.testClass.orElse(null)?.name ?: ""
    val testMethodName = context.testMethod.orElse(null)?.name ?: ""
    val chunk = hashToRange("$testClassName.$testMethodName", totalWorkers) + 1

    return if (chunk == workerChunk) {
      ConditionEvaluationResult.enabled("$chunk == $workerChunk")
    } else {
      ConditionEvaluationResult.disabled("$chunk != $workerChunk")
    }
  }

  /** Returns a value in `[0..max)` that is uniformly distributed for different strings. */
  internal fun hashToRange(string: String, max: Int): Int {
    val hash = string.encodeUtf8().sha256()
    val hashInt = Buffer().write(hash).readInt()
    val hashUnsignedInt = hashInt.toLong() and 0xffffffffL
    return (hashUnsignedInt % max).toInt()
  }
}
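
The unsigned conversion in hashToRange matters because Kotlin's `%` keeps the sign of the dividend, so a hash whose top bit is set would otherwise produce a negative remainder. The small sketch below shows the difference; `unsignedMod` is a hypothetical helper that isolates the last two lines of hashToRange.

```kotlin
/** Maps a signed 32-bit hash to `[0..max)` by treating its bits as an unsigned value. */
fun unsignedMod(hashInt: Int, max: Int): Int {
  val hashUnsignedInt = hashInt.toLong() and 0xffffffffL
  return (hashUnsignedInt % max).toInt()
}

fun main() {
  val hashInt = -5 // Stands in for a hash whose first 4 bytes read as a negative Int.
  println(hashInt % 3)             // -2, which is outside [0..3)
  println(unsignedMod(hashInt, 3)) // 2, which is inside [0..3)
}
```

Without the mask, adding 1 to a negative remainder could yield a chunk of 0 or less, which would never match any worker and the test would be silently skipped everywhere.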
@swankjesse
Author

Some related work on JUnit itself:
junit-team/junit5#1041
junit-team/junit5#1055
