Created
September 15, 2017 15:36
-
-
Save rcoup/b076f2dcd46ded6a915c06caf3cf378a to your computer and use it in GitHub Desktop.
Server-side ORM cursors with PostgreSQL and Django 1.8
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Server-side ORM cursors with PostgreSQL and Django 1.8

Typically Postgres cursors fetch all results regardless of any Python-side iteration.
This will fetch chunks of rows at a time from the database using a server-side cursor,
and yield them usefully for iteration.

Usage:
    >>> queryset = MyModel.objects.all()
    >>> for o in server_side_iterator(queryset):
    ...     print(o)

:) Django 1.11+ has this built into QuerySet.iterator()
""" | |
import itertools | |
import uuid | |
import django | |
from django.db import connections, transaction | |
from django.db.models.query import RawQuerySet | |
from django.db.models.sql.query import RawQuery | |
from django.db.backends.postgresql_psycopg2.base import utc_tzinfo_factory | |
class ServerSideRawQuery(RawQuery):
    """RawQuery that executes its SQL through a named psycopg2 cursor.

    A named cursor lets psycopg2 handle DECLARE/FETCH itself, so rows are
    pulled from the server in batches instead of all at once.
    """

    # Batch size copied onto the cursor's ``itersize`` (when the driver
    # cursor exposes that attribute). Callers may override it per instance.
    chunk_size = 2000

    def _execute_query(self):
        """Create a uniquely named cursor and execute ``self.sql`` on it.

        The wrapped cursor is stored on ``self.cursor``.
        """
        # Imported here because the module's top-level imports do not bring
        # ``settings`` into scope; without it the USE_TZ lookup below fails
        # with NameError.
        from django.conf import settings

        cursor_name = "django_%s" % uuid.uuid4()

        # Django's cursor creation does not accept a cursor name, so we go to
        # the underlying driver connection directly and repeat the checks
        # Django would otherwise perform for us.
        dbconn = connections[self.using]
        dbconn.ensure_connection()
        dbconn.validate_thread_sharing()

        dbcursor = dbconn.connection.cursor(name=cursor_name)
        dbcursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
        if hasattr(dbcursor, 'itersize'):
            dbcursor.itersize = self.chunk_size

        self.cursor = dbconn.make_cursor(dbcursor)

        # Now that we have a named cursor, psycopg2 deals with DECLARE/FETCH/etc.
        self.cursor.execute(self.sql, self.params)

    def __iter__(self):
        """Return an iterator over the result rows.

        The first row is fetched eagerly: until the cursor has been evaluated
        the column/description details are not available.
        """
        it = super(ServerSideRawQuery, self).__iter__()
        try:
            first = next(it)
        except StopIteration:
            # No rows at all. Do not let StopIteration escape from __iter__:
            # inside a calling generator it would end that generator early or
            # (PEP 479) be converted into RuntimeError.
            return iter(())
        # Stitch the already-consumed first row back in front of the rest.
        return itertools.chain([first], it)
def server_side_iterator(queryset, chunk_size=2000):
    """Yield the objects of ``queryset`` using a server-side cursor.

    The queryset's compiled SQL is re-run through ``ServerSideRawQuery`` so
    rows are fetched in batches rather than all at once.

    Args:
        queryset: the queryset whose SQL, parameters, model and database
            alias are used.
        chunk_size: batch size handed to the raw query (default 2000).

    Yields:
        The objects produced by iterating the resulting ``RawQuerySet``.

    Note:
        This is a generator, and the ``transaction.atomic`` block stays open
        for as long as iteration is in progress.
    """
    # Resolve the alias exactly once. ``queryset.db`` is looked up on the
    # queryset each time it is read; if that lookup is not stable (e.g. a
    # router choosing between replicas), reading it repeatedly could open the
    # transaction on one alias and run the cursor on another.
    db_alias = queryset.db

    # DECLARE/FETCH need to happen in a transaction.
    with transaction.atomic(using=db_alias):
        # Extract the compiled SQL & parameters.
        sql, params = queryset.query.sql_with_params()

        # Build the raw query that will run on the named cursor.
        raw_query = ServerSideRawQuery(sql=sql, using=db_alias, params=params)
        raw_query.chunk_size = chunk_size

        # Wrap it in a RawQuerySet so rows are turned into model objects.
        raw_qs = RawQuerySet(
            sql,
            model=queryset.model,
            params=params,
            using=db_alias,
            query=raw_query,
        )

        # Yielding from inside the ``with`` keeps the transaction alive
        # while the caller consumes rows.
        for obj in raw_qs:
            yield obj
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Caveat: with complex multi-model querysets & deferred fields, I've seen occasional AttributeErrors along the lines of: