Created April 2, 2024 21:05
-
-
Save Bitnik212/9d94f26c4791af485504ca14c00b8187 to your computer and use it in GitHub Desktop.
aiohttp interceptor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from ssl import SSLContext | |
from types import SimpleNamespace | |
from typing import Any, Optional, Mapping, Iterable, Union, Callable, Awaitable | |
from aiohttp.helpers import _SENTINEL | |
from aiohttp import BasicAuth, ClientResponse as ClientResponse, ClientTimeout, Fingerprint as Fingerprint | |
from aiohttp.typedefs import StrOrURL, LooseCookies, LooseHeaders | |
@dataclass
class AIOHttpRequest:
    """Record of the arguments of one ``ClientSession._request`` call.

    Field names and order follow the keyword arguments of
    ``ClientSession._request``, so an instance can be expanded back into such
    a call with one keyword argument per field. Interceptors see it as
    ``chain.request``.
    """

    # Target of the request.
    method: str
    str_or_url: StrOrURL
    # Query string and body.
    params: Optional[Mapping[str, str]]
    data: Any
    json: Any
    # Per-request cookies, headers and authentication.
    cookies: Optional[LooseCookies]
    headers: Optional[LooseHeaders]
    skip_auto_headers: Optional[Iterable[str]]
    auth: Optional[BasicAuth]
    # Redirect handling.
    allow_redirects: bool
    max_redirects: int
    # Body encoding / transfer options.
    compress: Optional[str]
    chunked: Optional[bool]
    expect100: bool
    # Response handling.
    raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]]
    read_until_eof: bool
    # Proxy target and credentials.
    proxy: Optional[StrOrURL]
    proxy_auth: Optional[BasicAuth]
    # Either a ClientTimeout or aiohttp's sentinel (the default used when the
    # caller did not pass a timeout).
    timeout: Union[ClientTimeout, _SENTINEL]
    # TLS options.
    verify_ssl: Optional[bool]
    fingerprint: Optional[bytes]
    ssl_context: Optional[SSLContext]
    ssl: Union[SSLContext, bool, Fingerprint]
    server_hostname: Optional[str]
    # Extra proxy headers and the per-request tracing context.
    proxy_headers: Optional[LooseHeaders]
    trace_request_ctx: Optional[SimpleNamespace]
    # Response parsing limits and decompression.
    read_bufsize: Optional[int]
    auto_decompress: Optional[bool]
    max_line_size: Optional[int]
    max_field_size: Optional[int]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from aiohttp import ClientResponse | |
from .InterceptorChain import InterceptorChain | |
class Interceptor:
    """Base class for interceptors run by ``WebClientSession``.

    Subclasses override :meth:`intercept`. A typical implementation awaits
    ``chain.proceed()`` (optionally passing a modified request), may inspect
    the response, and returns it.
    """

    async def intercept(self, chain: InterceptorChain) -> ClientResponse:
        """Process one request and return its response.

        :param chain: the chain at this interceptor's position;
            ``chain.request`` is the current request and ``chain.proceed()``
            runs the remaining interceptors and then the request itself.
        :returns: the response to hand back up the chain.
        :raises NotImplementedError: when a subclass does not override it.
        """
        # Fail loudly: returning None here would travel back up the chain to
        # the session caller in place of a ClientResponse.
        raise NotImplementedError(
            f"{type(self).__name__} must override Interceptor.intercept()"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional | |
from aiohttp import ClientResponse | |
from .AIOHttpRequest import AIOHttpRequest | |
class InterceptorChain:
    """Interface of the chain object passed to ``Interceptor.intercept``.

    :param request: the request at the current position of the chain,
        exposed as ``self.request``.
    """

    def __init__(self, request: AIOHttpRequest):
        self.request: AIOHttpRequest = request

    async def proceed(self, request: Optional[AIOHttpRequest] = None) -> ClientResponse:
        """Continue down the chain and return the resulting response.

        :param request: request to pass on; ``InterceptorChainImpl`` treats
            ``None`` as "reuse ``self.request``".
        :returns: the response produced further down the chain.
        :raises NotImplementedError: when a subclass does not override it.
        """
        # Fail loudly instead of returning None where a ClientResponse is expected.
        raise NotImplementedError(
            f"{type(self).__name__} must override InterceptorChain.proceed()"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import asdict | |
from typing import Callable, Optional | |
from aiohttp import ClientResponse | |
from .AIOHttpRequest import AIOHttpRequest | |
from .Interceptor import Interceptor | |
from .InterceptorChain import InterceptorChain | |
class InterceptorChainImpl(InterceptorChain):
    """Concrete chain that passes a request through a list of interceptors.

    Every :meth:`proceed` call builds a new chain one position further along
    ``interceptors`` and hands it to that interceptor. When no interceptor is
    left, ``process`` is awaited with one keyword argument per field of the
    request and its result is returned.

    :param interceptors: interceptors to run, in list order.
    :param request: request carried by this link of the chain.
    :param process: coroutine function that performs the real request
        (``WebClientSession`` passes its parent's ``_request``).
    :param index: position of the interceptor this link was handed to;
        ``-1`` (the default) means no interceptor has run yet.
    """

    def __init__(
        self,
        interceptors: list[Interceptor],
        request: AIOHttpRequest,
        process: Callable,
        index: int = -1,
    ):
        super().__init__(request)
        self.__interceptors = interceptors
        self.__index = index
        self.send_request = process

    def copy(self, index: int, request: AIOHttpRequest) -> "InterceptorChainImpl":
        """Return a chain over the same interceptors and ``process``, positioned at ``index``."""
        return InterceptorChainImpl(
            index=index,
            interceptors=self.__interceptors,
            process=self.send_request,
            request=request,
        )

    async def proceed(self, request: Optional[AIOHttpRequest] = None) -> ClientResponse:
        """Run the next interceptor, or send the request when none is left.

        :param request: request to pass on; ``None`` reuses ``self.request``.
        :returns: the response produced further down the chain.
        """
        if request is None:
            request = self.request
        next_chain = self.copy(index=self.__index + 1, request=request)
        if next_chain.__index < len(self.__interceptors):
            interceptor: Interceptor = self.__interceptors[next_chain.__index]
            return await interceptor.intercept(chain=next_chain)

        from dataclasses import fields

        # Shallow field -> value mapping. dataclasses.asdict() must not be used
        # here: it recursively deep-copies every value, which fails for objects
        # that cannot be copied (e.g. an SSLContext, or an open file/stream
        # passed as ``data``), turns any dataclass-valued argument into a plain
        # dict, and hands aiohttp copies instead of the caller's own objects
        # (e.g. the trace_request_ctx meant for TraceConfig callbacks).
        kwargs = {field.name: getattr(request, field.name) for field in fields(request)}
        return await self.send_request(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from aiohttp import ClientResponse | |
from .Interceptor import Interceptor | |
from .InterceptorChainImpl import InterceptorChainImpl | |
class LogInterceptor(Interceptor):
    """Interceptor that prints a summary of each request/response pair.

    The request is sent first via ``chain.proceed()``. Request details are
    then printed from ``response.request_info``, together with the cookies
    carried on ``chain.request``, followed by the response's status,
    content type, headers and cookies. The response is returned unchanged.
    """

    async def intercept(self, chain: InterceptorChainImpl) -> ClientResponse:
        response = await chain.proceed()
        info = response.request_info

        rule = "==" * 25
        print(f"\n{rule} {info.method} '{info.url}' {rule}")

        request_summary = (
            "-> Request with \n"
            f"headers={info.headers}, \n"
            f"cookies={chain.request.cookies}\n"
        )
        print(request_summary)

        response_summary = (
            "-> Response with \n"
            f"status_code={response.status}, \n"
            f"content_type={response.content_type}, \n"
            f"headers={dict(response.headers)}, \n"
            f"cookies={response.cookies}"
        )
        print(response_summary)

        return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import json | |
from ssl import SSLContext | |
from types import SimpleNamespace | |
from typing import Any, Optional, Mapping, Iterable, Union, Callable, Awaitable, Type, List, Set | |
from aiohttp.abc import AbstractCookieJar | |
from aiohttp.http import HttpVersion, HttpVersion11 | |
from aiohttp.connector import BaseConnector | |
from aiohttp import ClientSession, BasicAuth, ClientResponse, ClientTimeout, Fingerprint, ClientRequest, ClientWebSocketResponse, TraceConfig | |
from aiohttp.helpers import _SENTINEL, sentinel | |
from aiohttp.typedefs import StrOrURL, LooseCookies, LooseHeaders, JSONEncoder | |
from .AIOHttpRequest import AIOHttpRequest | |
from .Interceptor import Interceptor | |
from .InterceptorChainImpl import InterceptorChainImpl | |
_CharsetResolver = Callable[[ClientResponse, bytes], str]


class WebClientSession(ClientSession):
    """``aiohttp.ClientSession`` whose requests can be routed through interceptors.

    This class overrides aiohttp's internal ``_request`` and relies on the
    session's request helpers going through it. With no interceptors
    registered, the call is forwarded straight to ``ClientSession._request``.
    Otherwise the arguments are packed into an :class:`AIOHttpRequest` and
    handed to an :class:`InterceptorChainImpl`, whose last link calls
    ``ClientSession._request``.

    :param base_url: forwarded to ``ClientSession``.
    :param interceptors: interceptors to run, in the iteration order of the
        container. A ``set`` iterates in arbitrary order, so pass a ``list``
        when the order matters. Stored as ``self.interceptors``.

    All keyword-only arguments are forwarded unchanged to
    ``ClientSession.__init__``.
    """

    def __init__(
        self,
        base_url: Optional[StrOrURL] = None,
        interceptors: Optional[Union[Set[Interceptor], List[Interceptor]]] = None,
        *,
        connector: Optional[BaseConnector] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        cookies: Optional[LooseCookies] = None,
        headers: Optional[LooseHeaders] = None,
        skip_auto_headers: Optional[Iterable[str]] = None,
        auth: Optional[BasicAuth] = None,
        json_serialize: JSONEncoder = json.dumps,
        request_class: Type[ClientRequest] = ClientRequest,
        response_class: Type[ClientResponse] = ClientResponse,
        ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse,
        version: HttpVersion = HttpVersion11,
        cookie_jar: Optional[AbstractCookieJar] = None,
        connector_owner: bool = True,
        raise_for_status: Union[
            bool, Callable[[ClientResponse], Awaitable[None]]
        ] = False,
        read_timeout: Union[float, _SENTINEL] = sentinel,
        conn_timeout: Optional[float] = None,
        timeout: Union[object, ClientTimeout] = sentinel,
        auto_decompress: bool = True,
        trust_env: bool = False,
        requote_redirect_url: bool = True,
        trace_configs: Optional[List[TraceConfig]] = None,
        read_bufsize: int = 2 ** 16,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
    ):
        super().__init__(
            base_url,
            connector=connector,
            loop=loop,
            cookies=cookies,
            headers=headers,
            skip_auto_headers=skip_auto_headers,
            auth=auth,
            json_serialize=json_serialize,
            request_class=request_class,
            response_class=response_class,
            ws_response_class=ws_response_class,
            version=version,
            cookie_jar=cookie_jar,
            connector_owner=connector_owner,
            raise_for_status=raise_for_status,
            read_timeout=read_timeout,
            conn_timeout=conn_timeout,
            timeout=timeout,
            auto_decompress=auto_decompress,
            trust_env=trust_env,
            requote_redirect_url=requote_redirect_url,
            trace_configs=trace_configs,
            read_bufsize=read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            fallback_charset_resolver=fallback_charset_resolver,
        )
        # Keep the caller's container as-is so that, for a list, its order is
        # the order the interceptors run in. The empty default is a set so
        # ``session.interceptors.add(...)`` keeps working.
        self.interceptors = set() if interceptors is None else interceptors
        # Constructor-level settings, kept for reference. _request only reads
        # "headers" (to show default headers to interceptors).
        self.session_data = {
            "base_url": base_url,
            "cookies": cookies,
            "headers": headers,
            "skip_auto_headers": skip_auto_headers,
            "auth": auth,
            "raise_for_status": raise_for_status,
            "timeout": timeout,
            "read_bufsize": read_bufsize,
            "auto_decompress": auto_decompress,
            "max_line_size": max_line_size,
            "max_field_size": max_field_size,
            "read_timeout": read_timeout,
            "conn_timeout": conn_timeout,
        }

    async def _request(
        self,
        method: str,
        str_or_url: StrOrURL,
        *,
        params: Optional[Mapping[str, str]] = None,
        data: Any = None,
        json: Any = None,
        cookies: Optional[LooseCookies] = None,
        headers: Optional[LooseHeaders] = None,
        skip_auto_headers: Optional[Iterable[str]] = None,
        auth: Optional[BasicAuth] = None,
        allow_redirects: bool = True,
        max_redirects: int = 10,
        compress: Optional[str] = None,
        chunked: Optional[bool] = None,
        expect100: bool = False,
        raise_for_status: Union[
            None, bool, Callable[[ClientResponse], Awaitable[None]]
        ] = None,
        read_until_eof: bool = True,
        proxy: Optional[StrOrURL] = None,
        proxy_auth: Optional[BasicAuth] = None,
        timeout: Union[ClientTimeout, _SENTINEL] = sentinel,
        verify_ssl: Optional[bool] = None,
        fingerprint: Optional[bytes] = None,
        ssl_context: Optional[SSLContext] = None,
        ssl: Union[SSLContext, bool, Fingerprint] = True,
        server_hostname: Optional[str] = None,
        proxy_headers: Optional[LooseHeaders] = None,
        trace_request_ctx: Optional[SimpleNamespace] = None,
        read_bufsize: Optional[int] = None,
        auto_decompress: Optional[bool] = None,
        max_line_size: Optional[int] = None,
        max_field_size: Optional[int] = None,
    ) -> ClientResponse:
        """Send one request, routing it through ``self.interceptors`` if any.

        Parameters are those of ``ClientSession._request``. Exceptions raised
        by interceptors or by the request propagate unchanged, and the session
        remains open, exactly as when no interceptors are registered.
        """
        # The caller's arguments, unchanged. Session-level cookies and auth are
        # deliberately not copied in: they were handed to ClientSession.__init__,
        # which keeps them in its cookie jar / default auth. Re-passing them as
        # per-request values would make the constructor-time cookies override
        # newer values the server stored in the jar, and could clash with
        # credentials embedded in the URL. It would also make the interceptor
        # path send different arguments than the direct path below.
        request = {
            "method": method,
            "str_or_url": str_or_url,
            "params": params,
            "data": data,
            "json": json,
            "cookies": cookies,
            "headers": headers,
            "skip_auto_headers": skip_auto_headers,
            "auth": auth,
            "allow_redirects": allow_redirects,
            "max_redirects": max_redirects,
            "compress": compress,
            "chunked": chunked,
            "expect100": expect100,
            "raise_for_status": raise_for_status,
            "read_until_eof": read_until_eof,
            "proxy": proxy,
            "proxy_auth": proxy_auth,
            "timeout": timeout,
            "verify_ssl": verify_ssl,
            "fingerprint": fingerprint,
            "ssl_context": ssl_context,
            "ssl": ssl,
            "server_hostname": server_hostname,
            "proxy_headers": proxy_headers,
            "trace_request_ctx": trace_request_ctx,
            "read_bufsize": read_bufsize,
            "auto_decompress": auto_decompress,
            "max_line_size": max_line_size,
            "max_field_size": max_field_size,
        }
        if not self.interceptors:
            return await super()._request(**request)

        # Let interceptors see the session's default headers when the caller
        # gave none. aiohttp merges the session's default headers into any
        # per-request headers, so this is expected not to change what is sent.
        if headers is None:
            request["headers"] = self.session_data.get("headers")

        chain = InterceptorChainImpl(
            interceptors=list(self.interceptors),
            request=AIOHttpRequest(**request),
            process=super()._request,
        )
        return await chain.proceed()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.