Created November 19, 2018 06:20
-
-
Save jnoortheen/13dead7e6ba9ec7df98bb8ca8bd8dcd9 to your computer and use it in GitHub Desktop.
Bulk Insert with Ignore for Django ORM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
import enum | |
from django.db import connection, transaction | |
from django.db.models import Model, Subquery, IntegerField, Sum, Func | |
from django.db.models.sql import InsertQuery | |
from typing import List, Type, Dict, Any, Union, NamedTuple, Tuple | |
from . import signals | |
def dictfetchall(cursor):
    """Return every remaining row of *cursor* as a ``{column: value}`` dict.

    Keys are the column names reported in ``cursor.description`` (the first
    item of each description entry). One dict is produced per row returned
    by ``cursor.fetchall()``, in fetch order.
    """
    names = [entry[0] for entry in cursor.description]
    rows = cursor.fetchall()
    return [dict(zip(names, values)) for values in rows]
def namedtuplefetchall(cursor):
    """Return all remaining rows from *cursor* as a list of namedtuples.

    Field names come from ``cursor.description`` (first item of each entry).
    Some column names are not valid namedtuple fields, for example:

    * reserved words such as ``class``;
    * names starting with a digit or underscore;
    * duplicates produced by joins or ``RETURNING *``.

    ``rename=True`` makes ``namedtuple`` replace such names with positional
    ones (``_<index>``) instead of raising ``ValueError``. Row values stay in
    cursor order either way.
    """
    columns = [col[0] for col in cursor.description]
    nt_result = namedtuple('Result', columns, rename=True)
    return [nt_result(*row) for row in cursor.fetchall()]
def get_insert_sql(model, objs) -> Tuple[str, tuple]:
    """Compile a multi-row INSERT of *objs* into a ``(sql, params)`` pair.

    Every concrete field of *model* whose ``auto_created`` flag is false is
    used as an insert column. The statement is built with Django's
    ``InsertQuery``, and the first element of ``sql_with_params()`` is
    returned; callers unpack it as ``sql, params``.
    """
    insert_columns = [
        field
        for field in model._meta.concrete_fields
        if not field.auto_created
    ]
    query = InsertQuery(model)
    query.insert_values(insert_columns, objs)
    compiled = query.sql_with_params()
    return compiled[0]
def get_upsert_sql(model: Type[Model], objs: List[Model]) -> Tuple[str, tuple]:
    """Build an INSERT of *objs* that skips rows which would conflict.

    Takes the statement from ``get_insert_sql`` and appends PostgreSQL's
    ``ON CONFLICT DO NOTHING``. Despite the name, conflicting rows are left
    untouched rather than updated. The parameter tuple is passed through
    unchanged.
    """
    insert_sql, params = get_insert_sql(model, objs)
    return f'{insert_sql} ON CONFLICT DO NOTHING', params
class ResultType(enum.Enum):
    """Selects what ``bulk_insert_with_ignore`` fetches back after inserting."""

    # list of primary-key values (RETURNING <pk column>)
    pk = 'pk'
    # list of column-name -> value dicts (RETURNING *)
    dict = 'Values'
    # list of namedtuples, one field per column (RETURNING *)
    named_tuple = 'Named Tuple'
    # no RETURNING clause; nothing is fetched
    none = 'None'
def bulk_insert_with_ignore(
    objs: List[Model],
    result_type: ResultType = ResultType.none
) -> Union[List[Union[str, int]], List[Dict[str, Any]], List[NamedTuple], None]:
    """Bulk-insert *objs* with PostgreSQL ``INSERT ... ON CONFLICT DO NOTHING``.

    Rows that hit a conflict are skipped, not updated. The statement runs
    inside ``transaction.atomic()`` after locking the target table in
    ``SHARE ROW EXCLUSIVE`` mode.

    Args:
        objs: model instances to insert. The model is taken from
            ``type(objs[0])``; all instances are expected to share it.
        result_type: what to fetch back for the inserted rows:
            ``ResultType.pk`` gives a list of primary-key values;
            ``ResultType.dict`` gives a list of dicts (``RETURNING *``);
            ``ResultType.named_tuple`` gives a list of namedtuples
            (``RETURNING *``);
            ``ResultType.none`` adds no RETURNING clause and returns ``None``.

    Returns:
        list: records of the rows actually inserted (conflicting rows are
        absent), empty if nothing was inserted. Returns ``None`` when
        *result_type* is ``ResultType.none``.
    """
    if not objs:
        # Nothing to insert. Honour the documented "empty list" contract
        # unless the caller asked for no result at all.
        return None if result_type == ResultType.none else []

    model = type(objs[0])
    quote_name = connection.ops.quote_name
    # Django quotes identifiers in the INSERT it compiles, so the LOCK and
    # RETURNING clauses must quote them too. Otherwise PostgreSQL folds
    # mixed-case names to lowercase and targets a different relation/column.
    table_name = quote_name(model._meta.db_table)

    sql, sql_args = get_upsert_sql(model, objs)
    if result_type != ResultType.none:
        # Use the pk's database column, not its Python attname; the two
        # differ when ``db_column`` is set on the primary key.
        returning_clause = (
            quote_name(model._meta.pk.column)
            if result_type == ResultType.pk
            else '*'
        )
        sql = f'{sql} RETURNING {returning_clause}'

    result = None
    with transaction.atomic(), connection.cursor() as cursor:
        # SHARE ROW EXCLUSIVE conflicts with itself and with the ROW EXCLUSIVE
        # lock taken by INSERT/UPDATE/DELETE. Other writers therefore wait
        # until the enclosing transaction ends.
        cursor.execute(f'LOCK TABLE {table_name} IN SHARE ROW EXCLUSIVE MODE')
        # Run the insert-or-ignore and read back whatever RETURNING produced.
        cursor.execute(sql, sql_args)
        if result_type != ResultType.none:
            result = _fetch_inserted(cursor, result_type)

    if result_type != ResultType.none:
        # NOTE(review): cursor.db is passed as ``using``; confirm receivers
        # expect that object rather than a database alias string.
        signals.post_bulk_insert.send(sender=model, records=result, using=cursor.db)
    return result


def _fetch_inserted(cursor, result_type: ResultType):
    """Read the RETURNING rows from *cursor* in the shape *result_type* selects."""
    if result_type == ResultType.pk:
        # RETURNING <pk column> yields one-column rows; unwrap them.
        return [row[0] for row in cursor.fetchall()]
    if result_type == ResultType.dict:
        return dictfetchall(cursor)
    if result_type == ResultType.named_tuple:
        return namedtuplefetchall(cursor)
    # Not a fetching ResultType member: nothing to read.
    return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment