Last active
October 26, 2018 07:29
-
-
Save mitallast/e6761af9180d70d56e9933e18c0bfedd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package doobieext | |
import java.nio.charset.CharsetEncoder | |
import java.nio.charset.StandardCharsets.UTF_8 | |
import java.nio.{ByteBuffer, CharBuffer} | |
import java.time.LocalDate | |
import java.util.UUID | |
import akka.Done | |
import akka.stream.Materializer | |
import akka.stream.scaladsl.Source | |
import cats.Foldable | |
import cats.implicits._ | |
import doobie._ | |
import doobie.implicits._ | |
import doobie.postgres.implicits._ | |
import doobie.postgres.Text | |
import org.apache.logging.log4j.scala.Logging | |
import org.postgresql.copy.{CopyIn, CopyManager} | |
import org.postgresql.core.BaseConnection | |
import scala.util.{Failure, Left, Right, Success} | |
/**
 * Bulk-load helpers that stream text-encoded rows into a pgjdbc `CopyIn` operation.
 *
 * Rows are rendered through a `Text[T]` instance into a shared `StringBuilder`, one row per
 * line, and flushed to the server in UTF-8 chunks once the buffer approaches `maxChars`.
 *
 * @param fragment fragment whose SQL text is handed to `CopyManager.copyIn`
 */
final class PostgresCopyOperations(fragment: Fragment) extends Logging {
  import scala.util.control.NonFatal

  /** Initial capacity of the row buffer, in chars. */
  private val maxChars = 131072

  /** Flush threshold: once the row buffer holds at least this many chars it is written out. */
  private val limitChars = maxChars - 1000

  // `+` binds tighter than `>>` in Scala, so the previous `maxChars + maxChars >> 1`
  // evaluated to `(maxChars + maxChars) >> 1 == maxChars`. Parenthesised to get 1.5x.
  private val maxBytes = maxChars + (maxChars >> 1)

  /**
   * Copies every element of a foldable container through a single copy-in operation.
   *
   * An empty container short-circuits to `0L` without touching the connection.
   * If anything fails after the copy has been opened, the copy is cancelled before the
   * error is rethrown.
   *
   * @param fa   rows to send
   * @param text encoder used to render each row into the text buffer
   * @return the value returned by `CopyIn.endCopy()`
   */
  def copyStream[T, F[_]: Foldable](fa: F[T])(implicit text: Text[T]): ConnectionIO[Long] =
    if (fa.isEmpty) FC.pure(0L)
    else {
      for {
        conn <- FC.unwrap(classOf[BaseConnection])
        rows <- FC.delay {
          val start = System.currentTimeMillis()
          val cp = new CopyManager(conn).copyIn(fragment.query.sql)
          try {
            val encoder = UTF_8.newEncoder()
            val sb = new StringBuilder(maxChars)
            val bb = ByteBuffer.allocate(maxBytes)
            // Plain lambda instead of `case (_, a: T)`: a type test against the erased `T`
            // is unchecked and adds nothing here.
            fa.foldLeft(()) { (_, a) =>
              text.unsafeEncode(a, sb)
              sb.append("\n")
              if (sb.length >= limitChars) write(sb, encoder, bb, cp)
            }
            if (sb.nonEmpty) write(sb, encoder, bb, cp)
            val rows = cp.endCopy()
            val end = System.currentTimeMillis()
            logger.info(s"copy $rows rows at ${end - start}ms")
            rows
          } catch {
            case NonFatal(e) =>
              cancelQuietly(cp, e)
              throw e
          }
        }
      } yield rows
    }

  /**
   * Copies every element emitted by an Akka Streams `Source` through a single copy-in operation.
   *
   * The stream is run with `runForeach`; its completion is bridged back into `ConnectionIO`
   * via `FC.async`. If the stream fails, or the final flush / `endCopy` fails, the copy is
   * cancelled before the error is propagated.
   *
   * @param fa   source of rows to send
   * @param text encoder used to render each row into the text buffer
   * @param mat  materializer used to run the stream
   * @return the value returned by `CopyIn.endCopy()`
   */
  def copySource[T](fa: Source[T, _])(implicit text: Text[T], mat: Materializer): ConnectionIO[Long] =
    for {
      conn <- FC.unwrap(classOf[BaseConnection])
      start <- FC.delay(System.currentTimeMillis())
      cp <- FC.delay(new CopyManager(conn).copyIn(fragment.query.sql))
      encoder = UTF_8.newEncoder()
      sb = new StringBuilder(maxChars)
      bb = ByteBuffer.allocate(maxBytes)
      _ <- FC.async[Done] { cb =>
        try {
          fa.runForeach { t: T =>
              text.unsafeEncode(t, sb)
              sb.append("\n")
              if (sb.length >= limitChars) write(sb, encoder, bb, cp)
            }
            .onComplete {
              case Success(done) => cb(Right(done))
              case Failure(e) =>
                // Abort the server-side copy before reporting the failure.
                cancelQuietly(cp, e)
                cb(Left(e))
            }(BlockingContext.blocking)
        } catch {
          case NonFatal(e) =>
            // Starting the stream itself failed; do not leave the copy open.
            cancelQuietly(cp, e)
            throw e
        }
      }
      rows <- FC.delay {
        try {
          if (sb.nonEmpty) write(sb, encoder, bb, cp)
          cp.endCopy()
        } catch {
          case NonFatal(e) =>
            cancelQuietly(cp, e)
            throw e
        }
      }
      _ <- FC.delay(logger.info(s"copy $rows rows at ${System.currentTimeMillis() - start}ms"))
    } yield rows

  /**
   * Cancels the copy if it is still active. A failure of the cancel itself is attached to
   * `cause` as a suppressed exception so that the original error is the one that surfaces.
   */
  private def cancelQuietly(cp: CopyIn, cause: Throwable): Unit =
    try {
      if (cp.isActive) cp.cancelCopy()
    } catch {
      case NonFatal(e) => if (e ne cause) cause.addSuppressed(e)
    }

  /**
   * Encodes the whole content of `sb` with `encoder` and writes it to `cp`, reusing `bb` as
   * the scratch byte buffer, then clears `sb`.
   *
   * The loop repeats until all chars are consumed, so content that does not fit into `bb`
   * in one pass is sent in several `writeToCopy` calls. An encoding error is raised as the
   * exception produced by `CoderResult.throwException()`.
   */
  private def write(sb: StringBuilder, encoder: CharsetEncoder, bb: ByteBuffer, cp: CopyIn): Unit = {
    val cb = CharBuffer.wrap(sb)
    encoder.reset()
    do {
      bb.clear()
      logger.trace(s"encode: ${cb.remaining()} chars")
      val cr = encoder.encode(cb, bb, true)
      if (cr.isError) {
        cr.throwException()
      }
      // After flip() the position is 0 and the limit marks the end of the encoded bytes.
      bb.flip()
      if (bb.hasRemaining) {
        logger.trace(s"write ${bb.remaining()} bytes")
        cp.writeToCopy(bb.array(), 0, bb.remaining())
      }
    } while (cb.hasRemaining)
    sb.clear()
  }
}
/** Implicit syntax and extra `Text` instances for [[PostgresCopyOperations]]. */
object PostgresCopyOperations {

  /** Implicit conversion from a `Fragment` to a `PostgresCopyOperations` wrapping it. */
  implicit def toCopyOperations(fragment: Fragment): PostgresCopyOperations =
    new PostgresCopyOperations(fragment)

  /** Appends the UUID's `toString` form to the output buffer. */
  final implicit val UUIDText: Text[UUID] =
    Text.instance[UUID]((uuid, out) => out.append(uuid.toString))

  /** Appends the date's `toString` form to the output buffer. */
  final implicit val LocalDateText: Text[LocalDate] =
    Text.instance[LocalDate]((date, out) => out.append(date.toString))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment