Skip to content

Instantly share code, notes, and snippets.

@anti1869
Created June 27, 2018 08:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anti1869/cf557cd8a04870811a5d56cc746b516b to your computer and use it in GitHub Desktop.
Save anti1869/cf557cd8a04870811a5d56cc746b516b to your computer and use it in GitHub Desktop.
Flushes items into a PostgreSQL table with fast upserts.
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, List
import simplejson as json
from psycopg2 import IntegrityError
from django.core.serializers.json import DjangoJSONEncoder
from django.db import connection
from django.db.utils import ProgrammingError
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class PGFlusher(object):
    """
    Flush dict items into a PostgreSQL table with fast multi-row upserts.

    Each ``flush`` call sends one ``INSERT ... ON CONFLICT ... DO UPDATE`` statement
    covering every passed item.

    Usage::

        data = [
            {
                "long_id": 123,
                "created": "2018-01-15",
                "street": "Street 1",
                "burg": "Burg 1",
                "_meta": "stuff1",
            },
            {
                "long_id": 124,
                "created": "2018-01-15",
                "street": "Street 2",
                "burg": "Burg 2",
                "_meta": "stuff2",
            },
        ]
        flusher = PGFlusher(
            db_table="test",
            unique_fields=("long_id", ),
            create_fields=("created", ),
            ignore_fields=("_meta", ),
        )
        pk_collection = flusher.flush(data, return_pk=True)
        for idx, pk in enumerate(pk_collection):
            data[idx]["_pk"] = pk  # redundant: flush() already stores "_pk" on each item

    NOTE(review): table and column names are interpolated into the SQL text verbatim;
    only values are passed as query parameters. Column names are taken from the keys
    of the first item ever flushed, so never flush items whose keys come from
    untrusted input.
    """

    def __init__(self, db_table: str, unique_fields: Tuple[str, ...],
                 create_fields: Optional[Tuple[str, ...]] = None,
                 ignore_fields: Optional[Tuple[str, ...]] = None,
                 pk_field: str = "id"):
        """
        Configure the target table and the conflict (unique) columns.

        E.g. ``db_table="store_items", unique_fields=("shop_id", "original_id")``

        :param db_table: DB table name.
        :param unique_fields: Columns used in the ``ON CONFLICT (...)`` clause; they must
            match a unique constraint/index on the table.
        :param create_fields: Columns written on INSERT only, never in the UPDATE part.
        :param ignore_fields: Item keys that are skipped in all operations.
        :param pk_field: Column returned via ``RETURNING`` when PKs are requested.
        """
        self.db_table = db_table
        self.unique_fields = unique_fields
        self.unique_fields_string = ", ".join(self.unique_fields)
        self.create_fields = create_fields or tuple()
        self.ignore_fields = ignore_fields or tuple()
        self.pk_field = pk_field
        # Populated by _make_fields() from the first data item passed to flush().
        self.all_fields = None
        self.data_fields = None
        self.all_fields_string = None
        self.placeholders_string = None
        self.set_statement = None
        # Its .default() is the fallback serializer for dict values dumped as JSON.
        self._encoder = DjangoJSONEncoder()

    def _make_fields(self, names: Sequence[str]) -> None:
        """
        Derive the column list, placeholders and ``SET`` clause from item keys.

        :param names: Keys of a data item; ``ignore_fields`` are dropped.
        """
        self.all_fields = tuple(name for name in names if name not in self.ignore_fields)
        self.data_fields = tuple(name for name in self.all_fields if name not in self.unique_fields)
        self.all_fields_string = ", ".join(self.all_fields)
        self.placeholders_string = ", ".join("%s" for _ in self.all_fields)
        update_fields = [f for f in self.data_fields if f not in self.create_fields]
        if not update_fields and self.unique_fields:
            # An empty "DO UPDATE SET" is a syntax error, and "DO NOTHING" would make
            # RETURNING skip conflicting rows. A no-op self-assignment of a conflict
            # column keeps the statement valid and still returns every row's PK.
            update_fields = [self.unique_fields[0]]
        self.set_statement = ", ".join(
            "{0} = EXCLUDED.{0}".format(f) for f in update_fields
        )

    def flush(self, data: List[Dict], return_pk: Optional[bool] = True) -> Optional[List[Any]]:
        """
        Insert-or-update all ``data`` items with one PostgreSQL upsert statement.

        Example statement produced here::

            INSERT INTO household_address
            (street_name, city, zip_code, house_number, house_number_extension)
            VALUES
            (%s, %s, %s, %s, %s),
            (%s, %s, %s, %s, %s)
            ON CONFLICT (zip_code, house_number, house_number_extension)
            DO UPDATE SET
            street_name = EXCLUDED.street_name, city = EXCLUDED.city
            RETURNING id;

        More info:
        https://www.postgresql.org/docs/9.6/static/sql-insert.html#SQL-ON-CONFLICT

        Items sharing the same conflict key must not appear twice in one batch:
        PostgreSQL refuses to let ON CONFLICT DO UPDATE affect a row twice.

        :param data: Items to save, dicts keyed by column name. The column set is fixed
            by the first item ever flushed: later items' missing keys are sent as NULL
            and extra keys are ignored.
        :param return_pk: When true, fetch ``pk_field`` of every inserted/updated row,
            store it under the ``"_pk"`` key of the matching item and return the list.
        :return: PKs in ``data`` order when ``return_pk`` is true (``[]`` for empty
            ``data``), otherwise ``None``.
        :raises django.db.utils.DatabaseError: any database error is logged with the
            SQL text and re-raised.
        """
        if not data:
            return [] if return_pk else None

        # First item passed triggers field initialization, so we can be a little dynamic.
        if self.all_fields is None:
            self._make_fields(tuple(data[0].keys()))

        sql = self._build_sql(data, return_pk)
        params = self.make_values_list(data)

        # Django's cursor re-raises driver errors as django.db.utils exceptions
        # (DatabaseError is their common base), so psycopg2's IntegrityError alone
        # would miss integrity violations raised through the Django connection.
        from django.db.utils import DatabaseError

        try:
            with connection.cursor() as cursor:
                cursor.execute(sql, params)
                if not return_pk:  # Exit early if PKs not needed
                    return None
                pk_list = [row[0] for row in cursor.fetchall()]
        except (IntegrityError, DatabaseError):
            logger.exception("SQL error %s", sql)
            raise

        # NOTE(review): assumes RETURNING yields rows in VALUES order; PostgreSQL
        # does not explicitly document this ordering — confirm for your version.
        for item, pk in zip(data, pk_list):
            item['_pk'] = pk
        return pk_list

    def _build_sql(self, data: List[Dict], return_pk: Optional[bool]) -> str:
        """Render the upsert statement for ``data``; fields must already be initialized."""
        return """
            INSERT INTO {table_name}
            ({all_fields})
            VALUES
            {values_lists}
            ON CONFLICT ({unique_fields})
            DO UPDATE SET {set_statement}
            {returning};
        """.format(
            table_name=self.db_table,
            all_fields=self.all_fields_string,
            values_lists=self.make_placeholder_lists(data),
            unique_fields=self.unique_fields_string,
            set_statement=self.set_statement,
            returning="RETURNING {}".format(self.pk_field) if return_pk else "",
        )

    def make_placeholder_lists(self, data: List[Dict]) -> str:
        """
        Build the VALUES placeholder groups filled in later by the database driver.

        E.g.: ``(%s, %s, %s), (%s, %s, %s)`` — one group per item in ``data``.
        """
        result = '({})'.format(
            '), ('.join(self.placeholders_string for _ in range(len(data)))
        )
        return result

    def make_values_list(self, data: List[Dict]) -> Tuple[Any, ...]:
        """
        Flatten item values into one tuple, ordered like ``all_fields`` for every item.

        Keeping the same field order per item is what lines values up with the
        placeholders from ``make_placeholder_lists``.
        """
        result = tuple(self._ex_value(v, name) for v in data for name in self.all_fields)
        return result

    def _ex_value(self, v: Dict, name: str) -> Any:
        """Extract ``v[name]`` (``None`` if absent), dumping dict values to a JSON string."""
        value = v.get(name)
        if isinstance(value, dict):  # Simplest JSON detection
            value = json.dumps(value, default=self._encoder.default)
        return value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment