Last active
February 22, 2021 14:39
-
-
Save DanielChabrowski/270e2725e30d7313fe417a20e8741d7a to your computer and use it in GitHub Desktop.
httpx transport for HTTP/2 with prior knowledge (HTTP/2 is spoken from the first byte, without an HTTP/1.1 Upgrade or ALPN negotiation)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import httpx | |
from typing import Optional, Tuple, cast | |
from httpcore import AsyncHTTPTransport, AsyncByteStream, ConnectError, ConnectTimeout | |
from httpcore._async.base import ConnectionState, NewConnectionRequired | |
from httpcore._async.connection import RETRIES_BACKOFF_FACTOR | |
from httpcore._async.http import AsyncBaseHTTPConnection | |
from httpcore._async.http2 import AsyncHTTP2Connection | |
from httpcore._backends.auto import AsyncSocketStream, AsyncBackend, AsyncLock, AutoBackend | |
from httpcore._types import Origin, URL, Headers, TimeoutDict | |
from httpcore._utils import url_to_origin, exponential_backoff | |
from ssl import SSLContext | |
class Http2PriorKnowledgeTransport(AsyncHTTPTransport):
    """httpcore transport that speaks HTTP/2 with prior knowledge.

    Instead of negotiating the protocol, the freshly opened socket is wrapped
    directly in an ``AsyncHTTP2Connection``. The transport manages a single
    connection, bound to the origin (scheme, host, port) of the first request
    that opens it. Requests issued while that connection is busy are sent over
    the same connection rather than opening a new one.

    Parameters:
        uds: path of a Unix domain socket to connect through instead of TCP.
        ssl_context: TLS context used only for ``https`` origins. With
            ``None`` (the default) the backend is given ``ssl_context=None``.
        socket: an already-open stream to use for the first connection.
        local_address: forwarded to the backend's ``open_tcp_stream``.
        retries: how many times to retry ``ConnectError``/``ConnectTimeout``
            while opening the socket, sleeping with exponential backoff.
        backend: async backend to use; ``AutoBackend()`` when ``None``.
    """

    def __init__(
        self,
        uds: Optional[str] = None,
        ssl_context: Optional[SSLContext] = None,
        socket: Optional[AsyncSocketStream] = None,
        local_address: Optional[str] = None,
        retries: int = 0,
        backend: Optional[AsyncBackend] = None,
    ):
        # Honour the caller's context; it is applied to https origins only.
        self.ssl_context = ssl_context
        self.uds = uds
        self.socket = socket
        self.local_address = local_address
        self.retries = retries
        self.connection: Optional[AsyncBaseHTTPConnection] = None
        self.connect_failed = False
        # Not read anywhere in this class.
        self.expires_at: Optional[float] = None
        # Origin served by this transport; recorded when a connection is made.
        self.origin: Optional[Origin] = None
        self.backend = AutoBackend() if backend is None else backend

    def __repr__(self) -> str:
        http_version = "HTTP2 with prior knowledge"
        return f"<{type(self).__name__} http_version={http_version} state={self.state}>"

    def info(self) -> str:
        """Return a short human-readable description of the connection status."""
        if self.connection is None:
            return "Not connected"
        elif self.state == ConnectionState.PENDING:
            return "Connecting"
        return self.connection.info()

    @property
    def request_lock(self) -> AsyncLock:
        """Lock serialising connection setup and state checks."""
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        if not hasattr(self, "_request_lock"):
            self._request_lock = self.backend.create_lock()
        return self._request_lock

    def _discard_closed_connection(self) -> None:
        """Forget a closed or failed connection so the next request reconnects."""
        self.connection = None
        self.socket = None
        self.connect_failed = False

    async def arequest(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """Send one request over the (lazily opened) HTTP/2 connection.

        ``ext`` may carry a ``"timeout"`` dict. It is used when opening the
        socket, and the whole ``ext`` dict is passed on to the connection.

        Raises:
            ValueError: ``url`` targets a different origin than the one this
                transport is bound to.
            NewConnectionRequired: the connection is in a state other than
                pending, ready, idle, active, or closed.
        """
        origin = url_to_origin(url)
        ext = {} if ext is None else ext
        timeout = cast(TimeoutDict, ext.get("timeout", {}))

        async with self.request_lock:
            # The connection talks to exactly one server. Reusing it for
            # another origin would deliver this request (headers, body) to
            # the wrong host.
            if self.origin is not None and origin != self.origin:
                raise ValueError(
                    f"{type(self).__name__} is bound to origin {self.origin!r} "
                    f"and cannot send a request to {origin!r}"
                )

            # CLOSED covers both a closed connection and a failed connect
            # attempt (see ``state``). Neither can be reused, so start over
            # instead of failing every future request.
            if self.state == ConnectionState.CLOSED:
                self._discard_closed_connection()

            if self.state == ConnectionState.PENDING:
                if not self.socket:
                    self.socket = await self._open_socket(origin, timeout)
                self._create_connection(self.socket)
                self.origin = origin
            elif self.state not in (
                ConnectionState.READY,
                ConnectionState.IDLE,
                # ACTIVE is accepted too: the request is sent on the same
                # connection as the ones already in flight.
                ConnectionState.ACTIVE,
            ):
                raise NewConnectionRequired()

        assert self.connection is not None
        return await self.connection.arequest(method, url, headers, stream, ext)

    async def _open_socket(
        self, origin: Origin, timeout: Optional[TimeoutDict] = None
    ) -> AsyncSocketStream:
        """Open a TCP (or UDS) stream to ``origin``, retrying connect errors.

        ``ConnectError``/``ConnectTimeout`` are retried up to ``self.retries``
        times with exponential backoff. On final failure, or on any other
        exception, ``connect_failed`` is set and the exception is re-raised.
        """
        scheme, hostname, port = origin
        timeout = {} if timeout is None else timeout
        # TLS only for https origins; plain http gets no ssl_context.
        ssl_context = self.ssl_context if scheme == b"https" else None
        retries_left = self.retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self.uds is None:
                    return await self.backend.open_tcp_stream(
                        hostname,
                        port,
                        ssl_context,
                        timeout,
                        local_address=self.local_address,
                    )
                else:
                    return await self.backend.open_uds_stream(
                        self.uds, hostname, ssl_context, timeout
                    )
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    self.connect_failed = True
                    raise
                retries_left -= 1
                delay = next(delays)
                await self.backend.sleep(delay)
            except Exception:  # noqa: PIE786
                self.connect_failed = True
                raise

    def _create_connection(self, socket: AsyncSocketStream) -> None:
        """Wrap ``socket`` in an HTTP/2 connection without any negotiation."""
        self.connection = AsyncHTTP2Connection(
            socket=socket, backend=self.backend, ssl_context=self.ssl_context
        )

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before one, else the connection's state."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        elif self.connection is None:
            return ConnectionState.PENDING
        return self.connection.get_state()

    async def aclose(self) -> None:
        """Close the underlying connection, if one was created."""
        async with self.request_lock:
            if self.connection is not None:
                await self.connection.aclose()
async def run_client(
    url: str = "http://192.168.1.3:3000/echo", content: str = "hello"
) -> None:
    """POST ``content`` to ``url`` over HTTP/2 with prior knowledge and print the reply.

    Prints the response body followed by the status code. The defaults
    reproduce the original demo request (``/echo`` on a LAN address).
    """
    transport = Http2PriorKnowledgeTransport()
    # NOTE: httpx applies ``http2`` to transports it builds itself; the
    # custom transport above speaks HTTP/2 regardless of this flag.
    async with httpx.AsyncClient(http2=True, transport=transport) as client:
        response = await client.post(url, content=content)
        print(response.content, response.status_code)
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(run_client())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment