Last active
March 24, 2017 01:08
-
-
Save sadikovi/e5308e6cb6e681fe08abb58c21632131 to your computer and use it in GitHub Desktop.
Fix for CSV read/write of an empty DataFrame, or of a DataFrame with some empty partitions: either store schema metadata for the output directory (csvfix1), or write a header into each empty file (csvfix2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.sql | |
import scala.language.implicitConversions | |
import scala.util.control.NonFatal | |
import org.apache.commons.io.IOUtils | |
import org.apache.hadoop.conf.Configuration | |
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} | |
import org.apache.spark.internal.Logging | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.types._ | |
/**
 * Helpers that persist the schema of a DataFrame as a JSON file inside its CSV output
 * directory, and read it back, so the schema is known even when no data file can be used
 * to infer it.
 */
private[sql] object MetadataCommiter extends Logging {
  /** Name of the schema metadata file stored directly inside the output directory */
  val METADATA = "_schema_metadata"

  /** Construct metadata path from directory */
  private def metadataPath(dir: Path): Path = dir.suffix(s"${Path.SEPARATOR}$METADATA")

  /**
   * Replace nulls with provided value for string columns.
   *
   * @param df DataFrame to convert
   * @param replacement value substituted for null in every `StringType` column
   * @return DataFrame with the same columns, string columns being null-free
   */
  private[sql] def convertNulls(df: DataFrame, replacement: String = ""): DataFrame = {
    df.schema
      .filter(_.dataType == StringType)
      .foldLeft(df) { (acc, field) =>
        // column name is wrapped in backticks so it is not parsed as a nested reference
        acc.withColumn(field.name, coalesce(col(s"`${field.name}`"), lit(replacement)))
      }
  }

  /** Write content as UTF-8 string, overwriting the file if it already exists */
  private def writeContent(fs: FileSystem, path: Path, content: String): Unit = {
    val out = fs.create(path, true)
    try {
      IOUtils.write(content, out, "UTF-8")
    } finally {
      out.close()
    }
  }

  /** Read content as UTF-8 string */
  private def readContent(fs: FileSystem, path: Path): String = {
    val in = fs.open(path)
    try {
      IOUtils.toString(in, "UTF-8")
    } finally {
      in.close()
    }
  }

  /**
   * Resolve path into fully-qualified file system path.
   *
   * @param conf Hadoop configuration used to look up the file system
   * @param path path as provided by the user
   * @return file system for the path and the path qualified with that file system's URI
   *         and working directory
   */
  def resolvePath(conf: Configuration, path: String): (FileSystem, Path) = {
    val unresolvedPath = new Path(path)
    val fs = unresolvedPath.getFileSystem(conf)
    val resolvedPath = unresolvedPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    (fs, resolvedPath)
  }

  /**
   * Rollback metadata in case of failures.
   *
   * @param fs file system of the directory
   * @param path output DIRECTORY (not the metadata file itself); the metadata file path is
   *             derived from it
   */
  private def deleteMetadata(fs: FileSystem, path: Path): Unit = {
    logInfo(s"Remove metadata from directory $path")
    fs.delete(metadataPath(path), false)
  }

  /**
   * Commit metadata for provided DataFrame: the schema is written as pretty-printed JSON
   * into the metadata file of the directory.
   *
   * @param fs file system of the directory
   * @param path output directory
   * @param df DataFrame whose schema is stored
   * @param partitionColumnNames partition columns; when non-empty they are moved to the end
   *                             of the stored schema, in the order given
   */
  def commitMetadata(
      fs: FileSystem,
      path: Path,
      df: DataFrame,
      partitionColumnNames: Seq[String]): Unit = {
    // create metadata file path
    val fullPath = metadataPath(path)
    // If partition columns are defined, insert them after non-partitioned columns,
    // this might not work for nested columns, but csv should not support them
    val schema = if (partitionColumnNames.nonEmpty) {
      val origSchema = df.schema.filter { field =>
        !partitionColumnNames.contains(field.name)
      }
      val partitionSpec = partitionColumnNames.map { name =>
        df.schema(name)
      }
      StructType(origSchema ++ partitionSpec).prettyJson
    } else {
      df.schema.prettyJson
    }

    logInfo(s"Commit metadata $schema for directory $path")
    try {
      writeContent(fs, fullPath, schema)
    } catch {
      case NonFatal(err) =>
        // deleteMetadata derives the metadata file from the directory, so it must be given
        // `path`; passing `fullPath` would target "<dir>/_schema_metadata/_schema_metadata"
        // and leave the partially written file behind
        try {
          deleteMetadata(fs, path)
        } catch {
          case NonFatal(cleanupErr) =>
            // do not let a failed rollback hide the original write failure
            logWarning(s"Failed to remove metadata from directory $path", cleanupErr)
        }
        throw err
    }
  }

  /**
   * Infer metadata if file exists in the directory.
   *
   * @param fs file system of the directory
   * @param path directory to look into
   * @return stored schema, or None when the directory has no metadata file
   * @throws IllegalArgumentException if the metadata file does not describe a struct type
   */
  def inferMetadata(fs: FileSystem, path: Path): Option[StructType] = {
    logInfo(s"Attempt to infer metadata from directory $path")
    val fullPath = metadataPath(path)
    if (fs.exists(fullPath)) {
      logInfo(s"Infer metadata from path $fullPath")
      DataType.fromJson(readContent(fs, fullPath)) match {
        case struct: StructType => Some(struct)
        case other =>
          throw new IllegalArgumentException(
            s"Expected struct type in metadata file $fullPath, found ${other.simpleString}")
      }
    } else {
      None
    }
  }
}
/**
 * DataFrameReader whose `csv` applies the schema stored in the directory metadata file,
 * when one exists, before loading the data, and replaces nulls in string columns of the
 * result with empty strings.
 */
private[sql] class CsvDataFrameReader(val spark: SparkSession) extends DataFrameReader(spark) {
  override def csv(path: String): DataFrame = {
    val hadoopConf = spark.sessionState.newHadoopConf()
    val (fileSystem, qualifiedDir) = MetadataCommiter.resolvePath(hadoopConf, path)
    // a stored schema, if present, is set on this reader before the actual load
    MetadataCommiter.inferMetadata(fileSystem, qualifiedDir).foreach { stored =>
      schema(stored)
    }
    MetadataCommiter.convertNulls(super.csv(path))
  }
}
/**
 * Thin wrapper around DataFrameWriter that, after writing CSV, also stores the schema of
 * the written DataFrame in the output directory. Nulls in string columns are written as
 * the literal string "null".
 */
private[sql] class CsvDataFrameWriter[T](ds: Dataset[T]) {
  @transient private val spark = ds.sparkSession
  private val df = MetadataCommiter.convertNulls(ds.toDF(), "null")
  private val writer = new DataFrameWriter(df)
  // remembered so the stored schema can list partition columns last
  private var partitionColumnNames: Seq[String] = Nil

  /** Run a configuration step and return this writer to allow call chaining */
  private def chain(step: => Any): CsvDataFrameWriter[T] = {
    step
    this
  }

  def mode(saveMode: SaveMode): CsvDataFrameWriter[T] = chain(writer.mode(saveMode))

  def mode(saveMode: String): CsvDataFrameWriter[T] = chain(writer.mode(saveMode))

  def option(key: String, value: String): CsvDataFrameWriter[T] = chain(writer.option(key, value))

  def option(key: String, value: Boolean): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Long): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Double): CsvDataFrameWriter[T] = option(key, value.toString)

  def partitionBy(colNames: String*): CsvDataFrameWriter[T] = chain {
    partitionColumnNames = colNames
    writer.partitionBy(colNames: _*)
  }

  /** Write data as CSV into `path`, then commit schema metadata for that directory */
  def csv(path: String): Unit = {
    writer.csv(path)
    val hadoopConf = spark.sessionState.newHadoopConf()
    val (fileSystem, outputDir) = MetadataCommiter.resolvePath(hadoopConf, path)
    MetadataCommiter.commitMetadata(fileSystem, outputDir, df, partitionColumnNames)
  }
}
/**
 * Workaround entry point exposing the metadata-aware CSV reader (`fread`) and writer
 * (`fwrite`); a new reader or writer is created on every access.
 */
private[sql] class Workaround[T](spark: SparkSession, ds: Dataset[T]) {
  /** Reader bound to the session of this workaround */
  def fread: CsvDataFrameReader = {
    new CsvDataFrameReader(spark)
  }

  /** Writer bound to the dataset of this workaround */
  def fwrite: CsvDataFrameWriter[T] = {
    new CsvDataFrameWriter[T](ds)
  }
}
/** Implicit conversions that add `fread` / `fwrite` to SparkSession and Dataset */
package object csvfix {
  /** Wrap a dataset, so that `ds.fwrite` is available */
  implicit def workaroundWriter[T](ds: Dataset[T]): Workaround[T] =
    new Workaround[T](ds.sparkSession, ds)

  /** Wrap a session, so that `spark.fread` is available; no dataset is attached (null) */
  implicit def workaroundReader(spark: SparkSession): Workaround[_] =
    new Workaround[Any](spark, null)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.sql | |
import scala.language.implicitConversions | |
import scala.util.control.NonFatal | |
import org.apache.commons.io.IOUtils | |
import org.apache.hadoop.conf.Configuration | |
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} | |
import org.apache.spark.internal.Logging | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.types._ | |
/**
 * Helpers for the CSV workaround that writes a header line into every empty data file of
 * an output directory.
 */
private[sql] object CsvWriteSupport extends Logging {
  /**
   * Replace nulls with provided value for string columns.
   *
   * @param df DataFrame to convert
   * @param replacement value substituted for null in every `StringType` column
   * @return DataFrame with the same columns, string columns being null-free
   */
  private[sql] def convertNulls(df: DataFrame, replacement: String = ""): DataFrame = {
    df.schema
      .filter(_.dataType == StringType)
      .foldLeft(df) { (acc, field) =>
        // column name is wrapped in backticks so it is not parsed as a nested reference
        acc.withColumn(field.name, coalesce(col(s"`${field.name}`"), lit(replacement)))
      }
  }

  /**
   * Generate header from given set of columns, does not escape column names.
   *
   * @throws IllegalArgumentException if separator is not exactly one character long
   */
  private def generateHeader(columns: Seq[String], separator: String): String = {
    if (separator.length != 1) {
      throw new IllegalArgumentException(s"Expected separator as single char, got '$separator'")
    }
    columns.mkString(separator)
  }

  /** Write header string as UTF-8 into the file, any existing content is overwritten */
  private def appendHeader(fs: FileSystem, path: Path, header: String): Unit = {
    val out = fs.create(path, true)
    try {
      IOUtils.write(header, out, "UTF-8")
    } finally {
      out.close()
    }
  }

  /**
   * Resolve path into fully-qualified file system path.
   *
   * @param conf Hadoop configuration used to look up the file system
   * @param path path as provided by the user
   * @return file system for the path and the path qualified with that file system's URI
   *         and working directory
   */
  def resolvePath(conf: Configuration, path: String): (FileSystem, Path) = {
    val unresolvedPath = new Path(path)
    val fs = unresolvedPath.getFileSystem(conf)
    val resolvedPath = unresolvedPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    (fs, resolvedPath)
  }

  /**
   * Write header into every empty data file located directly in the directory.
   *
   * @param fs file system of the directory
   * @param path output directory, only its direct children are inspected
   * @param df DataFrame whose column names form the header
   * @param separator single-character column separator
   * @throws IllegalArgumentException if separator is not exactly one character long
   */
  def writeHeader(fs: FileSystem, path: Path, df: DataFrame, separator: String): Unit = {
    val header = generateHeader(df.columns, separator)
    logInfo(s"Writing header [$header] for directory $path")
    // discover candidate data files: skip names starting with "_" (such as _SUCCESS) or "."
    // and checksum files
    val dataFileFilter = new PathFilter {
      override def accept(p: Path): Boolean = {
        val name = p.getName
        !name.startsWith("_") && !name.startsWith(".") && !name.endsWith(".crc")
      }
    }
    // only regular files qualify: a directory can report length 0 as well and must never be
    // passed to appendHeader, which would try to create a file over it
    val emptyDataFiles = fs.listStatus(path, dataFileFilter).filter { status =>
      status.isFile && status.getLen == 0
    }
    // non-empty files are assumed to already have header
    emptyDataFiles.foreach { status =>
      appendHeader(fs, status.getPath, header)
      logInfo(s"Appending header to the file ${status.getPath}")
    }
  }
}
/**
 * DataFrameReader whose `csv` replaces nulls in string columns of the loaded DataFrame
 * with empty strings.
 */
private[sql] class CsvDataFrameReader(val spark: SparkSession) extends DataFrameReader(spark) {
  override def csv(path: String): DataFrame =
    CsvWriteSupport.convertNulls(super.csv(path))
}
/**
 * Thin wrapper around DataFrameWriter that, when the "header" option is enabled, also
 * writes the header line into every empty file of the output directory. Nulls in string
 * columns are written as the literal string "null".
 *
 * Does not support partitioning.
 */
private[sql] class CsvDataFrameWriter[T](ds: Dataset[T]) {
  @transient private val spark = ds.sparkSession
  private val df = CsvWriteSupport.convertNulls(ds.toDF(), "null")
  private val writer = new DataFrameWriter(df)
  private val headerKey = "header"
  private var useHeader = false
  private val separatorKey = "sep"
  // alias of "sep" accepted by the CSV data source, "sep" wins when both are provided
  private val delimiterKey = "delimiter"
  private var separatorValue = ","
  private var separatorSetBySep = false

  def mode(saveMode: SaveMode): CsvDataFrameWriter[T] = {
    this.writer.mode(saveMode)
    this
  }

  def mode(saveMode: String): CsvDataFrameWriter[T] = {
    this.writer.mode(saveMode)
    this
  }

  /**
   * Set an option on the underlying writer; "header" and the separator ("sep" or its alias
   * "delimiter") are additionally remembered, because they are needed to write headers into
   * empty files. Option keys are matched case-insensitively.
   */
  def option(key: String, value: String): CsvDataFrameWriter[T] = {
    // Locale.ROOT keeps the match independent from the default locale of the JVM
    key.toLowerCase(java.util.Locale.ROOT) match {
      case `headerKey` =>
        useHeader = value.toBoolean
      case `separatorKey` =>
        separatorValue = value
        separatorSetBySep = true
      case `delimiterKey` if !separatorSetBySep =>
        separatorValue = value
      case _ =>
    }
    this.writer.option(key, value)
    this
  }

  /** Set all options of the map, each one is handled the same way as in `option` */
  def options(map: scala.collection.Map[String, String]): CsvDataFrameWriter[T] = {
    map.foreach { case (key, value) => option(key, value) }
    this
  }

  def option(key: String, value: Boolean): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Long): CsvDataFrameWriter[T] = option(key, value.toString)

  def option(key: String, value: Double): CsvDataFrameWriter[T] = option(key, value.toString)

  /** Write data as CSV into `path`, then add header to empty files if header is enabled */
  def csv(path: String): Unit = {
    this.writer.csv(path)
    if (useHeader) {
      val conf = spark.sessionState.newHadoopConf()
      val (fs, resolvedPath) = CsvWriteSupport.resolvePath(conf, path)
      CsvWriteSupport.writeHeader(fs, resolvedPath, df, separatorValue)
    }
  }
}
/**
 * Workaround entry point exposing the CSV reader (`fread`) and the header-writing CSV
 * writer (`fwrite`); a new reader or writer is created on every access.
 */
private[sql] class Workaround[T](spark: SparkSession, ds: Dataset[T]) {
  /** Reader bound to the session of this workaround */
  def fread: CsvDataFrameReader = {
    new CsvDataFrameReader(spark)
  }

  /** Writer bound to the dataset of this workaround */
  def fwrite: CsvDataFrameWriter[T] = {
    new CsvDataFrameWriter[T](ds)
  }
}
/** Implicit conversions that add `fread` / `fwrite` to SparkSession and Dataset */
package object csvfix {
  /** Wrap a dataset, so that `ds.fwrite` is available */
  implicit def workaroundWriter[T](ds: Dataset[T]): Workaround[T] =
    new Workaround[T](ds.sparkSession, ds)

  /** Wrap a session, so that `spark.fread` is available; no dataset is attached (null) */
  implicit def workaroundReader(spark: SparkSession): Workaround[_] =
    new Workaround[Any](spark, null)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Another approach is to write a header into every empty file.