Skip to content

Instantly share code, notes, and snippets.

@euri10
Created July 28, 2019 13:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save euri10/5aadc0c8f83ea81cb760c5c88c401fd9 to your computer and use it in GitHub Desktop.
Save euri10/5aadc0c8f83ea81cb760c5c88c401fd9 to your computer and use it in GitHub Desktop.
import asyncio
import functools
import ssl
from pathlib import Path
import pytest
from databases import Database
def async_adapter(wrapped_func):
    """
    Decorator used to run async test cases.

    Wraps a coroutine function in a plain synchronous callable so a test
    runner without native asyncio support can invoke it. Each call runs the
    coroutine to completion on a fresh event loop via ``asyncio.run``, which
    also closes that loop afterwards.

    :param wrapped_func: a coroutine function (``async def``).
    :returns: a synchronous function carrying ``wrapped_func``'s name,
        docstring and ``__wrapped__`` (via ``functools.wraps``) that returns
        the coroutine's result.
    :raises RuntimeError: if called while an event loop is already running
        in the current thread.
    """
    @functools.wraps(wrapped_func)
    def run_sync(*args, **kwargs):
        # asyncio.get_event_loop() is deprecated when no loop is running and
        # raises outside the main thread; asyncio.run() owns its own loop.
        return asyncio.run(wrapped_func(*args, **kwargs))
    return run_sync
# Client-side TLS material; these are relative paths, so they resolve
# against the current working directory when opened.
CERTS = Path("cert")
CLIENT_CERTS = CERTS.joinpath("client")
SSL_CA_CERT_FILE = CLIENT_CERTS.joinpath("root.crt")
SSL_CERT_FILE = CLIENT_CERTS.joinpath("postgresql.crt")
SSL_KEY_FILE = CLIENT_CERTS.joinpath("postgresql.key")

DSN_BASE = "postgresql://postgres:postgres@localhost:5555/postgres"

# Each entry pairs a DSN with the expected value of pg_stat_ssl.ssl.
dsn_data = [
    (f"{DSN_BASE}?sslmode={mode}", expected)
    for mode, expected in (
        ("disable", False),
        ("allow", True),
        ("prefer", True),
        ("require", True),
    )
]
# Cases not yet exercised because asyncpg does not implement these DSN params:
# (f"{DSN_BASE}?sslcert={SSL_CERT_FILE}&sslkey={SSL_KEY_FILE}", True)
# (f"{DSN_BASE}?sslrootcert={SSL_CA_CERT_FILE}", True)
@pytest.mark.parametrize("dsn, ssl_expected", dsn_data)
@async_adapter
async def test_sslmode(dsn, ssl_expected):
    """Check that connecting with *dsn* yields the expected SSL state.

    Reads the current backend's row from ``pg_stat_ssl`` and compares its
    ``ssl`` flag with *ssl_expected*.
    """
    database = Database(dsn)
    await database.connect()
    try:
        # Select only the asserted column: pg_stat_ssl's other columns changed
        # across PostgreSQL versions (clientdn -> client_dn in 12,
        # compression removed in 14), which made the wider query fail.
        query = "select ssl from pg_stat_ssl where pid = pg_backend_pid()"
        row = await database.fetch_one(query=query)
    finally:
        # Release the connection even if the query fails.
        await database.disconnect()
    assert row is not None, "no pg_stat_ssl row for the current backend"
    assert row["ssl"] == ssl_expected
@async_adapter
async def test_ssl_context():
    """Check that an explicit ``ssl.SSLContext`` yields an SSL connection.

    The context trusts only ``SSL_CA_CERT_FILE`` and requires the server
    certificate to verify against it.
    """
    # create_default_context() sets verify_mode=CERT_REQUIRED and
    # check_hostname=True. A bare SSLContext(PROTOCOL_SSLv23) defaults to
    # CERT_NONE, so a loaded CA would never actually be consulted.
    # NOTE(review): hostname checking assumes the server certificate covers
    # "localhost" -- confirm against the test certs; if it does not, set
    # ``ssl_context.check_hostname = False`` (chain verification still applies).
    ssl_context = ssl.create_default_context(cafile=SSL_CA_CERT_FILE)
    database = Database(DSN_BASE, ssl=ssl_context)
    await database.connect()
    try:
        # Only the ssl column is needed; wider column lists are not portable
        # across PostgreSQL versions.
        query = "select ssl from pg_stat_ssl where pid = pg_backend_pid()"
        row = await database.fetch_one(query=query)
    finally:
        # Release the connection even if the query fails.
        await database.disconnect()
    assert row is not None, "no pg_stat_ssl row for the current backend"
    assert row["ssl"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment