Skip to content

Instantly share code, notes, and snippets.

@seddonm1
Last active April 5, 2024 18:56
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save seddonm1/6e4354f68d7d74f71bb5e27d31d5b980 to your computer and use it in GitHub Desktop.
Save seddonm1/6e4354f68d7d74f71bb5e27d31d5b980 to your computer and use it in GitHub Desktop.
Makes a Spark Schema (StructType) from an input XSD file
// Requires the Apache WS XmlSchema library on the classpath: add the jar below to spark/jars (it has no dependencies of its own).
// https://repo1.maven.org/maven2/org/apache/ws/xmlschema/xmlschema-core/2.2.5/xmlschema-core-2.2.5.jar
import org.apache.ws.commons.schema.XmlSchemaCollection
import java.io.StringReader
import scala.collection.JavaConverters._
import org.apache.ws.commons.schema._
import org.apache.ws.commons.schema.constants.Constants
import org.apache.spark.sql.types._
/**
 * Converts an XSD type definition into a Spark [[StructField]].
 *
 * Simple types produce a field named "baseName" whose `dataType` is the mapped Spark type;
 * recursive callers in this function only read that `dataType` and supply their own name.
 * Complex types produce a field named after the complex type, holding a [[StructType]]:
 *  - xs:simpleContent extensions become `_VALUE` plus one `_<attribute>` field per attribute;
 *  - xs:all / xs:choice children become nullable fields;
 *  - xs:sequence children are nullable when `minOccurs == 0`; an xs:choice nested directly
 *    in a sequence is flattened into nullable fields;
 *  - any element whose `maxOccurs` is not 1 becomes a nullable array.
 * Attributes declared directly on a complex type (outside xs:simpleContent) are not read.
 *
 * @param xmlSchema  schema used to resolve named type references
 * @param schemaType the XSD type to convert
 * @return the Spark field describing `schemaType`
 * @throws IllegalArgumentException if the type uses a structural construct this converter
 *                                  does not handle (e.g. xs:complexContent, group references)
 */
def getStructField(xmlSchema: XmlSchema, schemaType: XmlSchemaType): StructField = {
  // Placeholder name for simple-type fields; only their dataType is consumed by callers.
  val baseName = "baseName"
  // Namespace shared by every XSD built-in type name.
  val xsdNamespace = Constants.XSD_STRING.getNamespaceURI

  // Fails with a readable message instead of an opaque MatchError.
  def unsupported(construct: AnyRef): Nothing = {
    val description = Option(construct).map(_.getClass.getName).getOrElse("null")
    throw new IllegalArgumentException(s"Unsupported XSD construct: $description")
  }

  // Resolves a type name against this schema first, then against its owning collection.
  def lookupType(typeName: javax.xml.namespace.QName): Option[XmlSchemaType] =
    Option(typeName).flatMap { name =>
      Option(xmlSchema.getTypeByName(name)).orElse {
        Option(xmlSchema.getParent).flatMap(collection => Option(collection.getTypeByQName(name)))
      }
    }

  // Spark type of an optionally resolved XSD type; unresolved types are read as strings.
  def dataTypeOf(resolved: Option[XmlSchemaType]): DataType =
    resolved.map(found => getStructField(xmlSchema, found).dataType).getOrElse(StringType)

  // Maps an XSD built-in type name to a Spark type, or None when the name is not mapped here.
  def builtinDataType(
      typeName: javax.xml.namespace.QName,
      restriction: XmlSchemaSimpleTypeRestriction): Option[DataType] =
    typeName match {
      case Constants.XSD_BOOLEAN => Some(BooleanType)
      // xs:byte is an 8-bit signed integer, not binary data.
      case Constants.XSD_BYTE => Some(ByteType)
      case Constants.XSD_UNSIGNEDBYTE => Some(ShortType)
      case Constants.XSD_DECIMAL =>
        // The fractionDigits facet, when present, determines the scale; default scale is 18.
        val scale = restriction.getFacets.asScala.collectFirst {
          case facet: XmlSchemaFractionDigitsFacet => facet.getValue.toString.toInt
        }
        Some(DecimalType(38, scale.getOrElse(18)))
      case Constants.XSD_DOUBLE => Some(DoubleType)
      case Constants.XSD_FLOAT => Some(FloatType)
      // The unbounded xs:integer family is kept as IntegerType, as before.
      case Constants.XSD_INTEGER | Constants.XSD_INT | Constants.XSD_NEGATIVEINTEGER |
          Constants.XSD_NONNEGATIVEINTEGER | Constants.XSD_NONPOSITIVEINTEGER |
          Constants.XSD_POSITIVEINTEGER | Constants.XSD_SHORT | Constants.XSD_UNSIGNEDSHORT =>
        Some(IntegerType)
      // xs:unsignedInt ranges up to 2^32 - 1, which does not fit a 32-bit signed integer.
      case Constants.XSD_LONG | Constants.XSD_UNSIGNEDINT => Some(LongType)
      // xs:unsignedLong ranges up to 2^64 - 1 (20 digits), which does not fit a signed long.
      case Constants.XSD_UNSIGNEDLONG => Some(DecimalType(20, 0))
      case Constants.XSD_BASE64 | Constants.XSD_DATE | Constants.XSD_DATETIME |
          Constants.XSD_STRING | Constants.XSD_TIME =>
        Some(StringType)
      case _ => None
    }

  // Spark type for an xs:simpleType. Resolution order: the type's own name when it is an XSD
  // built-in, then its restriction base name, then an inline or named base type (recursively).
  // Anything still unmapped (lists, unions, other built-ins) is read as a string.
  def simpleDataType(simpleType: XmlSchemaSimpleType): DataType =
    simpleType.getContent match {
      case restriction: XmlSchemaSimpleTypeRestriction =>
        val ownBuiltinName = Option(simpleType.getQName).filter(_.getNamespaceURI == xsdNamespace)
        val baseTypeName = Option(restriction.getBaseTypeName)
        ownBuiltinName.flatMap(builtinDataType(_, restriction))
          .orElse(baseTypeName.flatMap(builtinDataType(_, restriction)))
          .orElse(Option(restriction.getBaseType).map(simpleDataType))
          .orElse(baseTypeName.flatMap(lookupType).collect {
            // Guard against a type that resolves to itself.
            case base: XmlSchemaSimpleType if base ne simpleType => simpleDataType(base)
          })
          .getOrElse(StringType)
      case _ => StringType
    }

  // Field for an attribute of an xs:simpleContent extension, named "_<attribute name>".
  def attributeField(attribute: XmlSchemaAttribute): StructField = {
    val attributeType: Option[XmlSchemaType] =
      Option(attribute.getSchemaType).orElse(lookupType(attribute.getSchemaTypeName))
    StructField(s"_${attribute.getName}", dataTypeOf(attributeType), true)
  }

  // Field for a child element; elements that may repeat are wrapped in a nullable array.
  def elementField(element: XmlSchemaElement, nullable: Boolean): StructField = {
    val elementType: Option[XmlSchemaType] =
      Option(element.getSchemaType).orElse(lookupType(element.getSchemaTypeName))
    val dataType = dataTypeOf(elementType)
    if (element.getMaxOccurs == 1) {
      StructField(element.getName, dataType, nullable)
    } else {
      StructField(element.getName, ArrayType(dataType, true), true)
    }
  }

  // Child fields of a complex type, derived from its content model or particle.
  def complexFields(complexType: XmlSchemaComplexType): List[StructField] =
    Option(complexType.getContentModel) match {
      // xs:simpleContent
      case Some(simpleContent: XmlSchemaSimpleContent) =>
        simpleContent.getContent match {
          case extension: XmlSchemaSimpleContentExtension =>
            val value =
              StructField("_VALUE", dataTypeOf(lookupType(extension.getBaseTypeName)), true)
            val attributes = extension.getAttributes.asScala.toList.map {
              case attribute: XmlSchemaAttribute => attributeField(attribute)
              case other => unsupported(other)
            }
            value :: attributes
          case other => unsupported(other)
        }
      case Some(other) => unsupported(other)
      case None =>
        complexType.getParticle match {
          // xs:all
          case schemaAll: XmlSchemaAll =>
            schemaAll.getItems.asScala.toList.map {
              case element: XmlSchemaElement => elementField(element, nullable = true)
              case other => unsupported(other)
            }
          // xs:choice
          case schemaChoice: XmlSchemaChoice =>
            schemaChoice.getItems.asScala.toList.map {
              case element: XmlSchemaElement => elementField(element, nullable = true)
              case other => unsupported(other)
            }
          // xs:sequence (an xs:choice nested directly in it is flattened into nullable fields)
          case schemaSequence: XmlSchemaSequence =>
            schemaSequence.getItems.asScala.toList.flatMap {
              case nestedChoice: XmlSchemaChoice =>
                nestedChoice.getItems.asScala.toList.map {
                  case element: XmlSchemaElement => elementField(element, nullable = true)
                  case other => unsupported(other)
                }
              case element: XmlSchemaElement =>
                List(elementField(element, nullable = element.getMinOccurs == 0))
              case other => unsupported(other)
            }
          // A complex type without any particle has no child fields.
          case null => Nil
          case other => unsupported(other)
        }
    }

  schemaType match {
    // xs:simpleType
    case simpleType: XmlSchemaSimpleType =>
      StructField(baseName, simpleDataType(simpleType), true)
    // xs:complexType
    case complexType: XmlSchemaComplexType =>
      StructField(complexType.getName, StructType(complexFields(complexType)), true)
    case other => unsupported(other)
  }
}
/**
 * Builds a Spark [[StructType]] from the first global element declared in the schema.
 *
 * The resulting struct contains a single field produced by `getStructField` for that
 * element's type. If the element's type is anonymous it is given the element's local name
 * (this mutates the type held by `xmlSchema`) so that the produced field has a name.
 *
 * @param xmlSchema the parsed XSD
 * @return a single-field StructType describing the first global element
 * @throws IllegalArgumentException if the schema declares no global elements, or the first
 *                                  global element has no resolved type
 */
def getStructType(xmlSchema: XmlSchema): StructType = {
  val (rootName, rootElement) = xmlSchema.getElements.asScala.headOption.getOrElse {
    throw new IllegalArgumentException("XSD declares no global elements; cannot derive a schema")
  }
  val rootType = Option(rootElement.getSchemaType).getOrElse {
    throw new IllegalArgumentException(
      s"Global element '${rootName.getLocalPart}' has no resolved type; cannot derive a schema")
  }
  // An anonymous type has no name of its own, so borrow the element's name for the field.
  if (rootType.isAnonymous) {
    rootType.setName(rootName.getLocalPart)
  }
  StructType(getStructField(xmlSchema, rootType) :: Nil)
}
// read the XSD: with "wholetext" the file is loaded as one row holding its full contents
val df = spark.read
  .option("wholetext", "true")
  .text("/src/pain.001.001.03.xsd")
val xsdText = df.head.getString(0)

// parse the XSD text and derive the Spark schema from it
val xmlSchemaCollection = new XmlSchemaCollection
val xmlSchema = xmlSchemaCollection.read(new StringReader(xsdText))
val sparkSchema = getStructType(xmlSchema)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment