Skip to content

Instantly share code, notes, and snippets.

@yayahuman
Last active September 3, 2023 15:13
Show Gist options
  • Save yayahuman/db06718ffdf8a9b66e133e29d7d7965f to your computer and use it in GitHub Desktop.
Save yayahuman/db06718ffdf8a9b66e133e29d7d7965f to your computer and use it in GitHub Desktop.
Support async file types in `files = {}` and `content = ...`
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2023 yayahuman <yayahuman@pm.me>
# SPDX-License-Identifier: 0BSD
# Public API of this module: only the patching entry point `apply` is exported.
__all__ = (
'apply',
)
import binascii
import os
import warnings
from abc import ABC
from inspect import isawaitable, iscoroutinefunction
from io import UnsupportedOperation
from json import dumps as json_dumps
from pathlib import PurePath
from typing import AsyncIterable, Iterable, Mapping, Sequence
from urllib.parse import urlencode
import httpx
from httpx import (
_api as httpx_api,
_client as httpx_client,
_models as httpx_models,
_multipart as httpx_multipart,
)
from httpx._exceptions import StreamConsumed
from httpx._transports import (
asgi as httpx_transports_asgi,
default as httpx_transports_default,
wsgi as httpx_transports_wsgi,
)
from httpx._urls import URL
# ------------------------------- httpx._utils ------------------------------ #
from httpx._utils import guess_content_type, primitive_value_to_str, to_bytes
# Sentinel default for optional arguments, so "not passed" can be told apart
# from an explicit None.
MISSING = object()
# Sentinel returned by peek_content_length() when an object's tell/seek are
# coroutine functions and therefore cannot be called synchronously.
UNAVAILABLE = object()
async def maybe_await(obj):
    """Return *obj*, awaiting it first when it is awaitable.

    Lets callers treat the result of a call uniformly whether the callee
    was a plain function or a coroutine function.
    """
    return (await obj) if isawaitable(obj) else obj
def is_file_content(obj):
    """Tell whether *obj* is acceptable as file/body content.

    Sets, dicts, lists and tuples are always rejected.  Otherwise the
    object qualifies when it is a ``str``/``bytes``, any sync or async
    iterable, or exposes a ``read`` or ``aread`` attribute.
    """
    # Plain containers are iterable, but are never meant as raw content.
    if isinstance(obj, (set, dict, list, tuple)):
        return False
    if isinstance(obj, (str, bytes, Iterable, AsyncIterable)):
        return True
    return hasattr(obj, 'read') or hasattr(obj, 'aread')
def peek_content_type(obj):
    """Look up the content type advertised by *obj* without reading it.

    A ``content_type`` attribute wins; otherwise ``get_content_type()`` is
    called when present.  Returns ``None`` when *obj* offers neither.
    """
    if hasattr(obj, 'content_type'):
        return obj.content_type
    if hasattr(obj, 'get_content_type'):
        return obj.get_content_type()
    return None
async def apeek_content_type(obj):
    """Async counterpart of ``peek_content_type``.

    Lookup order: the ``content_type`` attribute, then an awaited
    ``aget_content_type()``, then a plain ``get_content_type()``.
    Returns ``None`` when *obj* offers none of them.
    """
    if hasattr(obj, 'content_type'):
        return obj.content_type
    if hasattr(obj, 'aget_content_type'):
        return await obj.aget_content_type()
    if hasattr(obj, 'get_content_type'):
        return obj.get_content_type()
    return None
def peek_content_length(obj, encoding='utf-8'):
    """Determine the byte length of *obj* without consuming it.

    Strategies, tried in order:

    1. ``str``/``bytes``: length after ``to_bytes(obj, encoding)``.
    2. A ``content_length`` attribute.
    3. A ``get_content_length()`` method.
    4. ``os.fstat`` on ``obj.fileno()``.
    5. ``tell``/``seek`` to the end, then seek back to the original offset.

    Returns ``UNAVAILABLE`` when step 5 would need coroutine functions,
    and ``None`` when no strategy applies.
    """
    if isinstance(obj, (str, bytes)):
        return len(to_bytes(obj, encoding))
    if hasattr(obj, 'content_length'):
        return obj.content_length
    if hasattr(obj, 'get_content_length'):
        return obj.get_content_length()

    try:
        return os.fstat(obj.fileno()).st_size
    except (AttributeError, OSError):
        # No usable file descriptor; fall through to tell/seek probing.
        pass

    try:
        tell, seek = obj.tell, obj.seek
        if iscoroutinefunction(tell) or iscoroutinefunction(seek):
            # Cannot be awaited from synchronous code.
            return UNAVAILABLE
        position = tell()
        size = seek(0, os.SEEK_END)
        seek(position)  # put the cursor back where it was
    except (AttributeError, OSError):
        return None
    return size
async def apeek_content_length(obj, encoding='utf-8'):
    """Async counterpart of ``peek_content_length``.

    Strategies, tried in order:

    1. ``str``/``bytes``: length after ``to_bytes(obj, encoding)``.
    2. A ``content_length`` attribute.
    3. An awaited ``aget_content_length()``.
    4. A plain ``get_content_length()``.
    5. ``os.fstat`` on ``obj.fileno()``.
    6. ``tell``/``seek`` to the end and back, awaiting each result when
       it is awaitable.

    Returns ``None`` when no strategy applies.
    """
    if isinstance(obj, (str, bytes)):
        return len(to_bytes(obj, encoding))
    if hasattr(obj, 'content_length'):
        return obj.content_length
    if hasattr(obj, 'aget_content_length'):
        return await obj.aget_content_length()
    if hasattr(obj, 'get_content_length'):
        return obj.get_content_length()

    try:
        return os.fstat(obj.fileno()).st_size
    except (AttributeError, OSError):
        # No usable file descriptor; fall through to tell/seek probing.
        pass

    try:
        tell, seek = obj.tell, obj.seek
        position = await maybe_await(tell())
        size = await maybe_await(seek(0, os.SEEK_END))
        await maybe_await(seek(position))  # put the cursor back
    except (AttributeError, OSError):
        return None
    return size
# ------------------------------ httpx._content ----------------------------- #
from httpx._types import AsyncByteStream, SyncByteStream
class DefaultEncoding(str):
    """``str`` subclass used to tag the module's built-in default encodings."""


# Encoding where every character occupies exactly one byte.
FIXED_WIDTH_ENCODING = DefaultEncoding('ascii')
# Encoding where a character occupies between one and four bytes.
VARIABLE_WIDTH_ENCODING = DefaultEncoding('utf-8')
class ByteStream(ABC):
    """Factory namespace that turns request/response bodies into streams.

    The class itself carries no state; every entry point is a class or
    static method returning a concrete stream object.
    """

    __slots__ = ()

    @classmethod
    def from_(
        cls,
        content=None,
        *,
        data=None,
        files=None,
        boundary=None,
        text=None,
        html=None,
        json=None,
        encoding=None,
    ):
        """Build a stream from whichever body argument was supplied.

        Precedence: non-mapping ``data`` (deprecated alias for
        ``content``), ``content``, ``files`` (multipart), ``data``
        (urlencoded form), ``text``, ``html``, ``json``.  With nothing
        supplied an empty ``SimpleByteStream`` is returned.
        """
        if data is not None and not isinstance(data, Mapping):
            warnings.warn(
                'Use \'content=<...>\' to upload raw bytes/text content.',
                DeprecationWarning,
            )
            return cls.from_content(data, encoding)
        if content is not None:
            return cls.from_content(content, encoding)
        if files:
            return cls.from_multipart_data(
                data=data, files=files, boundary=boundary, encoding=encoding,
            )
        if data:
            return cls.from_urlencoded_data(data)
        if text is not None:
            return cls.from_text(text, encoding)
        if html is not None:
            return cls.from_html(html, encoding)
        if json is not None:
            return cls.from_json(json)
        return SimpleByteStream(b'')

    @staticmethod
    def from_content(content, encoding=None):
        """Wrap raw content in a stream.

        ``str``/``bytes`` become a ``SimpleByteStream``; objects that are
        already httpx byte streams are returned unchanged; any other
        accepted file content is wrapped in an ``IteratorByteStream``.

        Raises:
            TypeError: if ``is_file_content(content)`` is false.
        """
        if not is_file_content(content):
            raise TypeError(f"Unexpected type for content, {type(content)!r}")
        if isinstance(content, (str, bytes)):
            return SimpleByteStream(content, encoding=encoding)
        if isinstance(content, (SyncByteStream, AsyncByteStream)):
            return content
        return IteratorByteStream(content, encoding=encoding)

    @staticmethod
    def from_multipart_data(
        data=None,
        *,
        files=None,
        boundary=None,
        encoding=None,
    ):
        """Build a ``MultipartStream``; missing data/files become ``{}``."""
        return MultipartStream(
            {} if data is None else data,
            {} if files is None else files,
            boundary,
            encoding,
        )

    @staticmethod
    def from_urlencoded_data(data):
        """Encode a mapping as an ``application/x-www-form-urlencoded`` body.

        List and tuple values are expanded into one key/value pair per
        item; every value goes through ``primitive_value_to_str``.
        """
        pairs = []
        for key, value in data.items():
            items = value if isinstance(value, (list, tuple)) else (value,)
            pairs.extend(
                (key, primitive_value_to_str(item)) for item in items
            )
        encoded = urlencode(pairs, doseq=True).encode()
        return SimpleByteStream(encoded, 'application/x-www-form-urlencoded')

    @staticmethod
    def from_text(text, encoding=None):
        """Build a ``text/plain`` stream, advertising the charset used."""
        charset = (
            SimpleByteStream.DEFAULT_ENCODING if encoding is None
            else encoding
        )
        return SimpleByteStream(
            text, f"text/plain; charset={charset.lower()}", encoding=charset,
        )

    @staticmethod
    def from_html(html, encoding=None):
        """Build a ``text/html`` stream, advertising the charset used."""
        charset = (
            SimpleByteStream.DEFAULT_ENCODING if encoding is None
            else encoding
        )
        return SimpleByteStream(
            html, f"text/html; charset={charset.lower()}", encoding=charset,
        )

    @staticmethod
    def from_json(json):
        """Serialise *json* with ``json.dumps`` into an ``application/json`` stream."""
        payload = json_dumps(json).encode()
        return SimpleByteStream(payload, 'application/json')
# Register httpx's own stream base classes as virtual subclasses, so that
# isinstance(x, ByteStream) is true for any httpx sync or async byte stream.
ByteStream.register(SyncByteStream)
ByteStream.register(AsyncByteStream)
class SimpleByteStream(SyncByteStream, AsyncByteStream):
    """Stream that yields its whole content as a single bytes chunk.

    Usable from both sync and async iteration.  ``content_type`` and
    ``content_length`` are stored as instance attributes only when they
    were passed explicitly; otherwise they are derived from ``content``
    on demand through the ``get_*`` / ``aget_*`` methods.
    """

    DEFAULT_ENCODING = VARIABLE_WIDTH_ENCODING

    def __init__(
        self,
        content,
        content_type=MISSING,
        content_length=MISSING,
        *,
        encoding=None,
    ):
        self.content = content
        # Leave the attributes undefined when not supplied, so hasattr()
        # checks in the peek helpers fall through to the getter methods.
        if content_type is not MISSING:
            self.content_type = content_type
        if content_length is not MISSING:
            self.content_length = content_length
        self.encoding = self.DEFAULT_ENCODING if encoding is None else encoding

    def __iter__(self):
        payload = to_bytes(self.content, self.encoding)
        yield payload

    async def __aiter__(self):
        payload = to_bytes(self.content, self.encoding)
        yield payload

    def get_content_type(self):
        """Return the content type advertised by the wrapped content."""
        return peek_content_type(self.content)

    async def aget_content_type(self):
        """Async variant of ``get_content_type``."""
        return await apeek_content_type(self.content)

    def get_content_length(self):
        """Return the encoded byte length of the wrapped content."""
        return peek_content_length(self.content, self.encoding)

    async def aget_content_length(self):
        """Async variant of ``get_content_length``."""
        return await apeek_content_length(self.content, self.encoding)
class IteratorByteStream(SyncByteStream, AsyncByteStream):
    """Stream over file-like objects and sync/async iterables.

    Content is emitted chunk by chunk, each chunk passed through
    ``to_bytes`` with the configured encoding.  Before each iteration the
    stream tries to seek back to the start; when that is unsupported, a
    second iteration over non-``Sequence`` content raises
    ``StreamConsumed``.
    """

    DEFAULT_ENCODING = FIXED_WIDTH_ENCODING
    DEFAULT_CHUNK_SIZE = 64 * 1024

    def __init__(
        self,
        content,
        content_type=MISSING,
        content_length=MISSING,
        *,
        encoding=None,
        chunk_size=None,
    ):
        self.content = content
        # Only define the attributes when supplied, so hasattr() checks in
        # the peek helpers fall through to the getter methods.
        if content_type is not MISSING:
            self.content_type = content_type
        if content_length is not MISSING:
            self.content_length = content_length
        self.encoding = self.DEFAULT_ENCODING if encoding is None else encoding
        self.chunk_size = (
            self.DEFAULT_CHUNK_SIZE if chunk_size is None else chunk_size
        )
        self.consumed = False

    def __iter__(self):
        was_consumed = self.consumed
        encoding = self.encoding
        chunk_size = self.chunk_size
        self.consumed = True
        source = self.content

        if isinstance(source, (str, bytes)):
            yield to_bytes(source, encoding)
            return

        try:
            self.seek(0)
        except UnsupportedOperation:
            # Cannot rewind: only sequences may be iterated a second time.
            if was_consumed and not isinstance(source, Sequence):
                raise StreamConsumed from None

        if hasattr(source, 'read'):
            chunk = source.read(chunk_size)
            while chunk:
                yield to_bytes(chunk, encoding)
                chunk = source.read(chunk_size)
        else:
            for part in source:
                yield to_bytes(part, encoding)

    async def __aiter__(self):
        was_consumed = self.consumed
        encoding = self.encoding
        chunk_size = self.chunk_size
        self.consumed = True
        source = self.content

        if isinstance(source, (str, bytes)):
            yield to_bytes(source, encoding)
            return

        try:
            await self.aseek(0)
        except UnsupportedOperation:
            # Cannot rewind: only sequences may be iterated a second time.
            if was_consumed and not isinstance(source, Sequence):
                raise StreamConsumed from None

        # Reading strategy, in order of preference: aread(), a coroutine
        # read(), async iteration, a plain read(), sync iteration.
        if hasattr(source, 'aread'):
            chunk = await source.aread(chunk_size)
            while chunk:
                yield to_bytes(chunk, encoding)
                chunk = await source.aread(chunk_size)
        elif hasattr(source, 'read') and iscoroutinefunction(source.read):
            chunk = await source.read(chunk_size)
            while chunk:
                yield to_bytes(chunk, encoding)
                chunk = await source.read(chunk_size)
        elif hasattr(source, '__aiter__'):
            async for part in source:
                yield to_bytes(part, encoding)
        elif hasattr(source, 'read'):
            chunk = await maybe_await(source.read(chunk_size))
            while chunk:
                yield to_bytes(chunk, encoding)
                chunk = await maybe_await(source.read(chunk_size))
        else:
            for part in source:
                yield to_bytes(part, encoding)

    def seek(self, offset, whence=os.SEEK_SET):
        """Delegate to ``content.seek``.

        Raises:
            UnsupportedOperation: if the content has no ``seek``.
        """
        if not hasattr(self.content, 'seek'):
            raise UnsupportedOperation
        return self.content.seek(offset, whence)

    async def aseek(self, offset, whence=os.SEEK_SET):
        """Delegate to ``content.aseek``, else to a possibly-async ``seek``.

        Raises:
            UnsupportedOperation: if the content has neither method.
        """
        if hasattr(self.content, 'aseek'):
            return await self.content.aseek(offset, whence)
        if not hasattr(self.content, 'seek'):
            raise UnsupportedOperation
        return await maybe_await(self.content.seek(offset, whence))

    def get_content_type(self):
        """Return the content type advertised by the wrapped content."""
        return peek_content_type(self.content)

    async def aget_content_type(self):
        """Async variant of ``get_content_type``."""
        return await apeek_content_type(self.content)

    def get_content_length(self):
        """Return the byte length of the wrapped content, if determinable."""
        return peek_content_length(self.content, self.encoding)

    async def aget_content_length(self):
        """Async variant of ``get_content_length``."""
        return await apeek_content_length(self.content, self.encoding)
# ----------------------------- httpx._multipart ---------------------------- #
from httpx._multipart import (
DataField,
get_multipart_boundary_from_content_type,
)
class FileField(httpx_multipart.FileField):
    """Multipart file field that also accepts async file content.

    ``value`` is either the file content itself or a list/tuple of
    ``(filename, file[, content_type[, headers[, encoding]]])``.

    The part's Content-Type header is taken, in order, from the supplied
    headers, the explicit ``content_type`` item, the file object itself
    (``peek_content_type`` / ``apeek_content_type``) and finally a guess
    based on the filename.
    """

    DEFAULT_FILENAME = 'upload'

    def __init__(self, name, value, encoding):
        """Validate the field and resolve its headers synchronously.

        Raises:
            TypeError: if ``name`` is not a ``str`` or the file is not
                acceptable file content.
            ValueError: if a list/tuple ``value`` cannot be unpacked into
                2 to 5 items.
        """
        if not isinstance(name, str):
            raise TypeError(
                'Invalid type for name.'
                f" Expected str, got {type(name)}: {name!r}"
            )
        content_type = None
        headers = None
        if isinstance(value, (list, tuple)):
            if len(value) == 2:
                filename, file = value
            elif len(value) == 3:
                filename, file, content_type = value
            elif len(value) == 4:
                filename, file, content_type, headers = value
            else:
                filename, file, content_type, headers, encoding = value
        else:
            file = value
            filename = getattr(file, 'name', None)
            if filename is None:
                filename = self.DEFAULT_FILENAME
            else:
                # Keep only the final path component of the file's name.
                filename = PurePath(str(filename)).name
        if not is_file_content(file):
            # Report the offending *file*; the original message printed
            # the field name here by mistake.
            raise TypeError(
                'Invalid type for file.'
                f" Expected FileContent, got {type(file)}: {file!r}"
            )
        if headers is None:
            headers = {}
        self.name = name
        self.file = file
        self.filename = filename
        self.headers = headers
        self.encoding = encoding

        has_header = self._has_content_type_header()
        # The content type only needs computing when neither a header nor
        # an explicit content_type item was supplied.
        needs_calculation = not has_header and content_type is None
        self.calculate_content_type = needs_calculation
        self.calculate_content_type_sync = needs_calculation
        self.calculate_content_type_async = needs_calculation
        if not has_header and content_type is not None:
            self.headers = {'Content-Type': content_type, **self.headers}
        self.prepare()

    def _has_content_type_header(self):
        """Return True if any header key contains 'content-type'."""
        return any('content-type' in key.lower() for key in self.headers)

    def _store_content_type(self, peeked):
        """Record a peeked content type, guessing from the filename if needed.

        Returns True when the outcome is final, i.e. the other (sync or
        async) mode does not need to try again.
        """
        unavailable = peeked is UNAVAILABLE
        content_type = peeked
        if content_type is None or unavailable:
            content_type = guess_content_type(self.filename)
        if content_type is not None:
            # Auto-computed header goes first, explicit headers follow.
            self.headers = {'Content-Type': content_type, **self.headers}
        return content_type is not None or not unavailable

    def prepare(self):
        """Resolve the Content-Type header using synchronous lookups only."""
        if not (
            self.calculate_content_type
            and self.calculate_content_type_sync
        ):
            return
        if not self._has_content_type_header():
            if self._store_content_type(peek_content_type(self.file)):
                self.calculate_content_type_async = False
        self.calculate_content_type_sync = False

    async def aprepare(self):
        """Resolve the Content-Type header, awaiting async lookups."""
        if not (
            self.calculate_content_type
            and self.calculate_content_type_async
        ):
            return
        if not self._has_content_type_header():
            peeked = await apeek_content_type(self.file)
            if self._store_content_type(peeked):
                self.calculate_content_type_sync = False
        self.calculate_content_type_async = False

    def get_length(self):
        """Return headers + body length, or None when it is not known.

        ``UNAVAILABLE`` (length only obtainable asynchronously) is
        reported as None instead of being added to an int.
        """
        content_length = peek_content_length(self.stream)
        if content_length is None or content_length is UNAVAILABLE:
            return None
        return len(self.render_headers()) + content_length

    async def aget_length(self):
        """Async variant of ``get_length``."""
        content_length = await apeek_content_length(self.stream)
        if content_length is None or content_length is UNAVAILABLE:
            return None
        return len(await self.arender_headers()) + content_length

    def render_headers(self):
        """Return the rendered part headers, re-rendering when stale."""
        pending = (
            self.calculate_content_type
            and self.calculate_content_type_sync
        )
        if pending or not hasattr(self, '_headers'):
            self.prepare()
            # Drop the cached bytes so the parent class renders afresh.
            if hasattr(self, '_headers'):
                del self._headers
            super().render_headers()
        return self._headers

    async def arender_headers(self):
        """Async variant of ``render_headers``."""
        pending = (
            self.calculate_content_type
            and self.calculate_content_type_async
        )
        if pending or not hasattr(self, '_headers'):
            await self.aprepare()
            # Drop the cached bytes so the parent class renders afresh.
            if hasattr(self, '_headers'):
                del self._headers
            super().render_headers()
        return self._headers

    def render_data(self):
        """Yield the file body synchronously."""
        yield from self.stream

    async def arender_data(self):
        """Yield the file body asynchronously."""
        async for chunk in self.stream:
            yield chunk

    async def arender(self):
        """Yield the rendered headers followed by the body chunks."""
        yield await self.arender_headers()
        async for chunk in self.arender_data():
            yield chunk

    @property
    def stream(self):
        """Byte stream over ``self.file``, cached until the file changes."""
        cached = getattr(self, '_stream', None)
        if cached is not None:
            # ByteStream.from_ returns existing byte streams unchanged, and
            # those need not expose ``.content``; compare identity first.
            if (
                cached is self.file
                or getattr(cached, 'content', None) is self.file
            ):
                return cached
        self._stream = stream = ByteStream.from_(
            self.file, encoding=self.encoding,
        )
        return stream
class MultipartStream(httpx_multipart.MultipartStream):
    """Multipart body stream with additional async iteration support.

    Sync behaviour is inherited from httpx; this subclass adds
    ``__aiter__``, ``aiter_chunks`` and ``aget_content_length``, and
    builds its file fields with this module's ``FileField``.
    """

    def __init__(self, data, files, boundary, encoding):
        # Generate a random hex boundary when the caller supplied none.
        if boundary is None:
            boundary = binascii.hexlify(os.urandom(16))
        self.boundary = boundary
        boundary_text = boundary.decode('ascii')
        self.content_type = f"multipart/form-data; boundary={boundary_text}"
        self.fields = list(self._iter_fields(data, files, encoding))

    async def __aiter__(self):
        async for chunk in self.aiter_chunks():
            yield chunk

    def _iter_fields(self, data, files, encoding):
        """Yield a DataField per data value, then a FileField per file."""
        for name, value in data.items():
            # A list/tuple value produces one field per item.
            items = value if isinstance(value, (list, tuple)) else (value,)
            for item in items:
                yield DataField(name, item)

        pairs = files.items() if isinstance(files, Mapping) else files
        for name, value in pairs:
            yield FileField(name, value, encoding)

    async def aiter_chunks(self):
        """Yield the encoded multipart body, preferring async field renderers."""
        boundary = self.boundary
        for field in self.fields:
            yield b'--%s\r\n' % boundary
            if hasattr(field, 'arender'):
                async for chunk in field.arender():
                    yield chunk
            else:
                for chunk in field.render():
                    yield chunk
            yield b'\r\n'
        yield b'--%s--\r\n' % boundary

    async def aget_content_length(self):
        """Return the total encoded length, or None if any field is unsized."""
        boundary_length = len(self.boundary)
        # b'--{boundary}\r\n' before each field plus b'\r\n' after it.
        per_field_overhead = (2 + boundary_length + 2) + 2
        total = 0
        for field in self.fields:
            if hasattr(field, 'aget_length'):
                field_length = await field.aget_length()
            else:
                field_length = field.get_length()
            if field_length is None:
                return None
            total += per_field_overhead + field_length
        # Closing delimiter: b'--{boundary}--\r\n'.
        return total + 2 + boundary_length + 4
# ------------------------------ httpx._models ------------------------------ #
from httpx._models import Cookies, Headers
class Request(httpx_models.Request):
    """httpx ``Request`` whose automatic headers can be resolved lazily.

    When no ready-made ``stream`` is passed, the Host, Content-Type and
    Content-Length / Transfer-Encoding headers are computed from the
    body.  ``prepare()`` uses synchronous lookups only; ``aprepare()``
    may await.  A value that the sync pass reports as ``UNAVAILABLE`` is
    left for the async pass to resolve.
    """

    # Order in which automatically managed headers are moved to the front.
    _AUTO_HEADER_ORDER = (
        'Host', 'Content-Type', 'Content-Length', 'Transfer-Encoding',
    )

    def __init__(
        self,
        method,
        url,
        *,
        params=None,
        headers=None,
        cookies=None,
        content=None,
        data=None,
        files=None,
        json=None,
        stream=None,
        stream_encoding=None,
        extensions=None,
    ):
        if isinstance(method, bytes):
            method = method.decode('ascii').upper()
        else:
            method = method.upper()
        if params is None:
            url = URL(url)
        else:
            url = URL(url).copy_merge_params(params=params)
        headers = Headers(headers)
        if stream_encoding is None:
            has_only_transfer_encoding_header = (
                'Content-Length' not in headers
                and 'Transfer-Encoding' in headers
            )
            if has_only_transfer_encoding_header:
                stream_encoding = VARIABLE_WIDTH_ENCODING
        if extensions is None:
            extensions = {}
        self.method = method
        self.url = url
        self.headers = headers
        self.extensions = extensions
        if cookies:
            Cookies(cookies).set_cookie_header(self)

        if stream is not None:
            # A caller-supplied stream is used as is; nothing is computed.
            self.calculate_host = False
            self._set_content_flags(content_type=False, content_length=False)
            self.stream = stream
            return

        content_type = self.headers.get('Content-Type')
        stream = ByteStream.from_(
            content=content,
            data=data,
            files=files,
            json=json,
            boundary=get_multipart_boundary_from_content_type(
                content_type=content_type.encode(self.headers.encoding)
                if content_type
                else None
            ),
            encoding=stream_encoding,
        )
        has_message_framing_headers = (
            'Content-Length' in self.headers
            or 'Transfer-Encoding' in self.headers
        )
        self.calculate_host = 'Host' not in self.headers
        self._set_content_flags(
            content_type='Content-Type' not in self.headers,
            content_length=not has_message_framing_headers,
        )
        self.stream = stream
        self.prepare()
        if isinstance(stream, SimpleByteStream):
            self.read()

    def _set_content_flags(self, *, content_type, content_length):
        """Initialise the master/sync/async calculation flags."""
        self.calculate_content_type = content_type
        self.calculate_content_type_sync = content_type
        self.calculate_content_type_async = content_type
        self.calculate_content_length = content_length
        self.calculate_content_length_sync = content_length
        self.calculate_content_length_async = content_length

    def _prepare_host(self, auto_headers):
        """Queue a Host header derived from the URL, once."""
        if not self.calculate_host:
            return
        if 'Host' not in self.headers and self.url.host:
            auto_headers['Host'] = self.url.netloc
        self.calculate_host = False

    def _wants_content_length(self):
        """Return True if message framing headers may still be computed."""
        return (
            'Content-Length' not in self.headers
            and (
                'Transfer-Encoding' not in self.headers
                or self.headers['Transfer-Encoding'] == 'chunked'
            )
        )

    def _store_content_length(self, content_length, auto_headers):
        """Queue Content-Length or chunked framing for a peeked length.

        Returns True when the result is final, i.e. anything other than
        ``UNAVAILABLE``.
        """
        unavailable = content_length is UNAVAILABLE
        if content_length is not None and not unavailable:
            can_be_set = (
                content_length > 0
                or 'Content-Type' in auto_headers
                or 'Content-Type' in self.headers
                or self.method in ['POST', 'PUT', 'PATCH']
            )
            # A known length supersedes provisional chunked framing.
            try:
                del self.headers['Transfer-Encoding']
            except KeyError:
                pass
            if can_be_set:
                auto_headers['Content-Length'] = str(content_length)
        elif 'Transfer-Encoding' not in self.headers:
            auto_headers['Transfer-Encoding'] = 'chunked'
        return not unavailable

    def _merge_auto_headers(self, auto_headers):
        """Rebuild ``self.headers`` with the managed headers listed first."""
        if not auto_headers:
            return
        raw_headers = []
        for key in self._AUTO_HEADER_ORDER:
            value = auto_headers.pop(key, None)
            if value is None:
                value = self.headers.pop(key, None)
            if value is not None:
                raw_headers.append((to_bytes(key), to_bytes(value)))
        self.headers = Headers(raw_headers + self.headers.raw)

    def prepare(self):
        """Compute pending automatic headers using synchronous lookups."""
        auto_headers = {}
        self._prepare_host(auto_headers)

        if self.calculate_content_type and self.calculate_content_type_sync:
            if 'Content-Type' not in self.headers:
                content_type = peek_content_type(self.stream)
                unavailable = content_type is UNAVAILABLE
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    # A new Content-Type can change whether Content-Length
                    # is emitted, so let the length be recomputed.
                    if self.calculate_content_length:
                        self.calculate_content_length_sync = True
                # Only a final answer switches the async attempt off.  The
                # previous test `... is not None or not unavailable` was
                # always true and disabled it even for UNAVAILABLE.
                if not unavailable:
                    self.calculate_content_type_async = False
            self.calculate_content_type_sync = False

        if (
            self.calculate_content_length
            and self.calculate_content_length_sync
        ):
            if self._wants_content_length():
                content_length = peek_content_length(self.stream)
                if self._store_content_length(content_length, auto_headers):
                    self.calculate_content_length_async = False
            self.calculate_content_length_sync = False

        self._merge_auto_headers(auto_headers)

    async def aprepare(self):
        """Compute pending automatic headers, awaiting async lookups."""
        auto_headers = {}
        self._prepare_host(auto_headers)

        if self.calculate_content_type and self.calculate_content_type_async:
            if 'Content-Type' not in self.headers:
                content_type = await apeek_content_type(self.stream)
                unavailable = content_type is UNAVAILABLE
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    # A new Content-Type can change whether Content-Length
                    # is emitted, so let the length be recomputed.
                    if self.calculate_content_length:
                        self.calculate_content_length_async = True
                if not unavailable:
                    self.calculate_content_type_sync = False
            self.calculate_content_type_async = False

        if (
            self.calculate_content_length
            and self.calculate_content_length_async
        ):
            if self._wants_content_length():
                content_length = await apeek_content_length(self.stream)
                if self._store_content_length(content_length, auto_headers):
                    self.calculate_content_length_sync = False
            self.calculate_content_length_async = False

        self._merge_auto_headers(auto_headers)
class Response(httpx_models.Response):
    """httpx ``Response`` whose automatic headers can be resolved lazily.

    When no ready-made ``stream`` is passed, the Content-Type and
    Content-Length / Transfer-Encoding headers are computed from the
    body.  ``prepare()`` uses synchronous lookups only; ``aprepare()``
    may await.  A value that the sync pass reports as ``UNAVAILABLE`` is
    left for the async pass to resolve.
    """

    # Order in which automatically managed headers are moved to the front.
    _AUTO_HEADER_ORDER = (
        'Content-Type', 'Content-Length', 'Transfer-Encoding',
    )

    def __init__(
        self,
        status_code,
        *,
        headers=None,
        content=None,
        text=None,
        html=None,
        json=None,
        stream=None,
        stream_encoding=None,
        request=None,
        extensions=None,
        history=None,
        default_encoding='utf-8',
    ):
        headers = Headers(headers)
        if stream_encoding is None:
            has_only_transfer_encoding_header = (
                'Content-Length' not in headers
                and 'Transfer-Encoding' in headers
            )
            if has_only_transfer_encoding_header:
                stream_encoding = VARIABLE_WIDTH_ENCODING
        if extensions is None:
            extensions = {}
        if history is None:
            history = []
        else:
            history = list(history)
        self.status_code = status_code
        self.headers = headers
        self._request = request
        self.next_request = None
        self.extensions = extensions
        self.history = history
        self.is_closed = False
        self.is_stream_consumed = False
        self.default_encoding = default_encoding
        self._num_bytes_downloaded = 0

        if stream is not None:
            # A caller-supplied stream is used as is; nothing is computed.
            self._set_content_flags(content_type=False, content_length=False)
            self.stream = stream
            return

        stream = ByteStream.from_(
            content=content,
            text=text,
            html=html,
            json=json,
            encoding=stream_encoding,
        )
        has_message_framing_headers = (
            'Content-Length' in self.headers
            or 'Transfer-Encoding' in self.headers
        )
        self._set_content_flags(
            content_type='Content-Type' not in self.headers,
            content_length=not has_message_framing_headers,
        )
        self.stream = stream
        self.prepare()
        if isinstance(stream, SimpleByteStream):
            self.read()

    def _set_content_flags(self, *, content_type, content_length):
        """Initialise the master/sync/async calculation flags."""
        self.calculate_content_type = content_type
        self.calculate_content_type_sync = content_type
        self.calculate_content_type_async = content_type
        self.calculate_content_length = content_length
        self.calculate_content_length_sync = content_length
        self.calculate_content_length_async = content_length

    def _wants_content_length(self):
        """Return True if message framing headers may still be computed."""
        return (
            'Content-Length' not in self.headers
            and (
                'Transfer-Encoding' not in self.headers
                or self.headers['Transfer-Encoding'] == 'chunked'
            )
        )

    def _store_content_length(self, content_length, auto_headers):
        """Queue Content-Length or chunked framing for a peeked length.

        Returns True when the result is final, i.e. anything other than
        ``UNAVAILABLE``.
        """
        unavailable = content_length is UNAVAILABLE
        if content_length is not None and not unavailable:
            can_be_set = (
                content_length > 0
                or 'Content-Type' in auto_headers
                or 'Content-Type' in self.headers
            )
            # A known length supersedes provisional chunked framing.
            try:
                del self.headers['Transfer-Encoding']
            except KeyError:
                pass
            if can_be_set:
                auto_headers['Content-Length'] = str(content_length)
        elif 'Transfer-Encoding' not in self.headers:
            auto_headers['Transfer-Encoding'] = 'chunked'
        return not unavailable

    def _merge_auto_headers(self, auto_headers):
        """Rebuild ``self.headers`` with the managed headers listed first."""
        if not auto_headers:
            return
        raw_headers = []
        for key in self._AUTO_HEADER_ORDER:
            value = auto_headers.pop(key, None)
            if value is None:
                value = self.headers.pop(key, None)
            if value is not None:
                raw_headers.append((to_bytes(key), to_bytes(value)))
        self.headers = Headers(raw_headers + self.headers.raw)

    def prepare(self):
        """Compute pending automatic headers using synchronous lookups."""
        auto_headers = {}

        if self.calculate_content_type and self.calculate_content_type_sync:
            if 'Content-Type' not in self.headers:
                content_type = peek_content_type(self.stream)
                unavailable = content_type is UNAVAILABLE
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    # A new Content-Type can change whether Content-Length
                    # is emitted, so let the length be recomputed.
                    if self.calculate_content_length:
                        self.calculate_content_length_sync = True
                # Only a final answer switches the async attempt off.  The
                # previous test `... is not None or not unavailable` was
                # always true and disabled it even for UNAVAILABLE.
                if not unavailable:
                    self.calculate_content_type_async = False
            self.calculate_content_type_sync = False

        if (
            self.calculate_content_length
            and self.calculate_content_length_sync
        ):
            if self._wants_content_length():
                content_length = peek_content_length(self.stream)
                if self._store_content_length(content_length, auto_headers):
                    self.calculate_content_length_async = False
            self.calculate_content_length_sync = False

        self._merge_auto_headers(auto_headers)

    async def aprepare(self):
        """Compute pending automatic headers, awaiting async lookups."""
        auto_headers = {}

        if self.calculate_content_type and self.calculate_content_type_async:
            if 'Content-Type' not in self.headers:
                content_type = await apeek_content_type(self.stream)
                unavailable = content_type is UNAVAILABLE
                if content_type is not None and not unavailable:
                    auto_headers['Content-Type'] = content_type
                    # A new Content-Type can change whether Content-Length
                    # is emitted, so let the length be recomputed.
                    if self.calculate_content_length:
                        self.calculate_content_length_async = True
                if not unavailable:
                    self.calculate_content_type_sync = False
            self.calculate_content_type_async = False

        if (
            self.calculate_content_length
            and self.calculate_content_length_async
        ):
            if self._wants_content_length():
                content_length = await apeek_content_length(self.stream)
                if self._store_content_length(content_length, auto_headers):
                    self.calculate_content_length_sync = False
            self.calculate_content_length_async = False

        self._merge_auto_headers(auto_headers)
# ------------------------------ httpx._client ------------------------------ #
class Client(httpx_client.Client):
    """Sync client that finalises automatic headers around each send."""

    def send(self, request, *args, **kwargs):
        """Call ``prepare()`` on the request, send it, then on the response."""
        request.prepare()
        reply = super().send(request, *args, **kwargs)
        reply.prepare()
        return reply
class AsyncClient(httpx_client.AsyncClient):
    """Async client that finalises automatic headers around each send."""

    async def send(self, request, *args, **kwargs):
        """Await ``aprepare()`` on the request, send it, then on the response."""
        await request.aprepare()
        reply = await super().send(request, *args, **kwargs)
        await reply.aprepare()
        return reply
# --------------------------------------------------------------------------- #
def apply():
    """Install this module's classes over their httpx counterparts.

    Each (module, attribute) pair below is rebound, in order, to the
    replacement class defined in this file.
    """
    replacements = (
        (httpx_multipart, 'FileField', FileField),
        (httpx_multipart, 'MultipartStream', MultipartStream),
        (httpx_models, 'Request', Request),
        (httpx_models, 'Response', Response),
        (httpx_transports_asgi, 'Response', Response),
        (httpx_transports_default, 'Response', Response),
        (httpx_transports_wsgi, 'Response', Response),
        (httpx_client, 'Request', Request),
        (httpx_client, 'Client', Client),
        (httpx_client, 'AsyncClient', AsyncClient),
        (httpx_api, 'Client', Client),
        (httpx, 'Request', Request),
        (httpx, 'Response', Response),
        (httpx, 'Client', Client),
        (httpx, 'AsyncClient', AsyncClient),
    )
    for target, attribute, replacement in replacements:
        setattr(target, attribute, replacement)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment