Skip to content

Instantly share code, notes, and snippets.

@alexanderscott
Created October 14, 2015 09:43
Show Gist options
  • Save alexanderscott/7c9b87bdb11f4b6a4ba1 to your computer and use it in GitHub Desktop.
Save alexanderscott/7c9b87bdb11f4b6a4ba1 to your computer and use it in GitHub Desktop.
MySql connection pooling in Scala
// Taken largely from https://gist.github.com/tsuna/2245176
import java.sql._
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.{AtomicLong, AtomicInteger}
import java.util.concurrent.TimeUnit.MILLISECONDS
import scala.collection.mutable.ArrayBuffer
import org.slf4j.LoggerFactory
import com.mysql.jdbc.log.Log
import com.twitter.conversions.time._
import com.twitter.util.Duration
import com.twitter.util.Future
import com.twitter.util.FuturePool
import scala.collection.{JavaConversions, JavaConverters}
import java.util.concurrent.atomic.LongAdder
/**
 * Settings needed to build a connection pool.
 *
 * @param servers A list of "ip:port".
 * @param user MySQL username to connect as.
 * @param pass MySQL password for that user.
 * @param schema Name of the DB schema to use.
 */
final case class PoolConfig(servers: Seq[String], user: String, pass: String, schema: String) {

  /**
   * Order-insensitive variant of `equals`: two configurations are equivalent
   * when they carry the same user, password and schema, and their server
   * lists are equal once sorted (so a shuffled list still matches).
   *
   * @param other The configuration to compare against.
   * @return `true` if both configurations describe the same pool.
   */
  def equivalentTo(other: PoolConfig): Boolean = {
    // Both predicates are `def`s so the sort only happens when the cheap
    // scalar comparison has already succeeded.
    def sameCredentials = user == other.user && pass == other.pass && schema == other.schema
    def sameServers = servers.sorted == other.servers.sorted
    sameCredentials && sameServers
  }
}
/**
 * Pairs a JDBC connection with the server it is connected to.
 * Every connection is wrapped so that the pool remembers which server it
 * belongs to and can reconnect to that same server when something bad
 * happens; a JDBC connection object offers no reliable way to extract this
 * information.
 *
 * @param server A "host:port" string.
 * @param connection The underlying JDBC connection.
 */
final case class MySQLConnection(server: String, connection: Connection) {

  /** Prepares `sql` on the underlying JDBC connection. */
  def prepareStatement(sql: String): PreparedStatement = {
    val statement = connection.prepareStatement(sql)
    statement
  }

  /** Closes the underlying JDBC connection. */
  def close(): Unit = {
    connection.close()
  }
}
/**
 * A connection pool for asynchronous operations.
 * For each connection, there is a dedicated thread, because MySQL doesn't
 * have an asynchronous RPC protocol, and because JDBC doesn't have an
 * asynchronous API.
 *
 * NOTE: the thread pool is sized once, from the number of servers in the
 * initial configuration, and is not resized by [[updateConfig]].
 *
 * @param cfg Configuration for this connection pool. Must list at least one
 * server.
 * @param options A "query string" appended to the JDBC URL. A leading `?` is
 * added when missing; when null or blank the default
 * `ConnectionPool.jdbcOptions` are used instead.
 * @param readonly Whether or not to set the connection read-only mode.
 * @param appName Name of the current app (e.g. "honeybadger").
 */
final class ConnectionPool(cfg: PoolConfig,
                           options: String,
                           val readonly: Boolean,
                           appName: String) {
  import ConnectionPool._

  ensureDriverLoaded
  // Same exception type the executor would raise for a zero-sized pool, but
  // with a message that points at the actual problem.
  require(cfg.servers.nonEmpty, "PoolConfig.servers must not be empty")

  /** Suffix appended to every JDBC URL built by this pool. */
  private[this] val urlOptions: String =
    Option(options).map(_.trim).filter(_.nonEmpty) match {
      case Some(o) if o.startsWith("?") => o
      case Some(o) => "?" + o
      case None => jdbcOptions
    }

  @volatile private[this] var conf = cfg
  private[this] val pool = makePool(cfg.servers.length)
  /** How many queries did we send to MySQL. */
  private[this] val queries = new LongAdder()
  /** How many exceptions we got from JDBC. */
  private[this] val exceptions = new LongAdder()
  @volatile private[this] var connections: ArrayBlockingQueue[MySQLConnection] = _
  createConnections() // Populates `connections'.

  /** Returns the current configuration of this pool. */
  def config: PoolConfig = conf

  /**
   * Attempts to apply the new configuration given to this pool.
   * If successful, a full set of new connections replaces the current one and
   * the previous connections are closed.
   * If opening the new connections throws, the previous configuration and
   * connections are restored before the exception is rethrown.
   * Queries already running keep using (and return their connection to) the
   * previous set; closing waits a bounded time for each of them.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to return.
   * @param newcfg The new configuration to apply to this pool. The
   * configuration is assumed to be sane.
   * @throws SQLException if something bad happens (e.g. being unable to open
   * a connection to any one of the hosts for whatever reason).
   * @throws IllegalArgumentException if `newcfg` lists no server.
   */
  def updateConfig(newcfg: PoolConfig): Unit = synchronized {
    // The whole method is synchronized so that only one change is applied at
    // a time and a failed change can be rolled back reliably.
    val prevconns = connections // volatile-read
    val prevcfg = conf // volatile-read
    try {
      conf = newcfg
      createConnections() // volatile-write on connections
    } catch {
      case e: Exception =>
        // Roll-back.
        connections = prevconns // volatile-write
        conf = prevcfg // volatile-write
        throw e
    }
    // Success! Now dispose of the previous connections, to not leak them.
    // A failure here must not undo the (already committed) new configuration.
    try {
      closeAllConnections(prevcfg, prevconns)
    } catch {
      case e: Exception =>
        log.warn("Uncaught exception while closing an old connection after"
                 + " reloading a new configuration", e)
    }
  }

  /**
   * Opens one connection per configured server and publishes them as the
   * current set. If any server fails, the connections opened so far are
   * closed again and the current set is left untouched.
   */
  private def createConnections(): Unit = {
    val servers = conf.servers // single volatile-read, used consistently below
    require(servers.nonEmpty, "PoolConfig.servers must not be empty")
    val newconns = new ArrayBlockingQueue[MySQLConnection](servers.length)
    try {
      servers foreach { server => // server is already "ip:port".
        newconns.add(newConnection(server))
      }
    } catch {
      case e: Throwable =>
        // Don't leak the connections that were opened before the failure.
        var opened = newconns.poll()
        while (opened != null) {
          try opened.close()
          catch { case ce: Exception => e.addSuppressed(ce) }
          opened = newconns.poll()
        }
        throw e
    }
    connections = newconns // commit: volatile-write
  }

  /** Returns the number of queries sent to MySQL. */
  def queriesSent: Long = queries.longValue()

  /** Returns the number of exceptions caught while talking to MySQL. */
  def exceptionsCaught: Long = exceptions.longValue()

  /** Closes all connections and releases all threads. */
  def shutdown(): Unit = {
    pool.executor.shutdown()
    closeAllConnections(conf, connections)
  }

  /**
   * Executes a SELECT statement on the database, asynchronously.
   * The statement is prefixed with a SQL comment carrying the app name.
   * @param f Function called on each row returned by the database. This
   * function is called with the connection locked, so if this function takes
   * time, it will prevent the connection from being reused for another query.
   * @param sql The SQL statement, e.g. "SELECT foo FROM t WHERE id = ?"
   * @param params The parameters to substitute in the `?' placeholders.
   * They are bound through a prepared statement, not spliced into the SQL.
   * @return A future sequence of things returned by `f'.
   * @throws SQLException (async) if something bad happens.
   */
  def select[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, "/*" + appName + "*/ " + sql, params))
  }

  /**
   * Executes an INSERT statement on the database, asynchronously.
   * See [[execute]] for the meaning of the parameters; a statement that
   * produces no result set yields an empty sequence.
   */
  def insert[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, sql, params))
  }

  /**
   * Executes an UPDATE statement on the database, asynchronously.
   * See [[execute]] for the meaning of the parameters; a statement that
   * produces no result set yields an empty sequence.
   */
  def update[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, sql, params))
  }

  /**
   * Runs a statement synchronously on a connection borrowed from the pool.
   * The connection is always returned to the pool it was taken from.
   * @param f Function called on each row of the result set, if any.
   * @param sql The SQL statement with `?' placeholders.
   * @param params The values to bind to the placeholders, see
   * `bindParameters` for the supported types.
   * @return What `f' returned for each row, or an empty sequence when the
   * statement does not produce a result set (e.g. INSERT / UPDATE).
   * @throws IllegalStateException if no connection is available.
   * @throws SQLException if JDBC reports an error; it is logged and counted
   * before being rethrown.
   */
  def execute[T](f: ResultSet => T, sql: String, params: Seq[Any]): Seq[T] = {
    queries.increment()
    val connpool = this.connections // volatile-read
    var connection = Option(connpool.poll()).getOrElse {
      // Should never happen: we have as many threads as connections so this
      // can only happen if a thread is leaking a connection, which would be
      // really bad.
      val e = new IllegalStateException("WTF? Couldn't get a connection from the pool.")
      exceptions.increment()
      log.error(e.getMessage)
      throw e
    }
    try {
      val statement = connection.prepareStatement(sql)
      try {
        bindParameters(statement, params)
        if (log.isDebugEnabled) {
          log.debug(connection.server + ": " + sql
                    + " " + params.mkString("(", ", ", ")"))
        }
        // `execute` (unlike `executeQuery`) is valid for every kind of
        // statement; it tells us whether there is a result set to read.
        if (statement.execute()) readRows(statement.getResultSet, f)
        else Vector.empty[T]
      } finally {
        statement.close()
      }
    } catch {
      case e: SQLSyntaxErrorException =>
        logAndRethrow(connection, "Syntax error in SQL query",
                      sql, params, e)
      case e: SQLIntegrityConstraintViolationException =>
        logAndRethrow(connection, "Integrity constraint violated by SQL query",
                      sql, params, e)
      case e: SQLFeatureNotSupportedException =>
        logAndRethrow(connection, "Feature not supported in SQL query",
                      sql, params, e)
      case e: SQLDataException =>
        logAndRethrow(connection, "Data exception caused by SQL query",
                      sql, params, e)
      case e @ (_: SQLRecoverableException | _: SQLNonTransientException) =>
        // The remaining kinds of SQLNonTransientException are typically
        // connection-level problems, and for a SQLRecoverableException the
        // JDK javadoc says that "the recovery operation must include closing
        // the current connection and getting a new connection".
        // The `finally' block below puts the replacement back in the pool.
        connection = reconnect(connection, e)
        // TODO(tsuna): If we wanted we could retry once here.
        logAndRethrow(connection, "Error on connection when trying to execute",
                      sql, params, e)
      case e: Throwable =>
        // TODO(tsuna): Should we close the connection here? I'm not sure.
        logAndRethrow(connection, "Uncaught exception", sql, params, e)
    } finally {
      // Always return the connection to the pool.
      connpool.put(connection)
    }
  }

  /** Applies `f' to every row of `rs', then closes `rs'. */
  private def readRows[T](rs: ResultSet, f: ResultSet => T): Seq[T] =
    try {
      val rows = Vector.newBuilder[T]
      while (rs.next()) {
        rows += f(rs)
      }
      rows.result()
    } finally {
      rs.close()
    }

  /**
   * Closes `broken' and opens a fresh connection to the same server.
   * Never throws for ordinary failures, so the exception that triggered the
   * reconnect is the one callers get: a failure to close is logged, a failure
   * to reconnect is logged and attached to `cause' as a suppressed exception.
   * @return The new connection, or `broken' itself if reconnecting failed (so
   * the pool keeps its size).
   */
  private def reconnect(broken: MySQLConnection, cause: Throwable): MySQLConnection = {
    try {
      broken.close()
    } catch {
      case ce: Exception =>
        log.warn(broken.server + ": Failed to close broken connection", ce)
    }
    try {
      newConnection(broken.server)
    } catch {
      case re: Exception =>
        log.error(broken.server + ": Failed to re-establish connection", re)
        cause.addSuppressed(re)
        broken
    }
  }

  /**
   * Counts and logs an exception, together with the class name and message
   * of every exception in its cause chain, then rethrows it.
   */
  private def logAndRethrow(connection: MySQLConnection, msg: String,
                            sql: String, params: Seq[Any], e: Throwable): Nothing = {
    exceptions.increment()
    // Names & messages of all exceptions in the chain, outermost first.
    val cause = Iterator.iterate(e)(_.getCause)
      .takeWhile(_ != null)
      .map(t => ", caused by " + t.getClass.getName + ": " + t.getMessage)
      .mkString
    log.error(connection.server + ": " + msg + ": " + sql
              + " with params " + params.mkString("(", ", ", ")")
              + cause)
    throw e
  }

  /** Binds `params' to the placeholders of `statement', starting at 1. */
  private def bindParameters(statement: PreparedStatement,
                             params: TraversableOnce[Any]): Unit = {
    bindParameters(statement, 1, params)
  }

  /**
   * Binds `params' to consecutive placeholders of `statement'.
   * Collections and products (tuples, case classes) are flattened: each of
   * their elements takes one placeholder. `null' is bound as SQL NULL.
   * @param startIndex 1-based index of the first placeholder to bind.
   * @return The index following the last placeholder that was bound.
   * @throws IllegalArgumentException if a parameter has an unsupported type.
   */
  private def bindParameters(statement: PreparedStatement,
                             startIndex: Int,
                             params: TraversableOnce[Any]): Int = {
    var index = startIndex
    for (param <- params) {
      param match {
        case null => statement.setNull(index, Types.NULL)
        case i: Int => statement.setInt(index, i)
        case l: Long => statement.setLong(index, l)
        case s: String => statement.setString(index, s)
        case l: TraversableOnce[_] =>
          // The recursive call returns the next free index; compensate for
          // the unconditional increment below.
          index = bindParameters(statement, index, l) - 1
        case p: Product =>
          index = bindParameters(statement, index, p.productIterator.toList) - 1
        //case ab: Array[Byte] => statement.setBytes(index, ab)
        case b: Boolean => statement.setBoolean(index, b)
        case s: Short => statement.setShort(index, s)
        case f: Float => statement.setFloat(index, f)
        case d: Double => statement.setDouble(index, d)
        case _ =>
          throw new IllegalArgumentException("Unsupported data type "
            + param.asInstanceOf[AnyRef].getClass.getName + ": " + param)
      }
      index += 1
    }
    index
  }

  /**
   * Returns a new MySQL connection, using the current configuration's
   * schema and credentials and this pool's URL options.
   * @param server A "host:port" string.
   */
  private def newConnection(server: String): MySQLConnection = {
    val current = conf // single volatile-read
    val url = "jdbc:mysql://" + server + "/" + current.schema + urlOptions
    val connection = DriverManager.getConnection(url, current.user, current.pass)
    connection.setReadOnly(readonly)
    MySQLConnection(server, connection)
  }

  /** Describes this pool; the password is masked so it can be logged. */
  override def toString: String = "ConnectionPool(" + conf.copy(pass = "****") + ")"
}
/** Factories, default JDBC options and shared helpers for [[ConnectionPool]]. */
object ConnectionPool {
  private val log = LoggerFactory.getLogger(getClass)

  /**
   * Makes sure the MySQL JDBC driver class is loaded and initialized.
   * `Class.forName` is used because it initializes the class, which a plain
   * class literal does not do.
   * @throws AssertionError if the driver class cannot be loaded.
   */
  private def ensureDriverLoaded: Unit =
    try {
      // `classOf` keeps the compile-time check that the driver is on the
      // classpath; `forName` forces its static initializer to run.
      Class.forName(classOf[com.mysql.jdbc.Driver].getName)
      ()
    } catch {
      case e @ (_: ClassNotFoundException | _: LinkageError) =>
        throw new AssertionError("MySQL JDBC connector missing.", e)
    }

  /**
   * Default options we use to connect to MySQL, rendered as a URL query
   * string starting with `?`. Durations are rendered in milliseconds and
   * classes by their fully-qualified name.
   */
  val jdbcOptions: String = {
    val options = Map(
      "connectTimeout" -> 4.seconds,
      "socketTimeout" -> 2.seconds,
      "useServerPrepStmts" -> true,
      "cachePrepStmts" -> true,
      "cacheResultSetMetadata" -> true,
      "cacheServerConfiguration" -> true,
      "logger" -> classOf[MySQLLogger]
    )
    def render(value: Any): String = value match {
      case c: Class[_] => c.getName
      case d: Duration => d.inMilliseconds.toString
      case other => other.toString
    }
    options.toList
      .map { case (option, value) => option + "=" + render(value) }
      .mkString("?", "&", "")
  }

  /** Creates a pool whose connections are set read-only, with the default options. */
  def readOnly(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, true, appName)

  /** Creates a pool whose connections are read-write, with the default options. */
  def readWrite(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, false, appName)

  /**
   * Creates a fixed-size thread-pool to use the given connections.
   * Threads are named "MySQL-1", "MySQL-2", ...
   * @param size Number of threads in the pool.
   */
  private def makePool(size: Int) = {
    val factory = new ThreadFactory {
      val id = new AtomicInteger(0)
      def newThread(r: Runnable) =
        new Thread(r, "MySQL-" + id.incrementAndGet)
    }
    FuturePool(Executors.newFixedThreadPool(size, factory))
  }

  /**
   * Closes all the connections from the given pool with the given config.
   * Every connection that can be obtained is closed, even if closing an
   * earlier one failed; the first failure is rethrown at the end with any
   * later ones attached as suppressed exceptions.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to clear up the pool.
   * @param conf Configuration the connections were created from; its number
   * of servers is the number of connections expected in the queue.
   * @param connections The queue to drain.
   */
  private def closeAllConnections(conf: PoolConfig,
                                  connections: ArrayBlockingQueue[MySQLConnection]): Unit = {
    val total = conf.servers.length
    var firstFailure: Option[Exception] = None
    for (i <- 1 to total) {
      // We're not serving a query to an end-user, and our goal is to
      // close the connection but we don't want to wait forever in case
      // the connection is somehow badly stuck. So allow quite a bit of
      // time to grab a connection.
      Option(connections.poll(500, MILLISECONDS)) match {
        case None =>
          log.error("Timeout while trying to get connection #" + i + " / "
                    + total + ", connection will be leaked.")
        case Some(connection) =>
          try {
            connection.close()
          } catch {
            case e: Exception =>
              // Remember the failure but keep closing the other connections.
              firstFailure match {
                case Some(first) => first.addSuppressed(e)
                case None => firstFailure = Some(e)
              }
          }
      }
    }
    firstFailure.foreach(e => throw e)
  }
}
/**
 * Class for MySQL's JDBC logging (otherwise it goes to stderr by default).
 * Forwards every message to the SLF4J logger called `name`. SLF4J has no
 * "fatal" level, so fatal messages are logged as errors with a
 * "** FATAL ** " prefix.
 *
 * @param name Name of the SLF4J logger to forward to.
 */
private final class MySQLLogger(name: String) extends Log {
  val log = LoggerFactory.getLogger(name)

  def isDebugEnabled: Boolean = log.isDebugEnabled
  def isErrorEnabled: Boolean = log.isErrorEnabled
  def isFatalEnabled: Boolean = log.isErrorEnabled
  def isInfoEnabled: Boolean = log.isInfoEnabled
  def isTraceEnabled: Boolean = log.isTraceEnabled
  def isWarnEnabled: Boolean = log.isWarnEnabled

  /**
   * Renders a log message as text. Strings are passed through, `null`
   * becomes "null" and anything else is rendered with `toString`, so that
   * logging itself never throws because of the message's type.
   */
  private def cast(msg: Any): String =
    msg match {
      case m: String => m
      case null => "null"
      case other => other.toString
    }

  def logDebug(msg: Any): Unit = {
    log.debug(cast(msg))
  }

  def logDebug(msg: Any, e: Throwable): Unit = {
    log.debug(cast(msg), e)
  }

  def logError(msg: Any): Unit = {
    log.error(cast(msg))
  }

  def logError(msg: Any, e: Throwable): Unit = {
    log.error(cast(msg), e)
  }

  def logFatal(msg: Any): Unit = {
    log.error("** FATAL ** " + cast(msg)) // Keep going anyway.
  }

  def logFatal(msg: Any, e: Throwable): Unit = {
    log.error("** FATAL ** " + cast(msg), e) // Keep going anyway.
  }

  def logInfo(msg: Any): Unit = {
    log.info(cast(msg))
  }

  def logInfo(msg: Any, e: Throwable): Unit = {
    log.info(cast(msg), e)
  }

  def logTrace(msg: Any): Unit = {
    log.trace(cast(msg))
  }

  def logTrace(msg: Any, e: Throwable): Unit = {
    log.trace(cast(msg), e)
  }

  def logWarn(msg: Any): Unit = {
    log.warn(cast(msg))
  }

  def logWarn(msg: Any, e: Throwable): Unit = {
    log.warn(cast(msg), e)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment