@rdblue
Created May 22, 2019 22:21
Prototype DataFrameWriter for v2 tables
/**
 * Interface used to write a [[Dataset]] to external storage using the v2 API.
 *
 * @since 3.0.0
 */
@Experimental
final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
    extends CreateTableWriter[T] with LookupCatalog {
  import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._

  private val df = ds.toDF()
  private val sparkSession = ds.sparkSession
  private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
  private val logicalPlan = df.queryExecution.logical

  private var provider: Option[String] = None
  private val options = new mutable.HashMap[String, String]()
  private var partitioning: Option[Seq[Transform]] = None
  /**
   * Specifies a provider for the underlying output data source. Spark's default catalog supports
   * "parquet", "json", etc.
   *
   * @since 3.0.0
   */
  def using(provider: String): DataFrameWriterV2[T] = {
    this.provider = Some(provider)
    this
  }
  /**
   * Add a write option.
   *
   * @since 3.0.0
   */
  def option(key: String, value: String): DataFrameWriterV2[T] = {
    this.options.put(key, value)
    this
  }
  /**
   * Partition the output table created by `create` using the given columns or transforms.
   *
   * @since 3.0.0
   */
  def partitionBy(columns: Column*): CreateTableWriter[T] = {
    val asTransforms = columns.map(_.expr).map {
      case Years(attr: Attribute) =>
        LogicalExpressions.years(attr.name)
      case Months(attr: Attribute) =>
        LogicalExpressions.months(attr.name)
      case Days(attr: Attribute, _: DateType, _) =>
        LogicalExpressions.days(attr.name)
      case Hours(attr: Attribute, _) =>
        LogicalExpressions.hours(attr.name)
      case attr: Attribute =>
        LogicalExpressions.identity(attr.name)
      case expr =>
        throw new AnalysisException(s"Invalid partition transformation: ${expr.sql}")
    }

    this.partitioning = Some(asTransforms)
    this
  }
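  // Usage sketch (not part of the original gist): assuming column functions exist that
  // produce the Years/Months/Days/Hours expressions matched above, a call such as
  //   writer.partitionBy(years($"ts"), $"category")
  // would be converted to LogicalExpressions.years("ts") plus an identity transform on
  // "category", and the resulting transforms are attached to the table built by create().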
  /**
   * Create a new table from the contents of the data frame.
   *
   * @since 3.0.0
   */
  def create(): Unit = {
    val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
    val catalog = maybeCatalog
      .getOrElse(throw new AnalysisException(
        s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
      .asTableCatalog

    // TODO: Maybe this should be a different statement instead? CreateTableFromDataFrame?
    runCommand("create") {
      CreateTableAsSelect(
        catalog,
        identifier,
        partitioning.getOrElse(Seq.empty),
        logicalPlan,
        properties = Map.empty[String, String],
        writeOptions = options.toMap,
        ignoreIfExists = false)
    }
  }
  /**
   * Append the contents of the data frame to the output table.
   *
   * @since 3.0.0
   */
  def append(): Unit = {
    runCommand("append") {
      AppendData.byName(UnresolvedRelation(tableName), logicalPlan)
    }
  }

  /**
   * Overwrite rows matching the given filter condition with the contents of the data frame.
   *
   * @since 3.0.0
   */
  def overwrite(condition: Column): Unit = {
    runCommand("overwrite") {
      OverwriteByExpression.byName(UnresolvedRelation(tableName), logicalPlan, condition.expr)
    }
  }

  /**
   * Dynamically overwrite partitions: any partition with at least one matching row in the
   * data frame is replaced with the contents of the data frame.
   *
   * @since 3.0.0
   */
  def overwritePartitions(): Unit = {
    runCommand("overwritePartitions") {
      OverwritePartitionsDynamic.byName(UnresolvedRelation(tableName), logicalPlan)
    }
  }
  /**
   * Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the
   * user-registered callback functions.
   */
  private def runCommand(name: String)(command: LogicalPlan): Unit = {
    val qe = sparkSession.sessionState.executePlan(command)
    // call `QueryExecution.toRdd` to trigger the execution of commands.
    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd)
  }

  override protected def lookupCatalog: Option[String => CatalogPlugin] =
    Some(sparkSession.catalog(_))
}
/**
 * Configuration methods common to the v2 write builders.
 */
private trait WriteConfigMethods[R] {
  def using(provider: String): R
  def option(key: String, value: String): R
}
/**
 * Builder trait for writes that create a v2 table: exposes `create()` and table-creation
 * configuration such as `partitionBy`.
 */
trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
  def create(): Unit

  def partitionBy(columns: Column*): CreateTableWriter[T]
}
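// Note (an interpretation, not part of the original gist): partitionBy returns
// CreateTableWriter[T] rather than DataFrameWriterV2[T], so once partitioning has been
// configured the only terminal action available is create(). Partitioning only applies
// when a new table is created, so the narrowed type keeps append()/overwrite() out of
// reach; for example, the following would not compile:
//   df.writeTo("db.table").partitionBy($"category").append()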
class Demo(df: DataFrame) {
  import df.sparkSession.implicits._
  import org.apache.spark.sql.functions._

  def spark: SparkSession = df.sparkSession

  implicit class DSv2Write[T](ds: Dataset[T]) {
    def writeTo(table: String): DataFrameWriterV2[T] = {
      new DataFrameWriterV2[T](table, ds)
    }
  }

  def test(): Unit = {
    df.writeTo("db.table").option("key", "value").append()
    df.writeTo("db.table").option("key", "value").overwrite($"day" === "2019-01-01")
    df.writeTo("db.table").partitionBy(month($"ts")).using("test").create()
  }
}
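// Further usage sketches (not part of the original gist; "db.table" and the option key
// are placeholders, and writeTo is only in scope via the implicit class above):
//   df.writeTo("db.table").using("parquet").option("key", "value").create()
//   df.writeTo("db.table").overwritePartitions()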