@viirya
Created January 24, 2020 19:29
Snippet for extracting a nested column from an input row in Spark
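The genExtractor helper below builds a chain of UnresolvedExtractValue expressions for a schema in which every struct level has exactly one field. The test then resolves that chain by analyzing a dummy Project over a LocalRelation carrying the input schema, and finally evaluates the resolved project list against a hand-built GenericInternalRow using an InterpretedProjection.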
import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class DataFrameSuite extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper {
  import testImplicits._

  // Builds a chain of UnresolvedExtractValue expressions that drills down a
  // schema in which every struct level has exactly one field.
  def genExtractor(
      structField: StructField,
      optChild: Option[Expression]): Expression = structField.dataType match {
    case StructType(fields) if fields.length == 1 =>
      val nextChild = optChild.map { child =>
        UnresolvedExtractValue(child, Literal(fields(0).name))
      }.getOrElse {
        // The root field.
        UnresolvedExtractValue(
          UnresolvedAttribute(structField.name),
          Literal(fields(0).name))
      }
      genExtractor(fields(0), Some(nextChild))
    case StructType(_) =>
      throw new AnalysisException("Accessed type should have only one field.")
    case _ =>
      // The leaf field.
      optChild.getOrElse(UnresolvedAttribute(structField.name))
  }
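
  // For the test below, genExtractor(accessField, None) returns
  // UnresolvedExtractValue(UnresolvedAttribute("topCol"), Literal("col1")),
  // i.e. the unresolved form of topCol.col1.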
test("test") {
val nestedSchema = StructType(
StructField("col1", StringType) ::
StructField("col2", StringType) ::
StructField("col3", IntegerType) :: Nil)
val values = Array("value1", "value2", 1)
// val nestedRow = new GenericRowWithSchema(values, nestedSchema)
val schema = StructType(StructField("topCol", nestedSchema) :: Nil)
// val row = new GenericRowWithSchema(Array(nestedRow), schema)
val nestedRow = new GenericInternalRow(values)
val inputRow = new GenericInternalRow(Array(nestedRow.asInstanceOf[Any]))
val accessField = StructField("topCol",
StructType(StructField("col1", StringType) :: Nil))
val extractorProjection = ProjectionOverSchema(schema)
val extractors = Seq(genExtractor(accessField, None)).map { extractor =>
val projected = extractor.transform {
case extractorProjection(expr) => expr
}
projected match {
case n: NamedExpression => n
case other => Alias(other, other.prettyName)()
}
}
val attrs = schema.toAttributes
val dummyPlan = Project(extractors, LocalRelation(attrs))
val analyzedPlan = SimpleAnalyzer.execute(dummyPlan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
val resolvedExtractors = analyzedPlan match {
case Project(projectList, _) => projectList
case _ => throw new AnalysisException(s"wrong analyzed plan: $analyzedPlan")
}
val project = new InterpretedProjection(resolvedExtractors, attrs)
val outputRow = project(inputRow)
println(outputRow)
}
}
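
Evaluated against the row built above, the interpreted projection should print a single-field row, something like [value1].

For comparison, here is the same nested-field read through the public DataFrame API. This is a minimal, self-contained sketch, not part of the original test; the local SparkSession setup and the object name are illustrative.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object NestedColumnExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("nested-extract").getOrCreate()
    val nestedSchema = StructType(
      StructField("col1", StringType) ::
      StructField("col2", StringType) ::
      StructField("col3", IntegerType) :: Nil)
    val schema = StructType(StructField("topCol", nestedSchema) :: Nil)
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row(Row("value1", "value2", 1)))),
      schema)
    // The dotted column syntax resolves to the same GetStructField chain that
    // the snippet above constructs by hand.
    df.select("topCol.col1").show()
    spark.stop()
  }
}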