@feynmanliang
Last active September 4, 2015 05:35
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c02004..83d47c7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -577,7 +577,9 @@ public String toString() {
StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
- build.append(',');
+ if (i + 8 < sizeInBytes) {
+ build.append(',');
+ }
}
build.append(']');
return build.toString();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bf03c61..6b4145c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,14 +17,25 @@
package org.apache.spark.sql
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
+import org.apache.spark.SparkEnv
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.types.StructField
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -110,4 +121,165 @@ private[sql] abstract class SQLImplicits {
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
+
+ /**
+ * ::Experimental::
+ *
+ * "Pimp my library" decorator that adds Tungsten-based caching to [[DataFrame]]s.
+ * @since 1.5.1
+ */
+ @Experimental
+ implicit class TungstenCache(df: DataFrame) {
+ /**
+ * Packs the rows of [[df]] into contiguous blocks of memory, optionally compressing each
+ * block before it is cached.
+ * @param compressionType "" for no compression (default), or one of "lz4", "lzf", "snappy";
+ * see [[CompressionCodec.ALL_COMPRESSION_CODECS]]
+ * @param blockSize size in bytes of each MemoryBlock (default = 4 MB)
+ * @return the cached RDD of blocks and a DataFrame that scans them
+ */
+ def tungstenCache(
+ compressionType: String = "", blockSize: Int = 4000000): (RDD[_], DataFrame) = {
+ val schema = df.schema
+
+ val convert = CatalystTypeConverters.createToCatalystConverter(schema)
+ val internalRows = df.rdd.map(convert(_).asInstanceOf[InternalRow])
+ val cachedRDD = internalRows.mapPartitions { rowIterator =>
+ val bufferedRowIterator = rowIterator.buffered
+ val convertToUnsafe = UnsafeProjection.create(schema)
+ val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+ val compressionCodec: Option[CompressionCodec] = if (compressionType.isEmpty) {
+ None
+ } else {
+ Some(CompressionCodec.createCodec(SparkEnv.get.conf, compressionType))
+ }
+
+ new Iterator[MemoryBlock] {
+ // NOTE: This assumes that the size of every row is < blockSize
+ def next(): MemoryBlock = {
+ // Packs rows into a contiguous block of `blockSize` bytes, starting a new block
+ // whenever the current one fills up
+ // Each row is laid out in memory as [rowSize (4)|row (rowSize)]
+ val block = taskMemoryManager.allocateUnchecked(blockSize)
+
+ var currOffset = 0
+ while (bufferedRowIterator.hasNext && currOffset < blockSize) {
+ val currRow = convertToUnsafe.apply(bufferedRowIterator.head)
+ val recordSize = 4 + currRow.getSizeInBytes
+ if (currOffset + recordSize < blockSize) {
+ Platform.putInt(
+ block.getBaseObject, block.getBaseOffset + currOffset, currRow.getSizeInBytes)
+ currRow.writeToMemory(block.getBaseObject, block.getBaseOffset + currOffset + 4)
+ bufferedRowIterator.next()
+ }
+ currOffset += recordSize // Increment currOffset regardless so the loop exits once the block is full
+ }
+
+ // Optionally compress block before writing
+ compressionCodec match {
+ case Some(codec) =>
+ // Compress the block using an on-heap byte array
+ val blockArray = new Array[Byte](blockSize)
+ Platform.copyMemory(
+ block.getBaseObject,
+ block.getBaseOffset,
+ blockArray,
+ Platform.BYTE_ARRAY_OFFSET,
+ blockSize)
+ val baos = new ByteArrayOutputStream(blockSize)
+ val compressedBaos = codec.compressedOutputStream(baos)
+ compressedBaos.write(blockArray)
+ compressedBaos.flush()
+ compressedBaos.close()
+ val compressedBlockArray = baos.toByteArray
+
+ // Allocate a new block laid out as [padding (4)|compressed data], rounded up to a word boundary
+ val totalRecordSize = compressedBlockArray.size + 4
+ val nearestWordBoundary =
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(totalRecordSize)
+ val padding = nearestWordBoundary - totalRecordSize
+ val compressedBlock = taskMemoryManager.allocateUnchecked(totalRecordSize + padding)
+ Platform.putInt(
+ compressedBlock.getBaseObject,
+ compressedBlock.getBaseOffset,
+ padding)
+ Platform.copyMemory(
+ compressedBlockArray,
+ Platform.BYTE_ARRAY_OFFSET,
+ compressedBlock.getBaseObject,
+ compressedBlock.getBaseOffset + 4,
+ compressedBlockArray.size)
+ taskMemoryManager.freeUnchecked(block)
+ compressedBlock
+ case None => block
+ }
+ }
+
+ def hasNext: Boolean = bufferedRowIterator.hasNext
+ }
+ }.setName(compressionType + "_" + df.toString).persist(StorageLevel.MEMORY_ONLY)
+
+ val baseRelation: BaseRelation = new BaseRelation with TableScan {
+ override val sqlContext = _sqlContext
+ override val schema = df.schema
+ override val needConversion = false
+
+ override def buildScan(): RDD[Row] = {
+ val numFields = this.schema.length
+ val _compressionType: String = compressionType
+ val _blockSize = blockSize
+
+ cachedRDD.flatMap { rawBlock =>
+ // Optionally decompress block
+ val compressionCodec: Option[CompressionCodec] = if (_compressionType.isEmpty) {
+ None
+ } else {
+ Some(CompressionCodec.createCodec(SparkEnv.get.conf, _compressionType))
+ }
+ val block = compressionCodec match {
+ case Some(codec) =>
+ // Copy compressed block (excluding padding) to on-heap byte array
+ val padding = Platform.getInt(rawBlock.getBaseObject, rawBlock.getBaseOffset)
+ val compressedBlockArray = new Array[Byte](_blockSize)
+ Platform.copyMemory(
+ rawBlock.getBaseObject,
+ rawBlock.getBaseOffset + 4,
+ compressedBlockArray,
+ Platform.BYTE_ARRAY_OFFSET,
+ rawBlock.size() - padding)
+
+ // Decompress into MemoryBlock backed by on-heap byte array
+ val compressedBaos = new ByteArrayInputStream(compressedBlockArray)
+ val uncompressedBlockArray = new Array[Byte](_blockSize)
+ val cis = codec.compressedInputStream(compressedBaos)
+ cis.read(uncompressedBlockArray)
+ cis.close()
+ MemoryBlock.fromByteArray(uncompressedBlockArray)
+ case None => rawBlock
+ }
+
+ val rows = new ArrayBuffer[InternalRow]()
+ var currOffset = 0
+ var moreData = true
+ while (currOffset < block.size() && moreData) {
+ val rowSize = Platform.getInt(block.getBaseObject, block.getBaseOffset + currOffset)
+ currOffset += 4
+ // TODO: should probably have a null terminator rather than relying on zeroed-out memory
+ if (rowSize > 0) {
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(
+ block.getBaseObject, block.getBaseOffset + currOffset, numFields, rowSize)
+ rows.append(unsafeRow)
+ currOffset += rowSize
+ } else {
+ moreData = false
+ }
+ }
+ rows
+ }.asInstanceOf[RDD[Row]]
+ }
+
+ override def toString: String = getClass.getSimpleName + s"[${df.toString}]"
+ }
+ (cachedRDD, DataFrame(_sqlContext, LogicalRelation(baseRelation)))
+ }
+ }
}
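
A minimal usage sketch of the API added above (not part of the patch; assumes a local SparkContext and that the SQLContext implicits are imported, as the tests below do):

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.sql.SQLContext

  // Hypothetical driver-side usage of the TungstenCache decorator.
  val sc = new SparkContext(
    new SparkConf().setMaster("local[2]").setAppName("tungsten-cache-sketch"))
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._  // brings the TungstenCache implicit class into scope

  val df = sqlContext.range(0, 1000)

  // Pack the rows into 4 MB MemoryBlocks, LZ4-compressing each block before it is cached.
  // Returns both the backing RDD of blocks and a DataFrame whose scan decodes the blocks.
  val (cachedRDD, cachedDF) = df.tungstenCache(compressionType = "lz4")

  cachedDF.collect()  // reads the rows back out of the packed blocks

Because the returned DataFrame is backed by a plain BaseRelation with TableScan, it participates in query planning like any other relation.
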
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index af7590c..2eade42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Accumulators
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.storage.{StorageLevel, RDDBlockId}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
private case class BigData(s: String)
@@ -75,17 +75,17 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
}
test("unpersist an uncached table will not raise exception") {
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None === ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None === ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None === ctx.cacheManager.lookupCachedData(testData))
testData.persist()
- assert(None != ctx.cacheManager.lookupCachedData(testData))
+ assert(None !== ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None === ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None === ctx.cacheManager.lookupCachedData(testData))
}
test("cache table as select") {
@@ -333,7 +333,33 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
val accsSize = Accumulators.originals.size
ctx.uncacheTable("t1")
ctx.uncacheTable("t2")
- assert((accsSize - 2) == Accumulators.originals.size)
+ assert((accsSize - 2) === Accumulators.originals.size)
}
}
+
+ test("tungsten cache uncompressed table and read") {
+ val data = testData
+ // Use a 0.4 KB block size to force multiple blocks
+ val (_, tungstenCachedDF) = data.tungstenCache("", 400)
+ assert(tungstenCachedDF.collect() === testData.collect())
+ }
+
+ test("tungsten cache lz4 compressed table and read") {
+ val data = testData
+ val (_, tungstenCachedDF) = data.tungstenCache("lz4", 400)
+ assert(tungstenCachedDF.collect() === testData.collect())
+ }
+
+ test("tungsten cache lzf compressed table and read") {
+ val data = testData
+ val (_, tungstenCachedDF) = data.tungstenCache("lzf", 400)
+ assert(tungstenCachedDF.collect() === testData.collect())
+ }
+
+ test("tungsten cache snappy compressed table and read") {
+ val data = testData
+ val (_, tungstenCachedDF) = data.tungstenCache("snappy", 400)
+ assert(tungstenCachedDF.collect() === testData.collect())
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 2476b10..4f52535 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryAllocator
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index dd75820..3a51f0e 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -52,4 +52,11 @@ public long size() {
public static MemoryBlock fromLongArray(final long[] array) {
return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8);
}
+
+ /**
+ * Creates a memory block pointing to the memory used by the byte array.
+ */
+ public static MemoryBlock fromByteArray(final byte[] array) {
+ return new MemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length);
+ }
}
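
A small Scala sketch (not part of the patch) showing how the new fromByteArray factory pairs with Platform reads; this is the pattern the decompression path in buildScan relies on:

  import org.apache.spark.unsafe.Platform
  import org.apache.spark.unsafe.memory.MemoryBlock

  // Write an int into an on-heap byte array, wrap the array as a MemoryBlock, read it back.
  val bytes = new Array[Byte](16)
  Platform.putInt(bytes, Platform.BYTE_ARRAY_OFFSET, 42)

  val block = MemoryBlock.fromByteArray(bytes)
  Platform.getInt(block.getBaseObject, block.getBaseOffset)  // 42
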
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 97b2c93..8824f98 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -175,6 +175,16 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
}
/**
+ * Allocates a contiguous block of memory without the leak tracking provided by
+ * {@code allocatedNonPageMemory}.
+ */
+ public MemoryBlock allocateUnchecked(long size) throws OutOfMemoryError {
+ assert(size > 0) : "Size must be positive, but got " + size;
+ final MemoryBlock memory = executorMemoryManager.allocate(size);
+ return memory;
+ }
+
+ /**
* Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
*/
public void free(MemoryBlock memory) {
@@ -187,6 +197,15 @@ public void free(MemoryBlock memory) {
}
/**
+ * Frees a contiguous block of memory allocated by {@code allocateUnchecked}, without the
+ * leak tracking provided by {@code allocatedNonPageMemory}.
+ */
+ public void freeUnchecked(MemoryBlock memory) {
+ assert (memory.pageNumber == -1) : "Should call freePage() for pages, not freeUnchecked()";
+ executorMemoryManager.free(memory);
+ }
+
+ /**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
*
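
A hedged Scala sketch (not part of the patch) of how the unchecked allocate/free pair added above is used on the executor side; it assumes SparkEnv is initialized, as in the mapPartitions closure earlier in the patch:

  import org.apache.spark.SparkEnv
  import org.apache.spark.unsafe.Platform
  import org.apache.spark.unsafe.memory.TaskMemoryManager

  // Unlike allocate/free, the unchecked variants do not record the block in
  // allocatedNonPageMemory, so nothing is flagged as a leak at task completion.
  val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
  val block = taskMemoryManager.allocateUnchecked(4096)
  Platform.putLong(block.getBaseObject, block.getBaseOffset, 123L)
  taskMemoryManager.freeUnchecked(block)
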