Skip to content

Instantly share code, notes, and snippets.

@Bitnik212
Created April 2, 2024 21:05
Show Gist options
  • Save Bitnik212/9d94f26c4791af485504ca14c00b8187 to your computer and use it in GitHub Desktop.
Save Bitnik212/9d94f26c4791af485504ca14c00b8187 to your computer and use it in GitHub Desktop.
aiohttp interceptor
from dataclasses import dataclass
from ssl import SSLContext
from types import SimpleNamespace
from typing import Any, Optional, Mapping, Iterable, Union, Callable, Awaitable
from aiohttp.helpers import _SENTINEL
from aiohttp import BasicAuth, ClientResponse as ClientResponse, ClientTimeout, Fingerprint as Fingerprint
from aiohttp.typedefs import StrOrURL, LooseCookies, LooseHeaders
@dataclass
class AIOHttpRequest:
    """Every argument of one ``_request`` call, bundled so interceptors can inspect it.

    The field names match the keyword parameters of ``WebClientSession._request``,
    which builds instances of this class. ``InterceptorChainImpl`` later expands
    the fields back into keyword arguments when it finally sends the request, so
    field names must stay in sync with that signature.
    """

    # Request target.
    method: str
    str_or_url: StrOrURL
    # Query, body and request identity.
    params: Union[Mapping[str, str], None]
    data: Any
    json: Any
    cookies: Union[LooseCookies, None]
    headers: Union[LooseHeaders, None]
    skip_auto_headers: Union[Iterable[str], None]
    auth: Union[BasicAuth, None]
    # Redirect and transfer behaviour.
    allow_redirects: bool
    max_redirects: int
    compress: Union[str, None]
    chunked: Union[bool, None]
    expect100: bool
    raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]]
    read_until_eof: bool
    # Proxy settings.
    proxy: Union[StrOrURL, None]
    proxy_auth: Union[BasicAuth, None]
    # Timeout; the aiohttp sentinel value means "not specified for this call".
    timeout: Union[ClientTimeout, _SENTINEL]
    # TLS settings.
    verify_ssl: Union[bool, None]
    fingerprint: Union[bytes, None]
    ssl_context: Union[SSLContext, None]
    ssl: Union[SSLContext, bool, Fingerprint]
    server_hostname: Union[str, None]
    # Remaining per-request options.
    proxy_headers: Union[LooseHeaders, None]
    trace_request_ctx: Union[SimpleNamespace, None]
    read_bufsize: Union[int, None]
    auto_decompress: Union[bool, None]
    max_line_size: Union[int, None]
    max_field_size: Union[int, None]
from aiohttp import ClientResponse
from .InterceptorChain import InterceptorChain
class Interceptor:
    """Base class for request interceptors run by ``WebClientSession``.

    Subclasses override :meth:`intercept`. A typical implementation awaits
    ``chain.proceed()`` (optionally passing a modified request), may inspect or
    replace the response, and returns it.
    """

    async def intercept(self, chain: InterceptorChain) -> ClientResponse:
        """Handle the chain's current request and return the resulting response.

        The default implementation is a pass-through: it forwards the request
        unchanged to the rest of the chain, rather than returning ``None`` and
        dropping the request.
        """
        return await chain.proceed()
from typing import Optional
from aiohttp import ClientResponse
from .AIOHttpRequest import AIOHttpRequest
class InterceptorChain:
    """Interface handed to each interceptor: the pending request plus a way to continue.

    Attributes:
        request: the request as received by the interceptor holding this chain.
    """

    def __init__(self, request: AIOHttpRequest):
        self.request: AIOHttpRequest = request

    async def proceed(self, request: Optional[AIOHttpRequest] = None) -> ClientResponse:
        """Pass ``request`` (``self.request`` when omitted) onward and return the response.

        Concrete chains must override this. The base class has no next link, so
        it raises instead of silently returning ``None``.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement proceed()")
from dataclasses import asdict
from typing import Callable, Optional
from aiohttp import ClientResponse
from .AIOHttpRequest import AIOHttpRequest
from .Interceptor import Interceptor
from .InterceptorChain import InterceptorChain
class InterceptorChainImpl(InterceptorChain):
    """Chain that walks ``interceptors`` in list order and finally sends the request.

    Each call to :meth:`proceed` builds a new chain positioned one interceptor
    further along and hands it to that interceptor. Once every interceptor has
    run, the request's fields are passed as keyword arguments to ``process``
    (``WebClientSession`` supplies ``ClientSession._request``).
    """

    def __init__(self, interceptors: list[Interceptor], request: AIOHttpRequest, process: Callable, index: int = -1):
        super().__init__(request)
        self.__interceptors = interceptors
        # Index of the interceptor this chain was handed to; -1 means none has run yet.
        self.__index = index
        # Terminal callable, invoked with the request fields as keyword arguments.
        self.send_request = process

    def copy(self, index: int, request: AIOHttpRequest):
        """Return a chain sharing this one's interceptors and sender, at ``index`` with ``request``."""
        return InterceptorChainImpl(
            index=index,
            interceptors=self.__interceptors,
            process=self.send_request,
            request=request,
        )

    async def proceed(self, request: Optional[AIOHttpRequest] = None) -> ClientResponse:
        """Run the next interceptor with ``request`` (default ``self.request``), or send it if none remain.

        Returns:
            The response produced by the next interceptor, or by ``process`` at the end of the chain.
        """
        from dataclasses import fields

        if request is None:
            request = self.request
        next_chain = self.copy(index=self.__index + 1, request=request)
        if next_chain.__index < len(self.__interceptors):
            interceptor: Interceptor = self.__interceptors[next_chain.__index]
            return await interceptor.intercept(chain=next_chain)
        # Shallow extraction on purpose. dataclasses.asdict() deep-copies every value.
        # That raises for objects that cannot be copied (e.g. an ssl.SSLContext, or an
        # open file passed as ``data``), and it would send copies instead of the
        # caller's own objects.
        kwargs = {spec.name: getattr(request, spec.name) for spec in fields(request)}
        return await self.send_request(**kwargs)
from aiohttp import ClientResponse
from .Interceptor import Interceptor
from .InterceptorChainImpl import InterceptorChainImpl
class LogInterceptor(Interceptor):
    """Interceptor that prints each request/response pair to stdout.

    It lets the rest of the chain run first. It then reports the method, URL and
    headers recorded on the response's ``request_info``, the cookies carried by
    the chain's request, and the response status, content type, headers and
    cookies.
    """

    async def intercept(self, chain: InterceptorChainImpl) -> ClientResponse:
        response = await chain.proceed()
        info = response.request_info
        rule = "=" * 50
        print(f"\n{rule} {info.method} '{info.url}' {rule}")
        print(
            "-> Request with \n"
            f"headers={info.headers}, \n"
            f"cookies={chain.request.cookies}\n"
        )
        print(
            "-> Response with \n"
            f"status_code={response.status}, \n"
            f"content_type={response.content_type}, \n"
            f"headers={dict(response.headers)}, \n"
            f"cookies={response.cookies}"
        )
        return response
import asyncio
import json
from ssl import SSLContext
from types import SimpleNamespace
from typing import Any, Optional, Mapping, Iterable, Union, Callable, Awaitable, Type, List, Set
from aiohttp.abc import AbstractCookieJar
from aiohttp.http import HttpVersion, HttpVersion11
from aiohttp.connector import BaseConnector
from aiohttp import ClientSession, BasicAuth, ClientResponse, ClientTimeout, Fingerprint, ClientRequest, ClientWebSocketResponse, TraceConfig
from aiohttp.helpers import _SENTINEL, sentinel
from aiohttp.typedefs import StrOrURL, LooseCookies, LooseHeaders, JSONEncoder
from .AIOHttpRequest import AIOHttpRequest
from .Interceptor import Interceptor
from .InterceptorChainImpl import InterceptorChainImpl
# Signature of the ``fallback_charset_resolver`` callable accepted by the session.
_CharsetResolver = Callable[[ClientResponse, bytes], str]


class WebClientSession(ClientSession):
    """``aiohttp.ClientSession`` that routes every request through a chain of interceptors.

    Interceptors run in the iteration order of ``interceptors``. The default
    container is a ``set``, which has no guaranteed order, so pass an ordered
    collection (e.g. a list) when ordering matters. With no interceptors,
    requests go straight to ``ClientSession._request``.
    """

    def __init__(
        self,
        base_url: Optional[StrOrURL] = None,
        interceptors: Optional[Union[Set[Interceptor], List[Interceptor]]] = None,
        *,
        connector: Optional[BaseConnector] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        cookies: Optional[LooseCookies] = None,
        headers: Optional[LooseHeaders] = None,
        skip_auto_headers: Optional[Iterable[str]] = None,
        auth: Optional[BasicAuth] = None,
        json_serialize: JSONEncoder = json.dumps,
        request_class: Type[ClientRequest] = ClientRequest,
        response_class: Type[ClientResponse] = ClientResponse,
        ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse,
        version: HttpVersion = HttpVersion11,
        cookie_jar: Optional[AbstractCookieJar] = None,
        connector_owner: bool = True,
        raise_for_status: Union[
            bool, Callable[[ClientResponse], Awaitable[None]]
        ] = False,
        read_timeout: Union[float, _SENTINEL] = sentinel,
        conn_timeout: Optional[float] = None,
        timeout: Union[object, ClientTimeout] = sentinel,
        auto_decompress: bool = True,
        trust_env: bool = False,
        requote_redirect_url: bool = True,
        trace_configs: Optional[List[TraceConfig]] = None,
        read_bufsize: int = 2 ** 16,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
    ):
        super().__init__(
            base_url,
            connector=connector,
            loop=loop,
            cookies=cookies,
            headers=headers,
            skip_auto_headers=skip_auto_headers,
            auth=auth,
            json_serialize=json_serialize,
            request_class=request_class,
            response_class=response_class,
            ws_response_class=ws_response_class,
            version=version,
            cookie_jar=cookie_jar,
            connector_owner=connector_owner,
            raise_for_status=raise_for_status,
            read_timeout=read_timeout,
            conn_timeout=conn_timeout,
            timeout=timeout,
            auto_decompress=auto_decompress,
            trust_env=trust_env,
            requote_redirect_url=requote_redirect_url,
            trace_configs=trace_configs,
            read_bufsize=read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            fallback_charset_resolver=fallback_charset_resolver,
        )
        # Keep the caller's own container (or a fresh set) so interceptors can be added later.
        self.interceptors = set() if interceptors is None else interceptors
        # Constructor arguments remembered so _request can expose session defaults to interceptors.
        self.session_data = {
            "base_url": base_url,
            # The serializer passed to the session (this key used to hold the ``json`` module itself).
            "json_serialize": json_serialize,
            "cookies": cookies,
            "headers": headers,
            "skip_auto_headers": skip_auto_headers,
            "auth": auth,
            "raise_for_status": raise_for_status,
            "timeout": timeout,
            "read_bufsize": read_bufsize,
            "auto_decompress": auto_decompress,
            "max_line_size": max_line_size,
            "max_field_size": max_field_size,
            "read_timeout": read_timeout,
            "conn_timeout": conn_timeout,
        }

    async def _request(
        self,
        method: str,
        str_or_url: StrOrURL,
        *,
        params: Optional[Mapping[str, str]] = None,
        data: Any = None,
        json: Any = None,
        cookies: Optional[LooseCookies] = None,
        headers: Optional[LooseHeaders] = None,
        skip_auto_headers: Optional[Iterable[str]] = None,
        auth: Optional[BasicAuth] = None,
        allow_redirects: bool = True,
        max_redirects: int = 10,
        compress: Optional[str] = None,
        chunked: Optional[bool] = None,
        expect100: bool = False,
        raise_for_status: Union[
            None, bool, Callable[[ClientResponse], Awaitable[None]]
        ] = None,
        read_until_eof: bool = True,
        proxy: Optional[StrOrURL] = None,
        proxy_auth: Optional[BasicAuth] = None,
        timeout: Union[ClientTimeout, _SENTINEL] = sentinel,
        verify_ssl: Optional[bool] = None,
        fingerprint: Optional[bytes] = None,
        ssl_context: Optional[SSLContext] = None,
        ssl: Union[SSLContext, bool, Fingerprint] = True,
        server_hostname: Optional[str] = None,
        proxy_headers: Optional[LooseHeaders] = None,
        trace_request_ctx: Optional[SimpleNamespace] = None,
        read_bufsize: Optional[int] = None,
        auto_decompress: Optional[bool] = None,
        max_line_size: Optional[int] = None,
        max_field_size: Optional[int] = None,
    ) -> ClientResponse:
        """Send one request, routing it through the interceptor chain when any are registered.

        Exceptions raised by interceptors or by the underlying request propagate
        unchanged. The session is not closed on failure. It may be shared by
        other requests, and the no-interceptor path does not close it either.
        """
        options = {
            "params": params,
            "data": data,
            "json": json,
            "cookies": cookies,
            "headers": headers,
            "skip_auto_headers": skip_auto_headers,
            "auth": auth,
            "allow_redirects": allow_redirects,
            "max_redirects": max_redirects,
            "compress": compress,
            "chunked": chunked,
            "expect100": expect100,
            "raise_for_status": raise_for_status,
            "read_until_eof": read_until_eof,
            "proxy": proxy,
            "proxy_auth": proxy_auth,
            "timeout": timeout,
            "verify_ssl": verify_ssl,
            "fingerprint": fingerprint,
            "ssl_context": ssl_context,
            "ssl": ssl,
            "server_hostname": server_hostname,
            "proxy_headers": proxy_headers,
            "trace_request_ctx": trace_request_ctx,
            "read_bufsize": read_bufsize,
            "auto_decompress": auto_decompress,
            "max_line_size": max_line_size,
            "max_field_size": max_field_size,
        }
        if not self.interceptors:
            return await super()._request(method, str_or_url, **options)

        # Let interceptors see session-level cookies/headers/auth when the call did not override them.
        # NOTE(review): these values are then forwarded to ClientSession._request, which presumably
        # applies its own session defaults as well (and may let explicit cookies take precedence over
        # cookies the jar has since updated) — confirm re-sending constructor-time values is intended.
        for key in ("cookies", "headers", "auth"):
            if options[key] is None:
                options[key] = self.session_data.get(key)

        chain = InterceptorChainImpl(
            interceptors=list(self.interceptors),
            request=AIOHttpRequest(method=method, str_or_url=str_or_url, **options),
            process=super()._request,
        )
        return await chain.proceed()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment