Skip to content

Instantly share code, notes, and snippets.

@Nillerr
Created February 23, 2024 00:12
Show Gist options
  • Save Nillerr/cf159a878d66e6aa5c03af71da111512 to your computer and use it in GitHub Desktop.
Save Nillerr/cf159a878d66e6aa5c03af71da111512 to your computer and use it in GitHub Desktop.
/**
 * Base class for JUnit 5 extensions that need one-time "suite" setup and teardown.
 *
 * [beforeSuite] runs the first time any test class using a given concrete subclass reaches `beforeAll`. After a
 * successful setup, the extension instance is registered in the root store's `GLOBAL` namespace, so JUnit calls
 * [close] (and thus [afterSuite]) when that store is closed.
 *
 * Started suites are tracked per concrete subclass. Two different subclasses each get their own setup; a single
 * flag shared through the companion object would silently skip every subclass after the first one.
 */
abstract class BeforeSuiteExtension : BeforeAllCallback, ExtensionContext.Store.CloseableResource {
    override fun beforeAll(context: ExtensionContext) {
        val suiteType = this::class
        lock.withLock {
            // Every read and write of startedSuites happens under the lock, so no volatile fast path is needed.
            if (suiteType in startedSuites) {
                return
            }
            beforeSuite(context)
            // Marked only after setup succeeded, so a failed setup is retried by the next test class.
            startedSuites += suiteType
            context.root.getStore(GLOBAL).put(suiteType, this)
        }
    }

    override fun close() {
        afterSuite()
    }

    /** One-time setup, run at most once per concrete subclass while holding the shared lock. */
    abstract fun beforeSuite(context: ExtensionContext)

    /** Teardown, invoked from [close] when the root store this instance was registered in is closed. */
    abstract fun afterSuite()

    companion object {
        private val lock = ReentrantLock()

        // Concrete subclasses whose beforeSuite has completed; guarded by [lock].
        private val startedSuites = mutableSetOf<KClass<*>>()
    }
}
/**
 * Reified shorthand for `findAnnotationDeep(T::class)` on this [ExtensionContext].
 *
 * @return the annotation of type [T] found for the current test, or `null` when there is none.
 */
inline fun <reified T : Annotation> ExtensionContext.findAnnotationDeep(): T? =
    findAnnotationDeep(T::class)
/**
 * Looks for an annotation of [type] on the current test method first. Only when none is found there does it look
 * on the test class. Both lookups go through the `findAnnotationDeep` overloads for [KFunction] and [KClass].
 *
 * Must be called from a method-level context, because it reads [ExtensionContext.getRequiredTestMethod].
 *
 * @throws IllegalStateException if the test method has no Kotlin function counterpart.
 */
fun <T : Annotation> ExtensionContext.findAnnotationDeep(type: KClass<T>): T? {
    val testFunction = checkNotNull(requiredTestMethod.kotlinFunction)
    // The elvis keeps the class lookup lazy: it only runs when the method has no match.
    return testFunction.findAnnotationDeep(type)
        ?: requiredTestClass.kotlin.findAnnotationDeep(type)
}
/**
 * Depth-first search for an annotation of [type] on this class and, transitively, on the classes of its annotations
 * (meta-annotations, meta-meta-annotations, and so on).
 *
 * Each class is expanded at most once. Meta-annotation graphs are commonly cyclic: for example,
 * `java.lang.annotation.Documented` is itself annotated with `@Documented`. Without the visited set, the search
 * would never terminate whenever [type] is absent.
 *
 * @return the first matching annotation instance found, or `null` if none is reachable.
 */
fun <T : Annotation> KClass<*>.findAnnotationDeep(type: KClass<T>): T? {
    val visited = mutableSetOf<KClass<*>>()
    val elements = Stack<KClass<*>>()
    elements.push(this)
    while (elements.isNotEmpty()) {
        val element = elements.pop()
        // A class may be pushed more than once before it is expanded; only expand it the first time.
        if (!visited.add(element)) {
            continue
        }
        for (annotation in element.annotations) {
            val cast = type.safeCast(annotation)
            if (cast != null) {
                return cast
            }
            val annotationClass = annotation.annotationClass
            if (annotationClass !in visited) {
                elements.push(annotationClass)
            }
        }
    }
    return null
}
/**
 * Looks for an annotation of [type] on this function.
 *
 * The search checks the function's direct annotations first. It then searches, depth-first, the meta-annotations
 * of those annotations. This way a composed annotation (such as `@PostgresTest`) placed on a test method is
 * recognised, not only when it is placed on the test class.
 *
 * Each annotation class is expanded at most once, so cyclic meta-annotations (for example `@Documented` on
 * `Documented` itself) cannot cause an infinite loop.
 *
 * @return the matching annotation instance, or `null` if none is reachable.
 */
fun <T : Annotation> KFunction<*>.findAnnotationDeep(type: KClass<T>): T? {
    val direct = annotations
    // Direct annotations take precedence over anything found through meta-annotations.
    for (annotation in direct) {
        val cast = type.safeCast(annotation)
        if (cast != null) {
            return cast
        }
    }

    val visited = mutableSetOf<KClass<*>>()
    val pending = Stack<KClass<*>>()
    direct.forEach { pending.push(it.annotationClass) }
    while (pending.isNotEmpty()) {
        val element = pending.pop()
        if (!visited.add(element)) {
            continue
        }
        for (annotation in element.annotations) {
            val cast = type.safeCast(annotation)
            if (cast != null) {
                return cast
            }
            val annotationClass = annotation.annotationClass
            if (annotationClass !in visited) {
                pending.push(annotationClass)
            }
        }
    }
    return null
}
/**
 * Adapts any [AutoCloseable] to JUnit's [ExtensionContext.Store.CloseableResource]. Storing it in an extension
 * store closes [resource] when that store is closed.
 */
class AutoCloseableResource(val resource: AutoCloseable) : ExtensionContext.Store.CloseableResource {
    override fun close() = resource.close()
}
/**
 * Store key with value semantics, built from an ordered list of parts.
 *
 * It is used to register several entries in an [ExtensionContext.Store] that share a base key. Two keys are equal
 * when their parts are equal and appear in the same order.
 */
data class CompositeKey(val keys: List<Any>) {
    /** Convenience constructor taking the parts as varargs. */
    constructor(vararg keys: Any) : this(keys = keys.toList())
}
/**
 * Marks a test class or test method (directly or through a composed annotation) as needing a
 * [JdbcDatabaseContainer] and an open [Connection] to it.
 *
 * Registers [JdbcDatabaseContainerTestExtension]. That extension creates the container and connection on first
 * use for the current thread, caches them in the root extension store, and assigns the connection to mutable
 * properties of the test instance whose type accepts a [Connection].
 *
 * @property type Fully qualified name of the container class. It is loaded with the thread's context class loader
 * and must be a [JdbcDatabaseContainer] subclass.
 * @property dockerImageName Image name passed to the container type's single-`String` constructor. When empty, the
 * extension looks for a no-argument constructor instead.
 * @property queryString Passed to `createConnection` when the connection is opened (e.g. `?currentSchema=account`).
 * @property initSQL SQL executed once on a newly created connection. Skipped when empty.
 */
@Inherited
@ExtendWith(JdbcDatabaseContainerTestExtension::class)
annotation class JdbcDatabaseContainerTest(
val type: String,
val dockerImageName: String = "",
val queryString: String = "",
@Language("SQL") val initSQL: String = "",
)
/** Returns this property as a [KMutableProperty1] when it is one (a `var`), or `null` for a read-only property. */
@Suppress("UNCHECKED_CAST")
fun <T, V> KProperty1<T, V>.asMutableProperty(): KMutableProperty1<T, V>? =
    if (this is KMutableProperty1<*, *>) this as KMutableProperty1<T, V> else null
/**
 * Returns this property re-typed as holding a [V] when its declared return type is a supertype of `V` (so a `V`
 * can be assigned to it), or `null` otherwise.
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified V : Any> KMutableProperty1<Any, *>.asReturnType(): KMutableProperty1<Any, V>? =
    if (returnType.isSupertypeOf(typeOf<V>())) this as KMutableProperty1<Any, V> else null
/**
 * Returns this property re-typed as holding a [V] when its declared return type is a supertype of the
 * star-projected [type] (so a `V` can be assigned to it), or `null` otherwise.
 */
@Suppress("UNCHECKED_CAST")
fun <T : Any, V : Any> KMutableProperty1<T, *>.asReturnType(type: KClass<V>): KMutableProperty1<T, V>? {
    val accepts = returnType.isSupertypeOf(type.starProjectedType)
    return if (accepts) this as KMutableProperty1<T, V> else null
}
/**
 * The member properties of this class (including inherited ones, as reported by `memberProperties`) that are
 * mutable, i.e. whose reflective representation is a [KMutableProperty1].
 */
val <T : Any> KClass<T>.mutableMemberProperties: List<KMutableProperty1<T, *>>
    get() = memberProperties.mapNotNull { property -> property as? KMutableProperty1<T, *> }
/** Returns the [KClass] of the reified type argument [T]; shorthand for `T::class`. */
inline fun <reified T : Any> classOf(): KClass<T> {
    return T::class
}
/**
 * Returns the mutable member properties of this class that can be assigned a value of [type], each re-typed as
 * holding a `V`. Properties whose declared type does not accept a `V` are dropped.
 */
fun <T : Any, V : Any> KClass<T>.mutableMemberPropertiesOf(type: KClass<V>): List<KMutableProperty1<T, V>> =
    mutableMemberProperties.mapNotNull { candidate -> candidate.asReturnType(type) }
/**
 * JUnit 5 extension backing [JdbcDatabaseContainerTest].
 *
 * Before each test it makes sure the current thread has a started [JdbcDatabaseContainer] and an initialized
 * [Connection] to it. Both are created on first use and then cached in the root store, under a namespace keyed by
 * the current thread, so later tests on the same thread reuse them. The connection is then assigned to every
 * mutable property of the test instance whose type accepts a [Connection]. Container and connection are also
 * registered as closeable resources of that store.
 */
class JdbcDatabaseContainerTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        val annotation = context.findAnnotationDeep<JdbcDatabaseContainerTest>() ?: return

        // One container/connection pair per thread, so tests running in parallel never share a connection.
        val namespace = ExtensionContext.Namespace.create(Thread.currentThread())
        val store = context.root.getStore(namespace)

        val containerKey = JdbcDatabaseContainer::class
        val container = store.get(containerKey, JdbcDatabaseContainer::class.java)
            ?: store.createContainer(containerKey, annotation)

        val connectionKey = Connection::class
        val connection = store.get(connectionKey, Connection::class.java)
            ?: store.createConnection(container, connectionKey, annotation)

        assignConnection(context.requiredTestInstance, connection)
    }

    /** Assigns [connection] to every mutable property of [instance] whose type accepts a [Connection]. */
    @Suppress("UNCHECKED_CAST")
    private fun assignConnection(instance: Any, connection: Connection) {
        val type = instance::class as KClass<in Any>
        val properties = type.mutableMemberPropertiesOf(classOf<Connection>())
        for (property in properties) {
            property.set(instance, connection)
        }
    }

    /**
     * Instantiates the container class named by [JdbcDatabaseContainerTest.type] and starts it. Only after a
     * successful start is it cached under [key] and registered for closing.
     */
    private fun ExtensionContext.Store.createContainer(
        key: KClass<JdbcDatabaseContainer<*>>,
        annotation: JdbcDatabaseContainerTest,
    ): JdbcDatabaseContainer<*> {
        val classLoader = Thread.currentThread().contextClassLoader
        val containerType = classLoader.loadClass(annotation.type).kotlin
        val dockerImageName = annotation.dockerImageName

        println("[JdbcDatabaseContainerTestExtension] Creating")
        val container = if (dockerImageName.isEmpty()) {
            val constructor = containerType.constructors.firstOrNull { it.parameters.isEmpty() }
            checkNotNull(constructor) {
                "The type `$containerType` has no empty constructor and no `dockerImageName` was specified. " +
                    "Specify a `dockerImageName` if this type has a `String` constructor."
            }
            // A no-argument constructor must be called without arguments.
            constructor.call() as JdbcDatabaseContainer<*>
        } else {
            // The single parameter must *accept* a String, i.e. its type must be a supertype of String.
            val constructor = containerType.constructors
                .firstOrNull { it.parameters.size == 1 && it.parameters[0].type.isSupertypeOf(stringType) }
            checkNotNull(constructor) {
                "The type `$containerType` has no `String` constructor and a `dockerImageName` was specified. " +
                    "Either add a `String` constructor to the type or remove the `dockerImageName` from the annotation"
            }
            constructor.call(dockerImageName) as JdbcDatabaseContainer<*>
        }

        println("[JdbcDatabaseContainerTestExtension] Starting")
        container.start()
        println("[JdbcDatabaseContainerTestExtension] Started")

        // Cached only after a successful start, so a container that failed to start is never handed out later.
        put(key, container)
        val resource = AutoCloseableResource(container)
        put(CompositeKey(key, resource), resource)
        return container
    }

    /**
     * Opens a connection to [container], runs [JdbcDatabaseContainerTest.initSQL] on it, and caches it under
     * [key]. The connection is registered for closing before initialization, so it is closed even if
     * initialization fails. A failed initialization is therefore retried with a fresh connection by the next test.
     */
    private fun ExtensionContext.Store.createConnection(
        container: JdbcDatabaseContainer<*>,
        key: KClass<Connection>,
        annotation: JdbcDatabaseContainerTest,
    ): Connection {
        val connection = container.createConnection(annotation.queryString)
        val resource = AutoCloseableResource(connection)
        put(CompositeKey(key, resource), resource)
        connection.initialize(annotation)
        put(key, connection)
        return connection
    }

    /** Executes [JdbcDatabaseContainerTest.initSQL] on this connection when it is non-empty. */
    @Suppress("SqlSourceToSinkFlow")
    private fun Connection.initialize(annotation: JdbcDatabaseContainerTest) {
        println("[JdbcDatabaseContainerTestExtension] Initializing")
        val initSQL = annotation.initSQL
        if (initSQL.isNotEmpty()) {
            // Close the statement instead of leaking it until the connection is closed.
            createStatement().use { statement -> statement.execute(initSQL) }
        }
        println("[JdbcDatabaseContainerTestExtension] Initialized")
    }

    companion object {
        @JvmStatic
        private val stringType = typeOf<String>()
    }
}
/**
 * Applies a Liquibase change log to the [Connection] cached for the current thread. That connection is stored
 * under `Connection::class` in the thread's root-store namespace, e.g. by [JdbcDatabaseContainerTestExtension].
 *
 * Registers [LiquibaseTestExtension]. Once a migration has succeeded on a thread it is not re-run for later tests
 * on that thread. If no connection has been cached, the extension fails with "A connection could not be found".
 *
 * @property changeLogFile Location of the change log, resolved with a `ClassLoaderResourceAccessor`.
 * @property contexts Liquibase contexts passed to the update call. Empty by default.
 * @property rollback When `true`, the extension calls `Liquibase.updateTestingRollback` instead of `Liquibase.update`.
 */
@Inherited
@ExtendWith(LiquibaseTestExtension::class)
annotation class LiquibaseTest(
val changeLogFile: String,
val contexts: String = "",
val rollback: Boolean = true,
)
/**
 * JUnit 5 extension backing [LiquibaseTest].
 *
 * Before each test it looks up the [Connection] cached for the current thread and migrates it with the annotated
 * Liquibase change log. Each (change log, contexts) combination is applied at most once per thread. The
 * "done" flag is recorded only after a successful migration, so a failed migration is retried by the next test.
 */
class LiquibaseTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        val annotation = context.findAnnotationDeep<LiquibaseTest>() ?: return

        val namespace = ExtensionContext.Namespace.create(Thread.currentThread())
        val store = context.root.getStore(namespace)

        // Keyed by change log and contexts, so a test class using a different change log is still migrated.
        val key = Triple(LiquibaseTestExtension::class.java, annotation.changeLogFile, annotation.contexts)
        // javaObjectType is java.lang.Boolean (the type actually stored); Boolean::class.java is primitive boolean.
        if (store.get(key, Boolean::class.javaObjectType) == true) {
            return
        }

        store.getConnection().initialize(annotation)
        store.put(key, true)
    }

    /** Runs the change log from [annotation] against this connection. */
    private fun Connection.initialize(annotation: LiquibaseTest) {
        val changeLogFile = annotation.changeLogFile
        println("[LiquibaseTestExtension] Initializing")
        val liquibaseConnection = JdbcConnection(this)
        val database = DatabaseFactory.getInstance()
            .findCorrectDatabaseImplementation(liquibaseConnection)
        val resourceAccessor = ClassLoaderResourceAccessor()
        val liquibase = Liquibase(changeLogFile, resourceAccessor, database)
        if (annotation.rollback) {
            liquibase.updateTestingRollback(annotation.contexts)
        } else {
            liquibase.update(annotation.contexts)
        }
        println("[LiquibaseTestExtension] Initialized")
    }

    /** Returns the [Connection] stored under `Connection::class`, failing if none has been created yet. */
    private fun ExtensionContext.Store.getConnection(): Connection {
        val key = Connection::class
        val connection = get(key, Connection::class.java)
        checkNotNull(connection) { "A connection could not be found" }
        return connection
    }
}
/**
 * Provides a jOOQ [DSLContext] to the test instance.
 *
 * Registers [JOOQTestExtension]. Before each test, that extension wraps the [Connection] cached for the current
 * thread (e.g. by [JdbcDatabaseContainerTestExtension]) in a [DSLContext]. It assigns the context to mutable
 * properties whose type accepts a [DSLContext]. It fails with "A connection could not be found" if no connection
 * has been cached.
 */
@Inherited
@ExtendWith(JOOQTestExtension::class)
annotation class JOOQTest
/**
 * JUnit 5 extension backing [JOOQTest].
 *
 * Before each test it wraps the current thread's cached [Connection] in a fresh jOOQ [DSLContext]. It then
 * assigns that context to every mutable property of the test instance whose type accepts a [DSLContext].
 */
class JOOQTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        context.findAnnotationDeep<JOOQTest>() ?: return

        val store = context.root.getStore(ExtensionContext.Namespace.create(Thread.currentThread()))
        val dsl = DSL.using(store.getConnection())
        assignDSLContext(context.requiredTestInstance, dsl)
    }

    /** Returns the [Connection] stored under `Connection::class`, failing if none has been created yet. */
    private fun ExtensionContext.Store.getConnection(): Connection =
        checkNotNull(get(Connection::class, Connection::class.java)) { "A connection could not be found" }

    /** Assigns [dsl] to every mutable property of [instance] whose type accepts a [DSLContext]. */
    @Suppress("UNCHECKED_CAST")
    private fun assignDSLContext(instance: Any, dsl: DSLContext) {
        (instance::class as KClass<in Any>)
            .mutableMemberPropertiesOf(classOf<DSLContext>())
            .forEach { property -> property.set(instance, dsl) }
    }
}
/**
 * Provides a [TransactionalServiceContext] to the test instance.
 *
 * Registers [TransactionalServiceTestExtension]. Before each test, that extension builds a
 * [TransactionalServiceContext] around a jOOQ DSL context for the [Connection] cached for the current thread. It
 * assigns the context to mutable properties whose type accepts a [TransactionalServiceContext]. It fails with
 * "A connection could not be found" if no connection has been cached.
 */
@Inherited
@ExtendWith(TransactionalServiceTestExtension::class)
annotation class TransactionalServiceTest
/**
 * JUnit 5 extension backing [TransactionalServiceTest].
 *
 * Before each test it builds a [TransactionalServiceContext] around a jOOQ DSL context for the current thread's
 * cached [Connection]. It then assigns that context to every mutable property of the test instance whose type
 * accepts a [TransactionalServiceContext].
 */
class TransactionalServiceTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        context.findAnnotationDeep<TransactionalServiceTest>() ?: return

        val store = context.root.getStore(ExtensionContext.Namespace.create(Thread.currentThread()))
        val serviceContext = TransactionalServiceContext(DSL.using(store.getConnection()))
        assignTransactionalServiceContext(context.requiredTestInstance, serviceContext)
    }

    /** Returns the [Connection] stored under `Connection::class`, failing if none has been created yet. */
    private fun ExtensionContext.Store.getConnection(): Connection =
        checkNotNull(get(Connection::class, Connection::class.java)) { "A connection could not be found" }

    /** Assigns [serviceContext] to every mutable property of [instance] whose type accepts it. */
    @Suppress("UNCHECKED_CAST")
    private fun assignTransactionalServiceContext(instance: Any, serviceContext: TransactionalServiceContext) {
        (instance::class as KClass<in Any>)
            .mutableMemberPropertiesOf(classOf<TransactionalServiceContext>())
            .forEach { property -> property.set(instance, serviceContext) }
    }
}
/**
 * Composed annotation for tests that need the application's PostgreSQL schema. It combines:
 * - [JdbcDatabaseContainerTest]: a Testcontainers `PostgreSQLContainer` from `postgres:13-alpine`. The connection
 *   is opened with `?currentSchema=account`, and the `account` and `card` schemas are created on it.
 * - [LiquibaseTest]: applies `db/changelog/changelog.xml` (with the default `rollback = true`).
 * - [JOOQTest] and [TransactionalServiceTest]: inject a `DSLContext` and a `TransactionalServiceContext`.
 *
 * NOTE(review): unlike the annotations it composes, this annotation is not `@Inherited`. Confirm whether
 * subclasses of an annotated test class are expected to pick it up.
 */
@JdbcDatabaseContainerTest(
type = "org.testcontainers.containers.PostgreSQLContainer",
dockerImageName = "postgres:13-alpine",
queryString = "?currentSchema=account",
initSQL = """
CREATE SCHEMA account;
CREATE SCHEMA card;
"""
)
@LiquibaseTest(changeLogFile = "db/changelog/changelog.xml")
@JOOQTest
@TransactionalServiceTest
annotation class PostgresTest
/**
 * Read/write property delegate whose value is kept per thread.
 *
 * Each thread only sees its own writes. A thread that reads before writing gets the value produced by
 * [initialValue], which is evaluated on that thread's first read.
 */
class ThreadLocalProperty<V>(initialValue: () -> V) : ReadWriteProperty<Any, V> {
    private val storage: ThreadLocal<V> = ThreadLocal.withInitial { initialValue() }

    override fun getValue(thisRef: Any, property: KProperty<*>): V = storage.get()

    override fun setValue(thisRef: Any, property: KProperty<*>, value: V) = storage.set(value)
}
/** Creates a [ThreadLocalProperty] delegate, e.g. `var value by threadLocal { 0 }`. */
fun <V> threadLocal(initialValue: () -> V): ThreadLocalProperty<V> =
    ThreadLocalProperty(initialValue)
/** Lets a [ThreadLocal] be used directly as a property delegate: reads return the current thread's value. */
operator fun <T> ThreadLocal<T>.getValue(thisRef: Any?, property: KProperty<*>): T = get()

/** Lets a [ThreadLocal] be used directly as a property delegate: writes set the current thread's value. */
operator fun <T> ThreadLocal<T>.setValue(thisRef: Any?, property: KProperty<*>, value: T) {
    set(value)
}
/**
 * JUnit 5 extension for test classes and methods annotated with [PostgresTest].
 *
 * Before each test it:
 * - obtains a started and migrated [PostgreSQLContainer] for the current thread, which is cached in the root store
 *   and reused by later tests on that thread;
 * - opens a fresh [Connection] for this test only, stored in the test's own store so it is closed when that store
 *   is closed;
 * - assigns that connection, or a [DSLContext] / [TransactionalServiceContext] built on it, to every mutable
 *   property of the test instance whose type accepts one of them.
 */
class PostgresTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        val type = context.requiredTestClass.kotlin
        val annotation = context.findAnnotationDeep<PostgresTest>()
        if (annotation == null) {
            return
        }

        val container = context.getContainer()
        val connection = container.createConnection("")
        context.getStore(GLOBAL).put(connection, AutoCloseableResource(connection))

        val instance = context.requiredTestInstance
        val mutableProperties = type.memberProperties.filterIsInstance<KMutableProperty1<Any, Any>>()
        for (property in mutableProperties) {
            // The first matching branch wins, so a property typed as a common supertype receives the Connection.
            when {
                property.returnType.isSupertypeOf(connectionType) ->
                    property.set(instance, connection)
                property.returnType.isSupertypeOf(dslContextType) ->
                    property.set(instance, DSL.using(connection))
                property.returnType.isSupertypeOf(transactionalServiceContextType) ->
                    property.set(instance, TransactionalServiceContext(DSL.using(connection)))
            }
        }
    }

    /**
     * Returns the current thread's container, starting and migrating one on first use.
     *
     * The container is cached only after both start and migration succeed. If either fails, the container is
     * closed and the error is rethrown, so a half-initialized container is never reused or left running.
     */
    private fun ExtensionContext.getContainer(): PostgreSQLContainer<*> {
        val store = root.getStore(GLOBAL)
        val containerType = PostgreSQLContainer::class
        val dockerImageName = "postgres:13-alpine"
        // Keyed by this extension (not by the ExtensionContext implementation class), the image and the thread.
        val key = "${PostgresTestExtension::class}#${containerType}#${dockerImageName}#${Thread.currentThread().id}"

        val cached = store.get(key, AutoCloseableResource::class.java)
        if (cached != null) {
            println("[PostgresContainer] Reusing")
            return containerType.cast(cached.resource)
        }

        println("[PostgresContainer] Starting")
        val postgres = PostgreSQLContainer(dockerImageName)
        try {
            postgres.start()
            println("[PostgresContainer] Started")
            println("[PostgresContainer] Initializing")
            postgres.migrate()
            println("[PostgresContainer] Initialized")
        } catch (failure: Throwable) {
            try {
                postgres.close()
            } catch (closeFailure: Throwable) {
                failure.addSuppressed(closeFailure)
            }
            throw failure
        }

        store.put(key, AutoCloseableResource(postgres))
        return postgres
    }

    /** Creates the `account` and `card` schemas, then applies the Liquibase change log with `account` as the schema. */
    private fun PostgreSQLContainer<*>.migrate() {
        createConnection("").use { connection ->
            connection.createStatement().use { it.execute("CREATE SCHEMA account") }
            connection.createStatement().use { it.execute("CREATE SCHEMA card") }
            connection.schema = "account"
            val liquibaseConnection = JdbcConnection(connection)
            val database = DatabaseFactory.getInstance()
                .findCorrectDatabaseImplementation(liquibaseConnection)
            val resourceAccessor = ClassLoaderResourceAccessor()
            val liquibase = Liquibase("db/changelog/changelog.xml", resourceAccessor, database)
            liquibase.updateTestingRollback(null)
        }
    }

    companion object {
        @JvmStatic
        private val connectionType = typeOf<Connection>()

        @JvmStatic
        private val dslContextType = typeOf<DSLContext>()

        @JvmStatic
        private val transactionalServiceContextType = typeOf<TransactionalServiceContext>()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment