Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Specification API for Spring JDBC
package com.foo.util.specification.jdbc
import org.springframework.context.ApplicationEventPublisher
import org.springframework.data.jdbc.core.JdbcAggregateTemplate
import org.springframework.data.jdbc.core.convert.DataAccessStrategy
import org.springframework.data.jdbc.core.convert.JdbcConverter
import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory
import org.springframework.data.mapping.callback.EntityCallbacks
import org.springframework.data.relational.core.dialect.Dialect
import org.springframework.data.relational.core.dialect.RenderContextFactory
import org.springframework.data.relational.core.mapping.RelationalMappingContext
import org.springframework.data.relational.core.sql.render.SqlRenderer
import org.springframework.data.repository.core.RepositoryInformation
import org.springframework.data.repository.core.RepositoryMetadata
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
import org.springframework.stereotype.Component
/**
 * Repository factory that backs every repository with [CustomSimpleJdbcRepository].
 *
 * This lets any repository interface extend `SpecificationRepository` and have the specification
 * search methods implemented automatically by [CustomSimpleJdbcRepository].
 */
@Component
class CustomJdbcRepositoryFactory(
    private val accessStrategy: DataAccessStrategy,
    private val context: RelationalMappingContext,
    private val converter: JdbcConverter,
    private val dialect: Dialect,
    private val publisher: ApplicationEventPublisher,
    operations: NamedParameterJdbcOperations,
    private val jdbcTemplate: NamedParameterJdbcTemplate,
    private val jdbcConverter: JdbcConverter
): JdbcRepositoryFactory(accessStrategy, context, converter, dialect, publisher, operations) {

    /** Callbacks applied to the aggregate templates created in [getTargetRepository]; replaced via [setEntityCallbacks]. */
    private var entityCallbacks: EntityCallbacks = EntityCallbacks.create()

    /** Every repository uses [CustomSimpleJdbcRepository] as its base class, independent of [repositoryMetadata]. */
    override fun getRepositoryBaseClass(repositoryMetadata: RepositoryMetadata?): Class<*>? {
        return CustomSimpleJdbcRepository::class.java
    }

    /**
     * Creates the [CustomSimpleJdbcRepository] for the domain type of [repositoryInformation].
     *
     * The repository is backed by a fresh [JdbcAggregateTemplate] carrying the current [entityCallbacks].
     */
    override fun getTargetRepository(repositoryInformation: RepositoryInformation): Any? {
        val template = JdbcAggregateTemplate(publisher, context, converter, accessStrategy)
        // Configure the template before handing it to the repository.
        template.setEntityCallbacks(entityCallbacks)
        val renderer = SqlRenderer.create(RenderContextFactory(dialect).createRenderContext())
        val repository: CustomSimpleJdbcRepository<*, Any> = CustomSimpleJdbcRepository(
            jdbcTemplate,
            accessStrategy,
            template,
            context.getRequiredPersistentEntity(repositoryInformation.domainType),
            jdbcConverter,
            renderer
        )
        return repository
    }

    /**
     * Stores [entityCallbacks] for the templates created here AND forwards them to [JdbcRepositoryFactory].
     *
     * The base factory keeps its own reference (Spring Data JDBC passes it to the query lookup strategy used
     * for derived and `@Query` methods). Without the `super` call, those methods would run without entity callbacks.
     */
    override fun setEntityCallbacks(entityCallbacks: EntityCallbacks) {
        super.setEntityCallbacks(entityCallbacks)
        this.entityCallbacks = entityCallbacks
    }
}
package com.foo.util.specification.jdbc
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationEventPublisher
import org.springframework.context.ApplicationEventPublisherAware
import org.springframework.data.jdbc.core.convert.DataAccessStrategy
import org.springframework.data.jdbc.core.convert.JdbcConverter
import org.springframework.data.jdbc.repository.QueryMappingConfiguration
import org.springframework.data.mapping.callback.EntityCallbacks
import org.springframework.data.relational.core.dialect.Dialect
import org.springframework.data.relational.core.mapping.RelationalMappingContext
import org.springframework.data.repository.Repository
import org.springframework.data.repository.core.support.RepositoryFactorySupport
import org.springframework.data.repository.core.support.TransactionalRepositoryFactoryBeanSupport
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
/**
 * Factory bean that plugs [CustomJdbcRepositoryFactory] into Spring Data JDBC.
 *
 * Modeled on [org.springframework.data.jdbc.repository.support.JdbcRepositoryFactoryBean]. A custom factory bean
 * is needed because it is the hook through which a custom repository factory gets registered. Collaborators are
 * field-injected and handed to a new [CustomJdbcRepositoryFactory] in [doCreateRepositoryFactory].
 */
class CustomJdbcRepositoryFactoryBean<T : Repository<S, ID>, S, ID : java.io.Serializable>(
    repositoryInterface: Class<T>
): TransactionalRepositoryFactoryBeanSupport<T, S, ID>(repositoryInterface), ApplicationEventPublisherAware {
    @Autowired
    private lateinit var publisher: ApplicationEventPublisher

    @Autowired(required = false)
    private var beanFactory: BeanFactory? = null

    @Autowired
    private lateinit var mappingContext: RelationalMappingContext

    @Autowired
    private lateinit var converter: JdbcConverter

    @Autowired
    private lateinit var dataAccessStrategy: DataAccessStrategy

    @Autowired(required = false)
    private var queryMappingConfiguration = QueryMappingConfiguration.EMPTY

    @Autowired
    private lateinit var operations: NamedParameterJdbcOperations

    /** Explicitly injected callbacks, if any; otherwise derived from the bean factory in [setBeanFactory]. */
    @Autowired(required = false)
    private var entityCallbacks: EntityCallbacks? = null

    @Autowired
    private lateinit var dialect: Dialect

    @Autowired
    private lateinit var jdbcTemplate: NamedParameterJdbcTemplate

    @Autowired
    private lateinit var jdbcConverter: JdbcConverter

    override fun setApplicationEventPublisher(publisher: ApplicationEventPublisher) {
        super.setApplicationEventPublisher(publisher)
        this.publisher = publisher
    }

    /** Builds the [CustomJdbcRepositoryFactory] from the injected collaborators. */
    override fun doCreateRepositoryFactory(): RepositoryFactorySupport {
        val jdbcRepositoryFactory = CustomJdbcRepositoryFactory(
            dataAccessStrategy,
            mappingContext,
            converter,
            dialect,
            publisher,
            operations,
            jdbcTemplate,
            jdbcConverter
        )
        jdbcRepositoryFactory.setQueryMappingConfiguration(queryMappingConfiguration)
        // Normally populated by setBeanFactory. Fall back to an empty set (the same default the factory uses)
        // instead of failing with a NullPointerException when used outside a BeanFactory.
        jdbcRepositoryFactory.setEntityCallbacks(entityCallbacks ?: EntityCallbacks.create())
        return jdbcRepositoryFactory
    }

    override fun setBeanFactory(beanFactory: BeanFactory) {
        super.setBeanFactory(beanFactory)
        this.beanFactory = beanFactory
        // Field injection runs before this Aware callback. Keep an explicitly injected EntityCallbacks instead of
        // silently discarding it, matching JdbcRepositoryFactoryBean, which only derives callbacks when none are set.
        if (entityCallbacks == null) {
            entityCallbacks = EntityCallbacks.create(beanFactory)
        }
    }
}
package com.foo.util.specification.jdbc
import com.foo.entities.TransactionEntity
import com.foo.util.specification.Specification
import com.foo.util.specification.SpecificationRepository
import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import org.springframework.data.jdbc.core.JdbcAggregateOperations
import org.springframework.data.jdbc.core.convert.DataAccessStrategy
import org.springframework.data.jdbc.core.convert.EntityRowMapper
import org.springframework.data.jdbc.core.convert.JdbcConverter
import org.springframework.data.jdbc.repository.support.SimpleJdbcRepository
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity
import org.springframework.data.relational.core.sql.Expressions
import org.springframework.data.relational.core.sql.Functions
import org.springframework.data.relational.core.sql.Select
import org.springframework.data.relational.core.sql.Table
import org.springframework.data.relational.core.sql.render.SqlRenderer
import org.springframework.data.repository.support.PageableExecutionUtils
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
import org.springframework.transaction.annotation.Transactional
/**
 * Custom JdbcRepository implementation which supports the specification search.
 *
 * Extension of Spring's [SimpleJdbcRepository], implements [SpecificationRepository].
 *
 * @param template executes the rendered specification and count queries
 * @param accessStrategy not used by this class; kept so the constructor signature stays unchanged
 * @param entityOperations aggregate operations backing the inherited CRUD methods
 * @param entity mapping metadata of the managed entity type
 * @param jdbcConverter converts result rows into entities
 * @param renderer renders the generated statements into SQL
 */
@Transactional(readOnly = true)
class CustomSimpleJdbcRepository<T, ID>(
    private val template: NamedParameterJdbcTemplate,
    @Suppress("UNUSED_PARAMETER") accessStrategy: DataAccessStrategy,
    entityOperations: JdbcAggregateOperations,
    entity: RelationalPersistentEntity<T>,
    jdbcConverter: JdbcConverter,
    private val renderer: SqlRenderer
): SimpleJdbcRepository<T, ID>(entityOperations, entity), SpecificationRepository<T> {
    private val queryBuilder = SpecificationQueryBuilder(entity, renderer)
    private val rowMapper = EntityRowMapper(entity, jdbcConverter)

    /** Table of the managed entity; used by the count query. */
    private val table = Table.create(entity.tableName)

    override fun findAll(specification: Specification): Iterable<T> {
        return findAll(specification, Pageable.unpaged(), Sort.unsorted())
    }

    override fun findAll(specification: Specification, sort: Sort): Iterable<T> {
        return findAll(specification, Pageable.unpaged(), sort)
    }

    /**
     * Returns the requested page of entities matching [specification], sorted by `pageable.sort`.
     *
     * The total counts the rows matching [specification] in this repository's own table. The previous
     * implementation counted every row of an unrelated, hard-coded entity table. The count is passed as a
     * supplier, so [PageableExecutionUtils] runs it only when it cannot derive the total from the page content.
     */
    override fun findAll(specification: Specification, pageable: Pageable): Page<T> {
        val items = findAll(specification, pageable, pageable.sort)
        return PageableExecutionUtils.getPage(items, pageable) { countMatching(specification) }
    }

    private fun findAll(specification: Specification, pageable: Pageable, sort: Sort): List<T> {
        val query = queryBuilder.query(specification, pageable, sort)
        return template.query(query, rowMapper)
    }

    /** Runs `SELECT COUNT(*) FROM <table> WHERE <specification>` and returns the result (0 if the driver returns null). */
    private fun countMatching(specification: Specification): Long {
        val countSelect = Select.builder()
            .select(Functions.count(Expressions.asterisk()))
            .from(table)
            .where(specification.toCondition(table))
            .build()
        return template.queryForObject(
            renderer.render(countSelect),
            emptyMap<String, Any>(),
            Long::class.javaObjectType
        ) ?: 0L
    }
}
package com.foo.repositories
import com.foo.entities.FooEntity
import com.foo.util.specification.Specification
import com.foo.util.specification.Specification.Companion.columnIsEqualToIfNotNull
import com.foo.util.specification.SpecificationRepository
import com.foo.util.validation.CombinedNotNull
import org.springframework.data.jdbc.repository.query.Query
import org.springframework.data.relational.core.sql.Condition
import org.springframework.data.relational.core.sql.Table
import org.springframework.data.repository.PagingAndSortingRepository
/**
 * Repository for [FooEntity]: paging/sorting CRUD plus [Specification]-based search.
 */
interface FooRepository :
    PagingAndSortingRepository<FooEntity, Long>,
    SpecificationRepository<FooEntity> {

    /** Query derived by Spring Data from the method name: all entities whose `name` equals [name]. */
    fun findAllByName(name: String): List<FooEntity>
}
/**
 * Search criteria for [FooEntity].
 *
 * Each non-null field contributes an equality filter on its column; a null field contributes no filter.
 * All filters are combined with AND.
 */
data class FooSearchCriteria(
    val id: String? = null,
    val name: String? = null,
    val groupId: Long? = null
) : Specification {
    override fun toCondition(table: Table): Condition {
        val combined = Specification.and(
            columnIsEqualToIfNotNull("id", id),
            columnIsEqualToIfNotNull("name", name),
            columnIsEqualToIfNotNull("group_id", groupId)
        )
        return combined.toCondition(table)
    }
}
package com.foo.util.specification
import org.springframework.data.relational.core.sql.*
import org.springframework.data.relational.core.sql.SQL.literalOf
/**
 * DSL to specify filter criteria for a query. Inspired by JPA.
 *
 * Composed specifications render their operands in declaration order, and every composition is wrapped in
 * parentheses. Mixing [and] and [or] therefore keeps the grouping expressed by the call chain:
 * `a.or(b).and(c)` renders as `((a OR b) AND c)`.
 */
interface Specification {
    /** Converts this specification into a SQL [Condition] whose columns are resolved against [table]. */
    fun toCondition(table: Table): Condition

    /** Combines this specification with [other] using `OR`. */
    fun or(other: Specification): Specification {
        return compose(this, other) { left, right -> left.or(right) }
    }

    /** Combines this specification with [other] using `AND`. */
    fun and(other: Specification): Specification {
        return compose(this, other) { left, right -> left.and(right) }
    }

    companion object {
        /** Readability helper for the start of a chain; returns [specification] unchanged. */
        fun where(specification: Specification): Specification {
            return specification
        }

        /**
         * Combines all [specifications] with `AND`, left to right.
         *
         * An empty argument list yields [None] (always true), the neutral element of `AND`.
         */
        fun and(vararg specifications: Specification): Specification {
            if (specifications.isEmpty()) return None
            return specifications.reduce { left, right -> left.and(right) }
        }

        /**
         * Combines all [specifications] with `OR`, left to right.
         *
         * @throws UnsupportedOperationException if [specifications] is empty (there is no "always false" specification)
         */
        fun or(vararg specifications: Specification): Specification {
            return specifications.reduce { left, right -> left.or(right) }
        }

        /** Equality filter on a string column; no filter when [value] is null. */
        fun columnIsEqualToIfNotNull(columnName: String, value: String?): Specification {
            return StringColumnIsEqualToIfNotNull(columnName, value)
        }

        /** Equality filter on a numeric column; no filter when [value] is null. */
        fun columnIsEqualToIfNotNull(columnName: String, value: Number?): Specification {
            return IntColumnIsEqualToIfNotNull(columnName, value)
        }

        /** Equality filter on a boolean column; no filter when [value] is null. */
        fun columnIsEqualToIfNotNull(columnName: String, value: Boolean?): Specification {
            return BooleanColumnIsEqualToIfNotNull(columnName, value)
        }

        private fun compose(left: Specification, right: Specification, combiner: (left: Condition, right: Condition) -> Condition): Specification {
            return object : Specification {
                override fun toCondition(table: Table): Condition {
                    // Operands stay in declaration order (they used to be swapped). The result is nested because
                    // Spring's AND/OR conditions render without parentheses. Un-nested, `a.or(b).and(c)` would
                    // render as `a OR b AND c`, which SQL precedence reads as `a OR (b AND c)`.
                    return Conditions.nest(combiner(left.toCondition(table), right.toCondition(table)))
                }
            }
        }
    }
}
// Examples of specifications
/**
 * Specification to filter on a specific string column if the passed value is not null.
 *
 * A null [value] yields [None], i.e. no filtering.
 */
data class StringColumnIsEqualToIfNotNull(val columnName: String, val value: String?): Specification {
    override fun toCondition(table: Table): Condition {
        if (value == null) return None.toCondition(table)
        // SQL.literalOf places the content verbatim between single quotes, so an embedded quote
        // (e.g. "O'Brien") would end the literal early: broken SQL at best, SQL injection at worst.
        // Doubling the quote is the standard SQL escape.
        // NOTE(review): bind parameters would be the robust fix; e.g. MySQL without NO_BACKSLASH_ESCAPES
        // also treats '\' as an escape character.
        val escaped = value.replace("'", "''")
        return table.column(columnName).isEqualTo(literalOf(escaped))
    }
}
/**
 * Equality filter on the boolean column [columnName].
 *
 * Falls back to [None] (no filtering) when [value] is null.
 */
data class BooleanColumnIsEqualToIfNotNull(val columnName: String, val value: Boolean?): Specification {
    override fun toCondition(table: Table): Condition = when (value) {
        null -> None.toCondition(table)
        else -> table.column(columnName).isEqualTo(literalOf(value))
    }
}
/**
 * Equality filter on the numeric column [columnName].
 *
 * Falls back to [None] (no filtering) when [value] is null.
 */
data class IntColumnIsEqualToIfNotNull(val columnName: String, val value: Number?): Specification {
    override fun toCondition(table: Table): Condition = when (value) {
        null -> None.toCondition(table)
        else -> table.column(columnName).isEqualTo(literalOf(value))
    }
}
/**
 * Specification that matches every row.
 *
 * It renders as the tautology `1 = 1`, so it can stand in wherever "no filter" is needed.
 */
object None: Specification {
    override fun toCondition(table: Table): Condition =
        Conditions.isEqual(literalOf(1), literalOf(1)) // SQL: 1 = 1
}
package com.foo.util.specification.jdbc
import com.foo.util.specification.Specification
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity
import org.springframework.data.relational.core.sql.*
import org.springframework.data.relational.core.sql.Expressions.asterisk
import org.springframework.data.relational.core.sql.SelectBuilder.*
import org.springframework.data.relational.core.sql.render.SqlRenderer
/**
 * Renders SQL `SELECT` statements for [entity] from a [Specification] plus paging and sorting information.
 *
 * @param entity mapping metadata used to resolve the table name and the columns of sort properties
 * @param renderer turns the built statement into a SQL string
 */
class SpecificationQueryBuilder<T>(
    private val entity: RelationalPersistentEntity<T>,
    private val renderer: SqlRenderer
) {
    /**
     * Builds and renders `SELECT * FROM <table> WHERE <specification>`.
     *
     * `LIMIT`/`OFFSET` are added when [pageable] is paged, and `ORDER BY` when [sort] is sorted.
     */
    fun query(specification: Specification, pageable: Pageable, sort: Sort): String {
        val table = Table.create(entity.tableName)
        val base: SelectLimitOffset = Select.builder().select(asterisk()).from(table)
        val paged = withPaging(base, pageable)
        val filtered: SelectOrdered = paged.where(specification.toCondition(table))
        val ordered = if (sort.isSorted) filtered.orderBy(orderFields(sort, table)) else filtered
        return renderer.render(ordered.build())
    }

    /**
     * Applies LIMIT/OFFSET for paged requests and leaves unpaged ones untouched.
     *
     * The cast relies on the builder returned by `Select.builder()` also implementing [SelectWhere].
     */
    private fun withPaging(select: SelectLimitOffset, pageable: Pageable): SelectWhere {
        val limited = if (pageable.isPaged) {
            select.limit(pageable.pageSize.toLong()).offset(pageable.offset)
        } else {
            select
        }
        return limited as SelectWhere
    }

    /** Maps each sort order onto the mapped column of its property, keeping direction and null handling. */
    private fun orderFields(sort: Sort, table: Table): List<OrderByField> =
        sort.map { order ->
            val column: SqlIdentifier = entity.getRequiredPersistentProperty(order.property).columnName
            val field = OrderByField.from(table.column(column)).withNullHandling(order.nullHandling)
            if (order.isAscending) field.asc() else field.desc()
        }.toList()
}
package com.foo.util.specification
import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
/**
 * Repository contract for searching entities of type [T] by [Specification], optionally sorted or paged.
 */
interface SpecificationRepository<T> {
    /** All entities matching [specification], without an explicit ordering. */
    fun findAll(specification: Specification): Iterable<T>

    /** All entities matching [specification], ordered by [sort]. */
    fun findAll(specification: Specification, sort: Sort): Iterable<T>

    /** The page of entities matching [specification] described by [pageable]. */
    fun findAll(specification: Specification, pageable: Pageable): Page<T>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.