Skip to content

Instantly share code, notes, and snippets.

@yahorbarkouski
Last active July 10, 2024 01:26
Show Gist options
  • Save yahorbarkouski/bcfbf2bf1b10ff7757f2c629eab33a46 to your computer and use it in GitHub Desktop.
Save yahorbarkouski/bcfbf2bf1b10ff7757f2c629eab33a46 to your computer and use it in GitHub Desktop.
Binding pgvector::vector type to List<Double> using jOOQ
import org.jooq.*
import org.jooq.impl.DSL
import java.sql.SQLFeatureNotSupportedException
import java.sql.Types
/**
 * jOOQ [Binding] that maps a pgvector `vector` column to a Kotlin `List<Double>`.
 *
 * Values travel over JDBC in pgvector's text form (e.g. `[1.0, 2.0]`):
 * - [set] binds the text with `setString`.
 * - [sql] appends a `::vector` cast to the bind marker.
 * - The `get` overloads read the column back with `getString` and parse it.
 *
 * A SQL `NULL` read from the database becomes an empty list. A `null` user value is written as SQL `NULL`.
 */
@Suppress("UNCHECKED_CAST")
class PGVectorBinding : Binding<Any, List<Double>> {

    /** Stateless converter shared by every call to [converter], instead of allocating a new one each time. */
    private val vectorConverter: Converter<Any, List<Double>> = object : Converter<Any, List<Double>> {
        /**
         * Parses pgvector's text form (`[x1,x2,...]`) into a list of doubles.
         * `null` and an empty literal (`[]`) both yield an empty list, rather than failing on `"".toDouble()`.
         */
        override fun from(databaseObject: Any?): List<Double> {
            val inner = databaseObject?.toString()?.trim()?.removeSurrounding("[", "]") ?: return emptyList()
            if (inner.isBlank()) return emptyList()
            return inner.split(",").map { it.trim().toDouble() }
        }

        /**
         * Renders the list as `[1.0, 2.0]` (Kotlin's `List.toString`).
         * Passes `null` through, because `sql()` converts the bound value even when it is `null`.
         */
        override fun to(userObject: List<Double>?): Any? = userObject?.toString()

        override fun fromType(): Class<Any> = Any::class.java
        override fun toType(): Class<List<Double>> = List::class.java as Class<List<Double>>
    }

    override fun converter(): Converter<Any, List<Double>> = vectorConverter

    /** Renders the bind marker followed by a `::vector` cast, so Postgres receives a vector, not text. */
    override fun sql(ctx: BindingSQLContext<List<Double>>) {
        ctx.render().visit(DSL.`val`(ctx.convert(converter()).value())).sql("::vector")
    }

    /**
     * Registers OUT parameters as VARCHAR.
     * The value is read back with `getString()` in the statement `get` overload, and a vector is not an
     * SQL ARRAY, so registering `Types.ARRAY` would not match that read.
     */
    override fun register(ctx: BindingRegisterContext<List<Double>>) {
        ctx.statement().registerOutParameter(ctx.index(), Types.VARCHAR)
    }

    /** Reads the column as text and parses it into a list. */
    override fun get(ctx: BindingGetResultSetContext<List<Double>>) {
        val resultSet = ctx.resultSet()
        val vectorAsString = resultSet.getString(ctx.index())
        ctx.value(converter().from(vectorAsString))
    }

    /**
     * Binds the vector's text form.
     * A `null` list is bound as SQL `NULL`, which the `::vector` cast keeps as `NULL`. Writing `"[]"` instead
     * would be rejected, because pgvector requires at least one dimension.
     */
    override fun set(ctx: BindingSetStatementContext<List<Double>>) {
        val text = ctx.value()?.let { converter().to(it) as String }
        if (text == null) {
            ctx.statement().setNull(ctx.index(), Types.VARCHAR)
        } else {
            ctx.statement().setString(ctx.index(), text)
        }
    }

    /** Reads an OUT parameter as text and parses it into a list. */
    override fun get(ctx: BindingGetStatementContext<List<Double>>) {
        val statement = ctx.statement()
        val vectorAsString = statement.getString(ctx.index())
        ctx.value(converter().from(vectorAsString))
    }

    // The SQLInput/SQLOutput (UDT) paths are not used with Postgres and are deliberately unsupported.
    override fun get(ctx: BindingGetSQLInputContext<List<Double>>?) {
        throw SQLFeatureNotSupportedException()
    }

    override fun set(ctx: BindingSetSQLOutputContext<List<Double>>?) {
        throw SQLFeatureNotSupportedException()
    }
}
@MarcusDunn
Copy link

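If

  • you're using R2DBC, and
  • you're willing to use driver-specific types,

I got much better memory usage from the following binding. It stores each vector as a `FloatArray` (via the R2DBC driver's `Vector` type) instead of a `List<Double>`: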

import io.r2dbc.postgresql.codec.Vector
import org.jooq.BindingGetResultSetContext
import org.jooq.BindingSetStatementContext
import org.jooq.Converter
import org.jooq.impl.AbstractBinding

/**
 * Inline wrapper around the components of a pgvector value.
 *
 * Note: [values] is a `FloatArray`, so two instances are equal only when they wrap the same array
 * instance, not when the arrays merely have equal contents.
 */
@JvmInline
internal value class PgVector(
    val values: FloatArray,
)

/**
 * jOOQ binding that maps a pgvector column to [PgVector].
 * It reads and writes the R2DBC Postgres driver's [Vector] objects directly, rather than going through strings.
 */
internal class PgVectorBinding : AbstractBinding<Any, PgVector>() {
    // Uses the file-private `Converter` singleton.
    override fun converter(): Converter<Any, PgVector> = Converter

    /**
     * Reads the column as a driver [Vector].
     * A SQL `NULL` yields a `null` value; the previous non-null cast threw in that case.
     */
    override fun get(ctx: BindingGetResultSetContext<PgVector>) {
        val vector = ctx.resultSet().getObject(ctx.index()) as Vector?
        ctx.value(vector?.let { PgVector(it.vector) })
    }

    /** Binds the value converted to a driver [Vector], or `null` when the value is absent. */
    override fun set(ctx: BindingSetStatementContext<PgVector>) {
        val value = ctx.value()?.let { converter().to(it) }
        ctx.statement().setObject(ctx.index(), value)
    }
}

/**
 * Stateless jOOQ converter between the R2DBC driver's [Vector] and [PgVector].
 *
 * jOOQ can hand `null` (SQL `NULL`) to a converter. Both directions therefore accept `null` and return
 * `null`, instead of failing Kotlin's non-null parameter check.
 */
private object Converter : Converter<Any, PgVector> {
    // Converters may be serialized; resolve back to this singleton so it stays unique after deserialization.
    private fun readResolve(): Any = Converter

    override fun from(databaseObject: Any?): PgVector? =
        (databaseObject as Vector?)?.let { PgVector(it.vector) }

    override fun to(userObject: PgVector?): Any? =
        userObject?.let { Vector.of(*it.values) }

    override fun fromType(): Class<Any> = Any::class.java
    override fun toType(): Class<PgVector> = PgVector::class.java
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment