Skip to content

Instantly share code, notes, and snippets.

@mitallast
Last active October 26, 2018 07:29
Show Gist options
  • Save mitallast/e6761af9180d70d56e9933e18c0bfedd to your computer and use it in GitHub Desktop.
Save mitallast/e6761af9180d70d56e9933e18c0bfedd to your computer and use it in GitHub Desktop.
package doobieext
import java.nio.charset.CharsetEncoder
import java.nio.charset.CoderResult
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.{ByteBuffer, CharBuffer}
import java.time.LocalDate
import java.util.UUID
import akka.Done
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import cats.Foldable
import cats.implicits._
import doobie._
import doobie.implicits._
import doobie.postgres.implicits._
import doobie.postgres.Text
import org.apache.logging.log4j.scala.Logging
import org.postgresql.copy.{CopyIn, CopyManager}
import org.postgresql.core.BaseConnection
import scala.util.{Failure, Left, Right, Success}
import scala.util.control.NonFatal
/** COPY-based bulk loading for a doobie [[Fragment]].
  *
  * The fragment's SQL is handed to PgJDBC's `CopyManager.copyIn`. NOTE(review): it is presumably a
  * `COPY ... FROM STDIN` statement in text format; confirm with callers.
  *
  * Rows are rendered with doobie's `Text` type class, terminated with `"\n"`, buffered in a
  * `StringBuilder`, and written to the server as UTF-8 whenever the buffer reaches `limitChars`.
  * If anything fails mid-copy, the COPY is cancelled before the error propagates, so the connection
  * is not left in COPY mode.
  */
final class PostgresCopyOperations(fragment: Fragment) extends Logging {
  // Initial capacity of the row buffer. A chunk is written once the buffer reaches `limitChars`.
  // The 1000-char gap below `maxChars` is presumably headroom so that the row crossing the limit
  // usually fits without the builder having to grow.
  private val maxChars = 131072
  private val limitChars = maxChars - 1000
  // Byte buffer is 1.5x the char capacity. The parentheses matter: in Scala `+` binds tighter than
  // `>>`, so `maxChars + maxChars >> 1` evaluates to plain `maxChars`.
  private val bufferBytes = maxChars + (maxChars >> 1)

  /** Copies every element of `fa` through the fragment's COPY statement.
    *
    * @param fa   rows to copy; an empty collection short-circuits to `0L` without starting a COPY
    * @param text renders one row into the COPY text representation
    * @return the row count reported by `CopyIn.endCopy`
    */
  def copyStream[T, F[_]: Foldable](fa: F[T])(implicit text: Text[T]): ConnectionIO[Long] =
    if (fa.isEmpty) FC.pure(0L)
    else
      for {
        conn <- FC.unwrap(classOf[BaseConnection])
        rows <- FC.delay {
          val start = System.currentTimeMillis()
          val cp = new CopyManager(conn).copyIn(fragment.query.sql)
          val rows = cancelOnFailure(cp) {
            val encoder = UTF_8.newEncoder()
            val sb = new StringBuilder(maxChars)
            val bb = ByteBuffer.allocate(bufferBytes)
            // A plain lambda rather than `case (_, a: T)`: a type pattern on an erased `T` is unchecked.
            fa.foldLeft(())((_, row) => appendRow(row, sb, encoder, bb, cp))
            finish(sb, encoder, bb, cp)
          }
          val end = System.currentTimeMillis()
          logger.info(s"copy $rows rows at ${end - start}ms")
          rows
        }
      } yield rows

  /** Copies every element emitted by an Akka Streams `Source` through the fragment's COPY statement.
    *
    * The source is materialized with `mat`, and each element is appended (and periodically written)
    * as it arrives. Stream completion is observed on `BlockingContext.blocking`. When the stream
    * fails, the COPY is cancelled before the failure is reported to the `ConnectionIO`.
    *
    * @param fa   rows to copy
    * @param text renders one row into the COPY text representation
    * @param mat  materializer used to run `fa`
    * @return the row count reported by `CopyIn.endCopy`
    */
  def copySource[T](fa: Source[T, _])(implicit text: Text[T], mat: Materializer): ConnectionIO[Long] =
    for {
      conn <- FC.unwrap(classOf[BaseConnection])
      start <- FC.delay(System.currentTimeMillis())
      cp <- FC.delay(new CopyManager(conn).copyIn(fragment.query.sql))
      encoder = UTF_8.newEncoder()
      sb = new StringBuilder(maxChars)
      bb = ByteBuffer.allocate(bufferBytes)
      _ <- FC.async[Done] { cb =>
        try {
          fa.runForeach((t: T) => appendRow(t, sb, encoder, bb, cp))
            .onComplete {
              case Success(done) => cb(Right(done))
              case Failure(e) =>
                // Leave COPY mode before resuming, so the caller can roll back on this connection.
                abort(cp, e)
                cb(Left(e))
            }(BlockingContext.blocking)
        } catch {
          // Covers failures raised synchronously while materializing the stream.
          case NonFatal(e) =>
            abort(cp, e)
            cb(Left(e))
        }
      }
      rows <- FC.delay(cancelOnFailure(cp)(finish(sb, encoder, bb, cp)))
      _ <- FC.delay(logger.info(s"copy $rows rows at ${System.currentTimeMillis() - start}ms"))
    } yield rows

  /** Renders `row` plus a trailing newline into `sb`, then writes the buffer once it reaches `limitChars`. */
  private def appendRow[T](row: T, sb: StringBuilder, encoder: CharsetEncoder, bb: ByteBuffer, cp: CopyIn)(
      implicit text: Text[T]
  ): Unit = {
    text.unsafeEncode(row, sb)
    sb.append("\n")
    if (sb.length >= limitChars) write(sb, encoder, bb, cp)
  }

  /** Writes any remaining buffered rows and completes the COPY, returning the server's row count. */
  private def finish(sb: StringBuilder, encoder: CharsetEncoder, bb: ByteBuffer, cp: CopyIn): Long = {
    if (sb.nonEmpty) write(sb, encoder, bb, cp)
    cp.endCopy()
  }

  /** Encodes the contents of `sb` as UTF-8, sends the bytes to `cp`, and clears `sb`.
    *
    * Follows the `CharsetEncoder` sequence: `reset`, then `encode` with end-of-input, then `flush`.
    * Each call is repeated while it reports overflow, and `bb` is drained to the server each time.
    */
  private def write(sb: StringBuilder, encoder: CharsetEncoder, bb: ByteBuffer, cp: CopyIn): Unit = {
    val chars = CharBuffer.wrap(sb)
    encoder.reset()
    pump(bb, cp) { out =>
      logger.trace(s"encode: ${chars.remaining()} chars")
      // `true` = endOfInput: the whole chunk is in `chars`, and rows are only flushed whole.
      encoder.encode(chars, out, true)
    }
    pump(bb, cp)(out => encoder.flush(out))
    sb.clear()
  }

  /** Runs `step` into a freshly cleared `bb` and writes whatever it produced to `cp`.
    *
    * Repeats while `step` reports overflow. Malformed or unmappable input is rethrown through
    * `CoderResult.throwException`.
    */
  private def pump(bb: ByteBuffer, cp: CopyIn)(step: ByteBuffer => CoderResult): Unit = {
    var overflow = true
    while (overflow) {
      bb.clear()
      val cr = step(bb)
      if (cr.isError) cr.throwException()
      bb.flip()
      if (bb.hasRemaining) {
        logger.trace(s"write ${bb.remaining()} bytes")
        cp.writeToCopy(bb.array(), 0, bb.remaining())
      }
      overflow = cr.isOverflow
    }
  }

  /** Evaluates `body`. If it throws, cancels the COPY on `cp` and rethrows the original error. */
  private def cancelOnFailure[A](cp: CopyIn)(body: => A): A =
    try body
    catch {
      case NonFatal(e) =>
        abort(cp, e)
        throw e
    }

  /** Cancels `cp` if it is still active. A failure to cancel is attached to `cause` as suppressed. */
  private def abort(cp: CopyIn, cause: Throwable): Unit =
    try {
      if (cp.isActive) cp.cancelCopy()
    } catch {
      case NonFatal(e) => cause.addSuppressed(e)
    }
}
/** Companion holding the `Fragment` enrichment and extra `Text` instances used by COPY. */
object PostgresCopyOperations {

  /** Lets COPY helpers be called directly on a [[Fragment]], e.g. `frag.copyStream(rows)`. */
  implicit def toCopyOperations(fragment: Fragment): PostgresCopyOperations =
    new PostgresCopyOperations(fragment)

  /** Encodes a UUID as its `UUID.toString` form. */
  final implicit val UUIDText: Text[UUID] = textViaToString[UUID]

  /** Encodes a date as its `LocalDate.toString` form. */
  final implicit val LocalDateText: Text[LocalDate] = textViaToString[LocalDate]

  // Appends the value's `toString` verbatim, with no COPY escaping applied. That is adequate for
  // UUID and ISO-8601 date strings, which contain no tab, newline, or backslash.
  private def textViaToString[A]: Text[A] =
    Text.instance((value, out) => out.append(value.toString))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment