Skip to content

Instantly share code, notes, and snippets.

@w4rum
Created October 24, 2020 17:04
Show Gist options
  • Save w4rum/4e20ec18b9065b1b6780e2f92ac4b6f0 to your computer and use it in GitHub Desktop.
Save w4rum/4e20ec18b9065b1b6780e2f92ac4b6f0 to your computer and use it in GitHub Desktop.
syntax = "proto3";
package example_service;
// Example service with one RPC for each of the four gRPC call cardinalities.
// Every RPC uses ExampleRequest as input and ExampleResponse as output.
service ExampleService {
  // One request message -> one response message.
  rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
  // One request message -> a stream of response messages.
  rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
  // A stream of request messages -> one response message.
  rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
  // A stream of request messages -> a stream of response messages.
  rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
}
// Input message for every ExampleService RPC.
message ExampleRequest {
  // Single text payload (field number 1).
  string example_string = 1;
}
// Output message for every ExampleService RPC.
message ExampleResponse {
  // Single text payload (field number 1).
  string example_string = 1;
}
"""
Original gist by nat-n:
https://gist.github.com/nat-n/e90097ebfb861cbb25e20b68bec0e39c
"""
import functools
import inspect
import typing
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator

import grpclib
import typing_inspect
from grpclib.const import Cardinality
from grpclib.server import Stream
def rpc_unary_unary(*, name=None):
    """Decorate a handler that takes one request and returns one response.

    :param name: RPC name as written in the .proto service; when omitted the
        decorated method's own name is used.
    """
    return _rpc_method(cardinality=Cardinality.UNARY_UNARY, name=name)
def rpc_unary_stream(*, name=None):
    """Decorate a handler that takes one request and yields many responses.

    The handler's return value must be annotated as
    ``AsyncIterator[ResponseMessage]``.

    :param name: RPC name as written in the .proto service; when omitted the
        decorated method's own name is used.
    """
    return _rpc_method(cardinality=Cardinality.UNARY_STREAM, name=name)
def rpc_stream_unary(*, name=None):
    """Decorate a handler that consumes a request stream and returns one response.

    The handler's ``request`` parameter must be annotated as
    ``AsyncIterator[RequestMessage]``.

    :param name: RPC name as written in the .proto service; when omitted the
        decorated method's own name is used.
    """
    return _rpc_method(cardinality=Cardinality.STREAM_UNARY, name=name)
def rpc_stream_stream(*, name=None):
    """Decorate a handler that consumes a request stream and yields responses.

    Both the ``request`` parameter and the return value must be annotated as
    ``AsyncIterator[...]`` of the respective message type.

    :param name: RPC name as written in the .proto service; when omitted the
        decorated method's own name is used.
    """
    return _rpc_method(cardinality=Cardinality.STREAM_STREAM, name=name)
def _rpc_method(*, name=None, cardinality=Cardinality.UNARY_UNARY):
    """Build a decorator that adapts a typed handler to grpclib's Stream API.

    The decorated method must annotate its ``request`` parameter and its
    return value. On each streaming side of the RPC (per *cardinality*) the
    hint must be ``AsyncIterator[Msg]``; ``AsyncIterable[Msg]`` and
    ``AsyncGenerator[Msg, ...]`` are accepted as well. ``Msg`` is unwrapped
    and recorded as the message type.

    The returned wrapper takes ``(self, stream)``. It reads the request(s)
    from ``stream``, calls the handler, and sends the response(s) back. It
    carries an ``__rpc_method__`` dict with the keys ``request_type``,
    ``response_type``, ``cardinality`` and ``name``, which ``ServiceStub``
    uses to register it.

    :param name: RPC name as declared in the .proto service; defaults to the
        handler's ``__name__``.
    :param cardinality: grpclib ``Cardinality`` of the RPC.
    :raises TypeError: at decoration time, if the ``request`` or return
        annotation is missing, or a streaming side is not annotated with an
        async-iterator type. (Explicit raises instead of ``assert`` so the
        checks survive ``python -O``.)
    """
    # Generic origins accepted on the streaming side(s) of an RPC.
    stream_origins = (AsyncIterator, AsyncIterable, AsyncGenerator)

    def message_type(hint, streaming, role):
        """Return the message class described by *hint* for *role*."""
        if hint is None:
            raise TypeError(f"{role} type must be annotated")
        if not streaming:
            return hint
        args = typing_inspect.get_args(hint)
        if typing_inspect.get_origin(hint) not in stream_origins or not args:
            raise TypeError(
                f"streaming {role} type hint is not AsyncIterator")
        return args[0]

    def inner_decorator(func):
        hints = typing.get_type_hints(func)
        # .get(): a missing annotation must produce the descriptive error
        # above rather than a bare KeyError.
        request_type = message_type(
            hints.get("request"), cardinality.client_streaming, "request")
        response_type = message_type(
            hints.get("return"), cardinality.server_streaming, "response")

        @functools.wraps(func)
        async def wrapper(self, stream: grpclib.server.Stream[request_type,
                                                              response_type]):
            # Get the request (or the request stream) from the client.
            if cardinality.client_streaming:
                async def request_iterator() \
                        -> typing.AsyncIterator[request_type]:
                    async for request_message in stream:
                        yield request_message
                request = request_iterator()
            else:
                request = await stream.recv_message()

            # Execute the handler and send the response(s) back.
            if not cardinality.server_streaming:
                response = await func(self, request)
                await stream.send_message(response)
                return

            result = func(self, request)
            # An ``async def`` that contains ``yield`` always returns an async
            # generator, even if it never actually yields. A plain
            # ``async def`` returns a coroutine instead. Run it so its body
            # executes, then stream whatever it returned (``None`` means an
            # empty stream).
            if inspect.isawaitable(result):
                result = await result
            if result is not None:
                async for response_message in result:
                    await stream.send_message(response_message)

        wrapper.__rpc_method__ = {
            "request_type": request_type,
            "response_type": response_type,
            "cardinality": cardinality,
            "name": name or func.__name__,
        }
        return wrapper

    return inner_decorator
def is_rpc_method(value):
    """Return True if *value* is a callable tagged by one of the rpc_* decorators.

    A tagged callable carries an ``__rpc_method__`` attribute holding a dict.
    Non-callables, and callables without that attribute or with a non-dict
    value under that name, are rejected.
    """
    if not callable(value):
        return False
    metadata = getattr(value, "__rpc_method__", None)
    return isinstance(metadata, dict)
class ServiceStub:
    """Base class for services whose handlers use the rpc_* decorators.

    Subclasses must define ``service_name``, the fully qualified proto service
    name (e.g. ``"package.Service"``). ``__mapping__`` is the hook grpclib's
    ``Server`` uses to look up the handlers of a service object.
    """

    def __mapping__(self):
        """Map ``/<service_name>/<rpc name>`` to a grpclib ``Handler`` per RPC."""
        handlers = {}
        for _, bound_method in inspect.getmembers(self, is_rpc_method):
            meta = bound_method.__rpc_method__
            path = f"/{self.service_name}/{meta.get('name')}"
            handlers[path] = grpclib.const.Handler(
                bound_method,
                meta["cardinality"],
                meta["request_type"],
                meta["response_type"],
            )
        return handlers
import asyncio
from typing import AsyncIterator
from grpclib.server import Server
from .magic_glue import rpc_unary_unary, rpc_unary_stream, ServiceStub, \
rpc_stream_stream, rpc_stream_unary
from ..protobuf.example_service import ExampleRequest, ExampleResponse
class ExampleService(ServiceStub):
    """Echo implementation of ``example_service.ExampleService``.

    Each handler answers with ``ExampleResponse`` objects whose
    ``example_string`` is copied from the received request(s).
    """

    service_name = "example_service.ExampleService"

    @rpc_unary_unary(name="ExampleUnaryUnary")
    async def example_unary_unary(self, request: ExampleRequest) \
            -> ExampleResponse:
        """Echo the single request back once."""
        return ExampleResponse(example_string=request.example_string)

    @rpc_unary_stream(name="ExampleUnaryStream")
    async def example_unary_stream(self, request: ExampleRequest) \
            -> AsyncIterator[ExampleResponse]:
        """Echo the single request back three times."""
        for _ in range(3):
            yield ExampleResponse(example_string=request.example_string)

    @rpc_stream_unary(name="ExampleStreamUnary")
    async def example_stream_unary(
            self, request: AsyncIterator[ExampleRequest]) -> ExampleResponse:
        """Echo the first message of the request stream.

        If the client sends no messages, an empty ``ExampleResponse`` is
        returned. Otherwise the method would implicitly return ``None``, which
        the RPC wrapper would then try to send as the reply.
        """
        async for request_message in request:
            # Just answer the first message.
            return ExampleResponse(
                example_string=request_message.example_string)
        return ExampleResponse()

    @rpc_stream_stream(name="ExampleStreamStream")
    async def example_stream_stream(
            self, request: AsyncIterator[ExampleRequest]) \
            -> AsyncIterator[ExampleResponse]:
        """Echo every message of the request stream as it arrives."""
        async for request_message in request:
            yield ExampleResponse(
                example_string=request_message.example_string)
async def run_rpc_server(host="localhost", port=50051):
    """Serve ``ExampleService`` on *host*:*port* until the server is closed.

    :param host: interface to bind; defaults to ``"localhost"``.
    :param port: TCP port to listen on; defaults to ``50051``.
    """
    server = Server([ExampleService()])
    await server.start(host, port)
    await server.wait_closed()


if __name__ == '__main__':
    asyncio.run(run_rpc_server())
import asyncio
from grpclib.client import Channel
from ..protobuf.example_service import ExampleServiceStub, ExampleRequest
async def main(host="localhost", port=50051):
    """Call each ExampleService RPC once and print the replies.

    The channel is closed in a ``finally`` block, so a failing RPC does not
    leave the connection open.

    :param host: server host; defaults to ``"localhost"``.
    :param port: server port; defaults to ``50051``.
    """
    channel = Channel(host, port)
    try:
        example_service = ExampleServiceStub(channel)
        print("-- Start")

        print("Unary Unary:")
        print(await example_service.example_unary_unary(
            example_string="TEST UNARY UNARY"))

        print("Unary Stream:")
        async for response_message in example_service.example_unary_stream(
                example_string="TEST UNARY STREAM"):
            print(response_message)

        async def stream_request():
            # Ten numbered requests for the client-streaming RPCs; a fresh
            # generator is created for each call below.
            for i in range(10):
                yield ExampleRequest(example_string=f"TEST STREAM UNARY {i}")

        print("Stream Unary:")
        print(await example_service.example_stream_unary(stream_request()))

        print("Stream Stream:")
        async for response_message in example_service.example_stream_stream(
                stream_request()):
            print(response_message)
    finally:
        channel.close()
    print("-- Done")


if __name__ == '__main__':
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment