Skip to content

Instantly share code, notes, and snippets.

@fracaron
Forked from ps-feng/BatchSelect.kt
Created September 12, 2019 21:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fracaron/6334ef10dac92b0ecbf97235b2e164d8 to your computer and use it in GitHub Desktop.
Save fracaron/6334ef10dac92b0ecbf97235b2e164d8 to your computer and use it in GitHub Desktop.
Batch select in Exposed, inspired by Rails' find_in_batches
package org.jetbrains.exposed.sql
import org.jetbrains.exposed.dao.EntityID
import org.jetbrains.exposed.sql.FieldSet
import org.jetbrains.exposed.sql.Op
import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.SortOrder
import org.jetbrains.exposed.sql.SqlExpressionBuilder
import org.jetbrains.exposed.sql.and
import org.jetbrains.exposed.sql.isAutoInc
import org.jetbrains.exposed.sql.select
/**
 * Fetches the rows matching [where] in consecutive batches instead of one large result set.
 *
 * The predicate is built once, up front, and handed to the private `Op`-based overload,
 * which rejects non-positive batch sizes and sources without an auto-incrementing column.
 *
 * @param batchSize maximum number of rows per batch.
 * @param where builder for the filter applied to every batch query.
 * @return an iterable of batches; each batch is an iterable of rows.
 */
fun FieldSet.selectBatched(
    batchSize: Int = 1000,
    where: SqlExpressionBuilder.() -> Op<Boolean>
): Iterable<Iterable<ResultRow>> {
    val condition: Op<Boolean> = SqlExpressionBuilder.where()
    return selectBatched(batchSize, condition)
}
/**
 * Fetches every row of this field set in consecutive batches.
 *
 * Equivalent to a batched select whose filter is the always-true [Op.TRUE].
 *
 * @param batchSize maximum number of rows per batch.
 * @return an iterable of batches; each batch is an iterable of rows.
 */
fun FieldSet.selectAllBatched(
    batchSize: Int = 1000
): Iterable<Iterable<ResultRow>> = selectBatched(batchSize, Op.TRUE)
/**
 * Core implementation of the batched select.
 *
 * Rows are paged by the first auto-incrementing column of the source: each query asks for
 * rows whose auto-increment value is greater than the last one seen, ordered ascending and
 * limited to [batchSize]. Queries run lazily, one per batch, as the returned iterable is
 * consumed. Every call to `iterator()` restarts from the beginning.
 *
 * Paging starts from offset 0, so only rows with an auto-increment value greater than 0
 * are returned.
 *
 * @param batchSize maximum number of rows per batch; must be greater than 0.
 * @param whereOp filter combined (with AND) with the paging condition on every query.
 * @return an iterable of batches; each batch is a fully materialized list of rows.
 * @throws IllegalArgumentException if [batchSize] is not greater than 0.
 * @throws UnsupportedOperationException if the source has no auto-incrementing column.
 */
private fun FieldSet.selectBatched(
    batchSize: Int = 1000,
    whereOp: Op<Boolean>
): Iterable<Iterable<ResultRow>> {
    require(batchSize > 0) { "Batch size should be greater than 0" }

    val autoIncColumn = source.columns.firstOrNull { it.columnType.isAutoInc }
        ?: throw UnsupportedOperationException("Batched select only works on tables with an autoincrementing column")

    return object : Iterable<Iterable<ResultRow>> {
        override fun iterator(): Iterator<Iterable<ResultRow>> {
            return iterator {
                var lastOffset = 0L
                while (true) {
                    val query =
                        select { whereOp and (autoIncColumn greater lastOffset) }
                            .limit(batchSize)
                            .orderBy(autoIncColumn, SortOrder.ASC)

                    // query.iterator() executes the query
                    val results = query.iterator().asSequence().toList()
                    if (results.isEmpty()) break

                    yield(results)

                    // A short batch means the query ran out of matching rows, so the
                    // follow-up query (which would come back empty) can be skipped.
                    if (results.size < batchSize) break

                    // The paging column is auto-incrementing, so a fetched row is expected
                    // to always carry a value for it; hence the non-null assertion.
                    lastOffset = toLong(results.last()[autoIncColumn]!!)
                }
            }
        }

        /**
         * Converts an auto-increment value to [Long], unwrapping [EntityID] if needed.
         *
         * Accepts any [Number] so that Int-keyed tables work as well as Long-keyed ones;
         * a non-numeric value still fails with a [ClassCastException].
         */
        private fun toLong(autoIncVal: Any): Long {
            val raw: Any = if (autoIncVal is EntityID<*>) autoIncVal.value else autoIncVal
            return (raw as Number).toLong()
        }
    }
}
package org.jetbrains.exposed.sql
import org.jetbrains.exposed.config.KMySqlContainer
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
import io.kotlintest.shouldBe
import io.kotlintest.shouldThrow
import java.security.SecureRandom
import java.util.*
import org.jetbrains.exposed.dao.LongIdTable
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.batchInsert
import org.jetbrains.exposed.sql.transactions.transaction
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
/**
 * Integration tests for `selectBatched` / `selectAllBatched`, run against a MySQL
 * container that is started once when the companion object is initialized.
 */
internal class BatchSelectTest {

    /** Table with a Long id column, used as the batched-select target. */
    object Cities : LongIdTable() {
        val name = varchar("name", length = 50)
        val type = integer("type")
    }

    /** Table whose only column is a varchar primary key, i.e. nothing auto-incrementing. */
    object TableWithNoAutoIncCol : Table() {
        val name = varchar("name", length = 50).primaryKey()
    }

    /** Plain value used to compare inserted data with selected data. */
    data class City(
        val name: String,
        val type: CityType
    )

    /** Stored in the `type` column by ordinal. */
    enum class CityType {
        SMALL,
        BIG
    }

    @BeforeEach
    fun setUp() {
        transaction { SchemaUtils.create(Cities) }
    }

    @AfterEach
    fun tearDown() {
        transaction { SchemaUtils.drop(Cities) }
    }

    @Test
    fun `should respect 'where' expression and the provided batch size`() {
        transaction {
            val small = generateCities(count = 50, cityType = CityType.SMALL)
            val big = generateCities(count = 50, cityType = CityType.BIG)
            insert(small)
            insert(big)

            val fetched = Cities.selectBatched(batchSize = 25) {
                Cities.type eq CityType.SMALL.ordinal
            }.toList()
            val batches = fetched.map { it.toCityList() }

            // 50 small cities split into the first 25 and the last 25
            batches shouldBe small.chunked(25)
        }
    }

    @Test
    fun `when batch size is greater than the amount of available items, should return 1 batch`() {
        transaction {
            val inserted = generateCities(count = 25, cityType = CityType.SMALL)
            insert(inserted)

            val fetched = Cities.selectBatched(batchSize = 100) {
                Cities.type eq CityType.SMALL.ordinal
            }.toList()
            val batches = fetched.map { it.toCityList() }

            batches shouldBe listOf(inserted)
        }
    }

    @Test
    fun `when selecting all by batches, should return all available items`() {
        transaction {
            val small = generateCities(count = 30, cityType = CityType.SMALL)
            val big = generateCities(count = 30, cityType = CityType.BIG)
            insert(small)
            insert(big)

            val fetched = Cities.selectAllBatched(batchSize = 30).toList()
            val batches = fetched.map { it.toCityList() }

            batches shouldBe listOf(
                small,
                big
            )
        }
    }

    @Test
    fun `when there are no items, should return an empty iterable`() {
        transaction {
            val batches = Cities.selectAllBatched()
                .toList()

            batches shouldBe emptyList()
        }
    }

    @Test
    fun `when there are no items of the given condition, should return an empty iterable`() {
        transaction {
            insert(generateCities(count = 25, cityType = CityType.SMALL))

            val batches = Cities.selectBatched(batchSize = 100) {
                Cities.type eq CityType.BIG.ordinal
            }.toList()

            batches shouldBe emptyList()
        }
    }

    @Test
    fun `when the table doesn't have an autoinc column, should throw an exception`() {
        shouldThrow<UnsupportedOperationException> {
            TableWithNoAutoIncCol.selectAllBatched()
        }
    }

    @Test
    fun `when batch size is 0 or less, should throw an exception`() {
        // Always tries 0, then random values in [-Int.MAX_VALUE, -1].
        val nonPositiveInts = object : Gen<Int> {
            override fun constants(): Iterable<Int> = listOf(0)

            override fun random(): Sequence<Int> = generateSequence {
                secureRandom.nextInt(Integer.MAX_VALUE) - Integer.MAX_VALUE
            }
        }

        forAll(nonPositiveInts) { size ->
            val failure = runCatching { Cities.selectAllBatched(batchSize = size) }.exceptionOrNull()
            failure is IllegalArgumentException
        }
    }

    /** Builds [count] cities of the given type, each with a random UUID as its name. */
    private fun generateCities(count: Int = 50, cityType: CityType = CityType.SMALL): List<City> =
        List(count) { City(UUID.randomUUID().toString(), cityType) }

    /** Batch-inserts the cities, storing the city type by ordinal. */
    private fun insert(cities: List<City>) {
        Cities.batchInsert(cities) { city ->
            this[Cities.name] = city.name
            this[Cities.type] = city.type.ordinal
        }
    }

    /** Maps a selected row back to a [City], resolving the type from its stored ordinal. */
    private fun ResultRow.toCity(): City {
        val cityName = this[Cities.name]
        val cityType = CityType.values()[this[Cities.type]]
        return City(name = cityName, type = cityType)
    }

    private fun Iterable<ResultRow>.toCityList(): List<City> = map { row -> row.toCity() }

    companion object {
        private val mySqlContainer = KMySqlContainer()
        private val secureRandom = SecureRandom()

        init {
            // Start the container and point Exposed at it before any test runs.
            mySqlContainer.start()
            val container = mySqlContainer
            Database.connect(
                url = container.jdbcUrl,
                driver = "com.mysql.cj.jdbc.Driver",
                user = container.username,
                password = container.password
            )
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment