-
-
Save Diniboy1123/a54f22ecc177a73afa707b1aff4e8d13 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct
from enum import Enum
from typing import Generator, Optional, Tuple, Union
from urllib.parse import unquote

import requests
from requests import Response
from requests.exceptions import ContentDecodingError, InvalidHeader
# gRPC status codes. Every member's value is a ``(numeric_code, name)`` pair,
# so ``StatusCode.X.value[0]`` is the wire code and ``value[1]`` its label.
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html
StatusCode = Enum(
    "StatusCode",
    [
        ("OK", (0, "OK")),
        ("CANCELLED", (1, "CANCELLED")),
        ("UNKNOWN", (2, "UNKNOWN")),
        ("INVALID_ARGUMENT", (3, "INVALID_ARGUMENT")),
        ("DEADLINE_EXCEEDED", (4, "DEADLINE_EXCEEDED")),
        ("NOT_FOUND", (5, "NOT_FOUND")),
        ("ALREADY_EXISTS", (6, "ALREADY_EXISTS")),
        ("PERMISSION_DENIED", (7, "PERMISSION_DENIED")),
        ("RESOURCE_EXHAUSTED", (8, "RESOURCE_EXHAUSTED")),
        ("FAILED_PRECONDITION", (9, "FAILED_PRECONDITION")),
        ("ABORTED", (10, "ABORTED")),
        ("OUT_OF_RANGE", (11, "OUT_OF_RANGE")),
        ("UNIMPLEMENTED", (12, "UNIMPLEMENTED")),
        ("INTERNAL", (13, "INTERNAL")),
        ("UNAVAILABLE", (14, "UNAVAILABLE")),
        ("DATA_LOSS", (15, "DATA_LOSS")),
        ("UNAUTHENTICATED", (16, "UNAUTHENTICATED")),
    ],
    # Keep pickling / repr identical to a class-statement definition.
    module=__name__,
    qualname="StatusCode",
)
class RpcError(Exception):
    """Error raised when a gRPC-Web call reports a non-OK ``grpc-status``.

    Attributes:
        status: the numeric gRPC status code.
        message: the accompanying status message (may be an empty string).
    """

    def __init__(self, status: int, message: str):
        summary = f"RPC error {status}: {message}"
        super().__init__(summary)
        self.status, self.message = status, message
class GrpcWebClient:
    """Small gRPC-Web client built on ``requests``.

    Request bodies are sent as given, so callers frame them first (see
    :meth:`wrap_message`). Response bodies are parsed frame by frame and each
    message payload is yielded as raw bytes. A non-OK ``grpc-status`` is raised
    as :class:`RpcError`.
    """

    def __init__(self, base_url: str, headers: dict):
        """Create a client.

        Args:
            base_url: URL prefix; each endpoint is appended after a ``/``.
            headers: HTTP headers sent with every request. The dict is stored
                by reference, not copied.
        """
        self.base_url = base_url
        self.headers = headers
        # Frame header: one flag byte followed by a big-endian uint32 length.
        self._HEADER_FORMAT = ">BI"
        self._HEADER_LENGTH = struct.calcsize(self._HEADER_FORMAT)
        self._DEFAULT_CHUNK_SIZE = 512

    def _pack_header(self, trailer: bool, compressed: bool, length: int) -> bytes:
        """Encode a frame header (flag bit 7 = trailer, bit 0 = compressed)."""
        flags = (trailer << 7) | (compressed)
        return struct.pack(self._HEADER_FORMAT, flags, length)

    def _unpack_header(self, message: bytes) -> Tuple[bool, bool, int]:
        """Decode a frame header into ``(trailer, compressed, length)``."""
        flags, length = struct.unpack(self._HEADER_FORMAT, message)
        trailer = bool(flags & (1 << 7))
        compressed = bool(flags & 1)
        return trailer, compressed, length

    def _read_upto(self, length: int, previous: bytes, iterator) -> Tuple[bytes, bytes]:
        """Split ``previous`` plus data pulled from ``iterator`` at ``length``.

        Chunks are pulled from ``iterator`` only until at least ``length``
        bytes are buffered. Returns ``(head, leftover)``. ``head`` is shorter
        than ``length`` only if the iterator was exhausted first; callers
        check for that.
        """
        if len(previous) < length:
            # Gather chunks and join once. Repeated ``bytes +=`` copies the
            # whole buffer for every chunk, which is quadratic for large
            # messages.
            chunks = [previous]
            buffered = len(previous)
            for chunk in iterator:
                chunks.append(chunk)
                buffered += len(chunk)
                if buffered >= length:
                    break
            previous = b"".join(chunks)
        return previous[:length], previous[length:]

    def _unwrap_message_stream(
        self, response: Response
    ) -> Generator[bytes, None, None]:
        """Yield every message payload in a gRPC-Web response body.

        Frames are read until the trailer frame (flag bit 7) is reached, and
        its ``grpc-status`` is then checked. If the body ends exactly at a
        frame boundary while the HTTP response headers carry
        ``grpc-status`` (a "trailers-only" response), those headers are used
        as the trailer instead.

        Raises:
            InvalidHeader: a frame header is truncated, or the body ends
                without a trailer frame and without a ``grpc-status`` header.
            ContentDecodingError: a frame payload is shorter than its
                declared length, or a trailer line is malformed.
            RpcError: the reported ``grpc-status`` is not OK.
        """
        it = response.iter_content(self._DEFAULT_CHUNK_SIZE)
        content = b""
        while True:
            header, content = self._read_upto(self._HEADER_LENGTH, content, it)
            if not header and "grpc-status" in response.headers:
                # Trailers-only response: the status was delivered in the HTTP
                # headers rather than in a trailer frame.
                trailer = {
                    name.lower(): value.strip()
                    for name, value in response.headers.items()
                }
                break
            if len(header) != self._HEADER_LENGTH:
                raise InvalidHeader(
                    f"Expected {self._HEADER_LENGTH} bytes, got {len(header)} bytes"
                )
            is_trailer, _, length = self._unpack_header(header)
            # Read exactly ``length`` bytes for data *and* trailer frames. A
            # trailer can span network chunks just like a message, so using
            # only the bytes already buffered would truncate it.
            payload, content = self._read_upto(length, content, it)
            if len(payload) != length:
                raise ContentDecodingError(
                    f"Expected {length} bytes, got {len(payload)} bytes"
                )
            if is_trailer:
                trailer = self._deserialize_trailer(payload)
                break
            # NOTE(review): the compressed flag is not acted on; compressed
            # payloads are yielded as-is.
            yield payload

        status = trailer.get("grpc-status")
        if status is not None and int(status) != StatusCode.OK.value[0]:
            # The gRPC spec percent-encodes grpc-message on the wire.
            message = unquote(trailer.get("grpc-message", ""))
            raise RpcError(int(status), message)

    def _deserialize_trailer(self, data: bytes) -> dict:
        """Parse a trailer frame payload into ``{lowercase_name: value}``.

        The payload consists of ``name: value`` lines. Names are lowercased,
        since HTTP field names are case-insensitive. Whitespace around names
        and values is stripped, and blank lines are skipped. A later duplicate
        name overrides an earlier one.

        Raises:
            ContentDecodingError: a non-blank line contains no ``:``.
        """
        trailer = {}
        for line in data.decode("utf8").splitlines():
            if not line.strip():
                continue
            name, sep, value = line.partition(":")
            if not sep:
                raise ContentDecodingError(f"Malformed trailer line: {line!r}")
            trailer[name.strip().lower()] = value.strip()
        return trailer

    def wrap_message(
        self, message: bytes, trailer: bool = False, compressed: bool = False
    ) -> bytes:
        """Return ``message`` prefixed with a gRPC-Web frame header."""
        header = self._pack_header(trailer, compressed, len(message))
        return header + message

    def call_grpc(
        self,
        endpoint: str,
        data: bytes,
        timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> Generator[bytes, None, None]:
        """POST ``data`` to ``{base_url}/{endpoint}`` and yield reply messages.

        Args:
            endpoint: path appended to ``base_url``.
            data: an already-framed request body, e.g. from
                :meth:`wrap_message`.
            timeout: passed to ``requests.post``. ``None`` (the default)
                applies no timeout.

        The response is streamed, so messages are yielded as they arrive. The
        connection is released once the generator is exhausted or closed.
        """
        url = f"{self.base_url}/{endpoint}"
        with requests.post(
            url, headers=self.headers, data=data, stream=True, timeout=timeout
        ) as response:
            yield from self._unwrap_message_stream(response)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment