Last active
September 27, 2018 11:31
-
-
Save jilen/e225e87bb92113f8b54d3e3a0988dca0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.getquill.context.async | |
import com.github.mauricio.async.db.Connection | |
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult } | |
import com.github.mauricio.async.db.RowData | |
import com.github.mauricio.async.db.pool.PartitionedConnectionPool | |
import scala.concurrent.Await | |
import scala.concurrent.ExecutionContext | |
import scala.concurrent.Future | |
import scala.concurrent.duration.Duration | |
import scala.util.Try | |
import io.getquill.context.sql.SqlContext | |
import io.getquill.context.sql.idiom.SqlIdiom | |
import io.getquill.NamingStrategy | |
import io.getquill.util.ContextLogger | |
import io.getquill.monad.IOMonad | |
import io.getquill.context.Context | |
/**
 * SQL context whose `execute*` methods do not run anything by themselves: each one returns a
 * function from a connection `C` to a `Future`. `performIO` interprets an `IO` program by
 * handing one connection to every step of that program.
 *
 * @param idiom  SQL idiom of this context
 * @param naming naming strategy of this context
 * @param pool   connection pool used to run statements and to open transactions
 */
abstract class AsyncIOContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection](val idiom: D, val naming: N, pool: PartitionedConnectionPool[C])
  extends Context[D, N]
  with SqlContext[D, N] with IOMonad with IOEncoders with IODecoders {

  // Log under this class's own name so its output is not attributed to another context class.
  private val logger = ContextLogger(classOf[AsyncIOContext[_, _, _]])

  override type PrepareRow = List[Any]
  override type ResultRow = RowData

  // A result is a computation that still needs a connection before it can run.
  override type Result[T] = C => Future[T]
  override type RunQueryResult[T] = List[T]
  override type RunQuerySingleResult[T] = T
  override type RunActionResult = Long
  override type RunActionReturningResult[T] = T
  override type RunBatchActionResult = List[Long]
  override type RunBatchActionReturningResult[T] = List[T]

  /**
   * Interprets `io` and runs it against a single connection.
   *
   * @param io            program to run
   * @param transactional when `true` the whole program runs inside `pool.inTransaction`;
   *                      otherwise the connection is chosen by `withConnection`
   * @return the value produced by the program
   */
  def performIO[T](io: IO[T, _], transactional: Boolean = false)(implicit ec: ExecutionContext): Future[T] = {

    // Translates an IO tree into a function that waits for the connection to use.
    def performInternal[A](i: IO[A, _]): Result[A] = i match {
      case FromTry(v) => { c: C =>
        Future.fromTry(v)
      }
      case Run(f) => f()
      case Sequence(in, cbfIOToResult, cbfResultToValue) =>
        { c: C =>
          val pending = cbfIOToResult()
          in.foreach { ii => pending += performInternal(ii) }
          val init = cbfResultToValue()
          // Each step is started only after the previous one has completed, so the steps
          // run one after another on the same connection.
          val fut = pending.result().foldLeft(Future.successful(init)) { (r, f) =>
            for {
              vb <- r
              fr <- f(c)
            } yield vb += fr
          }
          fut.map(_.result())
        }
      case TransformWith(a, fA) =>
        { c: C =>
          // Reify success/failure as a Try so that fA can react to both outcomes.
          performInternal(a)(c)
            .map(scala.util.Success(_))
            .recover { case ex => scala.util.Failure(ex) }
            .flatMap(v => performInternal(fA(v))(c))
        }
      case Transactional(io) =>
        // NOTE(review): no transaction is opened for a nested Transactional node; only the
        // `transactional` flag of performIO opens one. Confirm this is intended.
        performInternal(io)
    }

    if (transactional) {
      pool.inTransaction(c => performInternal(io)(c.asInstanceOf[C]))
    } else {
      withConnection(c => performInternal(io)(c.asInstanceOf[C]))
    }
  }

  /** Blocks until the pool has been closed. */
  override def close(): Unit = {
    Await.result(pool.close, Duration.Inf)
    ()
  }

  /**
   * Runs `f` with the connection carried by a `TransactionalExecutionContext` when `ec` is one,
   * and with the pool itself otherwise.
   */
  protected def withConnection[T](f: Connection => Future[T])(implicit ec: ExecutionContext): Future[T] =
    ec match {
      case TransactionalExecutionContext(_, conn) => f(conn)
      case _                                      => f(pool)
    }

  /** Extracts the returned value of an action from the raw driver result. */
  protected def extractActionResult[O](returningColumn: String, extractor: Extractor[O])(result: DBQueryResult): O

  /** Hook to rewrite an action so that it returns `returningColumn`; returns `sql` unchanged here. */
  protected def expandAction(sql: String, returningColumn: String): String = sql

  /** Sends `sql` through the pool and blocks for the outcome, captured as a `Try`. */
  def probe(sql: String): Try[DBQueryResult] =
    Try {
      Await.result(pool.sendQuery(sql), Duration.Inf)
    }

  /** Runs `f` inside `pool.inTransaction`, passing an execution context that carries the connection. */
  def transaction[T](f: TransactionalExecutionContext => Future[T])(implicit ec: ExecutionContext): Future[T] =
    pool.inTransaction { c =>
      f(TransactionalExecutionContext(ec, c))
    }

  /**
   * Builds a computation that sends `sql` as a prepared statement and extracts every row.
   * Yields `Nil` when the driver result has no rows.
   */
  def executeQuery[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit ec: ExecutionContext): Result[List[T]] = { c: C =>
    val (params, values) = prepare(Nil)
    logger.logQuery(sql, params)
    c.sendPreparedStatement(sql, values).map {
      _.rows match {
        case Some(rows) => rows.map(extractor).toList
        case None       => Nil
      }
    }
  }

  /** Same as `executeQuery`, with the row list reduced by `handleSingleResult`. */
  def executeQuerySingle[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit ec: ExecutionContext): Result[T] = { c: C =>
    executeQuery(sql, prepare, extractor)(ec)(c).map(handleSingleResult)
  }

  /** Builds a computation that sends `sql` and yields the number of affected rows. */
  def executeAction[T](sql: String, prepare: Prepare = identityPrepare)(implicit ec: ExecutionContext): Result[Long] = { c =>
    val (params, values) = prepare(Nil)
    logger.logQuery(sql, params)
    c.sendPreparedStatement(sql, values).map(_.rowsAffected)
  }

  /**
   * Builds a computation that sends the statement produced by `expandAction` and extracts the
   * returned value with `extractActionResult`. The logged statement is the original `sql`.
   */
  def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String)(implicit ec: ExecutionContext): Result[T] = { c =>
    val expanded = expandAction(sql, returningColumn)
    val (params, values) = prepare(Nil)
    logger.logQuery(sql, params)
    c.sendPreparedStatement(expanded, values).map(extractActionResult(returningColumn, extractor))
  }

  /**
   * Builds a computation that runs every batch group on the given connection. Within a group the
   * statements are chained one after another; the per-group results are concatenated.
   */
  def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext): Result[List[Long]] = { c: C =>
    Future.sequence {
      groups.map {
        case BatchGroup(sql, prepares) =>
          prepares.foldLeft(Future.successful(List.newBuilder[Long])) {
            case (acc, prepare) =>
              acc.flatMap { list =>
                executeAction(sql, prepare)(ec)(c).map(list += _)
              }
          }.map(_.result())
      }
    }.map(_.flatten.toList)
  }

  /**
   * Returning variant of `executeBatchAction`: each statement goes through
   * `executeActionReturning` and the extracted values are concatenated.
   */
  def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T])(implicit ec: ExecutionContext): Result[List[T]] =
    { c: C =>
      Future.sequence {
        groups.map {
          case BatchGroupReturning(sql, column, prepares) =>
            prepares.foldLeft(Future.successful(List.newBuilder[T])) {
              case (acc, prepare) =>
                acc.flatMap { list =>
                  executeActionReturning(sql, prepare, extractor, column)(ec)(c).map(list += _)
                }
            }.map(_.result())
        }
      }.map(_.flatten.toList)
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.getquill.context.async | |
import java.time._ | |
import java.util.Date | |
import io.getquill.util.Messages.fail | |
import scala.reflect.{ ClassTag, classTag } | |
import org.joda.time.{ DateTimeZone => JodaDateTimeZone, DateTime => JodaDateTime, LocalDate => JodaLocalDate, LocalDateTime => JodaLocalDateTime } | |
/**
 * Encoders for `AsyncIOContext`. An encoder appends the (possibly converted) value to the
 * prepare row and carries the SQL type it was declared with.
 */
trait IOEncoders {
  this: AsyncIOContext[_, _, _] =>

  type Encoder[T] = AsyncEncoder[T]

  type EncoderSqlType = SqlTypes.SqlTypes

  /** Pairs a SQL type tag with the underlying encoding function and delegates to it. */
  case class AsyncEncoder[T](sqlType: DecoderSqlType)(implicit encoder: BaseEncoder[T])
    extends BaseEncoder[T] {
    override def apply(index: Index, value: T, row: PrepareRow) =
      encoder.apply(index, value, row)
  }

  /** Encoder that appends the value unchanged. */
  def encoder[T](sqlType: DecoderSqlType): Encoder[T] =
    encoder(identity[T], sqlType)

  /** Encoder that appends `f(value)` to the row. */
  def encoder[T](f: T => Any, sqlType: DecoderSqlType): Encoder[T] = {
    val appendConverted = new BaseEncoder[T] {
      def apply(index: Index, value: T, row: PrepareRow) =
        row :+ f(value)
    }
    AsyncEncoder[T](sqlType)(appendConverted)
  }

  /** Encodes an `I` by mapping it to `O` first and reusing the encoder of `O`. */
  implicit def mappedEncoder[I, O](implicit mapped: MappedEncoding[I, O], e: Encoder[O]): Encoder[I] = {
    val viaMapping = new BaseEncoder[I] {
      def apply(index: Index, value: I, row: PrepareRow) =
        e(index, mapped.f(value), row)
    }
    AsyncEncoder(e.sqlType)(viaMapping)
  }

  /** Encodes `Some(v)` with the encoder of `T` and `None` as a `null` entry. */
  implicit def optionEncoder[T](implicit d: Encoder[T]): Encoder[Option[T]] = {
    val viaOption = new BaseEncoder[Option[T]] {
      def apply(index: Index, value: Option[T], row: PrepareRow) =
        value
          .map(present => d(index, present, row))
          .getOrElse(nullEncoder(index, null, row))
    }
    AsyncEncoder(d.sqlType)(viaOption)
  }

  private[this] val nullEncoder: Encoder[Null] = encoder[Null](SqlTypes.NULL)

  // Values appended to the row as they are.
  implicit val stringEncoder: Encoder[String] = encoder[String](SqlTypes.VARCHAR)
  implicit val bigDecimalEncoder: Encoder[BigDecimal] = encoder[BigDecimal](SqlTypes.REAL)
  implicit val booleanEncoder: Encoder[Boolean] = encoder[Boolean](SqlTypes.BOOLEAN)
  implicit val byteEncoder: Encoder[Byte] = encoder[Byte](SqlTypes.TINYINT)
  implicit val shortEncoder: Encoder[Short] = encoder[Short](SqlTypes.SMALLINT)
  implicit val intEncoder: Encoder[Int] = encoder[Int](SqlTypes.INTEGER)
  implicit val longEncoder: Encoder[Long] = encoder[Long](SqlTypes.BIGINT)
  implicit val floatEncoder: Encoder[Float] = encoder[Float](SqlTypes.FLOAT)
  implicit val doubleEncoder: Encoder[Double] = encoder[Double](SqlTypes.DOUBLE)
  implicit val byteArrayEncoder: Encoder[Array[Byte]] = encoder[Array[Byte]](SqlTypes.VARBINARY)
  implicit val jodaDateTimeEncoder: Encoder[JodaDateTime] = encoder[JodaDateTime](SqlTypes.TIMESTAMP)
  implicit val jodaLocalDateEncoder: Encoder[JodaLocalDate] = encoder[JodaLocalDate](SqlTypes.DATE)
  implicit val jodaLocalDateTimeEncoder: Encoder[JodaLocalDateTime] = encoder[JodaLocalDateTime](SqlTypes.TIMESTAMP)

  // java.util.Date is converted to a Joda LocalDateTime before being appended.
  implicit val dateEncoder: Encoder[Date] =
    encoder[Date]((javaDate: Date) => new JodaLocalDateTime(javaDate), SqlTypes.TIMESTAMP)

  // java.time values are mapped onto their Joda counterparts.
  implicit val encodeZonedDateTime: MappedEncoding[ZonedDateTime, JodaDateTime] =
    MappedEncoding { zdt =>
      val epochMillis = zdt.toInstant.toEpochMilli
      val jodaZone = JodaDateTimeZone.forID(zdt.getZone.getId)
      new JodaDateTime(epochMillis, jodaZone)
    }

  implicit val encodeLocalDate: MappedEncoding[LocalDate, JodaLocalDate] =
    MappedEncoding { ld =>
      new JodaLocalDate(ld.getYear, ld.getMonthValue, ld.getDayOfMonth)
    }

  implicit val encodeLocalDateTime: MappedEncoding[LocalDateTime, JodaLocalDateTime] =
    MappedEncoding { ldt =>
      // The local date-time is resolved to an instant in the system default zone.
      val epochMillis = ldt.atZone(ZoneId.systemDefault()).toInstant.toEpochMilli
      new JodaLocalDateTime(epochMillis)
    }

  implicit val localDateEncoder: Encoder[LocalDate] = mappedEncoder(encodeLocalDate, jodaLocalDateEncoder)
}
/**
 * Decoders for `AsyncIOContext`. A decoder reads the value at an index of a result row and
 * converts it to the requested type, failing with a descriptive message when it cannot.
 */
trait IODecoders {
  this: AsyncIOContext[_, _, _] =>

  type Decoder[T] = AsyncDecoder[T]

  type DecoderSqlType = SqlTypes.SqlTypes

  /** Pairs a SQL type tag with the underlying decoding function and delegates to it. */
  case class AsyncDecoder[T](sqlType: DecoderSqlType)(implicit decoder: BaseDecoder[T])
    extends BaseDecoder[T] {
    override def apply(index: Index, row: ResultRow) =
      decoder(index, row)
  }

  /**
   * Builds a decoder for `T`.
   *
   * @param f       conversions tried when the raw value is not already a `T`
   * @param sqlType SQL type tag of the resulting decoder
   */
  def decoder[T: ClassTag](
    f:       PartialFunction[Any, T] = PartialFunction.empty,
    sqlType: DecoderSqlType
  ): Decoder[T] =
    AsyncDecoder[T](sqlType)(new BaseDecoder[T] {
      def apply(index: Index, row: ResultRow) = {
        row(index) match {
          // A value that already has the requested type wins over the conversions in `f`.
          case value: T                      => value
          case value if f.isDefinedAt(value) => f(value)
          case value =>
            fail(
              s"Value '$value' at index $index can't be decoded to '${classTag[T].runtimeClass}'"
            )
        }
      }
    })

  /** Decodes an `O` by decoding an `I` and applying the mapping. */
  implicit def mappedDecoder[I, O](implicit mapped: MappedEncoding[I, O], decoder: Decoder[I]): Decoder[O] =
    AsyncDecoder(decoder.sqlType)(new BaseDecoder[O] {
      def apply(index: Index, row: ResultRow): O =
        mapped.f(decoder.apply(index, row))
    })

  /**
   * Base for decoders that accept any numeric column value: the raw value is matched against
   * the supported numeric types and handed to `decode` together with its `Numeric` instance.
   */
  trait NumericDecoder[T] extends BaseDecoder[T] {
    def apply(index: Index, row: ResultRow) =
      row(index) match {
        case v: Byte       => decode(v)
        case v: Short      => decode(v)
        case v: Int        => decode(v)
        case v: Long       => decode(v)
        case v: Float      => decode(v)
        case v: Double     => decode(v)
        case v: BigDecimal => decode(v)
        case other =>
          fail(s"Value $other is not numeric")
      }
    def decode[U](v: U)(implicit n: Numeric[U]): T
  }

  /** Decodes a `null` entry as `None` and anything else with the decoder of `T`. */
  implicit def optionDecoder[T](implicit d: Decoder[T]): Decoder[Option[T]] =
    AsyncDecoder(d.sqlType)(new BaseDecoder[Option[T]] {
      def apply(index: Index, row: ResultRow) = {
        row(index) match {
          case null => None
          case _    => Some(d(index, row))
        }
      }
    })

  implicit val stringDecoder: Decoder[String] = decoder[String](PartialFunction.empty, SqlTypes.VARCHAR)

  implicit val bigDecimalDecoder: Decoder[BigDecimal] =
    AsyncDecoder(SqlTypes.REAL)(new NumericDecoder[BigDecimal] {
      def decode[U](v: U)(implicit n: Numeric[U]): BigDecimal =
        v match {
          // Going through Double would silently drop digits of exact values, so BigDecimal
          // is returned as is and integral values are converted through Long.
          case exact: BigDecimal                     => exact
          case _: Byte | _: Short | _: Int | _: Long => BigDecimal(n.toLong(v))
          case _                                     => BigDecimal(n.toDouble(v))
        }
    })

  // Integral column values are accepted as booleans: 1 is true, anything else is false.
  implicit val booleanDecoder: Decoder[Boolean] =
    decoder[Boolean]({
      case v: Byte  => v == (1: Byte)
      case v: Short => v == (1: Short)
      case v: Int   => v == 1
      case v: Long  => v == 1L
    }, SqlTypes.BOOLEAN)

  implicit val byteDecoder: Decoder[Byte] =
    decoder[Byte]({
      case v: Short => v.toByte
    }, SqlTypes.TINYINT)

  implicit val shortDecoder: Decoder[Short] =
    decoder[Short]({
      case v: Byte => v.toShort
    }, SqlTypes.SMALLINT)

  implicit val intDecoder: Decoder[Int] =
    AsyncDecoder(SqlTypes.INTEGER)(new NumericDecoder[Int] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toInt(v)
    })

  implicit val longDecoder: Decoder[Long] =
    AsyncDecoder(SqlTypes.BIGINT)(new NumericDecoder[Long] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toLong(v)
    })

  implicit val floatDecoder: Decoder[Float] =
    AsyncDecoder(SqlTypes.FLOAT)(new NumericDecoder[Float] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toFloat(v)
    })

  implicit val doubleDecoder: Decoder[Double] =
    AsyncDecoder(SqlTypes.DOUBLE)(new NumericDecoder[Double] {
      def decode[U](v: U)(implicit n: Numeric[U]) =
        n.toDouble(v)
    })

  // NOTE(review): tagged TINYINT here while byteArrayEncoder uses VARBINARY - confirm which is intended.
  implicit val byteArrayDecoder: Decoder[Array[Byte]] = decoder[Array[Byte]](PartialFunction.empty, SqlTypes.TINYINT)

  implicit val jodaDateTimeDecoder: Decoder[JodaDateTime] = decoder[JodaDateTime]({
    case dateTime: JodaDateTime           => dateTime
    case localDateTime: JodaLocalDateTime => localDateTime.toDateTime
  }, SqlTypes.TIMESTAMP)

  implicit val jodaLocalDateDecoder: Decoder[JodaLocalDate] = decoder[JodaLocalDate]({
    case localDate: JodaLocalDate => localDate
  }, SqlTypes.DATE)

  implicit val jodaLocalDateTimeDecoder: Decoder[JodaLocalDateTime] = decoder[JodaLocalDateTime]({
    case localDateTime: JodaLocalDateTime => localDateTime
  }, SqlTypes.TIMESTAMP)

  implicit val dateDecoder: Decoder[Date] = decoder[Date]({
    case localDateTime: JodaLocalDateTime => localDateTime.toDate
    case localDate: JodaLocalDate         => localDate.toDate
  }, SqlTypes.TIMESTAMP)

  // Joda values are mapped onto their java.time counterparts.
  implicit val decodeZonedDateTime: MappedEncoding[JodaDateTime, ZonedDateTime] =
    MappedEncoding(jdt => ZonedDateTime.ofInstant(Instant.ofEpochMilli(jdt.getMillis), ZoneId.of(jdt.getZone.getID)))

  implicit val decodeLocalDate: MappedEncoding[JodaLocalDate, LocalDate] =
    MappedEncoding(jld => LocalDate.of(jld.getYear, jld.getMonthOfYear, jld.getDayOfMonth))

  // The Joda local date-time is resolved to an instant and read back in the system default zone.
  implicit val decodeLocalDateTime: MappedEncoding[JodaLocalDateTime, LocalDateTime] =
    MappedEncoding(jldt => LocalDateTime.ofInstant(jldt.toDate.toInstant, ZoneId.systemDefault()))

  implicit val localDateDecoder: Decoder[LocalDate] = mappedDecoder(decodeLocalDate, jodaLocalDateDecoder)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: non-transactional operations are also performed sequentially; this could be improved.