Created July 28, 2019 13:00
Save euri10/5aadc0c8f83ea81cb760c5c88c401fd9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import functools | |
import ssl | |
from pathlib import Path | |
import pytest | |
from databases import Database | |
def async_adapter(wrapped_func):
    """
    Decorator used to run async test cases synchronously.

    The returned wrapper calls ``wrapped_func`` with the same arguments and
    drives the resulting coroutine to completion, returning its result.
    ``functools.wraps`` copies the original function's metadata (name,
    docstring, ``__wrapped__``) onto the wrapper.
    """

    @functools.wraps(wrapped_func)
    def run_sync(*args, **kwargs):
        coro = wrapped_func(*args, **kwargs)
        # asyncio.get_event_loop() without a running loop is deprecated
        # (DeprecationWarning on 3.12, RuntimeError on 3.14); asyncio.run
        # creates, runs and closes a dedicated event loop for this call.
        return asyncio.run(coro)

    return run_sync
# Client-side TLS material, resolved relative to the current working directory.
CERTS = Path("cert")
CLIENT_CERTS = CERTS.joinpath("client")
SSL_CA_CERT_FILE = CLIENT_CERTS.joinpath("root.crt")
SSL_CERT_FILE = CLIENT_CERTS.joinpath("postgresql.crt")
SSL_KEY_FILE = CLIENT_CERTS.joinpath("postgresql.key")

DSN_BASE = "postgresql://postgres:postgres@localhost:5555/postgres"

# (sslmode value, whether pg_stat_ssl is expected to report SSL in use)
_SSLMODE_EXPECTATIONS = (
    ("disable", False),
    ("allow", True),
    ("prefer", True),
    ("require", True),
)

# NOTE: DSN variants passing sslcert/sslkey or sslrootcert (built from the
# SSL_* paths above) are not exercised: they were marked "not implemented
# asyncpg" when this file was written.
dsn_data = [
    (f"{DSN_BASE}?sslmode={mode}", expected)
    for mode, expected in _SSLMODE_EXPECTATIONS
]
@pytest.mark.parametrize("dsn, ssl_expected", dsn_data)
@async_adapter
async def test_sslmode(dsn, ssl_expected):
    """
    Connect with ``dsn`` and check whether the backend reports SSL in use.

    ``pg_stat_ssl`` is filtered to ``pg_backend_pid()`` so the row describes
    this test's own connection; its ``ssl`` column is compared with
    ``ssl_expected``.
    """
    database = Database(dsn)
    await database.connect()
    try:
        # Only columns present in every pg_stat_ssl version are selected:
        # "clientdn" was renamed "client_dn" in PostgreSQL 12 and
        # "compression" was dropped in PostgreSQL 14.
        query = (
            "select current_user, pid, ssl, version, cipher, bits "
            "from pg_stat_ssl where pid = pg_backend_pid()"
        )
        row = await database.fetch_one(query=query)
        assert ssl_expected == row["ssl"]
    finally:
        # Release the connection even when the assertion fails, so the
        # parametrized cases do not leave sessions open on the server.
        await database.disconnect()
@async_adapter
async def test_ssl_context():
    """
    Connect with an explicit ``ssl.SSLContext`` and check SSL is in use.

    The context trusts ``SSL_CA_CERT_FILE`` and requires the server
    certificate to verify against it; the backend's own ``pg_stat_ssl`` row
    must then report ``ssl`` as true.
    """
    # ssl.PROTOCOL_SSLv23 is a deprecated alias, and a context built from it
    # defaults to verify_mode=CERT_NONE, so the CA file loaded below was never
    # consulted. PROTOCOL_TLS_CLIENT defaults to CERT_REQUIRED.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # NOTE(review): hostname checking is disabled because it is not known
    # whether the server certificate names "localhost"; enable it if it does.
    ssl_context.check_hostname = False
    ssl_context.load_verify_locations(SSL_CA_CERT_FILE)

    database = Database(DSN_BASE, ssl=ssl_context)
    await database.connect()
    try:
        # Only columns present in every pg_stat_ssl version are selected:
        # "clientdn" was renamed "client_dn" in PostgreSQL 12 and
        # "compression" was dropped in PostgreSQL 14.
        query = (
            "select current_user, pid, ssl, version, cipher, bits "
            "from pg_stat_ssl where pid = pg_backend_pid()"
        )
        row = await database.fetch_one(query=query)
        assert row["ssl"]
    finally:
        # Always release the connection, even when the assertion fails.
        await database.disconnect()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.