Skip to content

Instantly share code, notes, and snippets.

@tsuna
Created March 30, 2012 00:15
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 7 You must be signed in to fork a gist
  • Save tsuna/2245176 to your computer and use it in GitHub Desktop.
Save tsuna/2245176 to your computer and use it in GitHub Desktop.
MySQL JDBC connection pool for Scala + Finagle
// Copyright (C) 2012 Benoit Sigoure
// Copyright (C) 2012 StumbleUpon, Inc.
// This library is free software: you can redistribute it and/or modify it
// under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 2.1 of the License, or (at your
// option) any later version. This program is distributed in the hope that it
// will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
// of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details. You should have received a copy
// of the GNU Lesser General Public License along with this program. If not,
// see <http://www.gnu.org/licenses/>.
package com.stumbleupon.backends
import java.sql.Connection
import java.sql.DriverManager
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.sql.SQLDataException
import java.sql.SQLFeatureNotSupportedException
import java.sql.SQLIntegrityConstraintViolationException
import java.sql.SQLNonTransientException
import java.sql.SQLRecoverableException
import java.sql.SQLSyntaxErrorException
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.TimeUnit.MILLISECONDS
import scala.collection.mutable.ArrayBuffer
import org.slf4j.LoggerFactory
import com.mysql.jdbc.log.Log
import com.twitter.conversions.time._
import com.twitter.util.Duration
import com.twitter.util.Future
import com.twitter.util.FuturePool
import com.stumbleupon.common.Counter // This is like an AtomicLong except based on jsr166e.LongAdder
/**
 * Configuration for a connection Pool.
 * @param servers A list of "ip:port".
 * @param user MySQL username to connect as.
 * @param pass MySQL password for that user.  Redacted from `toString`.
 * @param schema Name of the DB schema to use.
 */
final case class PoolConfig(servers: Seq[String], user: String, pass: String, schema: String) {
  /** Like `equals' but ignores the order in `servers' in case they were shuffled. */
  def equivalentTo(other: PoolConfig): Boolean =
    user == other.user && pass == other.pass && schema == other.schema &&
      servers.sorted == other.servers.sorted

  /**
   * Same layout as the compiler-generated case-class `toString`, except that
   * the password is redacted.  Configs end up in log lines (directly or
   * embedded in other objects' string forms), and the default `toString`
   * would print the plaintext password there.
   */
  override def toString: String =
    productPrefix + Seq(servers, user, "<redacted>", schema).mkString("(", ",", ")")
}
/**
 * Pairs a JDBC connection with the "host:port" of the server it talks to.
 * JDBC offers no reliable way to recover the server from a connection
 * object, yet we need it to reconnect after a failure, so every connection
 * handed out by the pool is wrapped in one of these.
 * @param server A "host:port" string.
 * @param connection The underlying JDBC connection to `server`.
 */
final case class MySQLConnection(server: String, connection: Connection) {
  /** Prepares `sql` on the wrapped JDBC connection. */
  def prepareStatement(sql: String): PreparedStatement =
    connection.prepareStatement(sql)

  /** Closes the wrapped JDBC connection. */
  def close: Unit =
    connection.close()
}
/**
 * A connection pool for asynchronous operations.
 * For each connection, there is a dedicated thread, because MySQL doesn't
 * have an asynchronous RPC protocol, and because JDBC doesn't have an
 * asynchronous API.
 * @param cfg Configuration for this connection pool.
 * @param options A "query string" appended as-is to the JDBC URL, right
 * after the schema name (the default options start with "?").
 * @param readonly Whether or not to set the connection read-only mode.
 * @param appName Name of the current app (e.g. "honeybadger").  It is
 * prepended to every query as a SQL comment.
 */
final class ConnectionPool(cfg: PoolConfig,
                           options: String,
                           val readonly: Boolean,
                           appName: String) {
  import ConnectionPool._

  ensureDriverLoaded

  @volatile private[this] var conf = cfg
  // NOTE(review): the thread pool is sized once, from the initial config.
  // If updateConfig() later changes the number of servers, threads and
  // connections no longer match one-to-one; with fewer connections than
  // threads, execute() can find the queue empty and throw.
  private[this] val pool = makePool(cfg.servers.length)
  /** How many queries did we send to MySQL. */
  private[this] val queries = new Counter
  /** How many exceptions we got from JDBC. */
  private[this] val exceptions = new Counter
  @volatile private[this] var connections: ArrayBlockingQueue[MySQLConnection] = _
  createConnections() // Populates `connections'.

  /** Returns the current configuration of this pool. */
  def config: PoolConfig = conf

  /**
   * Attempts to apply the new configuration given to this pool.
   * Changes are applied atomically without disrupting ongoing traffic.
   * If successful, this closes all the previous connections and replaces
   * them all with new connections.
   * If an exception is thrown, changes are rolled back first and both
   * the configuration and the connection pool remain unchanged.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to return.
   * @param newcfg The new configuration to apply to this pool.  The
   * configuration is assumed to be sane.
   * @throws SQLException if something bad happens (e.g. being unable to open
   * a connection to any one of the hosts for whatever reason).
   */
  def updateConfig(newcfg: PoolConfig): Unit = {
    // Almost everything we do is thread-safe but in order to guarantee that
    // we can correctly rollback the changes in case of an exception, and in
    // order to ensure that we only attempt to apply one change at a time,
    // it's much safer and easier to make this entire method synchronized.
    synchronized {
      val prevconns = connections  // volatile-read
      val prevcfg = conf  // volatile-read
      try {
        conf = newcfg
        createConnections()  // volatile-write on connections
        // Success!  Now dispose of the previous connections, to not leak them.
        try {
          closeAllConnections(prevcfg, prevconns)
        } catch {
          case e: Exception =>
            log.warn("Uncaught exception while closing an old connection after"
                     + " reloading a new configuration", e)
        }
      } catch {
        case e: Exception =>
          // Roll-back.
          connections = prevconns  // volatile-write
          conf = prevcfg  // volatile-write
          throw e
      }
    }
  }

  /**
   * Opens one connection per server of the current configuration and
   * publishes them as the new `connections' queue.
   * If any connection fails to open, the ones already opened by this call
   * are closed before the exception propagates (so they don't leak), and
   * `connections' is left untouched.
   */
  private def createConnections(): Unit = {
    val servers = conf.servers
    val newconns = new ArrayBlockingQueue[MySQLConnection](servers.length)
    try {
      servers foreach { server =>  // server is already "ip:port".
        newconns.add(newConnection(server))
      }
    } catch {
      case e: Exception =>
        closeQuietly(newconns)
        throw e
    }
    connections = newconns  // commit: volatile-write
  }

  /** Drains `conns' and closes each connection, logging (not throwing) failures. */
  private def closeQuietly(conns: ArrayBlockingQueue[MySQLConnection]): Unit = {
    var conn = conns.poll
    while (conn != null) {
      try {
        conn.close
      } catch {
        case e: Exception =>
          log.warn("Failed to close connection to " + conn.server, e)
      }
      conn = conns.poll
    }
  }

  /** Returns the number of queries sent to MySQL. */
  def queriesSent: Long = queries.get

  /** Returns the number of exceptions caught while talking to MySQL. */
  def exceptionsCaught: Long = exceptions.get

  /** Closes all connections and releases all threads. */
  def shutdown(): Unit = {
    pool.executor.shutdown()
    closeAllConnections(conf, connections)
  }

  /**
   * Executes a SELECT statement on the database.
   * @param f Function called on each row returned by the database.  This
   * function is called with the connection locked, so if this function takes
   * time, it will prevent the connection from being reused for another query.
   * @param sql The SQL statement, e.g. "SELECT foo FROM t WHERE id = ?"
   * @param params The parameters to substitute in the `?' placeholders.
   * These parameters don't need to be escaped as prepared statements are used
   * and they already prevent SQL injections.
   * @return A future sequence of things returned by `f'.
   * @throws SQLException (async) if something bad happens (sorry I don't know more).
   */
  def select[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, "/*" + appName + "*/ " + sql, params))
  }

  // TODO(tsuna): Provide code for insert, update etc, not just select.

  /**
   * Runs `sql` on a connection borrowed from the pool, applying `f` to each
   * row.  Blocking; meant to run on one of `pool`'s threads.
   * The borrowed connection (or its replacement, after a connection-level
   * error) is always returned to the queue it was taken from.
   */
  private def execute[T](f: ResultSet => T, sql: String, params: Seq[Any]): Seq[T] = {
    queries.increment
    val connpool = this.connections  // volatile-read
    var connection = connpool.poll
    if (connection == null) {  // Should never happen.
      // We have as many threads as connections so this can only happen if a
      // thread is leaking a connection (which would be really bad), or if
      // updateConfig() shrank the server list (see the note on `pool').
      val e = new IllegalStateException("WTF? Couldn't get a connection from the pool.")
      exceptions.increment
      log.error(e.getMessage)
      throw e
    }
    try {
      val statement = connection.prepareStatement(sql)
      try {
        bindParameters(statement, params)
        if (log.isDebugEnabled)
          log.debug(connection.server + ": " + sql
                    + " " + params.mkString("(", ", ", ")"))
        val rs = statement.executeQuery
        try {
          val results = new ArrayBuffer[T]
          while (rs.next) {
            results += f(rs)
          }
          results
        } finally {
          rs.close
        }
      } finally {
        statement.close
      }
    } catch {
      case e: SQLSyntaxErrorException =>
        logAndRethrow(connection, "Syntax error in SQL query",
                      sql, params, e)
      case e: SQLIntegrityConstraintViolationException =>
        logAndRethrow(connection, "Integrity constraint violated by SQL query",
                      sql, params, e)
      case e: SQLFeatureNotSupportedException =>
        logAndRethrow(connection, "Feature not supported in SQL query",
                      sql, params, e)
      case e: SQLDataException =>
        logAndRethrow(connection, "Data exception caused by SQL query",
                      sql, params, e)
      case e @ (_: SQLRecoverableException | _: SQLNonTransientException) =>
        // The remaining kinds of SQLNonTransientException are typically
        // connection-level problems, so let's close this connection and get a
        // new one.
        // For a SQLRecoverableException the JDK javadoc manual says that "the
        // recovery operation must include closing the current connection and
        // getting a new connection".
        // The `finally' block below puts the replacement back in the pool.
        connection = reconnect(connection)
        // TODO(tsuna): If we wanted we could retry once here.
        logAndRethrow(connection, "Error on connection when trying to execute",
                      sql, params, e)
      case e: Throwable =>
        // TODO(tsuna): Should we close the connection here?  I'm not sure.
        logAndRethrow(connection, "Uncaught exception", sql, params, e)
    } finally {
      // Always return the connection to the pool.
      connpool.put(connection)
    }
  }

  /**
   * Closes a connection deemed broken and opens a fresh one to the same
   * server.  Never throws: if closing fails it is logged and ignored, and if
   * reconnecting fails it is logged and the old (closed) connection is
   * returned, so that the pool keeps the same number of entries and the
   * caller can still report the error that triggered the reconnect.
   * NOTE(review): a query on the closed connection presumably fails with a
   * connection-level SQLNonTransientException and retries the reconnect —
   * verify against the driver.
   */
  private def reconnect(broken: MySQLConnection): MySQLConnection = {
    try {
      broken.close  // If we double-close it's OK, it's a no-op.
    } catch {
      case e: Exception =>
        log.warn(broken.server + ": failed to close broken connection", e)
    }
    try {
      newConnection(broken.server)
    } catch {
      case e: Exception =>
        exceptions.increment
        log.error(broken.server + ": failed to reconnect", e)
        broken
    }
  }

  /**
   * Logs an exception (with the class and message of every exception in its
   * cause chain) and rethrows it.
   */
  private def logAndRethrow(connection: MySQLConnection, msg: String,
                            sql: String, params: Seq[Any], e: Throwable): Nothing = {
    // This function must never throw a new exception of its own.
    exceptions.increment
    val cause = new StringBuilder
    var exception = e
    // Get names & messages of all exceptions in the chain.
    while (exception != null) {
      cause.append(", caused by ")
        .append(exception.getClass.getName)
        .append(": ")
        .append(exception.getMessage)
      exception = exception.getCause  // previous exception causing this one.
    }
    log.error(connection.server + ": " + msg + ": " + sql
              + " with params " + params.mkString("(", ", ", ")")
              + cause)
    throw e
  }

  /** Binds `params` to the `?' placeholders of `statement`, starting at index 1. */
  private def bindParameters(statement: PreparedStatement,
                             params: TraversableOnce[Any]): Unit = {
    bindParameters(statement, 1, params)
  }

  /**
   * Binds `params` starting at placeholder `startIndex`, flattening nested
   * collections and products (tuples, case classes) into consecutive
   * placeholders.
   * NOTE(review): `Option` is a `Product`, so `Some(x)` binds `x` and `None`
   * binds nothing (shifting later placeholders) — confirm that's intended.
   * @return The index of the next placeholder to bind.
   * @throws IllegalArgumentException on a null or unsupported parameter.
   */
  private def bindParameters(statement: PreparedStatement,
                             startIndex: Int,
                             params: TraversableOnce[Any]): Int = {
    var index = startIndex
    for (param <- params) {
      param match {
        case null =>
          throw new IllegalArgumentException("Unsupported null parameter at index "
                                             + index)
        case i: Int => statement.setInt(index, i)
        case l: Long => statement.setLong(index, l)
        case s: String => statement.setString(index, s)
        case l: TraversableOnce[_] =>
          index = bindParameters(statement, index, l) - 1
        case p: Product =>
          index = bindParameters(statement, index, p.productIterator.toList) - 1
        case b: Array[Byte] => statement.setBytes(index, b)
        case b: Boolean => statement.setBoolean(index, b)
        case s: Short => statement.setShort(index, s)
        case f: Float => statement.setFloat(index, f)
        case d: Double => statement.setDouble(index, d)
        case _ =>
          throw new IllegalArgumentException("Unsupported data type "
            + param.asInstanceOf[AnyRef].getClass.getName + ": " + param)
      }
      index += 1
    }
    index
  }

  /**
   * Returns a new MySQL connection, using the current configuration and the
   * `options` given to this pool.
   * If the connection can't be configured, it is closed before the exception
   * propagates.
   * @param server A "host:port" string.
   */
  private def newConnection(server: String): MySQLConnection = {
    val current = conf  // Single volatile-read so schema, user and pass agree.
    val connection =
      DriverManager.getConnection("jdbc:mysql://" + server + "/" + current.schema + options,
                                  current.user, current.pass)
    try {
      connection.setReadOnly(readonly)
    } catch {
      case e: Exception =>
        try {
          connection.close()
        } catch {
          case _: Exception =>  // Report the original failure, not this one.
        }
        throw e
    }
    MySQLConnection(server, connection)
  }

  override def toString = "ConnectionPool(" + conf + ")"
}
object ConnectionPool {
  private val log = LoggerFactory.getLogger(getClass)

  /**
   * Makes sure the MySQL JDBC driver class is loaded AND initialized.
   * The driver registers itself with the JDBC DriverManager when it is
   * initialized.  A class literal (`classOf[...]`) only loads the class
   * without initializing it, so `Class.forName`, which does initialize, is
   * used instead.
   * @throws AssertionError if the MySQL JDBC connector isn't on the classpath.
   */
  private def ensureDriverLoaded: Unit = {
    try {
      Class.forName("com.mysql.jdbc.Driver")
    } catch {
      case e: ClassNotFoundException =>
        val err = new AssertionError("MySQL JDBC connector missing.")
        err.initCause(e)
        throw err
    }
  }

  /** Default options we use to connect to MySQL (a "?"-prefixed URL query string). */
  val jdbcOptions: String = "?" + {
    val options = Map(
      "connectTimeout" -> 4.seconds,
      "socketTimeout" -> 2.seconds,
      "useServerPrepStmts" -> true,
      "cachePrepStmts" -> true,
      "cacheResultSetMetadata" -> true,
      "cacheServerConfiguration" -> true,
      "logger" -> classOf[MySQLLogger]
    )
    options.toList.map { case (option, value) =>
      option + "=" + (value match {
        case c: Class[_] => c.getName          // Driver expects a class name.
        case d: Duration => d.inMilliseconds   // Driver expects milliseconds.
        case _ => value
      })
    } mkString "&"
  }

  /** Creates a pool whose connections are put in read-only mode. */
  def readOnly(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, true, appName)

  /** Creates a pool whose connections are left in read-write mode. */
  def readWrite(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, false, appName)

  /** Creates a fixed-size thread-pool ("MySQL-N" threads) to use the given connections. */
  private def makePool(size: Int) = {
    val factory = new ThreadFactory {
      val id = new AtomicInteger(0)
      def newThread(r: Runnable) =
        new Thread(r, "MySQL-" + id.incrementAndGet)
    }
    FuturePool(Executors.newFixedThreadPool(size, factory))
  }

  /**
   * Closes all the connections from the given pool with the given config.
   * A connection that fails to close is logged and skipped, so one bad
   * connection doesn't cause the remaining ones to leak.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to clear up the pool.
   */
  private def closeAllConnections(conf: PoolConfig,
                                  connections: ArrayBlockingQueue[MySQLConnection]): Unit = {
    val total = conf.servers.length
    for (i <- 1 to total) {
      // We're not serving a query to an end-user, and our goal is to
      // close the connection but we don't want to wait forever in case
      // the connection is somehow badly stuck.  So allow quite a bit of
      // time to grab a connection.
      val connection = connections.poll(500, MILLISECONDS)
      if (connection == null) {
        log.error("Timeout while trying to get connection #" + i + " / "
                  + total + ", connection will be leaked.")
      } else {
        try {
          connection.close
        } catch {
          case e: Exception =>
            log.warn("Failed to close connection #" + i + " / " + total
                     + " to " + connection.server, e)
        }
      }
    }
  }
}
/**
 * Routes MySQL Connector/J's internal logging to SLF4J (otherwise it goes to
 * stderr by default).  Installed through the "logger" JDBC option.
 * FATAL has no SLF4J equivalent, so it is mapped to ERROR with a marker.
 * @param name Logger name, passed by the driver.
 */
private final class MySQLLogger(name: String) extends Log {
  val log = LoggerFactory.getLogger(name)

  def isDebugEnabled: Boolean = log.isDebugEnabled
  def isErrorEnabled: Boolean = log.isErrorEnabled
  def isFatalEnabled: Boolean = log.isErrorEnabled
  def isInfoEnabled: Boolean = log.isInfoEnabled
  def isTraceEnabled: Boolean = log.isTraceEnabled
  def isWarnEnabled: Boolean = log.isWarnEnabled

  /**
   * Renders a log message as text.  The `Log` interface accepts any object,
   * so non-String messages (and null) are converted rather than rejected.
   * Throwing here would propagate into whatever driver code was logging.
   */
  private def text(msg: Any): String =
    msg match {
      case m: String => m
      case null => "null"
      case other => other.toString
    }

  def logDebug(msg: Any): Unit = log.debug(text(msg))
  def logDebug(msg: Any, e: Throwable): Unit = log.debug(text(msg), e)

  def logError(msg: Any): Unit = log.error(text(msg))
  def logError(msg: Any, e: Throwable): Unit = log.error(text(msg), e)

  def logFatal(msg: Any): Unit =
    log.error("** FATAL ** " + text(msg))  // Keep going anyway.
  def logFatal(msg: Any, e: Throwable): Unit =
    log.error("** FATAL ** " + text(msg), e)  // Keep going anyway.

  def logInfo(msg: Any): Unit = log.info(text(msg))
  def logInfo(msg: Any, e: Throwable): Unit = log.info(text(msg), e)

  def logTrace(msg: Any): Unit = log.trace(text(msg))
  def logTrace(msg: Any, e: Throwable): Unit = log.trace(text(msg), e)

  def logWarn(msg: Any): Unit = log.warn(text(msg))
  def logWarn(msg: Any, e: Throwable): Unit = log.warn(text(msg), e)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment