Skip to content

Instantly share code, notes, and snippets.

@ryankask
Created November 9, 2018 21:23
Show Gist options
  • Save ryankask/90c0dc155d3f95a52a1eaa0ac4ef63d0 to your computer and use it in GitHub Desktop.
Save ryankask/90c0dc155d3f95a52a1eaa0ac4ef63d0 to your computer and use it in GitHub Desktop.
Async Starlette test client
import http
import io
import typing
from urllib.parse import unquote, urljoin, urlparse
import requests
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
return self.getheaders(key)
class _MockOriginalResponse(object):
"""
We have to jump through some hoops to present the response as if
it was made using urllib3.
"""
def __init__(self, headers):
self.msg = _HeaderDict(headers)
self.closed = False
def isclosed(self):
return self.closed
def _get_reason_phrase(status_code):
try:
return http.HTTPStatus(status_code).phrase
except ValueError:
return ""
class StarletteTestClient:
def __init__(
self, app: typing.Callable, base_url: str, raise_server_exceptions=True
) -> None:
self.app = app
self.base_url = base_url
self.raise_server_exceptions = True
self.headers = requests.utils.default_headers()
self.headers.update({"user-agent": "testclient"})
async def send(self, request, *args, **kwargs):
scheme, netloc, path, params, query, fragement = urlparse(request.url)
if ":" in netloc:
host, port = netloc.split(":", 1)
port = int(port)
else:
host = netloc
port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
# Include the 'host' header.
if "host" in request.headers:
headers = []
elif port == 80:
headers = [[b"host", host.encode()]]
else:
headers = [[b"host", ("%s:%d" % (host, port)).encode()]]
# Include other request headers.
headers += [
[key.lower().encode(), value.encode()]
for key, value in request.headers.items()
]
if scheme in {"ws", "wss"}:
raise RuntimeError("Web sockets not supported")
scope = {
"type": "http",
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"root_path": "",
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
}
async def receive():
body = request.body
if isinstance(body, str):
body_bytes = body.encode("utf-8") # type: bytes
elif body is None:
body_bytes = b""
else:
body_bytes = body
return {"type": "http.request", "body": body_bytes}
async def send(message):
nonlocal raw_kwargs, response_started, response_complete
if message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
raw_kwargs["version"] = 11
raw_kwargs["status"] = message["status"]
raw_kwargs["reason"] = _get_reason_phrase(message["status"])
raw_kwargs["headers"] = [
(key.decode(), value.decode()) for key, value in message["headers"]
]
raw_kwargs["preload_content"] = False
raw_kwargs["original_response"] = _MockOriginalResponse(
raw_kwargs["headers"]
)
response_started = True
elif message["type"] == "http.response.body":
assert (
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
not response_complete
), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
raw_kwargs["body"].write(body)
if not more_body:
raw_kwargs["body"].seek(0)
response_complete = True
response_started = False
response_complete = False
raw_kwargs = {"body": io.BytesIO()}
try:
connection = self.app(scope)
await connection(receive, send)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc from None
if self.raise_server_exceptions:
assert response_started, "TestClient did not receive any response."
elif not response_started:
raw_kwargs = {
"version": 11,
"status": 500,
"reason": "Internal Server Error",
"headers": [],
"preload_content": False,
"original_response": _MockOriginalResponse([]),
"body": io.BytesIO(),
}
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
return requests.adapters.HTTPAdapter.build_response(self, request, raw)
def prepare_request(self, request):
cookies = request.cookies or {}
# Bootstrap CookieJar.
if not isinstance(cookies, requests.compat.cookielib.CookieJar):
cookies = requests.cookies.cookiejar_from_dict(cookies)
p = requests.PreparedRequest()
p.prepare(
method=request.method.upper(),
url=request.url,
files=request.files,
data=request.data,
json=request.json,
headers=requests.sessions.merge_setting(
request.headers,
self.headers,
dict_class=requests.structures.CaseInsensitiveDict,
),
params=request.params,
auth=request.auth,
cookies=cookies,
hooks=request.hooks,
)
return p
async def request(
self,
method,
url,
params=None,
data=None,
headers=None,
cookies=None,
files=None,
auth=None,
timeout=None,
allow_redirects=True,
proxies=None,
hooks=None,
stream=None,
verify=None,
cert=None,
json=None,
):
req = requests.Request(
method=method.upper(),
url=urljoin(self.base_url, url),
headers=headers,
files=files,
data=data or {},
json=json,
params=params or {},
auth=auth,
cookies=cookies,
hooks=hooks,
)
return await self.send(
self.prepare_request(req),
timeout=timeout,
allow_redirects=allow_redirects,
proxies=proxies or {},
)
async def get(self, url, **kwargs):
kwargs.setdefault("allow_redirects", True)
return await self.request("GET", url, **kwargs)
async def options(self, url, **kwargs):
kwargs.setdefault("allow_redirects", True)
return await self.request("OPTIONS", url, **kwargs)
async def head(self, url, **kwargs):
kwargs.setdefault("allow_redirects", False)
return await self.request("HEAD", url, **kwargs)
async def post(self, url, data=None, json=None, **kwargs): # noqa
return await self.request("POST", url, data=data, json=json, **kwargs)
async def put(self, url, data=None, **kwargs):
return await self.request("PUT", url, data=data, **kwargs)
async def patch(self, url, data=None, **kwargs):
return await self.request("PATCH", url, data=data, **kwargs)
async def delete(self, url, **kwargs):
return await self.request("DELETE", url, **kwargs)
@sbv-trueenergy
Copy link

sbv-trueenergy commented Jun 12, 2019

And for the new ASGI3 apps it would need to be

await self.app(scope, receive, send)

instead of

connection = self.app(scope)
await connection(receive, end)

Aka if you see:

TypeError: __call__() missing 2 required positional arguments: 'receive' and 'send'

Reference: encode/starlette#435

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment