Support async file types in `files = {}` and `content = ...`
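
A minimal usage sketch: assuming the file below is saved as `httpx_async_files.py` somewhere on your import path, an object that only implements `aread()` can then be uploaded through `files = {}`. The class, field names, and URL here are illustrative only.

import asyncio

import httpx

import httpx_async_files  # hypothetical module name for the file below

httpx_async_files.apply()  # patch httpx before building requests or clients


class AsyncReport:
    # Anything exposing aread() is accepted as file content; exposing
    # content_length as well lets the patch emit a Content-Length header.
    def __init__(self, data):
        self._data = data
        self._pos = 0
        self.content_length = len(data)

    async def aread(self, size=-1):
        if size is None or size < 0:
            size = len(self._data) - self._pos
        chunk = self._data[self._pos:self._pos + size]
        self._pos += len(chunk)
        return chunk


async def main():
    async with httpx.AsyncClient() as client:
        response = await client.post(
            'https://httpbin.org/post',  # any endpoint accepting multipart POSTs
            files={'report': ('report.bin', AsyncReport(b'hello world'))},
        )
        print(response.status_code)


asyncio.run(main())

The patched module itself follows.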
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2023 yayahuman <yayahuman@pm.me>
# SPDX-License-Identifier: 0BSD

__all__ = (
    'apply',
)

import binascii
import os
import warnings
from abc import ABC
from inspect import isawaitable, iscoroutinefunction
from io import UnsupportedOperation
from json import dumps as json_dumps
from pathlib import PurePath
from typing import AsyncIterable, Iterable, Mapping, Sequence
from urllib.parse import urlencode

import httpx
from httpx import (
    _api as httpx_api,
    _client as httpx_client,
    _models as httpx_models,
    _multipart as httpx_multipart,
)
from httpx._exceptions import StreamConsumed
from httpx._transports import (
    asgi as httpx_transports_asgi,
    default as httpx_transports_default,
    wsgi as httpx_transports_wsgi,
)
from httpx._urls import URL


# ------------------------------- httpx._utils ------------------------------ #

from httpx._utils import guess_content_type, primitive_value_to_str, to_bytes
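
# Sentinels and probing helpers shared by the patch below.  MISSING means
# "argument not given"; UNAVAILABLE means "value could not be determined on
# the synchronous path" (e.g. a length behind coroutine tell()/seek()), so it
# may be retried later on the async path.  The peek_*/apeek_* pairs look for
# a content type / length via attributes, get_*/aget_* methods, fileno(), or
# tell()/seek().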
MISSING = object()
UNAVAILABLE = object()


async def maybe_await(obj):
    if isawaitable(obj):
        return await obj
    else:
        return obj


def is_file_content(obj):
    return (
        not isinstance(obj, (set, dict, list, tuple))
        and (
            isinstance(obj, (str, bytes, Iterable, AsyncIterable))
            or hasattr(obj, 'read')
            or hasattr(obj, 'aread')
        )
    )


def peek_content_type(obj):
    if hasattr(obj, 'content_type'):
        content_type = obj.content_type
    elif hasattr(obj, 'get_content_type'):
        content_type = obj.get_content_type()
    else:
        content_type = None
    return content_type


async def apeek_content_type(obj):
    if hasattr(obj, 'content_type'):
        content_type = obj.content_type
    elif hasattr(obj, 'aget_content_type'):
        content_type = await obj.aget_content_type()
    elif hasattr(obj, 'get_content_type'):
        content_type = obj.get_content_type()
    else:
        content_type = None
    return content_type


def peek_content_length(obj, encoding='utf-8'):
    if isinstance(obj, (str, bytes)):
        content_length = len(to_bytes(obj, encoding))
    elif hasattr(obj, 'content_length'):
        content_length = obj.content_length
    elif hasattr(obj, 'get_content_length'):
        content_length = obj.get_content_length()
    else:
        try:
            fd = obj.fileno()
            content_length = os.fstat(fd).st_size
        except (AttributeError, OSError):
            try:
                tell = obj.tell
                seek = obj.seek
                if iscoroutinefunction(tell) or iscoroutinefunction(seek):
                    content_length = UNAVAILABLE
                else:
                    offset = tell()
                    content_length = seek(0, os.SEEK_END)
                    seek(offset)
            except (AttributeError, OSError):
                content_length = None
    return content_length


async def apeek_content_length(obj, encoding='utf-8'):
    if isinstance(obj, (str, bytes)):
        content_length = len(to_bytes(obj, encoding))
    elif hasattr(obj, 'content_length'):
        content_length = obj.content_length
    elif hasattr(obj, 'aget_content_length'):
        content_length = await obj.aget_content_length()
    elif hasattr(obj, 'get_content_length'):
        content_length = obj.get_content_length()
    else:
        try:
            fd = obj.fileno()
            content_length = os.fstat(fd).st_size
        except (AttributeError, OSError):
            try:
                tell = obj.tell
                seek = obj.seek
                offset = await maybe_await(tell())
                content_length = await maybe_await(seek(0, os.SEEK_END))
                await maybe_await(seek(offset))
            except (AttributeError, OSError):
                content_length = None
    return content_length


# ------------------------------ httpx._content ----------------------------- #

from httpx._types import AsyncByteStream, SyncByteStream


class DefaultEncoding(str):
    pass


FIXED_WIDTH_ENCODING = DefaultEncoding('ascii')  # 1 char = 1 byte
VARIABLE_WIDTH_ENCODING = DefaultEncoding('utf-8')  # 1 char = 1-4 bytes
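

# Factory used by the patched Request/Response below in place of httpx's
# encode_request()/encode_response(): from_() dispatches content=/data=/
# files=/text=/html=/json= to the concrete stream classes that follow.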
class ByteStream(ABC):
    __slots__ = ()

    @classmethod
    def from_(
        cls,
        content=None,
        *,
        data=None,
        files=None,
        boundary=None,
        text=None,
        html=None,
        json=None,
        encoding=None,
    ):
        if data is not None and not isinstance(data, Mapping):
            message = 'Use \'content=<...>\' to upload raw bytes/text content.'
            warnings.warn(message, DeprecationWarning)
            return cls.from_content(data, encoding)
        elif content is not None:
            return cls.from_content(content, encoding)
        elif files:
            return cls.from_multipart_data(
                data=data, files=files, boundary=boundary, encoding=encoding,
            )
        elif data:
            return cls.from_urlencoded_data(data)
        elif text is not None:
            return cls.from_text(text, encoding)
        elif html is not None:
            return cls.from_html(html, encoding)
        elif json is not None:
            return cls.from_json(json)
        else:
            return SimpleByteStream(b'')

    @staticmethod
    def from_content(content, encoding=None):
        if is_file_content(content):
            if isinstance(content, (str, bytes)):
                return SimpleByteStream(content, encoding=encoding)
            elif isinstance(content, (SyncByteStream, AsyncByteStream)):
                return content
            else:
                return IteratorByteStream(content, encoding=encoding)
        else:
            raise TypeError(f"Unexpected type for content, {type(content)!r}")

    @staticmethod
    def from_multipart_data(
        data=None,
        *,
        files=None,
        boundary=None,
        encoding=None,
    ):
        if data is None:
            data = {}
        if files is None:
            files = {}
        return MultipartStream(data, files, boundary, encoding)

    @staticmethod
    def from_urlencoded_data(data):
        plain_data = []
        convert = primitive_value_to_str
        for key, value in data.items():
            if isinstance(value, (list, tuple)):
                plain_data.extend((key, convert(x)) for x in value)
            else:
                plain_data.append((key, convert(value)))
        body = urlencode(plain_data, doseq=True).encode()
        return SimpleByteStream(body, 'application/x-www-form-urlencoded')

    @staticmethod
    def from_text(text, encoding=None):
        if encoding is None:
            encoding = SimpleByteStream.DEFAULT_ENCODING
        return SimpleByteStream(
            text, f"text/plain; charset={encoding.lower()}", encoding=encoding,
        )

    @staticmethod
    def from_html(html, encoding=None):
        if encoding is None:
            encoding = SimpleByteStream.DEFAULT_ENCODING
        return SimpleByteStream(
            html, f"text/html; charset={encoding.lower()}", encoding=encoding,
        )

    @staticmethod
    def from_json(json):
        return SimpleByteStream(json_dumps(json).encode(), 'application/json')


ByteStream.register(SyncByteStream)
ByteStream.register(AsyncByteStream)


class SimpleByteStream(SyncByteStream, AsyncByteStream):
    DEFAULT_ENCODING = VARIABLE_WIDTH_ENCODING

    def __init__(
        self,
        content,
        content_type=MISSING,
        content_length=MISSING,
        *,
        encoding=None,
    ):
        if encoding is None:
            encoding = self.DEFAULT_ENCODING
        self.content = content
        if content_type is not MISSING:
            self.content_type = content_type
        if content_length is not MISSING:
            self.content_length = content_length
        self.encoding = encoding

    def __iter__(self):
        yield to_bytes(self.content, self.encoding)

    async def __aiter__(self):
        yield to_bytes(self.content, self.encoding)

    def get_content_type(self):
        return peek_content_type(self.content)

    async def aget_content_type(self):
        return await apeek_content_type(self.content)

    def get_content_length(self):
        return peek_content_length(self.content, self.encoding)

    async def aget_content_length(self):
        return await apeek_content_length(self.content, self.encoding)
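

# IteratorByteStream wraps file-like objects and (a)sync iterables.  It
# prefers aread()/read() when present, falls back to (async) iteration,
# rewinds through (a)seek() when the source supports it, and raises
# StreamConsumed on a second pass over a non-rewindable, non-sequence source.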
class IteratorByteStream(SyncByteStream, AsyncByteStream):
    DEFAULT_ENCODING = FIXED_WIDTH_ENCODING
    DEFAULT_CHUNK_SIZE = 64 * 1024

    def __init__(
        self,
        content,
        content_type=MISSING,
        content_length=MISSING,
        *,
        encoding=None,
        chunk_size=None,
    ):
        if encoding is None:
            encoding = self.DEFAULT_ENCODING
        if chunk_size is None:
            chunk_size = self.DEFAULT_CHUNK_SIZE
        self.content = content
        if content_type is not MISSING:
            self.content_type = content_type
        if content_length is not MISSING:
            self.content_length = content_length
        self.encoding = encoding
        self.chunk_size = chunk_size
        self.consumed = False

    def __iter__(self):
        consumed = self.consumed
        encoding = self.encoding
        chunk_size = self.chunk_size
        self.consumed = True
        if isinstance(self.content, (str, bytes)):
            yield to_bytes(self.content, encoding)
            return
        try:
            self.seek(0)
        except UnsupportedOperation:
            if consumed and not isinstance(self.content, Sequence):
                raise StreamConsumed from None
        if hasattr(self.content, 'read'):
            while True:
                chunk = self.content.read(chunk_size)
                if not chunk:
                    break
                yield to_bytes(chunk, encoding)
        else:
            for part in self.content:
                yield to_bytes(part, encoding)

    async def __aiter__(self):
        consumed = self.consumed
        encoding = self.encoding
        chunk_size = self.chunk_size
        self.consumed = True
        if isinstance(self.content, (str, bytes)):
            yield to_bytes(self.content, encoding)
            return
        try:
            await self.aseek(0)
        except UnsupportedOperation:
            if consumed and not isinstance(self.content, Sequence):
                raise StreamConsumed from None
        if hasattr(self.content, 'aread'):
            while True:
                chunk = await self.content.aread(chunk_size)
                if not chunk:
                    break
                yield to_bytes(chunk, encoding)
        elif hasattr(self.content, 'read'):
            if iscoroutinefunction(self.content.read):
                while True:
                    chunk = await self.content.read(chunk_size)
                    if not chunk:
                        break
                    yield to_bytes(chunk, encoding)
            elif hasattr(self.content, '__aiter__'):
                async for part in self.content:
                    yield to_bytes(part, encoding)
            else:
                while True:
                    chunk = await maybe_await(self.content.read(chunk_size))
                    if not chunk:
                        break
                    yield to_bytes(chunk, encoding)
        elif hasattr(self.content, '__aiter__'):
            async for part in self.content:
                yield to_bytes(part, encoding)
        else:
            for part in self.content:
                yield to_bytes(part, encoding)

    def seek(self, offset, whence=os.SEEK_SET):
        if hasattr(self.content, 'seek'):
            return self.content.seek(offset, whence)
        else:
            raise UnsupportedOperation

    async def aseek(self, offset, whence=os.SEEK_SET):
        if hasattr(self.content, 'aseek'):
            return await self.content.aseek(offset, whence)
        if hasattr(self.content, 'seek'):
            return await maybe_await(self.content.seek(offset, whence))
        else:
            raise UnsupportedOperation

    def get_content_type(self):
        return peek_content_type(self.content)

    async def aget_content_type(self):
        return await apeek_content_type(self.content)

    def get_content_length(self):
        return peek_content_length(self.content, self.encoding)

    async def aget_content_length(self):
        return await apeek_content_length(self.content, self.encoding)


# ----------------------------- httpx._multipart ---------------------------- #

from httpx._multipart import (
    DataField,
    get_multipart_boundary_from_content_type,
)
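

# FileField mirrors httpx._multipart.FileField but accepts async file content
# and defers Content-Type detection: prepare()/aprepare() take the type from
# the file object or guess it from the filename the first time the part
# headers are rendered.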
class FileField(httpx_multipart.FileField):
    DEFAULT_FILENAME = 'upload'

    def __init__(self, name, value, encoding):
        if not isinstance(name, str):
            raise TypeError(
                'Invalid type for name.'
                f" Expected str, got {type(name)}: {name!r}"
            )
        content_type = None
        headers = None
        if isinstance(value, (list, tuple)):
            if len(value) == 2:
                filename, file = value
            elif len(value) == 3:
                filename, file, content_type = value
            elif len(value) == 4:
                filename, file, content_type, headers = value
            else:
                # A 5-tuple additionally carries a per-file encoding.
                filename, file, content_type, headers, encoding = value
        else:
            file = value
            filename = getattr(file, 'name', None)
            if filename is None:
                filename = self.DEFAULT_FILENAME
            else:
                filename = PurePath(str(filename)).name
        if not is_file_content(file):
            raise TypeError(
                'Invalid type for file.'
                f" Expected FileContent, got {type(file)}: {file!r}"
            )
        if headers is None:
            headers = {}
        self.name = name
        self.file = file
        self.filename = filename
        self.headers = headers
        self.encoding = encoding
        has_content_type_header = any(
            'content-type' in key.lower() for key in headers
        )
        if not has_content_type_header:
            if content_type is None:
                self.calculate_content_type = True
                self.calculate_content_type_sync = True
                self.calculate_content_type_async = True
            else:
                self.calculate_content_type = False
                self.calculate_content_type_sync = False
                self.calculate_content_type_async = False
                self.headers = {'Content-Type': content_type, **self.headers}
        else:
            self.calculate_content_type = False
            self.calculate_content_type_sync = False
            self.calculate_content_type_async = False
        self.prepare()

    def prepare(self):
        auto_headers = {}
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_sync
        )
        if calculate_content_type:
            has_content_type = any(
                'content-type' in key.lower() for key in self.headers
            )
            if not has_content_type:
                content_type = peek_content_type(self.file)
                unavailable = (content_type is UNAVAILABLE)
                if content_type is None or unavailable:
                    content_type = guess_content_type(self.filename)
                if content_type is not None:
                    auto_headers['Content-Type'] = content_type
                if content_type is not None or not unavailable:
                    self.calculate_content_type_async = False
                    self.calculate_content_type_sync = False
        if auto_headers:
            self.headers = {**auto_headers, **self.headers}

    async def aprepare(self):
        auto_headers = {}
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_async
        )
        if calculate_content_type:
            has_content_type = any(
                'content-type' in key.lower() for key in self.headers
            )
            if not has_content_type:
                content_type = await apeek_content_type(self.file)
                unavailable = (content_type is UNAVAILABLE)
                if content_type is None or unavailable:
                    content_type = guess_content_type(self.filename)
                if content_type is not None:
                    auto_headers['Content-Type'] = content_type
                if content_type is not None or not unavailable:
                    self.calculate_content_type_sync = False
                    self.calculate_content_type_async = False
        if auto_headers:
            self.headers = {**auto_headers, **self.headers}

    def get_length(self):
        content_length = peek_content_length(self.stream)
        if content_length is None or content_length is UNAVAILABLE:
            # A length that is only obtainable asynchronously is reported as
            # unknown here instead of leaking the sentinel into arithmetic.
            return None
        return len(self.render_headers()) + content_length

    async def aget_length(self):
        content_length = await apeek_content_length(self.stream)
        if content_length is None:
            return None
        return len(await self.arender_headers()) + content_length

    def render_headers(self):
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_sync
        )
        if calculate_content_type or not hasattr(self, '_headers'):
            self.prepare()
            if hasattr(self, '_headers'):
                del self._headers
            super().render_headers()
        return self._headers

    async def arender_headers(self):
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_async
        )
        if calculate_content_type or not hasattr(self, '_headers'):
            await self.aprepare()
            if hasattr(self, '_headers'):
                del self._headers
            super().render_headers()
        return self._headers

    def render_data(self):
        yield from self.stream

    async def arender_data(self):
        async for chunk in self.stream:
            yield chunk

    async def arender(self):
        yield await self.arender_headers()
        async for chunk in self.arender_data():
            yield chunk

    @property
    def stream(self):
        try:
            stream = self._stream
        except AttributeError:
            pass
        else:
            # The cached stream may be a raw ByteStream without a .content
            # attribute, so probe it defensively.
            if getattr(stream, 'content', MISSING) is self.file:
                return stream
        self._stream = stream = ByteStream.from_(
            self.file, encoding=self.encoding,
        )
        return stream
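

# MultipartStream builds the multipart/form-data body.  aiter_chunks() lets
# file parts stream asynchronously through FileField.arender(), and
# aget_content_length() awaits per-field lengths so a Content-Length can
# still be computed for async files.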
class MultipartStream(httpx_multipart.MultipartStream):
    def __init__(self, data, files, boundary, encoding):
        if boundary is None:
            boundary = binascii.hexlify(os.urandom(16))
        self.boundary = boundary
        self.content_type = (
            f"multipart/form-data; boundary={boundary.decode('ascii')}"
        )
        self.fields = list(self._iter_fields(data, files, encoding))

    async def __aiter__(self):
        async for chunk in self.aiter_chunks():
            yield chunk

    def _iter_fields(self, data, files, encoding):
        for name, value in data.items():
            if isinstance(value, (list, tuple)):
                for item in value:
                    yield DataField(name, item)
            else:
                yield DataField(name, value)
        if isinstance(files, Mapping):
            file_items = files.items()
        else:
            file_items = files
        for name, value in file_items:
            yield FileField(name, value, encoding)

    async def aiter_chunks(self):
        for field in self.fields:
            yield b'--%s\r\n' % self.boundary
            if hasattr(field, 'arender'):
                async for chunk in field.arender():
                    yield chunk
            else:
                for chunk in field.render():
                    yield chunk
            yield b'\r\n'
        yield b'--%s--\r\n' % self.boundary

    async def aget_content_length(self):
        boundary_length = len(self.boundary)
        content_length = 0
        for field in self.fields:
            content_length += 2 + boundary_length + 2  # b'--{boundary}\r\n'
            if hasattr(field, 'aget_length'):
                field_length = await field.aget_length()
            else:
                field_length = field.get_length()
            if field_length is None:
                return None
            content_length += field_length
            content_length += 2  # b'\r\n'
        content_length += 2 + boundary_length + 4  # b'--{boundary}--\r\n'
        return content_length


# ------------------------------ httpx._models ------------------------------ #

from httpx._models import Cookies, Headers
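

# Request defers Host / Content-Type / Content-Length / Transfer-Encoding
# calculation: prepare() runs synchronously at construction time, while
# aprepare() (invoked by AsyncClient.send below) retries whatever was still
# UNAVAILABLE on the sync path, such as lengths of async-only files.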
class Request(httpx_models.Request):
    def __init__(
        self,
        method,
        url,
        *,
        params=None,
        headers=None,
        cookies=None,
        content=None,
        data=None,
        files=None,
        json=None,
        stream=None,
        stream_encoding=None,
        extensions=None,
    ):
        if isinstance(method, bytes):
            method = method.decode('ascii').upper()
        else:
            method = method.upper()
        if params is None:
            url = URL(url)
        else:
            url = URL(url).copy_merge_params(params=params)
        headers = Headers(headers)
        if stream_encoding is None:
            has_only_transfer_encoding_header = (
                'Content-Length' not in headers
                and 'Transfer-Encoding' in headers
            )
            if has_only_transfer_encoding_header:
                stream_encoding = VARIABLE_WIDTH_ENCODING
        if extensions is None:
            extensions = {}
        self.method = method
        self.url = url
        self.headers = headers
        self.extensions = extensions
        if cookies:
            Cookies(cookies).set_cookie_header(self)
        if stream is None:
            content_type = self.headers.get('Content-Type')
            stream = ByteStream.from_(
                content=content,
                data=data,
                files=files,
                json=json,
                boundary=get_multipart_boundary_from_content_type(
                    content_type=content_type.encode(self.headers.encoding)
                    if content_type
                    else None
                ),
                encoding=stream_encoding,
            )
            has_message_framing_headers = (
                'Content-Length' in self.headers
                or 'Transfer-Encoding' in self.headers
            )
            if 'Host' not in self.headers:
                self.calculate_host = True
            else:
                self.calculate_host = False
            if 'Content-Type' not in self.headers:
                self.calculate_content_type = True
                self.calculate_content_type_sync = True
                self.calculate_content_type_async = True
            else:
                self.calculate_content_type = False
                self.calculate_content_type_sync = False
                self.calculate_content_type_async = False
            if not has_message_framing_headers:
                self.calculate_content_length = True
                self.calculate_content_length_sync = True
                self.calculate_content_length_async = True
            else:
                self.calculate_content_length = False
                self.calculate_content_length_sync = False
                self.calculate_content_length_async = False
            self.stream = stream
            self.prepare()
            if isinstance(stream, SimpleByteStream):
                self.read()
        else:
            self.calculate_host = False
            self.calculate_content_type = False
            self.calculate_content_type_sync = False
            self.calculate_content_type_async = False
            self.calculate_content_length = False
            self.calculate_content_length_sync = False
            self.calculate_content_length_async = False
            self.stream = stream

    def prepare(self):
        auto_headers = {}
        calculate_host = self.calculate_host
        if calculate_host:
            if 'Host' not in self.headers:
                if self.url.host:
                    auto_headers['Host'] = self.url.netloc
            self.calculate_host = False
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_sync
        )
        if calculate_content_type:
            if 'Content-Type' not in self.headers:
                content_type = peek_content_type(self.stream)
                unavailable = (content_type is UNAVAILABLE)
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    if self.calculate_content_length:
                        self.calculate_content_length_sync = True
                if content_type is not None or not unavailable:
                    self.calculate_content_type_async = False
                    self.calculate_content_type_sync = False
        calculate_content_length = (
            self.calculate_content_length
            and self.calculate_content_length_sync
        )
        if calculate_content_length:
            calculate_content_length = (
                'Content-Length' not in self.headers
                and (
                    'Transfer-Encoding' not in self.headers
                    or self.headers['Transfer-Encoding'] == 'chunked'
                )
            )
        if calculate_content_length:
            content_length = peek_content_length(self.stream)
            unavailable = (content_length is UNAVAILABLE)
            if content_length is not None and not unavailable:
                can_be_set = (
                    content_length > 0
                    or 'Content-Type' in auto_headers
                    or 'Content-Type' in self.headers
                    or self.method in ['POST', 'PUT', 'PATCH']
                )
                try:
                    del self.headers['Transfer-Encoding']
                except KeyError:
                    pass
                if can_be_set:
                    auto_headers['Content-Length'] = str(content_length)
                elif 'Transfer-Encoding' not in self.headers:
                    auto_headers['Transfer-Encoding'] = 'chunked'
            if content_length is not None or not unavailable:
                self.calculate_content_length_async = False
                self.calculate_content_length_sync = False
        if auto_headers:
            raw_headers = []
            keys = [
                'Host', 'Content-Type', 'Content-Length', 'Transfer-Encoding',
            ]
            for key in keys:
                value = auto_headers.pop(key, None)
                if value is None:
                    value = self.headers.pop(key, None)
                if value is not None:
                    raw_headers.append((to_bytes(key), to_bytes(value)))
            self.headers = Headers(raw_headers + self.headers.raw)

    async def aprepare(self):
        auto_headers = {}
        calculate_host = self.calculate_host
        if calculate_host:
            if 'Host' not in self.headers:
                if self.url.host:
                    auto_headers['Host'] = self.url.netloc
            self.calculate_host = False
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_async
        )
        if calculate_content_type:
            if 'Content-Type' not in self.headers:
                content_type = await apeek_content_type(self.stream)
                unavailable = (content_type is UNAVAILABLE)
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    if self.calculate_content_length:
                        self.calculate_content_length_async = True
                if content_type is not None or not unavailable:
                    self.calculate_content_type_sync = False
                    self.calculate_content_type_async = False
        calculate_content_length = (
            self.calculate_content_length
            and self.calculate_content_length_async
        )
        if calculate_content_length:
            calculate_content_length = (
                'Content-Length' not in self.headers
                and (
                    'Transfer-Encoding' not in self.headers
                    or self.headers['Transfer-Encoding'] == 'chunked'
                )
            )
        if calculate_content_length:
            content_length = await apeek_content_length(self.stream)
            unavailable = (content_length is UNAVAILABLE)
            if content_length is not None and not unavailable:
                can_be_set = (
                    content_length > 0
                    or 'Content-Type' in auto_headers
                    or 'Content-Type' in self.headers
                    or self.method in ['POST', 'PUT', 'PATCH']
                )
                try:
                    del self.headers['Transfer-Encoding']
                except KeyError:
                    pass
                if can_be_set:
                    auto_headers['Content-Length'] = str(content_length)
                elif 'Transfer-Encoding' not in self.headers:
                    auto_headers['Transfer-Encoding'] = 'chunked'
            if content_length is not None or not unavailable:
                self.calculate_content_length_sync = False
                self.calculate_content_length_async = False
        if auto_headers:
            raw_headers = []
            keys = [
                'Host', 'Content-Type', 'Content-Length', 'Transfer-Encoding',
            ]
            for key in keys:
                value = auto_headers.pop(key, None)
                if value is None:
                    value = self.headers.pop(key, None)
                if value is not None:
                    raw_headers.append((to_bytes(key), to_bytes(value)))
            self.headers = Headers(raw_headers + self.headers.raw)
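

# Response applies the same deferred header calculation to response content,
# so async streams passed as content=/stream= also receive a Content-Type and
# a Content-Length (or Transfer-Encoding: chunked).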
class Response(httpx_models.Response):
    def __init__(
        self,
        status_code,
        *,
        headers=None,
        content=None,
        text=None,
        html=None,
        json=None,
        stream=None,
        stream_encoding=None,
        request=None,
        extensions=None,
        history=None,
        default_encoding='utf-8',
    ):
        headers = Headers(headers)
        if stream_encoding is None:
            has_only_transfer_encoding_header = (
                'Content-Length' not in headers
                and 'Transfer-Encoding' in headers
            )
            if has_only_transfer_encoding_header:
                stream_encoding = VARIABLE_WIDTH_ENCODING
        if extensions is None:
            extensions = {}
        if history is None:
            history = []
        else:
            history = list(history)
        self.status_code = status_code
        self.headers = headers
        self._request = request
        self.next_request = None
        self.extensions = extensions
        self.history = history
        self.is_closed = False
        self.is_stream_consumed = False
        self.default_encoding = default_encoding
        self._num_bytes_downloaded = 0
        if stream is None:
            stream = ByteStream.from_(
                content=content,
                text=text,
                html=html,
                json=json,
                encoding=stream_encoding,
            )
            has_message_framing_headers = (
                'Content-Length' in self.headers
                or 'Transfer-Encoding' in self.headers
            )
            if 'Content-Type' not in self.headers:
                self.calculate_content_type = True
                self.calculate_content_type_sync = True
                self.calculate_content_type_async = True
            else:
                self.calculate_content_type = False
                self.calculate_content_type_sync = False
                self.calculate_content_type_async = False
            if not has_message_framing_headers:
                self.calculate_content_length = True
                self.calculate_content_length_sync = True
                self.calculate_content_length_async = True
            else:
                self.calculate_content_length = False
                self.calculate_content_length_sync = False
                self.calculate_content_length_async = False
            self.stream = stream
            self.prepare()
            if isinstance(stream, SimpleByteStream):
                self.read()
        else:
            self.calculate_content_type = False
            self.calculate_content_type_sync = False
            self.calculate_content_type_async = False
            self.calculate_content_length = False
            self.calculate_content_length_sync = False
            self.calculate_content_length_async = False
            self.stream = stream

    def prepare(self):
        auto_headers = {}
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_sync
        )
        if calculate_content_type:
            if 'Content-Type' not in self.headers:
                content_type = peek_content_type(self.stream)
                unavailable = (content_type is UNAVAILABLE)
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    if self.calculate_content_length:
                        self.calculate_content_length_sync = True
                if content_type is not None or not unavailable:
                    self.calculate_content_type_async = False
                    self.calculate_content_type_sync = False
        calculate_content_length = (
            self.calculate_content_length
            and self.calculate_content_length_sync
        )
        if calculate_content_length:
            calculate_content_length = (
                'Content-Length' not in self.headers
                and (
                    'Transfer-Encoding' not in self.headers
                    or self.headers['Transfer-Encoding'] == 'chunked'
                )
            )
        if calculate_content_length:
            content_length = peek_content_length(self.stream)
            unavailable = (content_length is UNAVAILABLE)
            if content_length is not None and not unavailable:
                can_be_set = (
                    content_length > 0
                    or 'Content-Type' in auto_headers
                    or 'Content-Type' in self.headers
                )
                try:
                    del self.headers['Transfer-Encoding']
                except KeyError:
                    pass
                if can_be_set:
                    auto_headers['Content-Length'] = str(content_length)
                elif 'Transfer-Encoding' not in self.headers:
                    auto_headers['Transfer-Encoding'] = 'chunked'
            if content_length is not None or not unavailable:
                self.calculate_content_length_async = False
                self.calculate_content_length_sync = False
        if auto_headers:
            raw_headers = []
            keys = [
                'Content-Type', 'Content-Length', 'Transfer-Encoding',
            ]
            for key in keys:
                value = auto_headers.pop(key, None)
                if value is None:
                    value = self.headers.pop(key, None)
                if value is not None:
                    raw_headers.append((to_bytes(key), to_bytes(value)))
            self.headers = Headers(raw_headers + self.headers.raw)

    async def aprepare(self):
        auto_headers = {}
        calculate_content_type = (
            self.calculate_content_type
            and self.calculate_content_type_async
        )
        if calculate_content_type:
            if 'Content-Type' not in self.headers:
                content_type = await apeek_content_type(self.stream)
                unavailable = (content_type is UNAVAILABLE)
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    if self.calculate_content_length:
                        self.calculate_content_length_async = True
                if content_type is not None or not unavailable:
                    self.calculate_content_type_sync = False
                    self.calculate_content_type_async = False
        calculate_content_length = (
            self.calculate_content_length
            and self.calculate_content_length_async
        )
        if calculate_content_length:
            calculate_content_length = (
                'Content-Length' not in self.headers
                and (
                    'Transfer-Encoding' not in self.headers
                    or self.headers['Transfer-Encoding'] == 'chunked'
                )
            )
        if calculate_content_length:
            content_length = await apeek_content_length(self.stream)
            unavailable = (content_length is UNAVAILABLE)
            if content_length is not None and not unavailable:
                can_be_set = (
                    content_length > 0
                    or 'Content-Type' in auto_headers
                    or 'Content-Type' in self.headers
                )
                try:
                    del self.headers['Transfer-Encoding']
                except KeyError:
                    pass
                if can_be_set:
                    auto_headers['Content-Length'] = str(content_length)
                elif 'Transfer-Encoding' not in self.headers:
                    auto_headers['Transfer-Encoding'] = 'chunked'
            if content_length is not None or not unavailable:
                self.calculate_content_length_sync = False
                self.calculate_content_length_async = False
        if auto_headers:
            raw_headers = []
            keys = [
                'Content-Type', 'Content-Length', 'Transfer-Encoding',
            ]
            for key in keys:
                value = auto_headers.pop(key, None)
                if value is None:
                    value = self.headers.pop(key, None)
                if value is not None:
                    raw_headers.append((to_bytes(key), to_bytes(value)))
            self.headers = Headers(raw_headers + self.headers.raw)


# ------------------------------ httpx._client ------------------------------ #
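
# The client subclasses run request.(a)prepare() and response.(a)prepare()
# around send(), so headers deferred at construction time are finalized
# before the request goes out.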
class Client(httpx_client.Client):
    def send(self, request, *args, **kwargs):
        request.prepare()
        response = super().send(request, *args, **kwargs)
        response.prepare()
        return response


class AsyncClient(httpx_client.AsyncClient):
    async def send(self, request, *args, **kwargs):
        await request.aprepare()
        response = await super().send(request, *args, **kwargs)
        await response.aprepare()
        return response


# --------------------------------------------------------------------------- #
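
# apply() swaps the patched classes into httpx's modules; it is presumably
# meant to be called once, before any requests or clients are built, so that
# the names httpx resolves internally point at the classes above.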
def apply():
    httpx_multipart.FileField = FileField
    httpx_multipart.MultipartStream = MultipartStream
    httpx_models.Request = Request
    httpx_models.Response = Response
    httpx_transports_asgi.Response = Response
    httpx_transports_default.Response = Response
    httpx_transports_wsgi.Response = Response
    httpx_client.Request = Request
    httpx_client.Client = Client
    httpx_client.AsyncClient = AsyncClient
    httpx_api.Client = Client
    httpx.Request = Request
    httpx.Response = Response
    httpx.Client = Client
    httpx.AsyncClient = AsyncClient