Created
February 23, 2024 00:12
-
-
Save Nillerr/cf159a878d66e6aa5c03af71da111512 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Base class for JUnit 5 extensions that need a one-time [beforeSuite] hook and a matching [afterSuite] hook.
 *
 * [beforeSuite] runs at most once per concrete subclass: the first time [beforeAll] is invoked for it. After it
 * succeeds, the extension instance is put into the root context's [GLOBAL] store as an
 * [ExtensionContext.Store.CloseableResource]. JUnit later calls [close] on it, which runs [afterSuite].
 *
 * The "started" state is tracked per concrete subclass (keyed by `this::class`). Two different subclasses therefore
 * each get their own [beforeSuite]/[afterSuite] pair instead of sharing one flag.
 */
abstract class BeforeSuiteExtension : BeforeAllCallback, ExtensionContext.Store.CloseableResource {
    companion object {
        /** Serializes [beforeSuite] calls so a subclass is never initialized twice concurrently. */
        @JvmStatic
        private val lock = ReentrantLock()

        /**
         * Concrete subclasses whose [beforeSuite] has completed.
         *
         * The set is concurrent so the lock-free fast path in [beforeAll] safely observes entries added under [lock].
         */
        @JvmStatic
        private val started: MutableSet<KClass<*>> = java.util.concurrent.ConcurrentHashMap.newKeySet<KClass<*>>()
    }

    override fun beforeAll(context: ExtensionContext) {
        val key = this::class
        // Fast path: no locking once this subclass has been started.
        if (key in started) {
            return
        }
        lock.withLock {
            // Re-check under the lock: another thread may have finished initialization while we waited.
            if (key in started) {
                return
            }
            // If beforeSuite throws, the subclass is not marked as started and the next beforeAll retries.
            beforeSuite(context)
            started.add(key)
            context.root.getStore(GLOBAL).put(key, this)
        }
    }

    /** Called when the store holding this instance is closed; delegates to [afterSuite]. */
    override fun close() {
        afterSuite()
    }

    /** One-time setup, run on the first [beforeAll] for this concrete subclass. */
    abstract fun beforeSuite(context: ExtensionContext)

    /** Teardown, run from [close] once the root store that holds this extension is closed. */
    abstract fun afterSuite()
}
/**
 * Reified shorthand for `findAnnotationDeep(T::class)` on this [ExtensionContext].
 *
 * @return the annotation of type [T] found for the current test, or `null`.
 */
inline fun <reified T : Annotation> ExtensionContext.findAnnotationDeep(): T? =
    findAnnotationDeep(T::class)
/**
 * Looks up an annotation of [type] for the current test, first on the test method and then on the test class.
 *
 * The method is searched through its `KFunction` form and the class through its `KClass` form, each with the
 * corresponding `findAnnotationDeep` overload. If the test method cannot be represented as a Kotlin function
 * (`kotlinFunction` returns `null`), the method search is skipped and only the class is searched. Previously this
 * case failed with an uninformative "Required value was null.".
 *
 * Uses [ExtensionContext.getRequiredTestMethod] and [ExtensionContext.getRequiredTestClass]. Call it only where both
 * are present, e.g. from `beforeEach`.
 *
 * @return the method-level match if any, otherwise the class-level match, otherwise `null`.
 */
fun <T : Annotation> ExtensionContext.findAnnotationDeep(type: KClass<T>): T? {
    val function = requiredTestMethod.kotlinFunction
    val funAnnotation = function?.findAnnotationDeep(type)
    if (funAnnotation != null) {
        return funAnnotation
    }
    return requiredTestClass.kotlin.findAnnotationDeep(type)
}
/**
 * Searches this class for an annotation of [type]. The search covers the class's own annotations and, transitively,
 * the annotations on those annotations' classes (meta-annotations).
 *
 * The search is depth-first and returns the first match it encounters. Each class is visited at most once, which
 * guarantees termination: JDK meta-annotations are self-referential (e.g. `java.lang.annotation.Documented` is itself
 * annotated with `@Documented`). Without the visited set, the walk would loop forever whenever [type] is absent.
 *
 * Annotations are read with Java reflection ([Class.getAnnotations]). For the starting class this includes
 * annotations inherited from superclasses when their type is marked `@Inherited`.
 *
 * @return the first matching annotation found, or `null` if none is reachable.
 */
fun <T : Annotation> KClass<*>.findAnnotationDeep(type: KClass<T>): T? {
    val annotationType = type.java
    val visited = HashSet<Class<*>>()
    val pending = ArrayDeque<Class<*>>()
    pending.addLast(this.java)
    while (pending.isNotEmpty()) {
        val element = pending.removeLast()
        if (!visited.add(element)) {
            continue
        }
        for (annotation in element.annotations) {
            if (annotationType.isInstance(annotation)) {
                return annotationType.cast(annotation)
            }
            pending.addLast(annotation.annotationClass.java)
        }
    }
    return null
}
/**
 * Searches this function for an annotation of [type].
 *
 * The function's own annotations are checked first. If none matches, the meta-annotations reachable from them are
 * searched transitively (annotations on the annotation classes, read with Java reflection). This lets a composed
 * annotation on a test method be resolved the same way as on a class.
 *
 * Directly declared annotations always take precedence. The meta-annotation walk skips classes it has already
 * visited, so self-referential JDK meta-annotations such as `@Documented` cannot make it loop.
 *
 * @return the first matching annotation, or `null` if none is reachable.
 */
fun <T : Annotation> KFunction<*>.findAnnotationDeep(type: KClass<T>): T? {
    val direct = annotations
    direct.firstNotNullOfOrNull { type.safeCast(it) }?.let { return it }

    val visited = HashSet<Class<*>>()
    val pending = ArrayDeque<Class<*>>()
    direct.forEach { pending.addLast(it.annotationClass.java) }
    while (pending.isNotEmpty()) {
        val element = pending.removeLast()
        if (!visited.add(element)) {
            continue
        }
        for (annotation in element.annotations) {
            type.safeCast(annotation)?.let { return it }
            pending.addLast(annotation.annotationClass.java)
        }
    }
    return null
}
/**
 * Adapts an [AutoCloseable] to JUnit's [ExtensionContext.Store.CloseableResource], so a store entry holding it gets
 * closed by JUnit.
 *
 * @property resource the wrapped object; [close] forwards to its `close()`.
 */
class AutoCloseableResource(val resource: AutoCloseable) : ExtensionContext.Store.CloseableResource {
    override fun close() = resource.close()
}
/**
 * Store key with value semantics built from several parts. Two keys are equal when their parts are equal, in order.
 */
data class CompositeKey(val keys: List<Any>) {
    /** Varargs convenience: `CompositeKey(a, b)` equals `CompositeKey(listOf(a, b))`. */
    constructor(vararg keys: Any) : this(listOf(*keys))
}
/**
 * Requests a JDBC test database for the annotated test. It is handled by [JdbcDatabaseContainerTestExtension],
 * registered through `@ExtendWith`.
 *
 * @property type fully-qualified name of the container class, loaded through the thread context class loader and
 *   cast to `JdbcDatabaseContainer` by the extension.
 * @property dockerImageName image name passed to a single-parameter `String` constructor of [type]. When empty, a
 *   no-argument constructor is used instead.
 * @property queryString passed to `JdbcDatabaseContainer.createConnection` when the connection is opened
 *   (e.g. `?currentSchema=account`).
 * @property initSQL SQL executed on each newly created connection before tests use it; skipped when empty.
 */
@Inherited
@ExtendWith(JdbcDatabaseContainerTestExtension::class)
annotation class JdbcDatabaseContainerTest(
    val type: String,
    val dockerImageName: String = "",
    val queryString: String = "",
    @Language("SQL") val initSQL: String = "",
)
/**
 * Narrows this property to a [KMutableProperty1] when it is a `var`.
 *
 * Only the property kind is checked at runtime; the type arguments are erased and taken on trust.
 *
 * @return the same property typed as mutable, or `null` when it is read-only.
 */
@Suppress("UNCHECKED_CAST")
fun <T, V> KProperty1<T, V>.asMutableProperty(): KMutableProperty1<T, V>? {
    if (this !is KMutableProperty1<*, *>) {
        return null
    }
    return this as KMutableProperty1<T, V>
}
/**
 * Re-types this property as holding [V] when a [V] can be stored in it.
 *
 * A [V] can be stored when the declared return type is a supertype of `V`. A property declared as `Any?` therefore
 * matches any [V].
 *
 * @return this property as `KMutableProperty1<Any, V>`, or `null` when a [V] is not assignable to it.
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified V : Any> KMutableProperty1<Any, *>.asReturnType(): KMutableProperty1<Any, V>? {
    val requested = typeOf<V>()
    return if (returnType.isSupertypeOf(requested)) this as KMutableProperty1<Any, V> else null
}
/**
 * Re-types this property as holding values of [type] when such a value can be stored in it.
 *
 * The check is that the declared return type is a supertype of `type`'s star-projected type. Broad declarations
 * such as `Any?` therefore match as well.
 *
 * @return this property as `KMutableProperty1<T, V>`, or `null` when a [V] is not assignable to it.
 */
@Suppress("UNCHECKED_CAST")
fun <T : Any, V : Any> KMutableProperty1<T, *>.asReturnType(type: KClass<V>): KMutableProperty1<T, V>? {
    val accepts = returnType.isSupertypeOf(type.starProjectedType)
    return if (accepts) this as KMutableProperty1<T, V> else null
}
/** The `var` member properties of this class (from `memberProperties`), typed as [KMutableProperty1]. */
val <T : Any> KClass<T>.mutableMemberProperties: List<KMutableProperty1<T, *>>
    get() {
        val properties = memberProperties
        return properties.mapNotNull { property -> property.asMutableProperty() }
    }
/** Returns the [KClass] of the reified type argument; `classOf<String>()` is the same as `String::class`. */
inline fun <reified T : Any> classOf(): KClass<T> {
    return T::class
}
/**
 * The mutable member properties of this class into which a value of [type] can be assigned, re-typed accordingly.
 */
fun <T : Any, V : Any> KClass<T>.mutableMemberPropertiesOf(type: KClass<V>): List<KMutableProperty1<T, V>> =
    mutableMemberProperties.mapNotNull { property -> property.asReturnType(type) }
/**
 * Creates a [JdbcDatabaseContainer] and a [Connection] for tests whose method or class carries
 * [JdbcDatabaseContainerTest] (looked up with `findAnnotationDeep`). It then assigns the connection to every mutable
 * property of the test instance that can hold a [Connection].
 *
 * The container and the connection are cached in the root [ExtensionContext.Store], in a namespace keyed by the
 * current [Thread]. Each thread therefore creates them once and reuses them for later tests. Both are also put into
 * that store wrapped in [AutoCloseableResource].
 *
 * Nothing is cached until it is usable. If a container fails to start, or a connection's `initSQL` fails, that
 * object is closed and the failure is rethrown, so the next test retries instead of reusing a broken instance.
 */
class JdbcDatabaseContainerTestExtension : BeforeEachCallback {
    companion object {
        @JvmStatic
        private val stringType = typeOf<String>()
    }

    override fun beforeEach(context: ExtensionContext) {
        val annotation = context.findAnnotationDeep<JdbcDatabaseContainerTest>() ?: return

        val namespace = ExtensionContext.Namespace.create(Thread.currentThread())
        val store = context.root.getStore(namespace)

        val containerKey = JdbcDatabaseContainer::class
        val container = store.get(containerKey, JdbcDatabaseContainer::class.java)
            ?: store.createContainer(containerKey, annotation)

        val connectionKey = Connection::class
        val connection = store.get(connectionKey, Connection::class.java)
            ?: store.createConnection(container, connectionKey, annotation)

        assignConnection(context.requiredTestInstance, connection)
    }

    /** Sets [connection] on every mutable member property of [instance] whose declared type accepts a [Connection]. */
    @Suppress("UNCHECKED_CAST")
    private fun assignConnection(instance: Any, connection: Connection) {
        val type = instance::class as KClass<in Any>
        val properties = type.mutableMemberPropertiesOf(classOf<Connection>())
        for (property in properties) {
            property.set(instance, connection)
        }
    }

    /** Instantiates and starts the container described by [annotation], then caches it under [key]. */
    private fun ExtensionContext.Store.createContainer(
        key: KClass<JdbcDatabaseContainer<*>>,
        annotation: JdbcDatabaseContainerTest,
    ): JdbcDatabaseContainer<*> {
        val container = instantiateContainer(annotation)

        println("[JdbcDatabaseContainerTestExtension] Starting")
        closeOnFailure(container) { container.start() }
        println("[JdbcDatabaseContainerTestExtension] Started")

        // Cache only after a successful start, so a failed start is retried rather than reused.
        put(key, container)
        val resource = AutoCloseableResource(container)
        put(CompositeKey(key, resource), resource)
        return container
    }

    /**
     * Loads [JdbcDatabaseContainerTest.type] through the thread context class loader and instantiates it.
     *
     * With an empty `dockerImageName` the no-argument constructor is used. Otherwise a single-parameter constructor
     * that accepts a `String` is called with the image name.
     *
     * @throws IllegalStateException if no suitable constructor exists.
     */
    private fun instantiateContainer(annotation: JdbcDatabaseContainerTest): JdbcDatabaseContainer<*> {
        val classLoader = Thread.currentThread().contextClassLoader
        val containerType = classLoader.loadClass(annotation.type).kotlin
        val dockerImageName = annotation.dockerImageName
        println("[JdbcDatabaseContainerTestExtension] Creating")
        val instance = if (dockerImageName.isEmpty()) {
            val constructor = containerType.constructors.firstOrNull { it.parameters.isEmpty() }
            checkNotNull(constructor) {
                "The type `$containerType` has no empty constructor and no `dockerImageName` was specified. " +
                    "Specify a `dockerImageName` if this type has a `String` constructor."
            }
            // A no-argument constructor must be called without arguments.
            constructor.call()
        } else {
            val constructor = containerType.constructors.firstOrNull {
                // The single parameter must be able to receive a String (String, String?, CharSequence, ...).
                it.parameters.size == 1 && it.parameters[0].type.isSupertypeOf(stringType)
            }
            checkNotNull(constructor) {
                "The type `$containerType` has no `String` constructor and a `dockerImageName` was specified. " +
                    "Either add a `String` constructor to the type or remove the `dockerImageName` from the annotation"
            }
            constructor.call(dockerImageName)
        }
        return instance as JdbcDatabaseContainer<*>
    }

    /** Opens a connection on [container], runs `initSQL` on it, then caches it under [key]. */
    private fun ExtensionContext.Store.createConnection(
        container: JdbcDatabaseContainer<*>,
        key: KClass<Connection>,
        annotation: JdbcDatabaseContainerTest,
    ): Connection {
        val connection = container.createConnection(annotation.queryString)
        closeOnFailure(connection) { connection.initialize(annotation) }

        // Cache only after initSQL succeeded; otherwise later tests would silently get an uninitialized connection.
        put(key, connection)
        val resource = AutoCloseableResource(connection)
        put(CompositeKey(key, resource), resource)
        return connection
    }

    /** Executes the annotation's `initSQL` (if any) on this connection, closing the statement afterwards. */
    @Suppress("SqlSourceToSinkFlow")
    private fun Connection.initialize(annotation: JdbcDatabaseContainerTest) {
        println("[JdbcDatabaseContainerTestExtension] Initializing")
        val initSQL = annotation.initSQL
        if (initSQL.isNotEmpty()) {
            createStatement().use { statement -> statement.execute(initSQL) }
        }
        println("[JdbcDatabaseContainerTestExtension] Initialized")
    }

    /**
     * Runs [action]. If it throws, [resource] is closed and the original failure is rethrown; any exception from
     * `close()` is attached to it as suppressed.
     */
    private inline fun <R> closeOnFailure(resource: AutoCloseable, action: () -> R): R {
        try {
            return action()
        } catch (e: Throwable) {
            try {
                resource.close()
            } catch (closeFailure: Throwable) {
                e.addSuppressed(closeFailure)
            }
            throw e
        }
    }
}
/**
 * Runs Liquibase migrations on the per-thread connection stored by [JdbcDatabaseContainerTestExtension]. It is
 * handled by [LiquibaseTestExtension], registered through `@ExtendWith`.
 *
 * @property changeLogFile change log path, resolved through a `ClassLoaderResourceAccessor`.
 * @property contexts Liquibase contexts string passed to the update call.
 * @property rollback when `true` the extension calls `updateTestingRollback`; otherwise it calls `update`.
 */
@Inherited
@ExtendWith(LiquibaseTestExtension::class)
annotation class LiquibaseTest(
    val changeLogFile: String,
    val contexts: String = "",
    val rollback: Boolean = true,
)
/**
 * Applies the change log named by [LiquibaseTest] to the per-thread [Connection], at most once per thread.
 *
 * Completion is recorded in the same per-thread store namespace as the connection. If the migration throws, nothing
 * is recorded and the next test on that thread tries again.
 */
class LiquibaseTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        val annotation = context.findAnnotationDeep<LiquibaseTest>() ?: return

        val store = context.root.getStore(ExtensionContext.Namespace.create(Thread.currentThread()))
        val doneKey = LiquibaseTestExtension::class.java
        if (store.get(doneKey, Boolean::class.java) == true) {
            return
        }

        store.getConnection().initialize(annotation)
        store.put(doneKey, true)
    }

    /** Runs the change log on this connection, with rollback testing when [LiquibaseTest.rollback] is set. */
    private fun Connection.initialize(annotation: LiquibaseTest) {
        val changeLogFile = annotation.changeLogFile
        println("[LiquibaseTestExtension] Initializing")
        val jdbcConnection = JdbcConnection(this)
        val database = DatabaseFactory.getInstance().findCorrectDatabaseImplementation(jdbcConnection)
        val liquibase = Liquibase(changeLogFile, ClassLoaderResourceAccessor(), database)
        if (annotation.rollback) liquibase.updateTestingRollback(annotation.contexts) else liquibase.update(annotation.contexts)
        println("[LiquibaseTestExtension] Initialized")
    }

    /** Returns the connection cached in this store, failing if none has been created yet. */
    private fun ExtensionContext.Store.getConnection(): Connection =
        checkNotNull(get(Connection::class, Connection::class.java)) { "A connection could not be found" }
}
/**
 * Marker annotation: injects a jOOQ `DSLContext` into the test instance. It is handled by [JOOQTestExtension],
 * registered through `@ExtendWith`.
 */
@Inherited
@ExtendWith(JOOQTestExtension::class)
annotation class JOOQTest
/**
 * For tests carrying [JOOQTest], builds a [DSLContext] with `DSL.using` on the per-thread [Connection]. The context is
 * assigned to every mutable property of the test instance that can hold a [DSLContext]. A new context is built for
 * each test.
 */
class JOOQTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        if (context.findAnnotationDeep<JOOQTest>() == null) {
            return
        }

        val store = context.root.getStore(ExtensionContext.Namespace.create(Thread.currentThread()))
        val dsl = DSL.using(store.getConnection())
        assignDSLContext(context.requiredTestInstance, dsl)
    }

    /** Returns the connection cached in this store, failing if none has been created yet. */
    private fun ExtensionContext.Store.getConnection(): Connection =
        checkNotNull(get(Connection::class, Connection::class.java)) { "A connection could not be found" }

    /** Sets [dsl] on every mutable member property of [instance] whose declared type accepts a [DSLContext]. */
    @Suppress("UNCHECKED_CAST")
    private fun assignDSLContext(instance: Any, dsl: DSLContext) {
        (instance::class as KClass<in Any>)
            .mutableMemberPropertiesOf(classOf<DSLContext>())
            .forEach { property -> property.set(instance, dsl) }
    }
}
/**
 * Marker annotation: injects a `TransactionalServiceContext` into the test instance. It is handled by
 * [TransactionalServiceTestExtension], registered through `@ExtendWith`.
 */
@Inherited
@ExtendWith(TransactionalServiceTestExtension::class)
annotation class TransactionalServiceTest
/**
 * For tests carrying [TransactionalServiceTest], wraps `DSL.using(connection)` on the per-thread [Connection] in a new
 * [TransactionalServiceContext]. The context is assigned to every mutable property of the test instance that can
 * hold one. A new context is built for each test.
 */
class TransactionalServiceTestExtension : BeforeEachCallback {
    override fun beforeEach(context: ExtensionContext) {
        if (context.findAnnotationDeep<TransactionalServiceTest>() == null) {
            return
        }

        val store = context.root.getStore(ExtensionContext.Namespace.create(Thread.currentThread()))
        val serviceContext = TransactionalServiceContext(DSL.using(store.getConnection()))
        assignTransactionalServiceContext(context.requiredTestInstance, serviceContext)
    }

    /** Returns the connection cached in this store, failing if none has been created yet. */
    private fun ExtensionContext.Store.getConnection(): Connection =
        checkNotNull(get(Connection::class, Connection::class.java)) { "A connection could not be found" }

    /**
     * Sets [serviceContext] on every mutable member property of [instance] whose declared type accepts a
     * [TransactionalServiceContext].
     */
    @Suppress("UNCHECKED_CAST")
    private fun assignTransactionalServiceContext(instance: Any, serviceContext: TransactionalServiceContext) {
        (instance::class as KClass<in Any>)
            .mutableMemberPropertiesOf(classOf<TransactionalServiceContext>())
            .forEach { property -> property.set(instance, serviceContext) }
    }
}
/**
 * Composed annotation for PostgreSQL-backed tests. It bundles four annotations, in this order:
 * - [JdbcDatabaseContainerTest]: a `postgres:13-alpine` `PostgreSQLContainer`, connected with
 *   `?currentSchema=account`, with the `account` and `card` schemas created by `initSQL`;
 * - [LiquibaseTest] applying `db/changelog/changelog.xml`;
 * - [JOOQTest];
 * - [TransactionalServiceTest].
 *
 * The later extensions read the connection stored by [JdbcDatabaseContainerTestExtension] and fail with
 * "A connection could not be found" if it is missing. Keep [JdbcDatabaseContainerTest] first; this presumably relies
 * on JUnit invoking extensions in declaration order (TODO confirm).
 *
 * NOTE(review): unlike the annotations it composes, this annotation is not marked `@Inherited`. Confirm whether
 * subclasses of an annotated test class are meant to pick it up.
 */
@JdbcDatabaseContainerTest(
    type = "org.testcontainers.containers.PostgreSQLContainer",
    dockerImageName = "postgres:13-alpine",
    queryString = "?currentSchema=account",
    initSQL = """
CREATE SCHEMA account;
CREATE SCHEMA card;
"""
)
@LiquibaseTest(changeLogFile = "db/changelog/changelog.xml")
@JOOQTest
@TransactionalServiceTest
annotation class PostgresTest
/**
 * A [ReadWriteProperty] whose value lives in a [ThreadLocal], so each thread reads and writes its own copy.
 *
 * A thread that reads before it has written gets the result of [initialValue] (via [ThreadLocal.withInitial]).
 * `thisRef` and `property` are ignored; the stored value depends only on this delegate instance and the current
 * thread.
 */
class ThreadLocalProperty<V>(initialValue: () -> V) : ReadWriteProperty<Any, V> {
    private val threadLocalValue: ThreadLocal<V> = ThreadLocal.withInitial { initialValue() }

    override fun getValue(thisRef: Any, property: KProperty<*>): V = threadLocalValue.get()

    override fun setValue(thisRef: Any, property: KProperty<*>, value: V) = threadLocalValue.set(value)
}
/** Factory for a [ThreadLocalProperty] delegate, e.g. `var counter by threadLocal { 0 }`. */
fun <V> threadLocal(initialValue: () -> V): ThreadLocalProperty<V> = ThreadLocalProperty(initialValue)
/**
 * Lets a [ThreadLocal] be used directly as a property delegate (`var x by someThreadLocal`). Reading the property
 * returns the calling thread's value via [ThreadLocal.get].
 */
operator fun <T> ThreadLocal<T>.getValue(thisRef: Any?, property: KProperty<*>) = this.get()

/** Writing a property delegated to a [ThreadLocal] stores the value for the calling thread via [ThreadLocal.set]. */
operator fun <T> ThreadLocal<T>.setValue(thisRef: Any?, property: KProperty<*>, value: T) = this.set(value)
/**
 * Self-contained extension for tests carrying [PostgresTest].
 *
 * For each test it opens a fresh [Connection] on a per-thread `postgres:13-alpine` container. It then assigns the
 * connection, a [DSLContext] built from it, or a [TransactionalServiceContext] to each mutable property of the test
 * instance whose type accepts one of them.
 *
 * The per-test connection is registered in the test's own [GLOBAL] store, wrapped in [AutoCloseableResource].
 *
 * The first time a thread needs the container, three steps run in order:
 * 1. the container is started;
 * 2. the `account` and `card` schemas are created;
 * 3. the Liquibase change log is applied with rollback testing.
 *
 * The container is cached in the root store only after all three succeed. If any step fails, the container is
 * closed and the failure rethrown, so the next test retries instead of silently reusing it.
 */
class PostgresTestExtension : BeforeEachCallback {
    companion object {
        @JvmStatic
        private val connectionType = typeOf<Connection>()

        @JvmStatic
        private val dslContextType = typeOf<DSLContext>()

        @JvmStatic
        private val transactionalServiceContextType = typeOf<TransactionalServiceContext>()
    }

    override fun beforeEach(context: ExtensionContext) {
        if (context.findAnnotationDeep<PostgresTest>() == null) {
            return
        }

        val container = context.getContainer()
        val connection = container.createConnection("")
        context.getStore(GLOBAL).put(connection, AutoCloseableResource(connection))

        val instance = context.requiredTestInstance
        val mutableProperties = context.requiredTestClass.kotlin.memberProperties
            .filterIsInstance<KMutableProperty1<Any, Any>>()
        for (property in mutableProperties) {
            // The first matching branch wins, so a broadly typed property (e.g. `Any`) only receives the Connection.
            when {
                property.returnType.isSupertypeOf(connectionType) ->
                    property.set(instance, connection)
                property.returnType.isSupertypeOf(dslContextType) ->
                    property.set(instance, DSL.using(connection))
                property.returnType.isSupertypeOf(transactionalServiceContextType) ->
                    property.set(instance, TransactionalServiceContext(DSL.using(connection)))
            }
        }
    }

    /** Returns this thread's cached, migrated PostgreSQL container, creating it on first use. */
    private fun ExtensionContext.getContainer(): PostgreSQLContainer<*> {
        val store = root.getStore(GLOBAL)
        val containerType = PostgreSQLContainer::class
        val dockerImageName = "postgres:13-alpine"
        // Identify the owner by the extension class. Inside this ExtensionContext extension, `this::class` would be
        // the context implementation instead.
        val key = "${PostgresTestExtension::class}#${containerType}#${dockerImageName}#${Thread.currentThread().id}"

        val cached = store.get(key, AutoCloseableResource::class.java)
        if (cached != null) {
            println("[PostgresContainer] Reusing")
            return containerType.cast(cached.resource)
        }

        println("[PostgresContainer] Starting")
        val postgres = PostgreSQLContainer(dockerImageName)
        try {
            postgres.start()
            println("[PostgresContainer] Started")
            println("[PostgresContainer] Initializing")
            postgres.migrate()
            println("[PostgresContainer] Initialized")
        } catch (e: Throwable) {
            // Don't leak a container that failed to start or migrate; keep the original failure as the primary one.
            try {
                postgres.close()
            } catch (closeFailure: Throwable) {
                e.addSuppressed(closeFailure)
            }
            throw e
        }

        // Cache only once the container is fully initialized, so a failure above is retried by the next test.
        store.put(key, AutoCloseableResource(postgres))
        return postgres
    }

    /** Creates the `account` and `card` schemas and applies the Liquibase change log with rollback testing. */
    private fun PostgreSQLContainer<*>.migrate() {
        createConnection("").use { connection ->
            connection.createStatement().use { it.execute("CREATE SCHEMA account") }
            connection.createStatement().use { it.execute("CREATE SCHEMA card") }
            connection.schema = "account"
            val liquibaseConnection = JdbcConnection(connection)
            val database = DatabaseFactory.getInstance()
                .findCorrectDatabaseImplementation(liquibaseConnection)
            val resourceAccessor = ClassLoaderResourceAccessor()
            val liquibase = Liquibase("db/changelog/changelog.xml", resourceAccessor, database)
            liquibase.updateTestingRollback(null)
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment