Skip to content

Instantly share code, notes, and snippets.

@dirkluijk
Last active December 3, 2023 06:03
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dirkluijk/1004c178922646dc297c862608c39c48 to your computer and use it in GitHub Desktop.
Save dirkluijk/1004c178922646dc297c862608c39c48 to your computer and use it in GitHub Desktop.
Specification API for Spring JDBC
package com.foo.util.specification.jdbc
import org.springframework.context.ApplicationEventPublisher
import org.springframework.data.jdbc.core.JdbcAggregateTemplate
import org.springframework.data.jdbc.core.convert.DataAccessStrategy
import org.springframework.data.jdbc.core.convert.JdbcConverter
import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory
import org.springframework.data.mapping.callback.EntityCallbacks
import org.springframework.data.relational.core.dialect.Dialect
import org.springframework.data.relational.core.dialect.RenderContextFactory
import org.springframework.data.relational.core.mapping.RelationalMappingContext
import org.springframework.data.relational.core.sql.render.SqlRenderer
import org.springframework.data.repository.core.RepositoryInformation
import org.springframework.data.repository.core.RepositoryMetadata
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
import org.springframework.stereotype.Component
/**
 * Repository factory that makes every repository use [CustomSimpleJdbcRepository] as its base implementation.
 *
 * This makes sure that every repository is allowed to extend the SpecificationRepository
 * interface and automatically uses the implementation of [CustomSimpleJdbcRepository].
 */
@Component
class CustomJdbcRepositoryFactory(
    private val accessStrategy: DataAccessStrategy,
    private val context: RelationalMappingContext,
    private val converter: JdbcConverter,
    private val dialect: Dialect,
    private val publisher: ApplicationEventPublisher,
    operations: NamedParameterJdbcOperations,
    private val jdbcTemplate: NamedParameterJdbcTemplate,
    private val jdbcConverter: JdbcConverter
) : JdbcRepositoryFactory(accessStrategy, context, converter, dialect, publisher, operations) {

    // Callbacks installed on the aggregate template of every repository built by getTargetRepository.
    private var entityCallbacks: EntityCallbacks = EntityCallbacks.create()

    /** Every repository created by this factory is backed by [CustomSimpleJdbcRepository]. */
    override fun getRepositoryBaseClass(repositoryMetadata: RepositoryMetadata?): Class<*>? =
        CustomSimpleJdbcRepository::class.java

    /**
     * Builds the [CustomSimpleJdbcRepository] instance for the repository described by
     * [repositoryInformation], wired with its own aggregate template and SQL renderer.
     */
    override fun getTargetRepository(repositoryInformation: RepositoryInformation): Any? {
        val template = JdbcAggregateTemplate(publisher, context, converter, accessStrategy)
        val renderer = SqlRenderer.create(RenderContextFactory(dialect).createRenderContext())
        val repository: CustomSimpleJdbcRepository<*, Any> = CustomSimpleJdbcRepository(
            jdbcTemplate,
            accessStrategy,
            template,
            context.getRequiredPersistentEntity(repositoryInformation.domainType),
            jdbcConverter,
            renderer
        )
        template.setEntityCallbacks(entityCallbacks)
        return repository
    }

    /**
     * Stores [entityCallbacks] for the templates created in [getTargetRepository] and also hands them
     * to the superclass. Without the `super` call, the superclass keeps no callbacks of its own, so
     * anything it builds itself (presumably the query-method lookup) would run without them.
     */
    override fun setEntityCallbacks(entityCallbacks: EntityCallbacks) {
        super.setEntityCallbacks(entityCallbacks)
        this.entityCallbacks = entityCallbacks
    }
}
package com.foo.util.specification.jdbc
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationEventPublisher
import org.springframework.context.ApplicationEventPublisherAware
import org.springframework.data.jdbc.core.convert.DataAccessStrategy
import org.springframework.data.jdbc.core.convert.JdbcConverter
import org.springframework.data.jdbc.repository.QueryMappingConfiguration
import org.springframework.data.mapping.callback.EntityCallbacks
import org.springframework.data.relational.core.dialect.Dialect
import org.springframework.data.relational.core.mapping.RelationalMappingContext
import org.springframework.data.repository.Repository
import org.springframework.data.repository.core.support.RepositoryFactorySupport
import org.springframework.data.repository.core.support.TransactionalRepositoryFactoryBeanSupport
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
/**
 * Factory bean needed to register [CustomJdbcRepositoryFactory].
 *
 * Inspired by [org.springframework.data.jdbc.repository.support.JdbcRepositoryFactoryBean].
 * I really don't like this class, but it's the only way to integrate into Spring JDBC.
 */
class CustomJdbcRepositoryFactoryBean<T : Repository<S, ID>, S, ID : java.io.Serializable>(
    repositoryInterface: Class<T>
) : TransactionalRepositoryFactoryBeanSupport<T, S, ID>(repositoryInterface), ApplicationEventPublisherAware {
    @Autowired
    private lateinit var publisher: ApplicationEventPublisher

    // Also assigned by setBeanFactory below.
    @Autowired(required = false)
    private var beanFactory: BeanFactory? = null

    @Autowired
    private lateinit var mappingContext: RelationalMappingContext

    @Autowired
    private lateinit var converter: JdbcConverter

    @Autowired
    private lateinit var dataAccessStrategy: DataAccessStrategy

    @Autowired(required = false)
    private var queryMappingConfiguration = QueryMappingConfiguration.EMPTY

    @Autowired
    private lateinit var operations: NamedParameterJdbcOperations

    // Overwritten by setBeanFactory with callbacks created from the bean factory.
    @Autowired(required = false)
    private var entityCallbacks: EntityCallbacks? = null

    @Autowired
    private lateinit var dialect: Dialect

    @Autowired
    private lateinit var jdbcTemplate: NamedParameterJdbcTemplate

    @Autowired
    private lateinit var jdbcConverter: JdbcConverter

    override fun setApplicationEventPublisher(publisher: ApplicationEventPublisher) {
        super.setApplicationEventPublisher(publisher)
        this.publisher = publisher
    }

    /** Creates the [CustomJdbcRepositoryFactory] with all collaborators injected into this bean. */
    override fun doCreateRepositoryFactory(): RepositoryFactorySupport {
        val jdbcRepositoryFactory = CustomJdbcRepositoryFactory(
            dataAccessStrategy,
            mappingContext,
            converter,
            dialect,
            publisher,
            operations,
            jdbcTemplate,
            jdbcConverter
        )
        jdbcRepositoryFactory.setQueryMappingConfiguration(queryMappingConfiguration)
        jdbcRepositoryFactory.setEntityCallbacks(resolveEntityCallbacks())
        return jdbcRepositoryFactory
    }

    override fun setBeanFactory(beanFactory: BeanFactory) {
        super.setBeanFactory(beanFactory)
        this.beanFactory = beanFactory
        this.entityCallbacks = EntityCallbacks.create(beanFactory)
    }

    /**
     * Returns the configured entity callbacks. When none are set, it falls back to callbacks
     * built from the bean factory, or to empty callbacks, instead of failing with a
     * NullPointerException as the previous `!!` did.
     */
    private fun resolveEntityCallbacks(): EntityCallbacks =
        entityCallbacks
            ?: beanFactory?.let { EntityCallbacks.create(it) }
            ?: EntityCallbacks.create()
}
package com.foo.util.specification.jdbc
import com.foo.entities.TransactionEntity
import com.foo.util.specification.Specification
import com.foo.util.specification.SpecificationRepository
import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import org.springframework.data.jdbc.core.JdbcAggregateOperations
import org.springframework.data.jdbc.core.convert.DataAccessStrategy
import org.springframework.data.jdbc.core.convert.EntityRowMapper
import org.springframework.data.jdbc.core.convert.JdbcConverter
import org.springframework.data.jdbc.repository.support.SimpleJdbcRepository
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity
import org.springframework.data.relational.core.sql.Expressions
import org.springframework.data.relational.core.sql.Functions
import org.springframework.data.relational.core.sql.Select
import org.springframework.data.relational.core.sql.Table
import org.springframework.data.relational.core.sql.render.SqlRenderer
import org.springframework.data.repository.support.PageableExecutionUtils
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
import org.springframework.transaction.annotation.Transactional
/**
 * Custom JdbcRepository implementation which supports the specification search.
 *
 * Extension of Spring's [SimpleJdbcRepository], implements [SpecificationRepository].
 *
 * NOTE(review): queries are rendered to SQL text and executed without bind parameters, so values
 * embedded by a [Specification] end up inline in the SQL. Confirm they are escaped.
 */
@Transactional(readOnly = true)
class CustomSimpleJdbcRepository<T, ID>(
    private val template: NamedParameterJdbcTemplate,
    // Retained for constructor compatibility; page totals now use a specification-aware COUNT query.
    private val accessStrategy: DataAccessStrategy,
    entityOperations: JdbcAggregateOperations,
    entity: RelationalPersistentEntity<T>,
    jdbcConverter: JdbcConverter,
    renderer: SqlRenderer
) : SimpleJdbcRepository<T, ID>(entityOperations, entity), SpecificationRepository<T> {
    private val queryBuilder = SpecificationQueryBuilder(entity, renderer)
    private val rowMapper = EntityRowMapper(entity, jdbcConverter)

    // Needed for the COUNT query, which SpecificationQueryBuilder does not produce.
    // Built the same way as in SpecificationQueryBuilder so specification conditions resolve identically.
    private val countTable = Table.create(entity.tableName)
    private val countRenderer = renderer

    override fun findAll(specification: Specification): Iterable<T> =
        findAll(specification, Pageable.unpaged(), Sort.unsorted())

    override fun findAll(specification: Specification, sort: Sort): Iterable<T> =
        findAll(specification, Pageable.unpaged(), sort)

    /**
     * Returns one page of entities matching [specification], sorted by `pageable.sort`.
     *
     * The page total is the number of rows of THIS repository's table that match [specification].
     * It is supplied lazily, so [PageableExecutionUtils] can skip the COUNT query when the total
     * is derivable from the page content.
     */
    override fun findAll(specification: Specification, pageable: Pageable): Page<T> {
        val items = findAll(specification, pageable, pageable.sort)
        return PageableExecutionUtils.getPage(items, pageable) { countMatching(specification) }
    }

    /** Runs the SELECT built by [SpecificationQueryBuilder] and maps every row to an entity. */
    private fun findAll(specification: Specification, pageable: Pageable, sort: Sort): List<T> {
        val query = queryBuilder.query(specification, pageable, sort)
        return template.query(query, rowMapper)
    }

    /** Counts the rows of this repository's table that satisfy [specification] (no paging, no ordering). */
    private fun countMatching(specification: Specification): Long {
        val select = Select.builder()
            .select(Functions.count(Expressions.asterisk()))
            .from(countTable)
            .where(specification.toCondition(countTable))
            .build()
        val sql = countRenderer.render(select)
        return template.queryForObject(sql, emptyMap<String, Any>(), Long::class.javaObjectType) ?: 0L
    }
}
package com.foo.repositories
import com.foo.entities.FooEntity
import com.foo.util.specification.Specification
import com.foo.util.specification.Specification.Companion.columnIsEqualToIfNotNull
import com.foo.util.specification.SpecificationRepository
import com.foo.util.validation.CombinedNotNull
import org.springframework.data.jdbc.repository.query.Query
import org.springframework.data.relational.core.sql.Condition
import org.springframework.data.relational.core.sql.Table
import org.springframework.data.repository.PagingAndSortingRepository
/**
 * Repository for [FooEntity] with `Long` ids, combining [PagingAndSortingRepository]
 * with the specification-based search of [SpecificationRepository].
 */
interface FooRepository : PagingAndSortingRepository<FooEntity, Long>, SpecificationRepository<FooEntity> {
// Derived query method; presumably resolved by Spring Data from the method name (filter on `name`).
fun findAllByName(name: String): List<FooEntity>
}
/**
 * Search criteria for Foo entities.
 *
 * Each field is combined (AND) into one condition through `columnIsEqualToIfNotNull`.
 * Judging by its name, that helper presumably skips the filter when the value is null (TODO confirm).
 */
data class FooSearchCriteria(
    val id: String? = null,
    val name: String? = null,
    val groupId: Long? = null
) : Specification {
    override fun toCondition(table: Table): Condition {
        val byId = columnIsEqualToIfNotNull("id", id)
        val byName = columnIsEqualToIfNotNull("name", name)
        val byGroup = columnIsEqualToIfNotNull("group_id", groupId)
        val combined = Specification.where(byId).and(byName).and(byGroup)
        return combined.toCondition(table)
    }
}
package com.foo.util.specification
import org.springframework.data.relational.core.sql.*
import org.springframework.data.relational.core.sql.SQL.literalOf
// Examples of specifications
object Specifications {
    /**
     * Specification to filter on a specific column, based on a property.
     *
     * Numbers and booleans become SQL literals directly. Strings and any other value (via
     * `toString()`) become quoted string literals with embedded single quotes escaped.
     */
    data class ColumnEquals<T, V>(private val property: KProperty1<T, V>, private val value: V) : Specification<T> {
        override fun toCondition(table: Table, entity: RelationalPersistentEntity<T>): Condition {
            val persistentProperty = entity.getRequiredPersistentProperty(property.name)
            val column = table.column(persistentProperty.columnName)
            return when (value) {
                is Number -> column.isEqualTo(literalOf(value))
                is Boolean -> column.isEqualTo(literalOf(value))
                // Covers String too: for a String, toString() is the value itself.
                else -> column.isEqualTo(quotedLiteral(value.toString()))
            }
        }
    }

    /**
     * Specification to filter a specific value column on a list of String, based on a property.
     *
     * An empty [values] list matches nothing, because `IN ()` is not valid SQL.
     */
    data class ColumnIn<T, V>(private val property: KProperty1<T, V>, private val values: List<String>) : Specification<T> {
        override fun toCondition(table: Table, entity: RelationalPersistentEntity<T>): Condition {
            val persistentProperty = entity.getRequiredPersistentProperty(property.name)
            val column = table.column(persistentProperty.columnName)
            if (values.isEmpty()) return NEVER
            val valueLiterals = values.map { quotedLiteral(it) }.toTypedArray()
            return column.`in`(*valueLiterals)
        }
    }

    /**
     * Specification which is always true
     */
    class Always<T> : Specification<T> {
        override fun toCondition(table: Table, entity: RelationalPersistentEntity<T>): Condition {
            return ALWAYS
        }
    }

    /**
     * Specification which is always false
     */
    class Never<T> : Specification<T> {
        override fun toCondition(table: Table, entity: RelationalPersistentEntity<T>): Condition {
            return NEVER
        }
    }

    val ALWAYS: Condition = Conditions.isEqual(literalOf(true), literalOf(true))
    val NEVER: Condition = Conditions.isEqual(literalOf(true), literalOf(false))

    /**
     * Creates a string literal with every single quote doubled (`'` -> `''`), the standard-SQL escape.
     *
     * These literals end up inline in the rendered SQL. An unescaped quote would end the literal
     * early, producing broken SQL or allowing injection from search input.
     * NOTE(review): dialects that also treat backslash as an escape (e.g. MySQL defaults) need
     * bind parameters for full safety.
     */
    private fun quotedLiteral(text: String) = literalOf(text.replace("'", "''"))
}
// Kotlin helpers

/**
 * Filters on `property == value` when [value] is given; matches everything when [value] is `null`.
 *
 * The specification classes are nested in [Specifications], so they are referenced qualified.
 */
infix fun <T, V> KProperty1<T, V>.ifGivenIsEqualTo(value: V?): Specification<T> =
    if (value == null) Specifications.Always() else Specifications.ColumnEquals(this, value)

/**
 * Filters on `property IN values` when [values] is given; matches everything when [values] is `null`.
 *
 * The receiver accepts nullable String properties too. `KProperty1<T, String>` is a subtype of
 * `KProperty1<T, String?>`, so existing callers are unaffected.
 */
infix fun <T> KProperty1<T, String?>.ifGivenIsIn(values: List<String>?): Specification<T> =
    if (values == null) Specifications.Always() else Specifications.ColumnIn(this, values)
package com.foo.util.specification.jdbc
import com.foo.util.specification.Specification
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity
import org.springframework.data.relational.core.sql.*
import org.springframework.data.relational.core.sql.Expressions.asterisk
import org.springframework.data.relational.core.sql.SelectBuilder.*
import org.springframework.data.relational.core.sql.render.SqlRenderer
/**
 * This class generates actual SQL queries out of specifications.
 *
 * Queries select all columns of [entity]'s table, filtered by a [Specification], with optional
 * limit/offset (from a paged [Pageable]) and ORDER BY (from a sorted [Sort]).
 */
class SpecificationQueryBuilder<T>(
    private val entity: RelationalPersistentEntity<T>,
    private val renderer: SqlRenderer
) {
    /**
     * Renders the SELECT statement for [specification] as SQL text.
     *
     * @param pageable limit/offset are applied only when it is paged
     * @param sort ORDER BY is applied only when it is sorted; properties must exist on [entity]
     */
    fun query(specification: Specification, pageable: Pageable, sort: Sort): String {
        val table = Table.create(entity.tableName)
        val query = table
            .selectAll()
            .applyLimitAndOffset(pageable)
            .applyCriteria(specification, table)
            .applyOrderBy(sort, table)
            .build()
        return renderer.render(query)
    }

    // Typed as SelectFromAndJoin (what from() returns) so the chain needs no unchecked casts.
    private fun Table.selectAll(): SelectFromAndJoin =
        Select.builder().select(asterisk()).from(this)

    private fun SelectFromAndJoin.applyLimitAndOffset(pageable: Pageable): SelectFromAndJoin =
        if (pageable.isPaged) {
            this.limit(pageable.pageSize.toLong()).offset(pageable.offset)
        } else {
            this
        }

    private fun SelectWhere.applyCriteria(specification: Specification, table: Table): SelectOrdered =
        this.where(specification.toCondition(table))

    private fun SelectOrdered.applyOrderBy(sort: Sort, table: Table): SelectOrdered =
        if (sort.isSorted) this.orderBy(sort.toListOfFields(table)) else this

    /** Maps each sort order (entity property name) to an ORDER BY field on the property's column. */
    private fun Sort.toListOfFields(table: Table): List<OrderByField> {
        return this.map { order ->
            val columnName: SqlIdentifier = entity
                .getRequiredPersistentProperty(order.property)
                .columnName
            val orderBy = OrderByField
                .from(table.column(columnName))
                .withNullHandling(order.nullHandling)
            if (order.isAscending) orderBy.asc() else orderBy.desc()
        }.toList()
    }
}
package com.foo.util.specification
import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
/**
 * Repository interface which accepts specifications to search for.
 * Also supports paging and sorting.
 *
 * @param T the entity type returned by the queries
 */
interface SpecificationRepository<T> {
// Returns all entities matching the specification, with no explicit ordering.
fun findAll(specification: Specification): Iterable<T>
// Returns all entities matching the specification, ordered by the given sort.
fun findAll(specification: Specification, sort: Sort): Iterable<T>
// Returns one page of entities matching the specification; paging and ordering come from the pageable.
fun findAll(specification: Specification, pageable: Pageable): Page<T>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment