Last active
June 10, 2020 12:21
-
-
Save sarkologist/7650fd025edcb20499290f9c46cc909e to your computer and use it in GitHub Desktop.
protobuf to bigquery
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package utils | |
import java.util | |
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} | |
import com.google.protobuf.Descriptors.FieldDescriptor.Type.{GROUP, MESSAGE} | |
import com.google.protobuf.Message | |
import scalaz.{Functor, Yoneda} | |
import scala.collection.JavaConverters._ | |
import scala.collection.mutable | |
object Protobuf {

  /** Folds a protobuf schema ([[Descriptor]]) bottom-up.
    *
    * Each non-message field becomes an `A` via `base`. Each MESSAGE/GROUP field is folded
    * recursively into its message type while the current depth is `<= maxDepth`. Past that
    * depth, message fields are handed to `base` as well, which is what lets recursive message
    * types be folded at all. The per-field results of one descriptor are combined by `recurse`.
    *
    * The list given to `recurse` is in REVERSE field-declaration order, because results are
    * prepended while folding. A caller that prepends again while folding gets declaration order
    * back.
    *
    * @param base       converts a leaf field (or a message field beyond `maxDepth`)
    * @param recurse    combines the `(result, field)` pairs of one descriptor
    * @param maxDepth   message fields are expanded while `depth <= maxDepth`; the top-level
    *                   descriptor is at depth 0. The default never truncates, so a recursive
    *                   message type recurses until the stack overflows.
    * @param descriptor schema to fold
    */
  def foldDescriptor[A](
      base: FieldDescriptor => A,
      recurse: List[(A, FieldDescriptor)] => A,
      maxDepth: Int = Integer.MAX_VALUE)(descriptor: Descriptor): A = {
    def go(d: Descriptor, depth: Int): A = {
      val fields =
        d.getFields.asScala.foldLeft(List.empty[(A, FieldDescriptor)]) {
          case (acc, f) =>
            val folded = f.getType match {
              case MESSAGE | GROUP if depth <= maxDepth => go(f.getMessageType, depth + 1)
              case _                                    => base(f)
            }
            // prepend: `recurse` receives fields in reverse declaration order
            (folded -> f) :: acc
        }
      recurse(fields)
    }
    go(descriptor, 0)
  }

  /** Wraps a raw protobuf field value in [[Repeated]].
    *
    * A repeated field's value is cast to a `java.util.List` and becomes [[Many]]; any other
    * value becomes [[One]]. The casts are unchecked, so `A` must match the runtime element type.
    */
  def lift[A](fieldDescriptor: FieldDescriptor, value: AnyRef): Repeated[A] =
    if (fieldDescriptor.isRepeated) Many(value.asInstanceOf[util.List[A]].asScala)
    else One(value.asInstanceOf[A])

  /** Folds a protobuf [[Message]] instance bottom-up.
    *
    * Which fields are visited:
    *  - every field returned by `getAllFields`;
    *  - then every absent field that has a usable default (see [[defaultValues]]).
    *
    * Each field value is lifted into [[Repeated]] (a single value, or all elements of a
    * repeated field), wrapped in `Yoneda`, and mapped as follows:
    *  - MESSAGE/GROUP elements are folded recursively and paired with the element itself;
    *  - every other element is converted with `base` and paired with the enclosing message.
    *
    * `recurse` then combines the fields of one message. Present fields come first, in the
    * iteration order of `getAllFields`; defaulted fields follow.
    *
    * @param recurse combines the mapped fields of one message into an `A`
    * @param base    converts a single non-message value of the given field
    * @param message message to fold
    */
  def foldMessage[A](
      recurse: Seq[(Yoneda[Repeated, (A, Message)], FieldDescriptor)] => A,
      base: (AnyRef, FieldDescriptor) => A)(message: Message): A = {
    def go(msg: Message): A = {
      // Vector: `:+` is effectively constant time. On the default List-backed Seq it is
      // linear, which made building the field list quadratic in the number of fields.
      val presentFields =
        msg.getAllFields.asScala.foldLeft(
          Vector.empty[(Yoneda[Repeated, (A, Message)], FieldDescriptor)]) {
          case (acc, (f, value)) =>
            val mapped = f.getType match {
              case MESSAGE | GROUP =>
                Yoneda(lift[Message](f, value)).map(m => go(m) -> m)
              case _ =>
                Yoneda(lift[AnyRef](f, value)).map(r => base(r, f) -> msg)
            }
            acc :+ (mapped -> f)
        }
      val allFields = defaultValues(msg)().foldLeft(presentFields) {
        case (acc, (f, value)) =>
          acc :+ (Yoneda(lift[AnyRef](f, value)).map(v => base(v, f) -> msg) -> f)
      }
      recurse(allFields)
    }
    go(message)
  }

  /** Default values for the fields of `message` that are absent.
    *
    * A field is absent when it is not in `presentFields`. Only fields that are not repeated,
    * not MESSAGE/GROUP typed and not a member of any oneof are included. Each is paired with
    * `FieldDescriptor.getDefaultValue`.
    *
    * @param presentFields fields to treat as present; defaults to the keys of `getAllFields`
    */
  def defaultValues(message: Message)(
      presentFields: Set[FieldDescriptor] = message.getAllFields.asScala.keys.toSet)
    : Set[(FieldDescriptor, AnyRef)] = {
    val descriptor = message.getDescriptorForType
    val absentNonOneOf =
      descriptor.getFields.asScala.toSet -- presentFields -- oneOfFields(descriptor)
    absentNonOneOf.collect {
      case field if !field.isRepeated && field.getType != MESSAGE && field.getType != GROUP =>
        field -> field.getDefaultValue
    }
  }

  /** All fields that belong to any oneof returned by `descriptor.getOneofs`. */
  def oneOfFields(descriptor: Descriptor): Set[FieldDescriptor] =
    descriptor.getOneofs.asScala
      .flatMap(_.getFields.asScala)
      .toSet

  /** A field value: either a single element ([[One]]) or the elements of a repeated field
    * ([[Many]]).
    */
  sealed trait Repeated[A] {
    /** The value as a plain object: the bare element, or a `java.util.List` of elements. */
    def underlying: AnyRef
  }

  case class One[A](a: A) extends Repeated[A] {
    def underlying: AnyRef = a.asInstanceOf[AnyRef]
  }

  // `asScala` on a java.util.List yields a mutable.Buffer, hence the field type.
  case class Many[A](as: mutable.Buffer[A]) extends Repeated[A] {
    def underlying: AnyRef = cast(as.asJava)
    def cast(x: util.List[A]): AnyRef = x.asInstanceOf[AnyRef]
  }

  /** Functor instance for [[Repeated]], picked up implicitly when values are lifted into
    * `Yoneda` in this object.
    */
  implicit def repeatedFunctor: Functor[Repeated] = new Functor[Repeated] {
    override def map[A, B](fa: Repeated[A])(f: A => B): Repeated[B] =
      fa match {
        case One(a)   => One(f(a))
        case Many(as) => Many(as.map(f))
      }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package utils | |
import com.google.api.services.bigquery.model.{TableFieldSchema, TableRow, TableSchema} | |
import com.google.protobuf.Descriptors.FieldDescriptor.Type._ | |
import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor} | |
import com.google.protobuf.{ByteString, Message} | |
import Protobuf._ | |
import scalaz.Yoneda | |
import scala.collection.JavaConverters._ | |
object ProtobufToBigQuery extends Serializable {

  /** Name of the BOOLEAN sub-column added to some nested records to record that the record
    * was present (see [[needsToFlagPresence]]).
    */
  val _set = "_set"

  // Depth at which makeTableSchema stops expanding nested message types.
  private val SchemaMaxDepth = 15

  /** Maps a protobuf field type to the BigQuery column type name used in the table schema. */
  def toTableSchemaType(t: FieldDescriptor.Type): String = t match {
    case DOUBLE | FLOAT => "FLOAT"
    case INT32 | UINT32 | SINT32 | FIXED32 | SFIXED32 | INT64 | UINT64 | SINT64 | FIXED64 |
        SFIXED64 =>
      "INTEGER"
    case BOOL            => "BOOLEAN"
    case STRING          => "STRING"
    case BYTES           => "BYTES"
    case ENUM            => "STRING"
    case MESSAGE | GROUP => "RECORD"
  }

  // Protobuf message types may be recursive: a message can contain itself, directly or
  // indirectly, so message trees have no static depth bound. A finite schema can still be
  // produced by cutting the descriptor fold off at a fixed depth. This is only valid if every
  // message written to the table at runtime stays within that depth.
  // https://cloud.google.com/bigquery/docs/nested-repeated#limitations
  //
  // NOTE(review): at the cut-off, message fields go through `base`/`makeField` and become
  // RECORD columns with no sub-fields. Confirm BigQuery accepts that for recursive types.
  def makeTableSchema(descriptor: Descriptor): TableSchema = {
    type Node = Either[TableFieldSchema, List[TableFieldSchema]]

    // Leaf field, or a message field past the depth cut-off: a single column.
    def base(f: FieldDescriptor): Node = Left(makeField(f))

    def presenceFlagColumn(): TableFieldSchema =
      new TableFieldSchema().setName(_set).setType("BOOLEAN").setMode("NULLABLE")

    // The columns of one message. `children` arrives in reverse declaration order, because
    // foldDescriptor prepends; prepending again here restores declaration order.
    def recurse(children: List[(Node, FieldDescriptor)]): Node = Right {
      children.foldLeft(List.empty[TableFieldSchema]) {
        case (fields, (Left(leaf), _)) => leaf :: fields
        case (fields, (Right(subFields), f)) =>
          val withFlag =
            if (!f.isRepeated && hitPotentialBottom(f.getMessageType) && needsToFlagPresence(f))
              subFields :+ presenceFlagColumn()
            else subFields
          makeField(f).setFields(withFlag.asJava) :: fields
      }
    }

    foldDescriptor(base, recurse, SchemaMaxDepth)(descriptor).fold(
      _ =>
        // unreachable: the top-level descriptor always goes through `recurse`
        throw new RuntimeException(
          s"expected a field list for ${descriptor.getFullName}, got a single field"),
      fields => new TableSchema().setFields(fields.asJava)
    )
  }

  /** A column for `f`: name from [[bigqueryColumnName]], type from [[toTableSchemaType]], and
    * mode REPEATED for repeated fields or NULLABLE otherwise.
    */
  def makeField(f: FieldDescriptor): TableFieldSchema =
    new TableFieldSchema()
      .setName(bigqueryColumnName(f))
      .setType(toTableSchemaType(f.getType))
      .setMode(if (f.isRepeated) "REPEATED" else "NULLABLE")

  /** Column name for `f`: its protobuf name, with any '.' replaced by '_'.
    *
    * NOTE(review): `getName` is the unqualified field name, so the replacement may never fire;
    * confirm whether `getFullName` was intended.
    */
  def bigqueryColumnName(f: FieldDescriptor): String =
    f.getName.replace(".", "_")

  /** Schema-level counterpart of [[hitBottom]].
    *
    * True when no MESSAGE/GROUP field of `d` is `required`. In that case an instance of `d`
    * may, at runtime, have no message-typed fields set.
    */
  def hitPotentialBottom(d: Descriptor): Boolean =
    d.getFields.asScala.forall { f =>
      f.getType match {
        case MESSAGE | GROUP => !f.isRequired
        case _               => true
      }
    }

  /** True when none of the fields set on `message` (per `getAllFields`) is MESSAGE/GROUP typed.
    *
    * This depends on the runtime message, not just its schema, because recursive schemas can
    * nest without bound.
    */
  def hitBottom(message: Message): Boolean =
    message.getAllFields.asScala.keys.forall { f =>
      f.getType match {
        case MESSAGE | GROUP => false
        case _               => true
      }
    }

  /** True when `descriptor` has at least one field for which [[isAlwaysPresentField]] holds. */
  def hasAlwaysPresentFields(descriptor: Descriptor): Boolean = {
    // loop-invariant: compute the oneof membership once rather than once per field
    val oneOfs = oneOfFields(descriptor)
    descriptor.getFields.asScala.exists(isAlwaysPresentField(oneOfs, _))
  }

  /** Whether a row built from a message always has a value for `f`.
    *
    * Message/group fields never qualify. Any other field qualifies when it is non-repeated and
    * not in `oneOfFields`. Such fields are filled from their default value when absent (see
    * `Protobuf.defaultValues`); oneof members are not.
    */
  def isAlwaysPresentField(oneOfFields: Set[FieldDescriptor], f: FieldDescriptor): Boolean =
    f.getType match {
      // Only proto2 has required fields; for proto3, isRequired is always false.
      case MESSAGE | GROUP => false
      case _ =>
        !oneOfFields.contains(f) && (f.isOptional || f.isRequired)
    }

  /** Builds a converter from a protobuf [[Message]] to a BigQuery [[TableRow]].
    *
    * - Column names come from [[bigqueryColumnName]].
    * - Scalar values are converted with [[toBigQueryType]].
    * - Nested messages become nested TableRows.
    * - Absent fields with defaults (see `Protobuf.defaultValues`) are included.
    * - A nested message also gets `_set = true` when it is non-repeated, [[hitBottom]] holds
    *   for it, and [[needsToFlagPresence]] holds for its field.
    * - `customRow` may rewrite each field's (possibly repeated) value before it is stored.
    */
  def makeTableRow(
      customRow: (FieldDescriptor, Yoneda[Repeated, Any]) => Yoneda[Repeated, Any])
    : Message => TableRow =
    msg => {
      def base(value: AnyRef, f: FieldDescriptor): Any =
        toBigQueryType(f.getType)(value)

      def recurse(children: Seq[(Yoneda[Repeated, (Any, Message)], FieldDescriptor)]): Any =
        children.foldLeft(new TableRow) {
          case (row, (repeated, f)) =>
            val flagged: Yoneda[Repeated, Any] = repeated.map {
              // a nested message already folded into a TableRow; `nested` is that message
              case (child: TableRow, nested) =>
                if (!f.isRepeated && hitBottom(nested) && needsToFlagPresence(f))
                  child.set(_set, true)
                else child
              case (primitive, _) => primitive
            }
            row.set(bigqueryColumnName(f), customRow(f, flagged).run.underlying)
        }

      foldMessage(recurse, base)(msg).asInstanceOf[TableRow]
    }

  /** Whether the non-repeated message field `f` should carry the `_set` presence column.
    *
    * This holds when `f` is not `required` and its message type has no always-present fields.
    * Presumably the motivation: an instance with nothing set would otherwise yield an empty
    * record that cannot be told apart from an absent one.
    * Must only be called for MESSAGE/GROUP fields, since it reads `f.getMessageType`.
    */
  def needsToFlagPresence(f: FieldDescriptor): Boolean =
    !f.isRequired && !hasAlwaysPresentFields(f.getMessageType)

  /** Converts a raw protobuf-java scalar value of type `f` into the value stored in a TableRow.
    *
    * Not defined for MESSAGE/GROUP; those are handled by [[makeTableRow]].
    */
  def toBigQueryType(f: FieldDescriptor.Type)(v: AnyRef): Any = f match {
    case DOUBLE => v.asInstanceOf[Double]
    case FLOAT  => v.asInstanceOf[Float]
    // protobuf-java stores uint32/fixed32 in a signed Java int, so values >= 2^31 would come
    // out negative. Widen them as unsigned into a Long; BigQuery INTEGER is 64-bit.
    case UINT32 | FIXED32 => java.lang.Integer.toUnsignedLong(v.asInstanceOf[Int])
    case INT32 | SINT32 | SFIXED32 =>
      v.asInstanceOf[Int]
    // NOTE: uint64/fixed64 values >= 2^63 arrive as negative Longs and cannot be represented
    // in a signed 64-bit BigQuery INTEGER; they are passed through unchanged.
    case INT64 | UINT64 | SINT64 | FIXED64 | SFIXED64 =>
      v.asInstanceOf[Long]
    case BOOL   => v.asInstanceOf[Boolean]
    case STRING => v.asInstanceOf[String]
    case BYTES  => v.asInstanceOf[ByteString].toByteArray
    case ENUM   => v.asInstanceOf[EnumValueDescriptor].getName
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment