Skip to content

Instantly share code, notes, and snippets.

@jilen
Last active September 27, 2018 11:31
Show Gist options
  • Save jilen/e225e87bb92113f8b54d3e3a0988dca0 to your computer and use it in GitHub Desktop.
Save jilen/e225e87bb92113f8b54d3e3a0988dca0 to your computer and use it in GitHub Desktop.
package io.getquill.context.async
import com.github.mauricio.async.db.Connection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.RowData
import com.github.mauricio.async.db.pool.PartitionedConnectionPool
import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.Try
import io.getquill.context.sql.SqlContext
import io.getquill.context.sql.idiom.SqlIdiom
import io.getquill.NamingStrategy
import io.getquill.util.ContextLogger
import io.getquill.monad.IOMonad
import io.getquill.context.Context
/**
 * Quill context for mauricio db-async whose operations are described as `IO`
 * programs and executed with `performIO`.
 *
 * Every `Result[T]` is a function from the connection to run on to a `Future`,
 * so one program can be executed either against the pool or against a single
 * connection that holds an open transaction.
 *
 * @param idiom  SQL dialect used to render statements
 * @param naming naming strategy used for identifiers
 * @param pool   pool that statements and transactions are run on
 */
abstract class AsyncIOContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection](val idiom: D, val naming: N, pool: PartitionedConnectionPool[C])
  extends Context[D, N]
  with SqlContext[D, N] with IOMonad with IOEncoders with IODecoders {

  // NOTE(review): the logger is keyed on AsyncContext rather than this class — confirm intended.
  private val logger = ContextLogger(classOf[AsyncContext[_, _, _]])

  override type PrepareRow = List[Any]
  override type ResultRow = RowData
  override type Result[T] = C => Future[T]
  override type RunQueryResult[T] = List[T]
  override type RunQuerySingleResult[T] = T
  override type RunActionResult = Long
  override type RunActionReturningResult[T] = T
  override type RunBatchActionResult = List[Long]
  override type RunBatchActionReturningResult[T] = List[T]

  /**
   * Runs an `IO` program.
   *
   * All steps of the program run one after another on the same connection.
   * A `Transactional` node opens a transaction on the pool unless the program
   * is already running inside one.
   *
   * @param io            program to execute
   * @param transactional when true, the whole program runs inside one
   *                      transaction obtained from `pool.inTransaction`
   * @return future completed with the program's result, or failed with the
   *         first failure the program did not handle itself
   */
  def performIO[T](io: IO[T, _], transactional: Boolean = false)(implicit ec: ExecutionContext): Future[T] = {

    // `inTx` is true when the connection passed to the returned function is
    // already inside an open transaction, so nested `Transactional` nodes
    // must not (and need not) open another one.
    def performInternal[A](i: IO[A, _], inTx: Boolean): Result[A] = i match {
      case FromTry(v) =>
        (_: C) => Future.fromTry(v)

      case Run(f) =>
        f()

      case Sequence(in, cbfIOToResult, cbfResultToValue) => { c: C =>
        val steps = cbfIOToResult()
        in.foreach(ii => steps += performInternal(ii, inTx))
        // Steps are chained rather than started together: `c` may be a single
        // transactional connection, which runs one statement at a time.
        steps.result().foldLeft(Future.successful(cbfResultToValue())) { (acc, step) =>
          for {
            values <- acc
            value <- step(c)
          } yield values += value
        }.map(_.result())
      }

      case TransformWith(a, fA) => { c: C =>
        performInternal(a, inTx)(c)
          .map(scala.util.Success(_))
          .recover { case ex => scala.util.Failure(ex) }
          .flatMap(v => performInternal(fA(v), inTx)(c))
      }

      case Transactional(inner) =>
        // Previously this ran `inner` as-is, silently dropping the transaction
        // when the program was not already transactional.
        if (inTx) performInternal(inner, inTx = true)
        else {
          (_: C) => pool.inTransaction(tc => performInternal(inner, inTx = true)(tc.asInstanceOf[C]))
        }
    }

    if (transactional) {
      pool.inTransaction(c => performInternal(io, inTx = true)(c.asInstanceOf[C]))
    } else {
      // Under a TransactionalExecutionContext, withConnection hands over the
      // connection of the caller's already-open transaction.
      val inCallerTx = ec match {
        case TransactionalExecutionContext(_, _) => true
        case _                                   => false
      }
      withConnection(c => performInternal(io, inCallerTx)(c.asInstanceOf[C]))
    }
  }

  /** Closes the pool, blocking until the close completes. */
  override def close = {
    Await.result(pool.close, Duration.Inf)
    ()
  }

  /**
   * Runs `f` on the connection carried by a `TransactionalExecutionContext`,
   * or on the pool for any other execution context.
   */
  protected def withConnection[T](f: Connection => Future[T])(implicit ec: ExecutionContext): Future[T] =
    ec match {
      case TransactionalExecutionContext(_, conn) => f(conn)
      case _                                      => f(pool)
    }

  /** Reads the value of `returningColumn` out of a statement result using `extractor`. */
  protected def extractActionResult[O](returningColumn: String, extractor: Extractor[O])(result: DBQueryResult): O

  /** Hook to rewrite an action so it returns `returningColumn`; the default leaves `sql` unchanged. */
  protected def expandAction(sql: String, returningColumn: String) = sql

  /** Sends `sql` through the pool and blocks until it completes; failures are captured in the `Try`. */
  def probe(sql: String) =
    Try {
      Await.result(pool.sendQuery(sql), Duration.Inf)
    }

  /**
   * Runs `f` inside a transaction on the pool. The execution context passed to
   * `f` carries the transaction's connection, which `withConnection` picks up.
   */
  def transaction[T](f: TransactionalExecutionContext => Future[T])(implicit ec: ExecutionContext) =
    pool.inTransaction { c =>
      f(TransactionalExecutionContext(ec, c))
    }

  /**
   * Sends `sql` as a prepared statement and maps every returned row with
   * `extractor`; a result without a row set yields `Nil`.
   */
  def executeQuery[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit ec: ExecutionContext): Result[List[T]] = { c: C =>
    val (params, values) = prepare(Nil)
    logger.logQuery(sql, params)
    c.sendPreparedStatement(sql, values).map {
      _.rows match {
        case Some(rows) => rows.map(extractor).toList
        case None       => Nil
      }
    }
  }

  /** Like `executeQuery`, then reduces the rows with `handleSingleResult`. */
  def executeQuerySingle[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit ec: ExecutionContext): Result[T] = { c: C =>
    executeQuery(sql, prepare, extractor)(ec)(c).map(handleSingleResult)
  }

  /** Sends `sql` as a prepared statement and returns the number of affected rows. */
  def executeAction[T](sql: String, prepare: Prepare = identityPrepare)(implicit ec: ExecutionContext): Result[Long] = { c =>
    val (params, values) = prepare(Nil)
    logger.logQuery(sql, params)
    c.sendPreparedStatement(sql, values).map(_.rowsAffected)
  }

  /**
   * Sends `sql` (after `expandAction`) as a prepared statement and extracts
   * the value of `returningColumn`. The statement is logged before expansion.
   */
  def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String)(implicit ec: ExecutionContext): Result[T] = { c =>
    val expanded = expandAction(sql, returningColumn)
    val (params, values) = prepare(Nil)
    logger.logQuery(sql, params)
    c.sendPreparedStatement(expanded, values).map(extractActionResult(returningColumn, extractor))
  }

  /**
   * Executes every statement of every group, strictly one after another, and
   * returns the affected-row counts in group order.
   *
   * Groups are no longer started concurrently: `c` may be a single
   * transactional connection that cannot run overlapping statements.
   */
  def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext): Result[List[Long]] = { c: C =>
    val statements = groups.flatMap {
      case BatchGroup(sql, prepares) => prepares.map(p => (sql, p))
    }
    statements.foldLeft(Future.successful(List.newBuilder[Long])) {
      case (acc, (sql, prepare)) =>
        acc.flatMap(results => executeAction(sql, prepare)(ec)(c).map(results += _))
    }.map(_.result())
  }

  /**
   * Executes every returning statement of every group, strictly one after
   * another, and returns the extracted values in group order.
   */
  def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T])(implicit ec: ExecutionContext): Result[List[T]] = { c: C =>
    val statements = groups.flatMap {
      case BatchGroupReturning(sql, column, prepares) => prepares.map(p => (sql, column, p))
    }
    statements.foldLeft(Future.successful(List.newBuilder[T])) {
      case (acc, (sql, column, prepare)) =>
        acc.flatMap(results => executeActionReturning(sql, prepare, extractor, column)(ec)(c).map(results += _))
    }.map(_.result())
  }
}
package io.getquill.context.async
import java.time._
import java.util.Date
import io.getquill.util.Messages.fail
import scala.reflect.{ ClassTag, classTag }
import org.joda.time.{ DateTimeZone => JodaDateTimeZone, DateTime => JodaDateTime, LocalDate => JodaLocalDate, LocalDateTime => JodaLocalDateTime }
/**
 * Encoders for `AsyncIOContext`. Each encoder appends its (possibly converted)
 * value to the end of the positional parameter list (`PrepareRow = List[Any]`);
 * the `index` argument is not used for placement.
 */
trait IOEncoders {
  this: AsyncIOContext[_, _, _] =>

  type Encoder[T] = AsyncEncoder[T]

  type EncoderSqlType = SqlTypes.SqlTypes

  /** Encoder tagged with the SQL type it produces; delegates the work to `encoder`. */
  case class AsyncEncoder[T](sqlType: EncoderSqlType)(implicit encoder: BaseEncoder[T])
    extends BaseEncoder[T] {
    override def apply(index: Index, value: T, row: PrepareRow) =
      encoder.apply(index, value, row)
  }

  /** Encoder that appends the value unchanged. */
  def encoder[T](sqlType: EncoderSqlType): Encoder[T] =
    encoder(identity[T], sqlType)

  /** Encoder that appends `f(value)`. */
  def encoder[T](f: T => Any, sqlType: EncoderSqlType): Encoder[T] =
    AsyncEncoder[T](sqlType)(new BaseEncoder[T] {
      def apply(index: Index, value: T, row: PrepareRow) =
        row :+ f(value)
    })

  /** Encodes `I` by mapping it to `O` and reusing the `O` encoder (and its SQL type). */
  implicit def mappedEncoder[I, O](implicit mapped: MappedEncoding[I, O], e: Encoder[O]): Encoder[I] =
    AsyncEncoder(e.sqlType)(new BaseEncoder[I] {
      def apply(index: Index, value: I, row: PrepareRow) =
        e(index, mapped.f(value), row)
    })

  /** `None` is appended as `null`; `Some(v)` delegates to the element encoder. */
  implicit def optionEncoder[T](implicit d: Encoder[T]): Encoder[Option[T]] =
    AsyncEncoder(d.sqlType)(new BaseEncoder[Option[T]] {
      def apply(index: Index, value: Option[T], row: PrepareRow) = {
        value match {
          case None    => nullEncoder(index, null, row)
          case Some(v) => d(index, v, row)
        }
      }
    })

  private[this] val nullEncoder: Encoder[Null] = encoder[Null](SqlTypes.NULL)

  implicit val stringEncoder: Encoder[String] = encoder[String](SqlTypes.VARCHAR)
  implicit val bigDecimalEncoder: Encoder[BigDecimal] = encoder[BigDecimal](SqlTypes.REAL)
  implicit val booleanEncoder: Encoder[Boolean] = encoder[Boolean](SqlTypes.BOOLEAN)
  implicit val byteEncoder: Encoder[Byte] = encoder[Byte](SqlTypes.TINYINT)
  implicit val shortEncoder: Encoder[Short] = encoder[Short](SqlTypes.SMALLINT)
  implicit val intEncoder: Encoder[Int] = encoder[Int](SqlTypes.INTEGER)
  implicit val longEncoder: Encoder[Long] = encoder[Long](SqlTypes.BIGINT)
  implicit val floatEncoder: Encoder[Float] = encoder[Float](SqlTypes.FLOAT)
  implicit val doubleEncoder: Encoder[Double] = encoder[Double](SqlTypes.DOUBLE)
  implicit val byteArrayEncoder: Encoder[Array[Byte]] = encoder[Array[Byte]](SqlTypes.VARBINARY)
  implicit val jodaDateTimeEncoder: Encoder[JodaDateTime] = encoder[JodaDateTime](SqlTypes.TIMESTAMP)
  implicit val jodaLocalDateEncoder: Encoder[JodaLocalDate] = encoder[JodaLocalDate](SqlTypes.DATE)
  implicit val jodaLocalDateTimeEncoder: Encoder[JodaLocalDateTime] = encoder[JodaLocalDateTime](SqlTypes.TIMESTAMP)
  implicit val dateEncoder: Encoder[Date] = encoder[Date]((d: Date) => new JodaLocalDateTime(d), SqlTypes.TIMESTAMP)

  // Joda's forID only understands tz database ids, "UTC" and "[+-]hh:mm";
  // java.time ids such as "Z" or "+05:30:45" would be rejected. Fixed-offset
  // zones are therefore converted by their offset, region zones by their id.
  private[this] def toJodaZone(zone: ZoneId): JodaDateTimeZone =
    zone.normalized() match {
      case offset: ZoneOffset => JodaDateTimeZone.forOffsetMillis(offset.getTotalSeconds * 1000)
      case region             => JodaDateTimeZone.forID(region.getId)
    }

  implicit val encodeZonedDateTime: MappedEncoding[ZonedDateTime, JodaDateTime] =
    MappedEncoding(zdt => new JodaDateTime(zdt.toInstant.toEpochMilli, toJodaZone(zdt.getZone)))

  implicit val encodeLocalDate: MappedEncoding[LocalDate, JodaLocalDate] =
    MappedEncoding(ld => new JodaLocalDate(ld.getYear, ld.getMonthValue, ld.getDayOfMonth))

  // Copied field by field (truncated to millisecond precision): a zone-less
  // value must not pass through an instant, which would shift times that fall
  // in a DST gap and depend on the JVM default zone.
  implicit val encodeLocalDateTime: MappedEncoding[LocalDateTime, JodaLocalDateTime] =
    MappedEncoding(ldt =>
      new JodaLocalDateTime(
        ldt.getYear, ldt.getMonthValue, ldt.getDayOfMonth,
        ldt.getHour, ldt.getMinute, ldt.getSecond, ldt.getNano / 1000000
      ))

  implicit val localDateEncoder: Encoder[LocalDate] = mappedEncoder(encodeLocalDate, jodaLocalDateEncoder)
}
/**
 * Decoders for `AsyncIOContext`. Each decoder reads the value at `index` of a
 * db-async `RowData` and converts it to the requested Scala/Java type, failing
 * with a descriptive message when the value cannot be converted.
 */
trait IODecoders {
  this: AsyncIOContext[_, _, _] =>

  type Decoder[T] = AsyncDecoder[T]

  type DecoderSqlType = SqlTypes.SqlTypes

  /** Decoder tagged with the SQL type it reads; delegates the work to `decoder`. */
  case class AsyncDecoder[T](sqlType: DecoderSqlType)(implicit decoder: BaseDecoder[T])
    extends BaseDecoder[T] {
    override def apply(index: Index, row: ResultRow) =
      decoder(index, row)
  }

  /**
   * Decoder that accepts values already of type `T`, converts other values
   * through `f` when it is defined for them, and fails otherwise.
   */
  def decoder[T: ClassTag](
    f:       PartialFunction[Any, T] = PartialFunction.empty,
    sqlType: DecoderSqlType
  ): Decoder[T] =
    AsyncDecoder[T](sqlType)(new BaseDecoder[T] {
      def apply(index: Index, row: ResultRow) = {
        row(index) match {
          case value: T                      => value
          case value if f.isDefinedAt(value) => f(value)
          case value =>
            fail(
              s"Value '$value' at index $index can't be decoded to '${classTag[T].runtimeClass}'"
            )
        }
      }
    })

  /** Decodes `O` by decoding `I` and mapping it; keeps the SQL type of the `I` decoder. */
  implicit def mappedDecoder[I, O](implicit mapped: MappedEncoding[I, O], decoder: Decoder[I]): Decoder[O] =
    AsyncDecoder(decoder.sqlType)(new BaseDecoder[O] {
      def apply(index: Index, row: ResultRow): O =
        mapped.f(decoder.apply(index, row))
    })

  /** Decoder that accepts any of the listed numeric runtime types and converts through `decode`. */
  trait NumericDecoder[T] extends BaseDecoder[T] {
    def apply(index: Index, row: ResultRow) =
      row(index) match {
        case v: Byte       => decode(v)
        case v: Short      => decode(v)
        case v: Int        => decode(v)
        case v: Long       => decode(v)
        case v: Float      => decode(v)
        case v: Double     => decode(v)
        case v: BigDecimal => decode(v)
        case other =>
          fail(s"Value $other is not numeric")
      }
    def decode[U](v: U)(implicit n: Numeric[U]): T
  }

  /** SQL `null` decodes to `None`; anything else is decoded by `d` and wrapped in `Some`. */
  implicit def optionDecoder[T](implicit d: Decoder[T]): Decoder[Option[T]] =
    AsyncDecoder(d.sqlType)(new BaseDecoder[Option[T]] {
      def apply(index: Index, row: ResultRow) = {
        row(index) match {
          case null  => None
          case value => Some(d(index, row))
        }
      }
    })

  implicit val stringDecoder: Decoder[String] = decoder[String](PartialFunction.empty, SqlTypes.VARCHAR)

  implicit val bigDecimalDecoder: Decoder[BigDecimal] =
    AsyncDecoder(SqlTypes.REAL)(new NumericDecoder[BigDecimal] {
      def decode[U](v: U)(implicit n: Numeric[U]): BigDecimal =
        (v: Any) match {
          // Exact paths: going through Double would drop digits of
          // high-precision decimals and of longs beyond 2^53.
          case d: BigDecimal => d
          case l: Long       => BigDecimal(l)
          case _             => BigDecimal(n.toDouble(v))
        }
    })

  implicit val booleanDecoder: Decoder[Boolean] =
    decoder[Boolean]({
      case v: Byte  => v == (1: Byte)
      case v: Short => v == (1: Short)
      case v: Int   => v == 1
      case v: Long  => v == 1L
    }, SqlTypes.BOOLEAN)

  implicit val byteDecoder: Decoder[Byte] =
    decoder[Byte]({
      case v: Short => v.toByte
    }, SqlTypes.TINYINT)

  implicit val shortDecoder: Decoder[Short] =
    decoder[Short]({
      case v: Byte => v.toShort
    }, SqlTypes.SMALLINT)

  implicit val intDecoder: Decoder[Int] =
    AsyncDecoder(SqlTypes.INTEGER)(new NumericDecoder[Int] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toInt(v)
    })

  implicit val longDecoder: Decoder[Long] =
    AsyncDecoder(SqlTypes.BIGINT)(new NumericDecoder[Long] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toLong(v)
    })

  implicit val floatDecoder: Decoder[Float] =
    AsyncDecoder(SqlTypes.FLOAT)(new NumericDecoder[Float] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toFloat(v)
    })

  implicit val doubleDecoder: Decoder[Double] =
    AsyncDecoder(SqlTypes.DOUBLE)(new NumericDecoder[Double] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toDouble(v)
    })

  // Tagged VARBINARY to match byteArrayEncoder (was mistakenly TINYINT).
  implicit val byteArrayDecoder: Decoder[Array[Byte]] = decoder[Array[Byte]](PartialFunction.empty, SqlTypes.VARBINARY)

  implicit val jodaDateTimeDecoder: Decoder[JodaDateTime] = decoder[JodaDateTime]({
    case dateTime: JodaDateTime           => dateTime
    case localDateTime: JodaLocalDateTime => localDateTime.toDateTime
  }, SqlTypes.TIMESTAMP)

  implicit val jodaLocalDateDecoder: Decoder[JodaLocalDate] = decoder[JodaLocalDate]({
    case localDate: JodaLocalDate => localDate
  }, SqlTypes.DATE)

  implicit val jodaLocalDateTimeDecoder: Decoder[JodaLocalDateTime] = decoder[JodaLocalDateTime]({
    case localDateTime: JodaLocalDateTime => localDateTime
  }, SqlTypes.TIMESTAMP)

  implicit val dateDecoder: Decoder[Date] = decoder[Date]({
    case dateTime: JodaDateTime           => dateTime.toDate
    case localDateTime: JodaLocalDateTime => localDateTime.toDate
    case localDate: JodaLocalDate         => localDate.toDate
  }, SqlTypes.TIMESTAMP)

  implicit val decodeZonedDateTime: MappedEncoding[JodaDateTime, ZonedDateTime] =
    MappedEncoding(jdt => ZonedDateTime.ofInstant(Instant.ofEpochMilli(jdt.getMillis), ZoneId.of(jdt.getZone.getID)))

  implicit val decodeLocalDate: MappedEncoding[JodaLocalDate, LocalDate] =
    MappedEncoding(jld => LocalDate.of(jld.getYear, jld.getMonthOfYear, jld.getDayOfMonth))

  // Copied field by field: converting through toDate and the JVM default zone
  // would shift local times that fall in a DST gap.
  implicit val decodeLocalDateTime: MappedEncoding[JodaLocalDateTime, LocalDateTime] =
    MappedEncoding(jldt =>
      LocalDateTime.of(
        jldt.getYear, jldt.getMonthOfYear, jldt.getDayOfMonth,
        jldt.getHourOfDay, jldt.getMinuteOfHour, jldt.getSecondOfMinute,
        jldt.getMillisOfSecond * 1000000
      ))

  implicit val localDateDecoder: Decoder[LocalDate] = mappedDecoder(decodeLocalDate, jodaLocalDateDecoder)
}
@jilen
Copy link
Author

jilen commented Sep 27, 2018

Note: non-transactional operations are also performed sequentially; this could be improved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment