Last active
December 19, 2019 17:14
-
-
Save cspinelive/9ba8c4034b42c00683ef175f54068546 to your computer and use it in GitHub Desktop.
Django Raw Sql Helper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Acme.db.py | |
# | |
import boto3 | |
from django.db import connections | |
from random import choice | |
def sign_for_s3(object_name, expiration=3600, bucket_name='acme-bucket'):
    """Return a presigned GET URL for ``object_name`` in an S3 bucket.

    Args:
        object_name: Key of the object inside the bucket.
        expiration: Lifetime of the generated URL, in seconds.
        bucket_name: Bucket that holds the object. Defaults to the
            previously hard-coded ``'acme-bucket'`` so existing callers
            behave the same.

    Returns:
        The presigned URL (a string).
    """
    s3_client = boto3.client('s3')
    # generate_presigned_url returns the URL string itself; the previous
    # ``response.url`` raised AttributeError because str has no ``url``.
    return s3_client.generate_presigned_url(
        'get_object',
        Params={'Bucket': bucket_name, 'Key': object_name},
        ExpiresIn=expiration,
    )
def sign_dict_for_s3(d, keys):
    """Replace S3 object keys stored in ``d`` with presigned URLs, in place.

    Commonly used on rows produced by ``dictfetchall`` from raw SQL queries.

    Args:
        d: Row dict to modify in place.
        keys: Names of entries in ``d`` that hold S3 object keys. Names not
            present in ``d`` are ignored.

    Returns:
        None; ``d`` is mutated.
    """
    for key in keys:
        s3_name = d.get(key)
        # Missing entries and NULL columns (None) have no object to sign;
        # leave them untouched instead of asking S3 to sign a None key.
        if s3_name is None:
            continue
        d[key] = sign_for_s3(s3_name)
def dictfetchall(cursor):
    """Return every remaining row of ``cursor`` as a dict keyed by column name.

    Column names are taken from the first item of each entry in
    ``cursor.description``; rows come from ``cursor.fetchall()``.
    """
    names = [description[0] for description in cursor.description]
    result = []
    for record in cursor.fetchall():
        result.append(dict(zip(names, record)))
    return result
def db_connection(use_replica=True):
    """Pick the Django database alias a query should run against.

    Args:
        use_replica: When true, return one of the replica aliases chosen at
            random; otherwise return ``'default'``.

    Returns:
        A key suitable for ``django.db.connections[...]``.
    """
    if use_replica:
        return choice(['replica-1', 'replica-2', 'replica-3'])
    return 'default'
class AcmeQuery(object):
    """
    helper class for running raw sql
    uses dictfetchall by default
    when only returning one column, flattens dict into a list like .values_list(flat=True)
    accepts a post processing function to be run on each row if needed
    has support for replica databases
    ------------------------------------------------------------------
    basic query with params
    ------------------------------------------------------------------
    from Acme.db import AcmeQuery
    sql = 'select id from books where status=%(status)s limit 10'
    params = {'status': 'A'}
    rows = AcmeQuery(sql, params=params).run()
    print(rows)
    -->: [1, 2, 3, 4, 5]
    ------------------------------------------------------------------
    more advanced example
    uses replica db and
    a row modifier function that does s3 url signing on the s3 files
    ------------------------------------------------------------------
    from Acme.db import AcmeQuery
    sql = 'select id, "pdf_url" from books where status=%(status)s limit 10'
    params = {'status': 'A'}
    s3signer = AcmeQuery.S3Signer(cols=['pdf_url'])
    rows = AcmeQuery(sql, params=params, use_db_replica=True, row_modifier=s3signer).run()
    print(rows[0]['pdf_url'])
    -->: https://acme-bucket.s3.amazonaws.com/pdfs/book1.pdf?key=key_goes_here
    """

    class RowModifier(object):
        """Pair a per-row post-processing function with extra keyword arguments.

        ``run`` calls ``func(row, **params)`` once per result row, where
        ``row`` is the dict built by ``dictfetchall``. ``func`` is expected to
        modify the row in place; its return value is ignored.
        """

        def __init__(self, func, params=None):
            self.func = func
            self.params = params or {}

    class S3Signer(RowModifier):
        """Row modifier that swaps the S3 keys in ``cols`` for presigned URLs."""

        def __init__(self, cols=None):
            if isinstance(cols, str):
                # a single column name; don't iterate it character by character
                cols = [cols]
            # sign_dict_for_s3 requires ``keys``; always pass it so a signer
            # built without columns is a no-op rather than raising TypeError
            # on the first row.
            super().__init__(sign_dict_for_s3, params={'keys': list(cols or [])})

    def __init__(self, sql, params=None, use_db_replica=False, row_modifier=None, flat=True):
        """Prepare a raw SQL query; nothing is executed until ``run``.

        Args:
            sql: SQL text with DB-API placeholders (e.g. ``%(name)s``).
            params: Parameters passed to ``cursor.execute`` with ``sql``.
            use_db_replica: Run against a randomly chosen replica alias
                instead of ``'default'``.
            row_modifier: Optional ``RowModifier`` applied to every row.
            flat: Collapse single-column results into a plain list of values.

        Raises:
            ValueError: if ``sql`` is empty (a subclass of ``Exception``, so
                callers catching ``Exception`` still work).
        """
        # validate before doing any other work
        if not sql:
            raise ValueError('AcmeQuery needs some sql.')
        self.db_conn = db_connection(use_db_replica)
        self.sql = sql
        self.params = params
        self.flat = flat
        self.row_modifier = row_modifier
        self.results = []

    def run(self):
        """Execute the query and return (and store in ``self.results``) its rows.

        Returns:
            A list of dicts, one per row, keyed by column name. When ``flat``
            is true and the query selected exactly one column, a plain list of
            that column's values instead.
        """
        # TODO: implement these as a custom cursor so we don't have to loop the results multiple times (psycopg cursor_factory)
        # TODO: implement namedtuple variant as alternate option to dictfetchall (http://initd.org/psycopg/docs/extras.html#namedtuple-cursor)
        with connections[self.db_conn].cursor() as cursor:
            cursor.execute(self.sql, self.params)
            # convert list of tuples to list of dictionaries
            rows = dictfetchall(cursor)
            # read the column count while the cursor is still open
            column_count = len(cursor.description)

        # Post-processing runs after the ``with`` block so slow modifiers
        # (e.g. S3 signing) don't keep the database cursor open.
        # The function must accept the row as a dictionary as its first
        # argument; other arguments are supplied from the modifier's params.
        if self.row_modifier:
            func = self.row_modifier.func
            extra = self.row_modifier.params
            for row in rows:
                func(row, **extra)

        # politely change [{'id': 1}, {'id': 2}] into [1, 2]
        if self.flat and column_count == 1:
            rows = [list(row.values())[0] for row in rows]

        self.results = rows
        return self.results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment