Skip to content

Instantly share code, notes, and snippets.

@pathikrit
Last active June 1, 2020 15:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pathikrit/44f13bb9492cc1827b208f6a9862da33 to your computer and use it in GitHub Desktop.
Spark utilities for loading delimited (CSV-style) files and reading from / writing to Snowflake.
import java.nio.charset.{ Charset, StandardCharsets }

import org.apache.spark.sql._
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.types._
object SparkDataLoad {

  /**
   * Reads one or more delimited text files into a [[DataFrame]], using the schema
   * derived from the implicit [[Encoder]] for `A` (no schema-inference pass over the data).
   *
   * @param path                     set of file paths / globs to read
   * @param encoding                 character set of the input files
   * @param useHeader                whether the first line of each file is a header row
   * @param delimiter                field separator (defaults to pipe, not comma)
   * @param quote                    quote character for fields containing the delimiter
   * @param escape                   escape character inside quoted fields
   * @param skipLinesStartingWith    if set, lines beginning with this char are treated as comments
   * @param dateFormat               pattern used to parse DateType columns
   * @param timestampFormat          pattern used to parse TimestampType columns
   * @param representEmptyValueAs    string form of an empty value
   * @param treatAsNull              string parsed as SQL NULL
   * @param treatAsNaN               string parsed as NaN for numeric columns
   * @param treatAsPositiveInf       string parsed as +Infinity
   * @param treatAsNegativeInf       string parsed as -Infinity
   * @param ignoreLeadingWhiteSpace  trim leading whitespace in fields
   * @param ignoreTrailingWhiteSpace trim trailing whitespace in fields
   * @param inputFileNameColumn      name of the provenance column recording each row's source file
   * @return a DataFrame with `A`'s schema plus the source-file provenance column
   */
  def fromCsv[A : Encoder](
    path: Set[String],
    encoding: Charset = StandardCharsets.UTF_8,
    useHeader: Boolean = false,
    delimiter: Char = '|',
    quote: Char = '"',
    escape: Char = '\\',
    skipLinesStartingWith: Option[Char] = None,
    dateFormat: String = "yyyyMMdd",
    timestampFormat: String = "yyyy-MM-dd'T'HH:mm:ss.SSSXXX",
    representEmptyValueAs: String = "",
    treatAsNull: String = "",
    treatAsNaN: String = "NaN",
    treatAsPositiveInf: String = "Inf",
    treatAsNegativeInf: String = "-Inf",
    ignoreLeadingWhiteSpace: Boolean = true,
    ignoreTrailingWhiteSpace: Boolean = true,
    inputFileNameColumn: String = "_source_file"
  )(implicit spark: SparkSession): DataFrame =
    spark.read
      .option("mode", "PERMISSIVE") // keep malformed rows (as nulls) instead of failing the whole read
      .option("encoding", encoding.name())
      .option("header", useHeader)
      .option("delimiter", delimiter.toString)
      .option("quote", quote.toString)
      .option("escape", escape.toString)
      .option("dateFormat", dateFormat)
      .option("timestampFormat", timestampFormat)
      .option("emptyValue", representEmptyValueAs)
      .option("nullValue", treatAsNull)
      .option("nanValue", treatAsNaN)
      .option("positiveInf", treatAsPositiveInf)
      .option("negativeInf", treatAsNegativeInf)
      .option("comment", skipLinesStartingWith.map(_.toString).orNull) // null disables comment handling
      .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
      .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
      .schema(implicitly[Encoder[A]].schema) // explicit schema avoids an extra inference pass
      .csv(path.toSeq: _*)
      .withColumn(inputFileNameColumn, input_file_name()) // provenance: which file each row came from

  /**
   * Reads a table or the result of a query from Snowflake via the Snowflake Spark connector.
   *
   * NOTE(security): credentials default to values hard-coded in source; prefer passing them
   * explicitly from a secrets store rather than relying on these defaults.
   *
   * @param account   Snowflake account URL
   * @param username  Snowflake user
   * @param password  Snowflake password
   * @param warehouse warehouse to execute in
   * @param database  database to read from
   * @param table     either a SELECT statement or a plain table name
   */
  def readFromSnowflake(
    account: String = "*****.us-east-1.snowflakecomputing.com",
    username: String = "dev",
    password: String = "***************",
    warehouse: String = "dev",
    database: String = "dev",
    table: String // Can be either a SELECT statement OR a table name
  )(implicit spark: SparkSession): DataFrame = {
    // Heuristic: a "SELECT " anywhere (covers CTEs like "WITH x AS (...) SELECT ...")
    // means `table` is a query; otherwise it is treated as a table name.
    // NOTE(review): a table whose name contains "SELECT " would be misclassified — acceptable in practice.
    val tableOrQueryKey = if (table.toUpperCase.contains("SELECT ")) "query" else "dbtable"
    spark.read
      .format("net.snowflake.spark.snowflake")
      .options(
        Map(
          "sfUrl" -> account,
          "sfUser" -> username,
          "sfPassword" -> password,
          "sfDatabase" -> database,
          "sfWarehouse" -> warehouse,
          tableOrQueryKey -> table
        )
      )
      .load()
  }

  /**
   * Writes a [[Dataset]] to Snowflake using a stage-and-swap strategy:
   * the data is loaded into a transient staging table (`<table>_stage`), and on success the
   * staging table is renamed over the target table — so readers never see a half-written table.
   *
   * NOTE(security): credentials default to values hard-coded in source; prefer passing them explicitly.
   * NOTE(review): when `isAppend` is true the staging table is still created fresh and swapped over
   * the target, so the result contains only the newly appended data — confirm this is intended
   * before using append mode.
   *
   * @param account   Snowflake account URL
   * @param username  Snowflake user
   * @param password  Snowflake password
   * @param warehouse warehouse to execute in
   * @param database  target database
   * @param schema    target schema (created if absent)
   * @param table     target table name
   * @param clusterBy optional clustering keys for the created table
   * @param dataset   the data to write; its Spark schema is mapped to Snowflake column types
   * @param isAppend  append to the staging table instead of overwriting it
   * @throws UnsupportedOperationException if a field's Spark type has no Snowflake mapping
   */
  def toSnowflake(
    account: String = "*****.us-east-1.snowflakecomputing.com",
    username: String = "dev",
    password: String = "***************",
    warehouse: String = "dev",
    database: String = "dev",
    schema: String,
    table: String,
    clusterBy: Seq[String] = Nil,
    dataset: Dataset[_],
    isAppend: Boolean = false
  ): Unit = {
    // Maps one Spark StructField to a Snowflake column definition (name, type, nullability).
    def toSnowflakeColumn(field: StructField): String = {
      val col = field.dataType match {
        case _: BooleanType => "BOOLEAN"
        case _: ByteType | _: ShortType | _: IntegerType | _: LongType => "INTEGER"
        case _: DecimalType | _: FloatType | _: DoubleType => "REAL"
        case _: DateType => "DATE"
        case _: TimestampType => "TIMESTAMP_TZ"
        case _: StringType | _: VarcharType => "TEXT"
        case _: ArrayType => "ARRAY"
        case _ => throw new UnsupportedOperationException(s"Unsupported field = ${field}")
      }
      s"${field.name.toLowerCase} ${if (field.nullable) s"$col" else s"$col NOT NULL"}"
    }

    val tempTable = s"${table}_stage"
    val clusterStmt = if (clusterBy.isEmpty) "" else clusterBy.mkString(" CLUSTER BY(", ", ", ")")
    // TRANSIENT: no fail-safe/time-travel storage cost for a short-lived staging table.
    val createTable = dataset.schema.fields
      .map(toSnowflakeColumn)
      .mkString(s"CREATE OR REPLACE TRANSIENT TABLE $schema.$tempTable(\n\t", ",\n\t", s") $clusterStmt")
    val preActions = Seq(
      s"USE DATABASE $database", // fixed: previously referenced an undefined `db` identifier
      s"USE WAREHOUSE $warehouse",
      s"CREATE SCHEMA IF NOT EXISTS $schema",
      s"USE SCHEMA $schema",
      createTable
    )
    // Atomic-ish swap: drop the old target and promote the freshly loaded staging table.
    val postActions = Seq(
      s"DROP TABLE IF EXISTS $schema.$table",
      s"ALTER TABLE $schema.$tempTable RENAME TO $table"
    )
    println(((preActions :+ s"COPY DATAFRAME TO ${schema}.${tempTable}") ++ postActions).mkString("", ";\n\n", ";"))
    dataset
      .write
      .format("snowflake")
      .options(Map(
        "sfUrl" -> account,
        "sfUser" -> username,
        "sfPassword" -> password,
        "sfDatabase" -> database,
        "sfWarehouse" -> warehouse,
        // Fixed: must load into the STAGING table (the postactions rename it over the target);
        // writing straight to `table` meant the postactions dropped the freshly written data.
        "dbtable" -> tempTable,
        "preactions" -> preActions.mkString("", ";", ";"),
        "postactions" -> postActions.mkString("", ";", ";")
      ))
      .mode(if (isAppend) SaveMode.Append else SaveMode.Overwrite)
      .save()
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment