Skip to content

Instantly share code, notes, and snippets.

@sarkologist
Last active June 10, 2020 12:21
Show Gist options
  • Save sarkologist/7650fd025edcb20499290f9c46cc909e to your computer and use it in GitHub Desktop.
Save sarkologist/7650fd025edcb20499290f9c46cc909e to your computer and use it in GitHub Desktop.
protobuf to bigquery
package utils
import java.util
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.Type.{GROUP, MESSAGE}
import com.google.protobuf.Message
import scalaz.{Functor, Yoneda}
import scala.collection.JavaConverters._
import scala.collection.mutable
/** Generic folds over protobuf descriptors and messages, together with the
  * small `Repeated` container that lets singular and repeated fields be
  * handled uniformly.
  */
object Protobuf {

  /** Folds a message descriptor, bottom-up, into a single `A`.
    *
    * @param base       turns a field that is not expanded (a non-message field,
    *                   or a message/group field beyond the depth cap) into an `A`
    * @param recurse    combines the results for all fields of one message type;
    *                   it receives them in the reverse of the order returned by
    *                   `Descriptor.getFields`
    * @param maxDepth   message/group fields are expanded only while the depth of
    *                   the enclosing message (the root is depth 0) is
    *                   `<= maxDepth`
    * @param descriptor the root message type
    */
  def foldDescriptor[A](
      base: FieldDescriptor => A,
      recurse: List[(A, FieldDescriptor)] => A,
      maxDepth: Int = Integer.MAX_VALUE)(descriptor: Descriptor): A = {
    def go(d: Descriptor, depth: Int): A = {
      // Prepending keeps the fold linear; the price is that `recurse` sees the
      // fields reversed, which callers are expected to account for.
      val fields =
        d.getFields.asScala.foldLeft(List.empty[(A, FieldDescriptor)]) {
          case (acc, f) =>
            val folded = f.getType match {
              // NOTE(review): the comparison is inclusive, so nested messages
              // are expanded at depths 0..maxDepth, i.e. up to maxDepth + 1
              // levels below the root. Confirm this is the intended cap.
              case MESSAGE | GROUP if depth <= maxDepth =>
                go(f.getMessageType, depth + 1)
              case _ => base(f)
            }
            (folded -> f) :: acc
        }
      recurse(fields)
    }
    go(descriptor, 0)
  }

  /** Wraps a raw field value as a [[Repeated]]: `Many` when the field is
    * repeated (the value is then cast to a `java.util.List`), `One` otherwise.
    *
    * The casts are unchecked; the caller chooses `A` to match the field type.
    */
  def lift[A](fieldDescriptor: FieldDescriptor, value: AnyRef): Repeated[A] =
    if (fieldDescriptor.isRepeated)
      Many(value.asInstanceOf[util.List[A]].asScala)
    else One(value.asInstanceOf[A])

  /** Folds a message instance, bottom-up, into a single `A`.
    *
    * Every field reported by `getAllFields` contributes one entry, followed by
    * one entry for each default value produced by [[defaultValues]].
    *
    * @param recurse combines the entries of one message. Each entry pairs the
    *                field descriptor with a `Yoneda[Repeated, (A, Message)]`;
    *                the `Message` is the nested message itself for
    *                message/group fields and the enclosing message otherwise
    * @param base    converts a single non-message value of the given field
    * @param message the root message
    */
  def foldMessage[A](
      recurse: Seq[(Yoneda[Repeated, (A, Message)], FieldDescriptor)] => A,
      base: (AnyRef, FieldDescriptor) => A)(message: Message): A = {
    type Entry = (Yoneda[Repeated, (A, Message)], FieldDescriptor)

    def go(msg: Message): A = {
      // Collected into a Vector: the previous `Seq.empty ... :+` was backed by
      // List, making every append O(n).
      val presentFields: Vector[Entry] =
        msg.getAllFields.asScala.iterator.map {
          case (f, value) =>
            val lifted: Yoneda[Repeated, (A, Message)] = f.getType match {
              case MESSAGE | GROUP =>
                Yoneda(lift[Message](f, value)).map(m => go(m) -> m)
              case _ =>
                Yoneda(lift[AnyRef](f, value)).map(r => base(r, f) -> msg)
            }
            lifted -> f
        }.toVector

      val defaultedFields: Vector[Entry] =
        defaultValues(msg)().toVector.map {
          case (f, value) =>
            val lifted: Yoneda[Repeated, (A, Message)] =
              Yoneda(lift[AnyRef](f, value)).map(v => base(v, f) -> msg)
            lifted -> f
        }

      recurse(presentFields ++ defaultedFields)
    }
    go(message)
  }

  /** Default values for the fields of `message` that are not in `presentFields`.
    *
    * Only singular, non-message, non-oneof fields are included; repeated fields,
    * message/group fields and members of a oneof are skipped.
    *
    * @param message       the message whose descriptor is inspected
    * @param presentFields fields to treat as already set; defaults to the keys
    *                      of `message.getAllFields`
    * @return pairs of field descriptor and `FieldDescriptor.getDefaultValue`
    */
  def defaultValues(message: Message)(presentFields: Set[FieldDescriptor] =
                                        message.getAllFields.asScala.keys.toSet)
    : Set[(FieldDescriptor, AnyRef)] = {
    val absentFields = message.getDescriptorForType.getFields.asScala.toSet -- presentFields
    val oneOfFieldsOfMessage = oneOfFields(message.getDescriptorForType)

    def isMessageLike(field: FieldDescriptor): Boolean =
      field.getType match {
        case MESSAGE | GROUP => true
        case _ => false
      }

    (absentFields -- oneOfFieldsOfMessage).collect {
      case field if !field.isRepeated && !isMessageLike(field) =>
        field -> field.getDefaultValue
    }
  }

  /** All fields of `descriptor` that belong to some oneof. */
  def oneOfFields(descriptor: Descriptor): Set[FieldDescriptor] =
    descriptor.getOneofs.asScala
      .flatMap(_.getFields.asScala)
      .toSet

  /** A field value that is either singular (`One`) or repeated (`Many`). */
  sealed trait Repeated[A] {

    /** The wrapped value as a plain object: the element itself for `One`, a
      * `java.util.List` view for `Many`.
      */
    def underlying: AnyRef
  }

  case class One[A](a: A) extends Repeated[A] {
    def underlying: AnyRef = a.asInstanceOf[AnyRef]
  }

  // JavaConverters returns mutable.Buffer
  case class Many[A](as: mutable.Buffer[A]) extends Repeated[A] {
    def underlying: AnyRef = cast(as.asJava)
    def cast(x: util.List[A]): AnyRef = x.asInstanceOf[AnyRef]
  }

  /** Functor instance for [[Repeated]]; it is what allows `Yoneda(...)` to be
    * built over a `Repeated` value.
    */
  implicit def repeatedFunctor: Functor[Repeated] = new Functor[Repeated] {
    override def map[A, B](fa: Repeated[A])(f: A => B): Repeated[B] =
      fa match {
        case One(a) => One(f(a))
        case Many(as) => Many(as.map(f))
      }
  }
}
package utils
import com.google.api.services.bigquery.model.{TableFieldSchema, TableRow, TableSchema}
import com.google.protobuf.Descriptors.FieldDescriptor.Type._
import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor}
import com.google.protobuf.{ByteString, Message}
import Protobuf._
import scalaz.Yoneda
import scala.collection.JavaConverters._
/** Derives BigQuery table schemas and table rows from protobuf descriptors
  * and messages, using the folds in `Protobuf`.
  */
object ProtobufToBigQuery extends Serializable {

  /** Name of the synthetic BOOLEAN column added to certain nested records
    * (see `needsToFlagPresence`) and set to `true` on the matching rows.
    * Presumably it lets an otherwise empty nested record be told apart from an
    * absent one — confirm against the table consumers.
    */
  val _set = "_set"

  /** Maps a protobuf field type to the BigQuery column type name. */
  def toTableSchemaType(t: FieldDescriptor.Type): String = t match {
    case DOUBLE | FLOAT => "FLOAT"
    case INT32 | UINT32 | SINT32 | FIXED32 | SFIXED32 | INT64 | UINT64 |
        SINT64 | FIXED64 | SFIXED64 =>
      "INTEGER"
    case BOOL => "BOOLEAN"
    case STRING => "STRING"
    case BYTES => "BYTES"
    case ENUM => "STRING"
    case MESSAGE | GROUP => "RECORD"
  }

  /** Builds the BigQuery schema for a message type.
    *
    * Protobuf messages have potentially unlimited depth, since
    * message-inside-message recursion is possible. A schema can still be
    * produced, provided that messages written to the table at runtime never go
    * beyond the maximum depth.
    * https://cloud.google.com/bigquery/docs/nested-repeated#limitations
    *
    * NOTE(review): a message/group field past the depth cap is handed to
    * `base`, which yields a RECORD column without sub-fields. Confirm that
    * BigQuery accepts this for the message types in use.
    *
    * @throws RuntimeException if the fold does not yield a field list; the
    *                          fold always ends in `recurse`, so this is not
    *                          expected to happen
    */
  def makeTableSchema(descriptor: Descriptor): TableSchema = {
    // Depth cap handed to foldDescriptor.
    val maxDepth = 15

    // Left: the column of a single unexpanded field.
    // Right: the columns of an expanded message.
    def base(
        f: FieldDescriptor): Either[TableFieldSchema, List[TableFieldSchema]] =
      Left(makeField(f))

    def presenceFlag: TableFieldSchema =
      new TableFieldSchema()
        .setName(_set)
        .setType("BOOLEAN")
        .setMode("NULLABLE")

    def recurse(
        children: List[(Either[TableFieldSchema, List[TableFieldSchema]],
                        FieldDescriptor)])
      : Either[TableFieldSchema, List[TableFieldSchema]] = Right {
      // foldDescriptor delivers the children reversed; prepending here puts
      // them back in the order returned by `getFields`.
      children
        .foldLeft(List.empty[TableFieldSchema]) {
          case (fields, (Left(zero), _)) => zero :: fields
          case (fields, (Right(many), f)) =>
            val flagged =
              !f.isRepeated &&
                hitPotentialBottom(f.getMessageType) &&
                needsToFlagPresence(f)
            val fs = if (flagged) many ::: List(presenceFlag) else many
            makeField(f).setFields(fs.asJava) :: fields
        }
    }

    foldDescriptor(base, recurse, maxDepth)(descriptor) match {
      case Right(fields) => new TableSchema().setFields(fields.asJava)
      case Left(_) =>
        throw new RuntimeException(
          s"Descriptor ${descriptor.getFullName} folded to a single field instead of a field list")
    }
  }

  /** A column for `f`: BigQuery-safe name, mapped type, and mode `REPEATED`
    * for repeated fields or `NULLABLE` otherwise. Sub-fields are not set here.
    */
  def makeField(f: FieldDescriptor): TableFieldSchema =
    new TableFieldSchema()
      .setName(bigqueryColumnName(f))
      .setType(toTableSchemaType(f.getType))
      .setMode(if (f.isRepeated) "REPEATED" else "NULLABLE")

  /** The field name with every `.` replaced by `_`. */
  def bigqueryColumnName(f: FieldDescriptor): String =
    f.getName.replace(".", "_")

  /** Static counterpart of `hitBottom`: true when none of the message/group
    * fields of `d` is required, so an instance of `d` may carry no nested
    * message at all.
    */
  def hitPotentialBottom(d: Descriptor): Boolean =
    d.getFields.asScala.forall { f =>
      f.getType match {
        case MESSAGE | GROUP => !f.isRequired
        case _ => true
      }
    }

  /** True when none of the fields present on `message` is a message/group.
    *
    * Note that this is not static: it depends on the message at runtime, not
    * just on the schema. It cannot depend on the schema alone, as message trees
    * have unbounded depth due to recursion.
    */
  def hitBottom(message: Message): Boolean = {
    def isNotFurtherMessage(fieldDescriptor: FieldDescriptor): Boolean =
      fieldDescriptor.getType match {
        case MESSAGE | GROUP => false
        case _ => true
      }
    message.getAllFields.asScala.keys.forall(isNotFurtherMessage)
  }

  /** True when at least one field of `descriptor` satisfies
    * `isAlwaysPresentField`.
    */
  def hasAlwaysPresentFields(descriptor: Descriptor): Boolean = {
    // Computed once instead of once per field.
    val oneOfs = oneOfFields(descriptor)
    descriptor.getFields.asScala.exists(f => isAlwaysPresentField(oneOfs, f))
  }

  /** True for a non-message field that is outside every oneof and is either
    * optional or required.
    *
    * @param oneOfFields the oneof members of the message type that owns `f`
    */
  def isAlwaysPresentField(oneOfFields: Set[FieldDescriptor],
                           f: FieldDescriptor): Boolean =
    f.getType match {
      // only proto2 has required fields;
      // for proto3 .isRequired is always false
      // oneof fields don't have default values
      // optional primitive-type fields have default values
      case MESSAGE | GROUP => false
      case _ =>
        !oneOfFields.contains(f) && (f.isOptional || f.isRequired)
    }

  /** Builds a converter from protobuf messages to BigQuery `TableRow`s.
    *
    * @param customRow hook applied to every field's converted value before it
    *                  is stored in the row; return the second argument
    *                  unchanged to keep the default conversion
    * @return a function converting one message. Nested messages become nested
    *         `TableRow`s; a singular nested row additionally gets `_set = true`
    *         when its message has no nested message fields present and
    *         `needsToFlagPresence` holds for the field
    */
  def makeTableRow(
      customRow: (FieldDescriptor,
                  Yoneda[Repeated, Any]) => Yoneda[Repeated, Any])
    : Message => TableRow =
    msg => {
      def base(value: AnyRef, f: FieldDescriptor): Any =
        toBigQueryType(f.getType)(value)

      def recurse(
          children: Seq[(Yoneda[Repeated, (Any, Message)], FieldDescriptor)])
        : Any =
        children.foldLeft(new TableRow) {
          case (tr, (repeated, f)) =>
            val value = customRow(
              f,
              repeated
                .map {
                  // Only nested messages fold to a TableRow; `nested` is that
                  // nested message (named so it does not shadow the outer `msg`).
                  case (child: TableRow, nested) =>
                    if (!f.isRepeated && hitBottom(nested) && needsToFlagPresence(
                          f))
                      child.set(_set, true)
                    else child
                  case (primitive, _) => primitive
                }
            ).run.underlying
            tr.set(bigqueryColumnName(f), value)
        }

      foldMessage(recurse, base)(msg).asInstanceOf[TableRow]
    }

  /** True when `f` is not required and its message type has no always-present
    * field. Only meaningful for message/group fields, since it reads
    * `f.getMessageType`.
    */
  def needsToFlagPresence(f: FieldDescriptor): Boolean =
    !f.isRequired && !hasAlwaysPresentFields(f.getMessageType)

  /** Converts a raw protobuf field value to the value stored in a `TableRow`.
    *
    * Bytes become a byte array and enums become the enum value's name. There is
    * no case for MESSAGE or GROUP, so those types raise a `MatchError`.
    */
  def toBigQueryType(f: FieldDescriptor.Type)(v: AnyRef): Any = f match {
    case DOUBLE => v.asInstanceOf[Double]
    case FLOAT => v.asInstanceOf[Float]
    case INT32 | UINT32 | SINT32 | FIXED32 | SFIXED32 =>
      v.asInstanceOf[Int]
    case INT64 | UINT64 | SINT64 | FIXED64 | SFIXED64 =>
      v.asInstanceOf[Long]
    case BOOL => v.asInstanceOf[Boolean]
    case STRING => v.asInstanceOf[String]
    case BYTES => v.asInstanceOf[ByteString].toByteArray
    case ENUM => v.asInstanceOf[EnumValueDescriptor].getName
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment