Skip to content

Instantly share code, notes, and snippets.

@dtaniwaki
Last active February 6, 2024 21:28
Show Gist options
  • Save dtaniwaki/341ec184eed20d965e154f71a826ee73 to your computer and use it in GitHub Desktop.
Save dtaniwaki/341ec184eed20d965e154f71a826ee73 to your computer and use it in GitHub Desktop.
import collections
import collections.abc
import functools
import logging
from typing import Any, Callable, Iterable, Union

from google.protobuf.message import Message
from grpc import HandlerCallDetails, RpcMethodHandler, ServerInterceptor, ServicerContext, StatusCode
from grpc.experimental import wrap_server_method_handler
from validator import ValidationFailed, validate
# Module logger; response-validation failures are reported through it.
logger = logging.getLogger(name=__name__)

# What a gRPC handler receives or returns: a single message (unary side) or an
# iterable of messages (streaming side).
MESSAGE_TYPE = Union[Message, Iterable[Message]]
def _validate_iter(message_iter: Iterable[Message]) -> Iterable[Message]:
    """Lazily validate a stream of request messages.

    Each message is checked only when the consumer pulls it from this
    generator. An invalid message raises ``ValidationFailed`` at that point and
    is never yielded.

    Args:
        message_iter: The incoming request stream.

    Yields:
        Each message from ``message_iter`` that passed validation.
    """
    for message in message_iter:
        validate(message)(message)
        yield message
def _validate_response_iter(response_iter: Iterable[Message]) -> Iterable[Message]:
    """Yield each response, logging (not raising) validation failures.

    Responses are validated lazily as the consumer pulls them. An invalid
    response is still yielded unchanged; the failure is only reported via
    ``logger.warning``, so a bad response never interrupts the stream.

    Args:
        response_iter: The outgoing response stream produced by the behavior.

    Yields:
        Every message from ``response_iter``, in order.
    """
    for res in response_iter:
        v = validate(res)
        try:
            v(res)
        except ValidationFailed as e:
            # Lazy %-args: the message is only formatted if WARNING is enabled.
            logger.warning("Response validation failed: %s", e)
        yield res
def _abort_on_invalid_request(
    response_iter: Iterable[Message], context: ServicerContext
) -> Iterable[Message]:
    """Re-yield a response stream, turning ``ValidationFailed`` into ``INVALID_ARGUMENT``.

    A streaming-response behavior typically runs lazily. An invalid request it
    pulls from a lazily validated request stream therefore raises only while
    the response is being iterated, after ``wrapper`` has already returned.
    Without this guard, that error escapes as an unhandled exception instead
    of an ``INVALID_ARGUMENT`` status.

    Args:
        response_iter: The (already response-validating) response stream.
        context: The RPC context used to abort the call.

    Yields:
        Every message from ``response_iter``, in order.
    """
    try:
        yield from response_iter
    except ValidationFailed as e:
        context.abort(StatusCode.INVALID_ARGUMENT, str(e))


def _wrapper(
    behavior: Callable[[MESSAGE_TYPE, ServicerContext], MESSAGE_TYPE]
) -> Callable[[MESSAGE_TYPE, ServicerContext], MESSAGE_TYPE]:
    """Wrap a gRPC method behavior with protoc-gen-validate checks.

    Requests are validated before the behavior runs (unary requests) or as the
    behavior consumes them (streaming requests). A request ``ValidationFailed``
    aborts the RPC with ``INVALID_ARGUMENT``, for both unary and streaming
    responses. Responses are validated too, but an invalid response is only
    logged, never rejected.

    Args:
        behavior: The handler function taken from an ``RpcMethodHandler``.

    Returns:
        A function with the same call signature as ``behavior``.
    """

    @functools.wraps(behavior)
    def wrapper(request: MESSAGE_TYPE, context: ServicerContext) -> MESSAGE_TYPE:
        # Streaming RPCs receive an iterator of requests; unary RPCs receive a
        # single message. ``collections.Iterable`` no longer exists on Python
        # 3.10+, so the ABC must come from ``collections.abc``.
        if isinstance(request, collections.abc.Iterable):
            # No validation until the behavior actually pulls each request.
            request = _validate_iter(request)
        else:
            try:
                v = validate(request)
                v(request)
            except ValidationFailed as e:
                context.abort(StatusCode.INVALID_ARGUMENT, str(e))

        try:
            response = behavior(request, context)
        except ValidationFailed as e:
            # Raised when an eager behavior pulls an invalid streamed request.
            context.abort(StatusCode.INVALID_ARGUMENT, str(e))
            return  # type: ignore  # unreachable: context.abort raises

        if isinstance(response, collections.abc.Iterable):
            # A lazy response stream may still raise request errors later.
            return _abort_on_invalid_request(_validate_response_iter(response), context)

        # Nothing to validate when the behavior returned no message (e.g. after
        # setting an error status on the context); validate(None) would crash.
        if response is not None:
            v = validate(response)
            try:
                v(response)
            except ValidationFailed as e:
                logger.warning("Response validation failed: %s", e)
        return response

    return wrapper
class ProtocValidationServerInterceptor(ServerInterceptor):  # type: ignore
    """gRPC server interceptor applying protoc-gen-validate checks to every RPC.

    Whatever method handler the rest of the interceptor chain resolves is
    returned wrapped by ``_wrapper`` via ``wrap_server_method_handler``.
    """

    def intercept_service(
        self,
        continuation: Callable[[HandlerCallDetails], RpcMethodHandler],
        handler_call_details: HandlerCallDetails,
    ) -> RpcMethodHandler:
        """Resolve the handler for this call and return it with validation added.

        Args:
            continuation: Resolves the handler for ``handler_call_details``.
            handler_call_details: Details of the incoming call.

        Returns:
            The resolved handler, wrapped with ``_wrapper``.
        """
        return wrap_server_method_handler(_wrapper, continuation(handler_call_details))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment