@sadikovi
Last active March 24, 2017 01:08
Fix for CSV read/write of an empty DataFrame, or one with some empty partitions: either store schema metadata for the output directory (csvfix1), or write a header into each empty part file (csvfix2).
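For context, a minimal sketch of the failure mode being worked around (assuming Spark 2.x behavior and an existing SparkSession named spark; paths are hypothetical): an empty DataFrame produces only zero-length part files, so a later read has nothing to infer a schema or header from.

// Sketch of the problem, not part of the fix itself.
import spark.implicits._

// An empty DataFrame with a known schema.
val empty = Seq.empty[(Long, String)].toDF("id", "name")

// Every partition is empty, so only zero-length part files are written;
// in Spark 2.x no header row is emitted for empty partitions either.
empty.write.option("header", "true").csv("/tmp/empty-csv")

// Reading the directory back has no data or header to infer a schema from and
// typically fails with an AnalysisException ("Unable to infer schema for CSV").
// spark.read.option("header", "true").csv("/tmp/empty-csv")

The first file below (csvfix1) addresses this by committing the schema as a metadata file next to the CSV output.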
package org.apache.spark.sql
import scala.language.implicitConversions
import scala.util.control.NonFatal
import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
private[sql] object MetadataCommiter extends Logging {
  val METADATA = "_schema_metadata"

  /** Construct metadata path from directory */
  private def metadataPath(dir: Path): Path = dir.suffix(s"${Path.SEPARATOR}$METADATA")

  /** Replace nulls with provided value for string columns */
  private[sql] def convertNulls(df: DataFrame, replacement: String = ""): DataFrame = {
    val fields = df.schema.filter { _.dataType == StringType }
    var res = df
    for (field <- fields) {
      res = res.withColumn(field.name,
        coalesce(col(s"`${field.name}`"), lit(replacement)))
    }
    res
  }

  /** Write content as UTF-8 string */
  private def writeContent(fs: FileSystem, path: Path, content: String): Unit = {
    val out = fs.create(path, true)
    try {
      IOUtils.write(content, out, "UTF-8")
    } finally {
      out.close()
    }
  }

  /** Read content as UTF-8 string */
  private def readContent(fs: FileSystem, path: Path): String = {
    val in = fs.open(path)
    try {
      IOUtils.toString(in, "UTF-8")
    } finally {
      in.close()
    }
  }

  /** Resolve path into fully-qualified file system path */
  def resolvePath(conf: Configuration, path: String): (FileSystem, Path) = {
    val unresolvedPath = new Path(path)
    val fs = unresolvedPath.getFileSystem(conf)
    val resolvedPath = unresolvedPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    (fs, resolvedPath)
  }
  /** Roll back metadata in case of failures */
  private def deleteMetadata(fs: FileSystem, dir: Path): Unit = {
    logInfo(s"Remove metadata from directory $dir")
    fs.delete(metadataPath(dir), false)
  }

  /** Commit metadata for provided DataFrame */
  def commitMetadata(
      fs: FileSystem,
      path: Path,
      df: DataFrame,
      partitionColumnNames: Seq[String]): Unit = {
    // create metadata file path
    val fullPath = metadataPath(path)
    // If partition columns are defined, place them after the non-partitioned columns;
    // this might not work for nested columns, but CSV does not support those anyway
    val schema = if (partitionColumnNames.nonEmpty) {
      val origSchema = df.schema.filter { field =>
        !partitionColumnNames.contains(field.name)
      }
      val partitionSpec = partitionColumnNames.map { name =>
        df.schema(name)
      }
      StructType(origSchema ++ partitionSpec).prettyJson
    } else {
      df.schema.prettyJson
    }
    logInfo(s"Commit metadata $schema for directory $path")
    try {
      writeContent(fs, fullPath, schema)
    } catch {
      case NonFatal(err) =>
        // pass the directory here; deleteMetadata resolves the metadata file path itself
        deleteMetadata(fs, path)
        throw err
    }
  }
  /** Infer metadata if file exists in the directory */
  def inferMetadata(fs: FileSystem, path: Path): Option[StructType] = {
    logInfo(s"Attempt to infer metadata from directory $path")
    val fullPath = metadataPath(path)
    if (fs.exists(fullPath)) {
      logInfo(s"Infer metadata from path $fullPath")
      val metadata = readContent(fs, fullPath)
      Some(DataType.fromJson(metadata).asInstanceOf[StructType])
    } else {
      None
    }
  }
}

private[sql] class CsvDataFrameReader(val spark: SparkSession) extends DataFrameReader(spark) {
  override def csv(path: String): DataFrame = {
    val (fs, resolvedPath) =
      MetadataCommiter.resolvePath(spark.sessionState.newHadoopConf(), path)
    val optSchema = MetadataCommiter.inferMetadata(fs, resolvedPath)
    if (optSchema.isDefined) {
      schema(optSchema.get)
    }
    val df = super.csv(path)
    MetadataCommiter.convertNulls(df)
  }
}

private[sql] class CsvDataFrameWriter[T](ds: Dataset[T]) {
  @transient private val spark = ds.sparkSession
  private val df = MetadataCommiter.convertNulls(ds.toDF(), "null")
  private val writer = new DataFrameWriter(df)
  private var partitionColumnNames: Seq[String] = Nil

  def mode(saveMode: SaveMode): CsvDataFrameWriter[T] = {
    this.writer.mode(saveMode)
    this
  }

  def mode(saveMode: String): CsvDataFrameWriter[T] = {
    this.writer.mode(saveMode)
    this
  }

  def option(key: String, value: String): CsvDataFrameWriter[T] = {
    this.writer.option(key, value)
    this
  }

  def option(key: String, value: Boolean): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Long): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Double): CsvDataFrameWriter[T] = option(key, value.toString)

  def partitionBy(colNames: String*): CsvDataFrameWriter[T] = {
    partitionColumnNames = colNames
    this.writer.partitionBy(colNames: _*)
    this
  }

  def csv(path: String): Unit = {
    this.writer.csv(path)
    val (fs, resolvedPath) =
      MetadataCommiter.resolvePath(spark.sessionState.newHadoopConf(), path)
    MetadataCommiter.commitMetadata(fs, resolvedPath, df, partitionColumnNames)
  }
}
/** Workaround fix to store metadata when there is no way of inferring schema from files */
private[sql] class Workaround[T](spark: SparkSession, ds: Dataset[T]) {
  def fread: CsvDataFrameReader = new CsvDataFrameReader(spark)
  def fwrite: CsvDataFrameWriter[T] = new CsvDataFrameWriter[T](ds)
}

/** Implicit methods that expose the workaround reader/writer */
package object csvfix {
  implicit def workaroundWriter[T](ds: Dataset[T]): Workaround[T] = {
    new Workaround[T](ds.sparkSession, ds)
  }

  implicit def workaroundReader(spark: SparkSession): Workaround[_] = {
    new Workaround[Any](spark, null)
  }
}
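A usage sketch for csvfix1 above; df, spark and the output path are hypothetical. Note that csvfix1 and csvfix2 both define the org.apache.spark.sql.csvfix package object, so only one of the two files should be compiled into the application at a time.

// Usage sketch for csvfix1 (assumes the file above is on the classpath).
import org.apache.spark.sql.csvfix._

// fwrite writes the CSV output and then commits a _schema_metadata file into the directory.
df.fwrite
  .mode("overwrite")
  .csv("/tmp/output-csv")

// fread applies the schema from _schema_metadata (when present) before delegating to the
// regular CSV reader, so an empty directory round-trips with its original schema.
val restored = spark.fread.csv("/tmp/output-csv")

The alternative fix, csvfix2, follows.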
package org.apache.spark.sql
import scala.language.implicitConversions
import scala.util.control.NonFatal
import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
private[sql] object CsvWriteSupport extends Logging {
  /** Replace nulls with provided value for string columns */
  private[sql] def convertNulls(df: DataFrame, replacement: String = ""): DataFrame = {
    val fields = df.schema.filter { _.dataType == StringType }
    var res = df
    for (field <- fields) {
      res = res.withColumn(field.name,
        coalesce(col(s"`${field.name}`"), lit(replacement)))
    }
    res
  }

  /** Generate header from given set of columns, does not escape column names */
  private def generateHeader(columns: Seq[String], separator: String): String = {
    if (separator.length != 1) {
      throw new IllegalArgumentException(s"Expected separator as single char, got '$separator'")
    }
    columns.mkString(separator)
  }
  /** Write header string into the file, overwriting any existing content */
  private def appendHeader(fs: FileSystem, path: Path, header: String): Unit = {
    val out = fs.create(path, true)
    try {
      IOUtils.write(header, out, "UTF-8")
    } finally {
      out.close()
    }
  }

  /** Resolve path into fully-qualified file system path */
  def resolvePath(conf: Configuration, path: String): (FileSystem, Path) = {
    val unresolvedPath = new Path(path)
    val fs = unresolvedPath.getFileSystem(conf)
    val resolvedPath = unresolvedPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    (fs, resolvedPath)
  }

  /** Write header into every empty part file of the output directory */
  def writeHeader(fs: FileSystem, path: Path, df: DataFrame, separator: String): Unit = {
    val header = generateHeader(df.columns, separator)
    logInfo(s"Writing header [$header] for directory $path")
    // discover all partition files that are neither checksum nor _SUCCESS files
    val statuses = fs.listStatus(path, new PathFilter() {
      override def accept(p: Path) = {
        // everything else is assumed to be a valid csv data file
        p.getName != "_SUCCESS" && !p.getName.startsWith(".") && !p.getName.endsWith(".crc")
      }
    })
    // find files that are empty and write header into those files;
    // non-empty files are assumed to already have a header
    statuses.foreach { status =>
      if (status.getLen == 0) {
        appendHeader(fs, status.getPath, header)
        logInfo(s"Appending header to the file ${status.getPath}")
      }
    }
  }
}
private[sql] class CsvDataFrameReader(val spark: SparkSession) extends DataFrameReader(spark) {
  override def csv(path: String): DataFrame = {
    val df = super.csv(path)
    CsvWriteSupport.convertNulls(df)
  }
}

// Does not support partitioning
private[sql] class CsvDataFrameWriter[T](ds: Dataset[T]) {
  @transient private val spark = ds.sparkSession
  private val df = CsvWriteSupport.convertNulls(ds.toDF(), "null")
  private val writer = new DataFrameWriter(df)
  private val headerKey = "header"
  private var useHeader = false
  private val separatorKey = "sep"
  private var separatorValue = ","

  def mode(saveMode: SaveMode): CsvDataFrameWriter[T] = {
    this.writer.mode(saveMode)
    this
  }

  def mode(saveMode: String): CsvDataFrameWriter[T] = {
    this.writer.mode(saveMode)
    this
  }

  def option(key: String, value: String): CsvDataFrameWriter[T] = {
    if (key.toLowerCase == headerKey) useHeader = value.toBoolean
    if (key.toLowerCase == separatorKey) separatorValue = value
    this.writer.option(key, value)
    this
  }

  def options(map: scala.collection.Map[String, String]): CsvDataFrameWriter[T] = {
    map.keys.foreach(x => option(x, map(x)))
    this
  }

  def option(key: String, value: Boolean): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Long): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Double): CsvDataFrameWriter[T] = option(key, value.toString)
  def csv(path: String): Unit = {
    this.writer.csv(path)
    if (useHeader) {
      val conf = spark.sessionState.newHadoopConf()
      val (fs, resolvedPath) = CsvWriteSupport.resolvePath(conf, path)
      CsvWriteSupport.writeHeader(fs, resolvedPath, df, separatorValue)
    }
  }
}

/** Workaround fix to store metadata when there is no way of inferring schema from files */
private[sql] class Workaround[T](spark: SparkSession, ds: Dataset[T]) {
  def fread: CsvDataFrameReader = new CsvDataFrameReader(spark)
  def fwrite: CsvDataFrameWriter[T] = new CsvDataFrameWriter[T](ds)
}
/** Implicit methods that expose the workaround reader/writer */
package object csvfix {
  implicit def workaroundWriter[T](ds: Dataset[T]): Workaround[T] = {
    new Workaround[T](ds.sparkSession, ds)
  }

  implicit def workaroundReader(spark: SparkSession): Workaround[_] = {
    new Workaround[Any](spark, null)
  }
}
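And a usage sketch for csvfix2 above; df, spark and the output path are hypothetical. This version post-processes the output directory instead of storing separate metadata.

// Usage sketch for csvfix2 (assumes the second file above is compiled instead of csvfix1).
import org.apache.spark.sql.csvfix._

// fwrite passes the options through to the regular CSV writer, then scans the output
// directory and writes the header row into every zero-length part file.
df.fwrite
  .mode("overwrite")
  .option("header", "true")
  .option("sep", ",")
  .csv("/tmp/output-csv")

// Because every part file now starts with a header, a plain read with header=true can
// recover the column names (though not the types) even when the data was empty.
val restored = spark.read.option("header", "true").csv("/tmp/output-csv")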
sadikovi commented Mar 7, 2017

Another approach is to write a header into every empty file.
