ShardedJdbcRDD: a JDBC RDD for Spark, for use when data is pre-sharded across multiple databases.
package org.apache.spark.rdd
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.sql.{Connection, ResultSet}
import org.apache.spark.util.NextIterator
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import scala.reflect.ClassTag

private[spark] class ShardedJdbcPartition(idx: Int, val shard: String) extends Partition {
  override def index = idx
}

/**
 * An RDD that executes a SQL query against each shard of a pre-sharded JDBC data
 * source and reads the results. Each shard becomes one partition of the RDD.
 *
 * @param sc the SparkContext the RDD belongs to.
 * @param getConnection a function that maps a shard name to an open connection.
 *   The RDD takes care of closing the connection.
 * @param shards the shard names.
 * @param sql the text of the query.
 * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
 *   This should only call getInt, getString, etc; the RDD takes care of calling next.
 */
class ShardedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: (String) => Connection,
    shards: Seq[String],
    sql: String,
    mapRow: (ResultSet) => T)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    // One partition per shard; the partition index is the shard's position in the sequence.
    shards.zipWithIndex.map { case (shard, index) =>
      new ShardedJdbcPartition(index, shard)
    }.toArray
  }

  override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
    // Release the connection, statement, and result set when the task completes,
    // whether or not the iterator was fully consumed.
    context.addTaskCompletionListener { _ => closeIfNeeded() }

    val part = thePart.asInstanceOf[ShardedJdbcPartition]
    val conn = getConnection(part.shard)
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    // setFetchSize(Integer.MIN_VALUE) is a MySQL-driver-specific way to force streaming
    // results rather than pulling the entire result set into memory.
    // See http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
    if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
      stmt.setFetchSize(Integer.MIN_VALUE)
      logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming")
    }

    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs && !rs.isClosed) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt && !stmt.isClosed) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn && !conn.isClosed) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}
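
A minimal usage sketch follows; it is not part of the gist. The shard names, JDBC URLs, credentials, and the users table are all hypothetical, and it assumes a MySQL Connector/J driver on the classpath plus the class above compiled into the org.apache.spark.rdd package.

import java.sql.{DriverManager, ResultSet}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.ShardedJdbcRDD

object ShardedJdbcExample {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("sharded-jdbc-example"))

    // Hypothetical shard names; in practice these would come from your shard map.
    val shards = Seq("shard01", "shard02", "shard03")

    val users = new ShardedJdbcRDD[(Long, String)](
      sc,
      // Open one connection per shard name; the RDD closes it when the task completes.
      (shard: String) =>
        DriverManager.getConnection(s"jdbc:mysql://${shard}.db.example.com/app", "reader", "secret"),
      shards,
      "SELECT id, name FROM users",
      // Only read column values here; the RDD drives ResultSet.next().
      rs => (rs.getLong(1), rs.getString(2)))

    println("total users across shards: " + users.count())
    sc.stop()
  }
}

Note the design choice this implies: unlike Spark's built-in JdbcRDD, which splits a numeric key range into numPartitions, this RDD's parallelism is fixed at one task per shard, which matches data that is already physically partitioned across databases.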