Skip to content

Instantly share code, notes, and snippets.

@Baccata
Created October 2, 2016 18:56
Show Gist options
  • Save Baccata/52dd5a45e99c1c6ae1f8f374afbe29ce to your computer and use it in GitHub Desktop.
monix-grpc
import com.google.common.util.concurrent.ListenableFuture
import com.trueaccord.scalapb.grpc.Grpc
import io.grpc.stub.StreamObserver
import monix.eval.{Callback, Task}
import monix.execution.Ack.{Continue, Stop}
import monix.execution.{Ack, Scheduler}
import monix.reactive.observers.{BufferedSubscriber, Subscriber}
import monix.reactive.subjects.PublishToOneSubject
import monix.reactive.{Observable, OverflowStrategy}
import scala.concurrent.Future
/** Glue between grpc-java's callback-based `StreamObserver` API and Monix's
  * `Task`, `Observable` and `Subscriber` abstractions. The code emitted by
  * `MonixGrpcPrinter` refers to these helpers by their `grpcmonix.` names.
  */
package object grpcmonix {

  /** Wraps a Guava `ListenableFuture` in a Monix `Task`.
    *
    * `guavaFuture` is passed by name so that the expression producing the
    * future (in the generated client stubs: the gRPC call itself) is only
    * evaluated when the returned `Task` runs, and is evaluated again on every
    * run. With a by-value parameter the future would already be in flight
    * before `Task.defer` saw it, so the `Task` would be eager and
    * non-repeatable.
    */
  def guavaFuture2Task[A](guavaFuture: => ListenableFuture[A]): Task[A] =
    Task.defer(Task.fromFuture(Grpc.guavaFuture2ScalaFuture(guavaFuture)))

  /** Adapts a Monix `Subscriber` so that gRPC can push elements into it.
    *
    * Elements go through a `BufferedSubscriber` with an unbounded overflow
    * strategy. `onNext` here never waits for the downstream `Ack`, so Monix
    * backpressure is not propagated back to the gRPC side.
    */
  def monixSubscriber2GrpcStreamObserver[T](
      subscriber: Subscriber[T]): StreamObserver[T] = new StreamObserver[T] {
    private[this] val rSubscriber =
      BufferedSubscriber[T](subscriber, OverflowStrategy.Unbounded)

    override def onError(ex: Throwable): Unit = rSubscriber.onError(ex)

    override def onCompleted(): Unit = rSubscriber.onComplete()

    override def onNext(value: T): Unit = {
      // The onNext method of the buffered subscriber returns an Ack
      // synchronously, so there is no Future[Ack] to chain further work onto.
      rSubscriber.onNext(value)
      ()
    }
  }

  /** Adapts a gRPC `StreamObserver` into a Monix `Subscriber` running on `s`.
    *
    * Every element is acknowledged with `Continue`. If `observer.onNext` throws
    * a non-fatal exception, the error is forwarded to `observer.onError` and the
    * upstream is told to `Stop`. Fatal errors (e.g. `OutOfMemoryError`) are not
    * caught.
    *
    * @param observer the gRPC observer receiving the elements
    * @param s        scheduler exposed as the subscriber's `scheduler`
    */
  def grpcStreamObserver2monixSubscriber[T](observer: StreamObserver[T],
                                            s: Scheduler): Subscriber[T] =
    new Subscriber[T] {
      override implicit def scheduler: Scheduler = s

      override def onError(ex: Throwable): Unit = observer.onError(ex)

      override def onComplete(): Unit = observer.onCompleted()

      override def onNext(elem: T): Future[Ack] =
        try {
          observer.onNext(elem)
          Continue
        } catch {
          case scala.util.control.NonFatal(ex) =>
            observer.onError(ex)
            Stop
        }
    }

  /** Turns a gRPC-style operator (given the response observer, returns the
    * request observer) into a Monix operator on subscribers. The returned
    * subscriber uses the downstream subscriber's scheduler.
    */
  private def grpcOperator2MonixOperator[In, Out](
      grpcOperator: StreamObserver[Out] => StreamObserver[In])
    : Subscriber[Out] => Subscriber[In] = (subscriberOut: Subscriber[Out]) => {
    val streamObserverOut: StreamObserver[Out] =
      monixSubscriber2GrpcStreamObserver(subscriberOut)
    val streamObserverIn: StreamObserver[In] = grpcOperator(streamObserverOut)
    grpcStreamObserver2monixSubscriber(streamObserverIn,
                                       subscriberOut.scheduler)
  }

  /** Function-style alias for `obsIn.liftByOperator(operator)`. */
  def liftByOperator[In, Out](
      operator: Subscriber[Out] => Subscriber[In],
      obsIn: Observable[In]
  ): Observable[Out] =
    obsIn.liftByOperator(operator)

  /** Applies a gRPC-style operator (e.g. a client streaming / bidi call
    * factory) to `obsIn`, producing the observable of responses.
    */
  def liftFromGrpcOperator[In, Out](
      operator: StreamObserver[Out] => StreamObserver[In],
      obsIn: Observable[In]
  ): Observable[Out] =
    liftByOperator(grpcOperator2MonixOperator(operator), obsIn)

  /** Inverse of lifting: given an `Observable` transformer and the downstream
    * subscriber, returns a subscriber that feeds its input through
    * `transformer` into `obsOut`.
    *
    * The input is routed through a `PublishToOneSubject`, which is subscribed
    * (via `transformer`) to `obsOut` when this subscriber is constructed.
    */
  def unliftByTransformer[In, Out](
      transformer: Observable[In] => Observable[Out],
      obsOut: Subscriber[Out]
  ): Subscriber[In] =
    new Subscriber[In] {
      // Declared (and typed) before the constructor side effect below, so it is
      // already initialised if anything reads it during construction.
      override implicit val scheduler: Scheduler = obsOut.scheduler

      private[this] val input = PublishToOneSubject[In]()

      locally {
        input.transform(transformer).subscribe(obsOut)
      }

      override def onNext(elem: In): Future[Ack] =
        input.onNext(elem)

      override def onError(ex: Throwable): Unit =
        input.onError(ex)

      override def onComplete(): Unit =
        input.onComplete()
    }

  /** Adapts a gRPC `StreamObserver` into a Monix `Callback`. A success emits
    * the single value and then completes the observer; a failure is forwarded
    * to `observer.onError`.
    */
  def grpcStreamObserver2Callback[T](
      observer: StreamObserver[T]): Callback[T] = new Callback[T] {
    override def onError(ex: Throwable): Unit = {
      observer.onError(ex)
    }

    override def onSuccess(value: T): Unit = {
      observer.onNext(value)
      observer.onCompleted()
    }
  }
}
import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor}
import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo
import com.trueaccord.scalapb.compiler._
import scala.collection.JavaConverters._
/** ScalaPB code generator emitting a Monix-flavoured gRPC API for one protobuf
  * `service`. The output contains:
  *  - a trait whose methods use `Task` / `Observable`,
  *  - per-method gRPC `MethodDescriptor`s,
  *  - a client stub,
  *  - a `bindService` server adapter.
  *
  * The generated code calls the helpers in the `grpcmonix` package object.
  */
final class MonixGrpcPrinter(service: ServiceDescriptor,
                             override val params: GeneratorParams)
    extends DescriptorPimps {

  /** Scala source text for `Observable[typeParam]`. */
  private[this] def observable(typeParam: String): String =
    s"Observable[$typeParam]"

  /** Scala source text for `Task[typeParam]`. */
  private[this] def task(typeParam: String): String =
    s"Task[$typeParam]"

  /** Source text of the generated trait method for `method`.
    *
    * A single request becomes a plain `request` parameter and a streamed
    * request becomes an `input` observable. A single response becomes a `Task`
    * and a streamed response becomes an `Observable`.
    */
  private[this] def serviceMethodSignature(method: MethodDescriptor): String = {
    s"def ${method.name}" + (method.streamType match {
      case StreamType.Unary =>
        s"(request: ${method.scalaIn}): ${task(method.scalaOut)}"
      case StreamType.ClientStreaming =>
        s"(input: ${observable(method.scalaIn)}): ${task(method.scalaOut)}"
      case StreamType.ServerStreaming =>
        s"(request: ${method.scalaIn}): ${observable(method.scalaOut)}"
      case StreamType.Bidirectional =>
        s"(input: ${observable(method.scalaIn)}): ${observable(method.scalaOut)}"
    })
  }

  /** Prints `trait <ServiceName> { ... }` with one abstract method per RPC. */
  private[this] def serviceTrait: PrinterEndo = {
    val endos: PrinterEndo = { p =>
      p.seq(service.methods.map(m => serviceMethodSignature(m) + "\n"))
    }
    { p =>
      p.add(s"trait ${service.name} {").withIndent(endos).add("}")
    }
  }

  // Fully qualified names used inside the generated source.
  private[this] val channel = "_root_.io.grpc.Channel"
  private[this] val callOptions = "_root_.io.grpc.CallOptions"
  private[this] val streamObserver = "_root_.io.grpc.stub.StreamObserver"
  private[this] val abstractStub = "_root_.io.grpc.stub.AbstractStub"
  private[this] val serverCalls = "_root_.io.grpc.stub.ServerCalls"
  private[this] val clientCalls = "_root_.io.grpc.stub.ClientCalls"
  private[this] val guavaFuture2Task = "grpcmonix.guavaFuture2Task"
  private[this] val liftFromGrpcOperator = "grpcmonix.liftFromGrpcOperator"
  private[this] val subscriber2Observer = "grpcmonix.monixSubscriber2GrpcStreamObserver"
  private[this] val observer2Subscriber = "grpcmonix.grpcStreamObserver2monixSubscriber"
  private[this] val observer2Callback = "grpcmonix.grpcStreamObserver2Callback"
  private[this] val unliftByTransformer = "grpcmonix.unliftByTransformer"

  /** Prints the client-stub implementation of `m`, followed by a blank line. */
  private[this] def clientMethodImpl(m: MethodDescriptor) =
    PrinterEndo { p =>
      m.streamType match {
        case StreamType.Unary =>
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = {
                | $guavaFuture2Task($clientCalls.futureUnaryCall(channel.newCall(${m.descriptorName}, options), request))
                |}"""
          )
        case StreamType.ServerStreaming =>
          // The ClientCall is kept so that unsubscribing from the Observable
          // cancels the RPC instead of leaving it running (the Cancelable used
          // to be a no-op).
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = Observable.create[${m.scalaOut}](OverflowStrategy.Unbounded){
                | subscriber: Subscriber[${m.scalaOut}] =>
                | val responseObserver = $subscriber2Observer(subscriber)
                | val clientCall = channel.newCall(${m.descriptorName}, options)
                | $clientCalls.asyncServerStreamingCall(
                | clientCall,
                | request,
                | responseObserver)
                | Cancelable(() => clientCall.cancel("Cancelled by subscriber", null))
                |}"""
          )
        case StreamType.Bidirectional =>
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = {
                | $liftFromGrpcOperator({responseObserver : $streamObserver[${m.scalaOut}] =>
                | $clientCalls.asyncBidiStreamingCall(channel.newCall(${m.descriptorName}, options), responseObserver)
                | }, input)
                |}"""
          )
        case StreamType.ClientStreaming =>
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = {
                | $liftFromGrpcOperator({responseObserver : $streamObserver[${m.scalaOut}] =>
                | $clientCalls.asyncClientStreamingCall(channel.newCall(${m.descriptorName}, options), responseObserver)
                | }, input).headL
                |}"""
          )
      }
    } andThen PrinterEndo { _.newline }

  /** Prints a stub class extending `AbstractStub` and `baseClass`, containing
    * `methods` and the `build` override that `AbstractStub` requires.
    */
  private def stubImplementation(
      className: String,
      baseClass: String,
      methods: Seq[PrinterEndo]
  ): PrinterEndo = { p =>
    val build =
      s" override def build(channel: $channel, options: $callOptions): $className = new $className(channel, options)"
    p.add(
        s"class $className(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$className](channel, options) with $baseClass {"
      )
      .withIndent(
        methods: _*
      )
      .add(
        build
      )
      .add(
        "}"
      )
  }

  /** Client stub for the whole service. Uses `service.methods`, like the rest
    * of this printer.
    */
  private[this] val stub: PrinterEndo = {
    val methods = service.methods.map(clientMethodImpl(_))
    stubImplementation(service.stub, service.name, methods)
  }

  /** Prints the `val` holding the gRPC `MethodDescriptor` for `method`. */
  private[this] def methodDescriptor(method: MethodDescriptor) = PrinterEndo {
    p =>
      def marshaller(typeName: String) =
        s"new com.trueaccord.scalapb.grpc.Marshaller($typeName)"
      val methodType = method.streamType match {
        case StreamType.Unary => "UNARY"
        case StreamType.ClientStreaming => "CLIENT_STREAMING"
        case StreamType.ServerStreaming => "SERVER_STREAMING"
        case StreamType.Bidirectional => "BIDI_STREAMING"
      }
      val grpcMethodDescriptor = "_root_.io.grpc.MethodDescriptor"
      p.addM(
        s"""val ${method.descriptorName}: $grpcMethodDescriptor[${method.scalaIn}, ${method.scalaOut}] =
           | $grpcMethodDescriptor.create(
           | $grpcMethodDescriptor.MethodType.$methodType,
           | $grpcMethodDescriptor.generateFullMethodName("${service.getFullName}", "${method.getName}"),
           | ${marshaller(method.scalaIn)},
           | ${marshaller(method.scalaOut)})
           |""")
  }

  /** Prints one `.addMethod(...)` clause of `bindService`. It adapts the Monix
    * service implementation method to the matching `ServerCalls` handler type.
    */
  private[this] def addMethodImplementation(
      method: MethodDescriptor): PrinterEndo = PrinterEndo {
    _.add(".addMethod(")
      .add(s" ${method.descriptorName},")
      .withIndent(PrinterEndo { p =>
        val call = method.streamType match {
          case StreamType.Unary => s"$serverCalls.asyncUnaryCall"
          case StreamType.ClientStreaming =>
            s"$serverCalls.asyncClientStreamingCall"
          case StreamType.ServerStreaming =>
            s"$serverCalls.asyncServerStreamingCall"
          case StreamType.Bidirectional =>
            s"$serverCalls.asyncBidiStreamingCall"
        }
        // Names of the bindService parameters, as seen by the generated code.
        val scheduler = "scheduler"
        val serviceImpl = "serviceImpl"
        method.streamType match {
          case StreamType.Unary =>
            val serverMethod =
              s"$serverCalls.UnaryMethod[${method.scalaIn}, ${method.scalaOut}]"
            p.addM(s"""$call(new $serverMethod {
                      | override def invoke(request: ${method.scalaIn}, observer: $streamObserver[${method.scalaOut}]): Unit =
                      | $serviceImpl.${method.name}(request).runAsync($observer2Callback(observer))(
                      | $scheduler)
                      |}))""")
          case StreamType.ServerStreaming =>
            val serverMethod =
              s"$serverCalls.ServerStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
            p.addM(s"""$call(new $serverMethod {
                      | override def invoke(request: ${method.scalaIn}, observer: $streamObserver[${method.scalaOut}]): Unit =
                      | $serviceImpl.${method.name}(request).subscribe($observer2Subscriber(observer, scheduler))
                      |}))""")
          case StreamType.Bidirectional =>
            // The implementation's Observable => Observable is unlifted into a
            // request StreamObserver that feeds the response observer.
            val serverMethod =
              s"$serverCalls.BidiStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
            p.addM(s"""$call(new $serverMethod {
                      | override def invoke(observer: $streamObserver[${method.scalaOut}]): $streamObserver[${method.scalaIn}] = {
                      | val subscriberOut : Subscriber[${method.scalaOut}] = $observer2Subscriber(observer, scheduler)
                      | val subscriberIn : Subscriber[${method.scalaIn}] = $unliftByTransformer(
                      | observableIn => $serviceImpl.${method.name}(observableIn),
                      | subscriberOut
                      | )
                      | $subscriber2Observer(subscriberIn)
                      | }
                      |}))""")
          case StreamType.ClientStreaming =>
            // Same as bidi, but the implementation returns a Task, which is
            // turned into a single-element Observable.
            val serverMethod =
              s"$serverCalls.ClientStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
            p.addM(s"""$call(new $serverMethod {
                      | override def invoke(observer: $streamObserver[${method.scalaOut}]): $streamObserver[${method.scalaIn}] = {
                      | val subscriberOut : Subscriber[${method.scalaOut}] = $observer2Subscriber(observer, scheduler)
                      | val subscriberIn : Subscriber[${method.scalaIn}] = $unliftByTransformer(
                      | observableIn => Observable.fromTask($serviceImpl.${method.name}(observableIn)),
                      | subscriberOut
                      | )
                      | $subscriber2Observer(subscriberIn)
                      | }
                      |}))""")
        }
      })
  }

  /** Prints `bindService(serviceImpl, scheduler)`, which builds the
    * `ServerServiceDefinition` registering every method of the service.
    */
  private[this] val bindService = {
    val scheduler = "scheduler"
    val methods = service.methods.map(addMethodImplementation)
    val serverServiceDef = "_root_.io.grpc.ServerServiceDefinition"
    PrinterEndo(
      _.add(
        s"""def bindService(serviceImpl: ${service.name}, $scheduler: monix.execution.Scheduler): $serverServiceDef =""")
        .withIndent(
          _.add(s"""$serverServiceDef.builder("${service.getFullName}")"""),
          _.call(methods: _*),
          _.add(".build()")))
  }

  /** Prints the complete generated source file for the service: package,
    * imports, service trait and companion object.
    *
    * NOTE(review): the companion extends `grpcmonix.MonixGrpcServiceCompanion`,
    * which is not defined in the `grpcmonix` package object shown alongside
    * this printer. Confirm it exists elsewhere.
    */
  def printService(printer: FunctionalPrinter): FunctionalPrinter = {
    printer
      .add(
        "package " + service.getFile.scalaPackageName,
        "",
        "import monix.execution.Cancelable",
        "import monix.eval.Task",
        "import monix.reactive.observers.Subscriber",
        "import monix.reactive.{Observable, OverflowStrategy}",
        ""
      )
      .call(serviceTrait)
      .newline
      .add(
        s"object ${service.getName} extends grpcmonix.MonixGrpcServiceCompanion[${service.getName}] {"
      )
      .newline
      .withIndent(
        _.call(service.methods.map(methodDescriptor): _*),
        _.newline,
        stub,
        _.newline,
        bindService,
        _.newline,
        _.add(
          s"def stub(channel: $channel): ${service.stub} = new ${service.stub}(channel)"),
        _.newline,
        _.add(
          s"def descriptor: _root_.com.google.protobuf.Descriptors.ServiceDescriptor = ${service.getFile.fileDescriptorObjectFullName}.descriptor.getServices().get(${service.getIndex})")
      )
      .add("}")
  }
}
@jchapuis
Copy link

Hey, I'm interested in your gist — is it working, and can it handle gRPC flow control through Monix backpressure? Also, how do you register it with ScalaPB so that the compiler uses your generator? Thanks a lot!

@vyshane
Copy link

vyshane commented Apr 28, 2017

@Baccata, I'd be interested in the answers to @jchapuis's question too :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment