Skip to content

Instantly share code, notes, and snippets.

@juliano
Created April 29, 2016 13:30
Show Gist options
  • Save juliano/86c269af2041e30699466b611fef1823 to your computer and use it in GitHub Desktop.
Save juliano/86c269af2041e30699466b611fef1823 to your computer and use it in GitHub Desktop.
package io.getquill.sources.jdbc
import java.sql.{ Connection, PreparedStatement, ResultSet }
import com.typesafe.scalalogging.Logger
import io.getquill.JdbcSourceConfig
import io.getquill.naming.NamingStrategy
import io.getquill.sources.BindedStatementBuilder
import io.getquill.sources.sql.SqlSource
import io.getquill.sources.sql.idiom.SqlIdiom
import org.slf4j.LoggerFactory
import scala.annotation.tailrec
import scala.util.{ DynamicVariable, Try }
import scala.util.control.NonFatal
/**
 * JDBC-backed SQL source.
 *
 * Each operation borrows a connection from `config.dataSource` and closes it
 * afterwards. Inside [[transaction]], operations issued from the calling thread
 * instead reuse the transaction's connection, tracked in a `DynamicVariable`.
 * Every JDBC statement created here is closed once its results have been read.
 *
 * @tparam D SQL idiom type passed through to `SqlSource`
 * @tparam N naming strategy type passed through to `SqlSource`
 */
class JdbcSource[D <: SqlIdiom, N <: NamingStrategy](config: JdbcSourceConfig[D, N])
  extends SqlSource[D, N, ResultSet, BindedStatementBuilder[PreparedStatement]]
  with JdbcEncoders
  with JdbcDecoders {

  protected val logger: Logger =
    Logger(LoggerFactory.getLogger(classOf[JdbcSource[_, _]]))

  type QueryResult[T] = List[T]
  type ActionResult[T] = Long
  type BatchedActionResult[T] = List[Long]

  /** Batch action runner that can also be applied to a single parameter. */
  class ActionApply[T](f: List[T] => List[Long]) extends Function1[List[T], List[Long]] {
    def apply(params: List[T]) = f(params)
    // Runs a one-element batch and returns its first result (throws if there is none).
    def apply(param: T) = f(List(param)).head
  }

  private val dataSource = config.dataSource

  override def close = dataSource.close

  /** Connection of the `transaction` block active on this thread, if any. */
  private val currentConnection = new DynamicVariable[Option[Connection]](None)

  /**
   * Runs `f` with the active transaction's connection. When no transaction is
   * active, it uses a fresh connection from the data source and closes that
   * connection afterwards.
   */
  protected def withConnection[T](f: Connection => T): T =
    currentConnection.value.map(f).getOrElse {
      val conn = dataSource.getConnection
      try f(conn)
      finally conn.close()
    }

  /**
   * Executes `sql` as a plain statement to check whether the database accepts it.
   *
   * @return `Success` with the value of `Statement.execute`, or `Failure` if
   *         obtaining a connection or executing the statement throws
   */
  def probe(sql: String): Try[Boolean] =
    Try {
      withConnection { conn =>
        closeAfter(conn.createStatement())(_.execute(sql))
      }
    }

  /**
   * Runs `f` inside a database transaction and returns its result.
   *
   * The source passed to `f` is this source. While `f` runs, every operation
   * issued from the calling thread shares a single connection with auto-commit
   * disabled. The transaction commits when `f` returns. If `f` throws a
   * non-fatal exception, the transaction rolls back and the exception is
   * rethrown. A `transaction` call made while one is already active on this
   * source joins the outer transaction rather than committing on its own.
   */
  def transaction[T](f: JdbcSource[D, N] => T): T =
    currentConnection.value match {
      case Some(_) => f(this)
      case None =>
        withConnection { conn =>
          currentConnection.withValue(Some(conn)) {
            inTransaction(conn)(f(this))
          }
        }
    }

  /** Evaluates `body` with auto-commit off, committing on success and rolling back on failure. */
  private def inTransaction[T](conn: Connection)(body: => T): T = {
    val previousAutoCommit = conn.getAutoCommit
    conn.setAutoCommit(false)
    try {
      val result = body
      conn.commit()
      result
    } catch {
      case NonFatal(e) =>
        // A failing rollback must not hide the error that caused it.
        try conn.rollback()
        catch { case NonFatal(rollbackError) => e.addSuppressed(rollbackError) }
        throw e
    } finally conn.setAutoCommit(previousAutoCommit)
  }

  /**
   * Executes a single update statement.
   *
   * @param generated column whose generated key should be returned, if any
   * @return the update count, or the first generated key when `generated` is set
   *         (throws `NoSuchElementException` if the driver returns no key)
   */
  def execute(sql: String, bind: BindedStatementBuilder[PreparedStatement] => BindedStatementBuilder[PreparedStatement], generated: Option[String] = None): Long = {
    val (expanded, setValues) = bind(new BindedStatementBuilder[PreparedStatement]).build(sql)
    logger.info(expanded)
    withConnection { conn =>
      closeAfter(newStatement(conn, expanded, generated)) { ps =>
        val bound = setValues(ps)
        val updateCount = bound.executeUpdate().toLong
        generated.fold(updateCount) { _ =>
          extractResult(bound.getGeneratedKeys, _.getLong(1)).head
        }
      }
    }
  }

  /**
   * Builds a batch runner for `sql`.
   *
   * Bound parameters may expand to different SQL texts. One JDBC batch is run
   * per distinct expanded statement, in order of first appearance. Results are
   * concatenated per batch, and within a batch they follow input order.
   *
   * @return per-row update counts, or generated keys when `generated` is set
   */
  def executeBatch[T](sql: String, bindParams: T => BindedStatementBuilder[PreparedStatement] => BindedStatementBuilder[PreparedStatement],
    generated: Option[String] = None): ActionApply[T] = {
    val func = { (values: List[T]) =>
      val built = values.map(bindParams(_)(new BindedStatementBuilder[PreparedStatement]).build(sql))
      built.map(_._1).distinct.flatMap { expanded =>
        val setters = built.filter(_._1 == expanded).map(_._2)
        logger.info(expanded)
        withConnection { conn =>
          closeAfter(newStatement(conn, expanded, generated)) { ps =>
            setters.foreach { set =>
              set(ps)
              ps.addBatch()
            }
            val updateCounts = ps.executeBatch().toList.map(_.toLong)
            generated.fold(updateCounts)(_ => extractResult(ps.getGeneratedKeys, _.getLong(1)))
          }
        }
      }
    }
    new ActionApply(func)
  }

  /** Runs a query and maps each result row with `extractor`. */
  def query[T](sql: String, bind: BindedStatementBuilder[PreparedStatement] => BindedStatementBuilder[PreparedStatement], extractor: ResultSet => T): List[T] = {
    val (expanded, setValues) = bind(new BindedStatementBuilder[PreparedStatement]).build(sql)
    logger.info(expanded)
    withConnection { conn =>
      closeAfter(conn.prepareStatement(expanded)) { ps =>
        extractResult(setValues(ps).executeQuery(), extractor)
      }
    }
  }

  /** Prepares `sql`, asking the driver to return `generated` keys when requested. */
  private def newStatement(conn: Connection, sql: String, generated: Option[String]): PreparedStatement =
    generated.fold(conn.prepareStatement(sql))(column => conn.prepareStatement(sql, Array(column)))

  /**
   * Applies `f` to `resource`, then closes the resource even if `f` throws.
   * Per JDBC, closing a statement also closes its open result sets.
   */
  private def closeAfter[R <: AutoCloseable, A](resource: R)(f: R => A): A =
    try f(resource)
    finally resource.close()

  /** Reads all remaining rows of `rs` (in order) after the already-extracted `acc`. */
  private def extractResult[T](rs: ResultSet, extractor: ResultSet => T, acc: List[T] = List()): List[T] = {
    // Prepend then reverse once: appending to a List on each row is O(n^2).
    @tailrec
    def loop(reversed: List[T]): List[T] =
      if (rs.next()) loop(extractor(rs) :: reversed)
      else reversed.reverse
    loop(acc.reverse)
  }
}
package io.getquill.sources.jdbc
import java.sql.Connection
import io.getquill.JdbcSourceConfig
import io.getquill.naming.NamingStrategy
import io.getquill.sources.sql.idiom.SqlIdiom
import scala.util.control.NonFatal
/**
 * Source that runs `f` inside a transaction on `conn` as part of its construction.
 *
 * Every operation issued through this source (including those made by `f`)
 * uses `conn`. Auto-commit is disabled while `f` runs. The transaction commits
 * when `f` returns. If `f` throws a non-fatal exception, the transaction rolls
 * back and the exception propagates out of the constructor. The previous
 * auto-commit setting of `conn` is restored afterwards.
 */
class TransactionalJdbcSource[D <: SqlIdiom, N <: NamingStrategy, T](config: JdbcSourceConfig[D, N], conn: Connection, f: JdbcSource[D, N] => T)
  extends JdbcSource[D, N](config) {

  // Route all work done through this source to the transaction's connection;
  // without this, each operation would borrow a separate auto-commit connection.
  override protected def withConnection[R](f: Connection => R): R = f(conn)

  /** The value returned by `f` once the transaction has committed. */
  val result: T = {
    val previousAutoCommit = conn.getAutoCommit
    conn.setAutoCommit(false)
    try {
      val res = f(this)
      conn.commit()
      res
    } catch {
      case NonFatal(e) =>
        // A failing rollback must not hide the error that caused it.
        try conn.rollback()
        catch { case NonFatal(rollbackError) => e.addSuppressed(rollbackError) }
        throw e
    } finally {
      conn.setAutoCommit(previousAutoCommit)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment