Skip to content

Instantly share code, notes, and snippets.

@feynmanliang
Last active September 3, 2015 23:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save feynmanliang/ca674a37574fb625011b to your computer and use it in GitHub Desktop.
Save feynmanliang/ca674a37574fb625011b to your computer and use it in GitHub Desktop.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c02004..83d47c7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -577,7 +577,10 @@ public String toString() {
     StringBuilder build = new StringBuilder("[");
     for (int i = 0; i < sizeInBytes; i += 8) {
       build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
-      build.append(',');
+      // Separate words with commas, but omit the comma after the final word this loop reads.
+      if (i + 8 < sizeInBytes) {
+        build.append(',');
+      }
     }
     build.append(']');
     return build.toString();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bf03c61..463de71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,14 +17,21 @@
package org.apache.spark.sql
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
+import org.apache.spark.SparkEnv
+import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.types.StructField
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -110,4 +117,117 @@ private[sql] abstract class SQLImplicits {
     DataFrameHolder(
       _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
   }
+
+  /**
+   * ::Experimental::
+   *
+   * Pimp my library decorator for tungsten caching of DataFrames.
+   * @since 1.5.1
+   */
+  @Experimental
+  implicit class TungstenCache(df: DataFrame) {
+    /**
+     * Packs the rows of the wrapped DataFrame into contiguous blocks of memory and returns a
+     * DataFrame that reads them back.
+     *
+     * Each block stores records as `[rowSize (4 bytes) | UnsafeRow bytes (rowSize)]`, followed
+     * by a `-1` end marker whenever at least 4 bytes remain after the last record.
+     *
+     * @param blockSize size in bytes of each memory block (default 4,000,000). Every row, plus
+     *                  its 4-byte size prefix, must fit in one block or the task fails.
+     * @return the persisted RDD of memory blocks (unpersist it to drop the cache) and a
+     *         DataFrame that scans those blocks.
+     * @throws IllegalArgumentException if `blockSize` is not positive.
+     */
+    def tungstenCache(blockSize: Int = 4000000): (RDD[_], DataFrame) = {
+      require(blockSize > 0, s"blockSize must be positive, but got $blockSize")
+      val schema = df.schema
+
+      val convert = CatalystTypeConverters.createToCatalystConverter(schema)
+      val internalRows = df.rdd.map(convert(_).asInstanceOf[InternalRow])
+      val cachedRDD = internalRows.mapPartitions { rowIterator =>
+        val bufferedRowIterator = rowIterator.buffered
+
+        val convertToUnsafe = UnsafeProjection.create(schema)
+        val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+        new Iterator[MemoryBlock] {
+          def hasNext: Boolean = bufferedRowIterator.hasNext
+
+          def next(): MemoryBlock = {
+            if (!hasNext) {
+              throw new NoSuchElementException("next() called on an exhausted block iterator")
+            }
+            // TODO(review): blocks from allocateUnchecked are never freed in this code; confirm
+            // how they are released once the returned RDD is unpersisted.
+            val block = taskMemoryManager.allocateUnchecked(blockSize)
+            val base = block.getBaseObject
+            val baseOffset = block.getBaseOffset
+            val capacity = block.size()
+            var currOffset = 0L
+            var blockFull = false
+
+            while (!blockFull && bufferedRowIterator.hasNext) {
+              val currRow = convertToUnsafe.apply(bufferedRowIterator.head)
+              val recordSize = 4 + currRow.getSizeInBytes
+              if (currOffset + recordSize <= capacity) {
+                // Pack into memory with layout [rowSize (4) | row (rowSize)]
+                Platform.putInt(base, baseOffset + currOffset, currRow.getSizeInBytes)
+                currRow.writeToMemory(base, baseOffset + currOffset + 4)
+                bufferedRowIterator.next()
+                currOffset += recordSize
+              } else if (currOffset == 0) {
+                // Not even an empty block can hold this row; returning an empty block here
+                // would make this iterator loop forever.
+                throw new IllegalStateException(
+                  s"Row of $recordSize bytes (with size prefix) exceeds block size $capacity")
+              } else {
+                blockFull = true
+              }
+            }
+            // Nothing here zeroes the block, so mark where the records end; otherwise the
+            // reader below would interpret leftover bytes as another row.
+            if (currOffset + 4 <= capacity) {
+              Platform.putInt(base, baseOffset + currOffset, -1)
+            }
+            block
+          }
+        }
+      }.persist(StorageLevel.MEMORY_ONLY)
+
+      val baseRelation: BaseRelation = new BaseRelation with TableScan {
+        override val sqlContext = _sqlContext
+        override val schema = df.schema
+        // With needConversion = false, buildScan must produce InternalRows (see
+        // BaseRelation.needConversion); the cast to RDD[Row] below only satisfies the signature.
+        override val needConversion = false
+
+        override def buildScan(): RDD[Row] = {
+          val numFields = this.schema.length
+
+          cachedRDD.flatMap { block =>
+            val rows = new ArrayBuffer[InternalRow]()
+            var currOffset = 0L
+            var moreData = true
+            // Stop at the -1 end marker, or when no complete size prefix fits in the block.
+            while (moreData && currOffset + 4 <= block.size()) {
+              val rowSize = Platform.getInt(block.getBaseObject, block.getBaseOffset + currOffset)
+              if (rowSize >= 0) {
+                val unsafeRow = new UnsafeRow()
+                unsafeRow.pointTo(
+                  block.getBaseObject, block.getBaseOffset + currOffset + 4, numFields, rowSize)
+                rows.append(unsafeRow)
+                currOffset += 4 + rowSize
+              } else {
+                moreData = false
+              }
+            }
+            rows
+          }.asInstanceOf[RDD[Row]]
+        }
+
+        override def toString: String = getClass.getSimpleName + s"[${df.toString}]"
+      }
+      (cachedRDD, DataFrame(_sqlContext, LogicalRelation(baseRelation)))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index af7590c..9036df2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Accumulators
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.storage.{StorageLevel, RDDBlockId}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
private case class BigData(s: String)
@@ -75,17 +75,17 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
   }
 
   test("unpersist an uncached table will not raise exception") {
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(ctx.cacheManager.lookupCachedData(testData).isEmpty)
     testData.unpersist(blocking = true)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(ctx.cacheManager.lookupCachedData(testData).isEmpty)
     testData.unpersist(blocking = false)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(ctx.cacheManager.lookupCachedData(testData).isEmpty)
     testData.persist()
-    assert(None != ctx.cacheManager.lookupCachedData(testData))
+    assert(ctx.cacheManager.lookupCachedData(testData).isDefined)
     testData.unpersist(blocking = true)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(ctx.cacheManager.lookupCachedData(testData).isEmpty)
     testData.unpersist(blocking = false)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(ctx.cacheManager.lookupCachedData(testData).isEmpty)
   }
 
   test("cache table as select") {
@@ -333,7 +333,17 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
       val accsSize = Accumulators.originals.size
       ctx.uncacheTable("t1")
       ctx.uncacheTable("t2")
-      assert((accsSize - 2) == Accumulators.originals.size)
+      assert((accsSize - 2) === Accumulators.originals.size)
     }
   }
+
+  test("tungsten cache table and read") {
+    val (cachedRDD, tungstenCachedDF) = testData.tungstenCache()
+    try {
+      assert(tungstenCachedDF.collect() === testData.collect())
+    } finally {
+      // Drop the persisted blocks so they do not outlive this test in the shared context.
+      cachedRDD.unpersist(blocking = true)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 2476b10..4f52535 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryAllocator
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 97b2c93..cc78fc8 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -175,6 +175,21 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
   }
 
   /**
+   * Allocates a contiguous block of {@code size} bytes directly from the executor memory
+   * manager.
+   *
+   * The block is not recorded in {@code allocatedNonPageMemory}, so this manager's leak checks
+   * do not cover it.
+   *
+   * @param size number of bytes to allocate; must be positive (enforced only by an assertion)
+   * @return the newly allocated memory block
+   */
+  public MemoryBlock allocateUnchecked(long size) throws OutOfMemoryError {
+    assert(size > 0) : "Size must be positive, but got " + size;
+    return executorMemoryManager.allocate(size);
+  }
+
+  /**
    * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
    */
   public void free(MemoryBlock memory) {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment