Skip to content

Instantly share code, notes, and snippets.

@ArtemBernatskyy
Forked from bendavis78/seed_pagination.py
Created November 3, 2017 00:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ArtemBernatskyy/9183885d9b4d4e9776bd492b8212c064 to your computer and use it in GitHub Desktop.
Save ArtemBernatskyy/9183885d9b4d4e9776bd492b8212c064 to your computer and use it in GitHub Desktop.
Example implementation of randomized pagination in django and django-rest-framework
"""
Adds a `seed` paramter to DRF's `next` and `prev` pagination urls
"""
from rest_framework import serializers
from rest_framework import pagination
from rest_framework.templatetags.rest_framework import replace_query_param
from . import utils
class PageSeedFieldMixin:
    """Mixin for DRF page-link fields that appends the ordering seed to the URL.

    The seed is read from the serializer context key ``'seed'`` (expected to be
    a float) and written, base-62 encoded via ``utils.encode_float``, into the
    query parameter named by ``seed_field``.
    """

    # Name of the query parameter that carries the encoded seed.
    seed_field = 'seed'

    def to_representation(self, value):
        url = super().to_representation(value)
        if not url:
            # No next/previous page: nothing to decorate.
            return None
        seed = self.context.get('seed')
        if seed is None:
            # Without a seed there is nothing to propagate; previously this
            # crashed inside struct.pack('>d', None).
            return url
        # Decode to str so the query string is built from text, matching
        # PaginationSeedSerializer.get_seed (the encoding is pure ASCII).
        encoded = utils.encode_float(seed).decode('ascii')
        return replace_query_param(url, self.seed_field, encoded)
class NextPageSeedField(PageSeedFieldMixin, pagination.NextPageField):
    """``next`` page link whose URL also carries the ``seed`` query parameter."""
class PreviousPageSeedField(PageSeedFieldMixin, pagination.PreviousPageField):
    """``previous`` page link whose URL also carries the ``seed`` query parameter."""
class PaginationSeedSerializer(pagination.PaginationSerializer):
    """Pagination serializer that propagates the random-ordering seed.

    ``next``/``previous`` links gain a ``seed`` query parameter, and the encoded
    seed is also exposed as a top-level ``seed`` field. The seed comes from the
    serializer context key ``'seed'`` (expected to be a float).
    """

    next = NextPageSeedField(source='*')
    previous = PreviousPageSeedField(source='*')
    seed = serializers.SerializerMethodField()

    def get_seed(self, object):
        """Return the base-62 encoded seed as ``str``, or ``None`` if unset.

        The ``object`` argument (the page being serialized) is unused. Its name
        is kept unchanged to preserve the method's signature.
        """
        seed = self.context.get('seed')
        if seed is None:
            # Previously struct.pack('>d', None) raised here.
            return None
        return utils.encode_float(seed).decode('ascii')
"""
`SeededQuerySet` adds the `set_seed` function which prepends a setseed() call
to the SQL query (postgres only).
See http://www.postgresql.org/docs/8.3/static/sql-set.html for more info.
"""
from django.db import models
from django.db import connections
from django.db.models.sql import Query
from django.db.models.sql.compiler import SQLCompiler
class SeededSQLCompiler(SQLCompiler):
    """SQL compiler that seeds PostgreSQL's RNG before a randomly ordered query.

    When the query has a truthy ``seed`` and is ordered by ``'?'``, the
    generated SQL is prefixed with ``SELECT setseed(%s); `` and the formatted
    seed is supplied as that placeholder's parameter.
    """

    def as_sql(self, *args, **kwargs):
        # Forward the caller's options (e.g. limit/alias flags); the original
        # silently dropped them.
        sql, params = super().as_sql(*args, **kwargs)
        seed = getattr(self.query, 'seed', None)
        if seed and '?' in self.query.order_by:
            # seed must be a float strictly between 0 and 1
            seed = float(seed)
            if not 0 < seed < 1:
                # str() is required: concatenating str + float raised TypeError
                # and masked this ValueError.
                raise ValueError("Invalid seed value: " + str(seed))
            # NOTE(review): if this compiler is ever used for a subquery, the
            # "setseed(...);" prefix would be embedded inside the outer
            # statement -- confirm subqueries never reach this path.
            sql = 'SELECT setseed(%s); ' + sql
            # The setseed placeholder comes FIRST in the SQL, so its parameter
            # must come first too. Appending it (as before) misaligned every
            # parameter whenever the query had its own params (e.g. filters).
            params = ('{:0.52f}'.format(seed),) + tuple(params)
        return sql, params
class SeededQuery(Query):
    """Django ``Query`` carrying an optional ``seed``, compiled by SeededSQLCompiler."""

    # Seed for the random ordering; None means "do not call setseed()".
    seed = None

    def clone(self, *args, **kwargs):
        # Hand the seed to the base clone() as an extra keyword. This assumes
        # the Django version in use applies extra clone() kwargs as attributes
        # on the copy -- TODO confirm against the installed Django.
        return super().clone(*args, **dict(kwargs, seed=self.seed))

    def get_compiler(self, using=None, connection=None):
        # NOTE(review): this appears to mirror the pre-1.8 Django
        # Query.get_compiler (aggregate_select / check_aggregate_support), but it
        # always returns SeededSQLCompiler.
        if connection is None and using is None:
            raise ValueError("Need either using or connection")
        conn = connections[using] if using else connection
        # Check that the compiler will be able to execute the query
        for aggregate in self.aggregate_select.values():
            conn.ops.check_aggregate_support(aggregate)
        return SeededSQLCompiler(self, conn, using)
class SeededQuerySet(models.QuerySet):
    """QuerySet whose random ordering (``order_by('?')``) can be seeded.

    It is backed by ``SeededQuery``, whose compiler emits a PostgreSQL
    ``setseed()`` call when a seed is set and the query is ordered by ``'?'``.
    """

    def __init__(self, model=None, query=None, using=None, hints=None):
        # Build a SeededQuery only when no query was supplied. Test against None
        # explicitly instead of relying on the query object's truthiness.
        if query is None:
            query = SeededQuery(model)
        super().__init__(model, query, using, hints)

    def set_seed(self, seed):
        """Return a copy of this queryset whose random ordering uses *seed*.

        Like other chainable QuerySet methods, this no longer mutates the
        receiver. Previously ``base.set_seed(a)`` and ``base.set_seed(b)``
        returned the same object, with the later seed overwriting the earlier.
        Callers must use the return value.
        """
        clone = self.all()
        clone.query.seed = seed
        return clone
"""
Utility functions for encoding integers and floats into short ASCII strings
(ideal for URL parameters)
"""
import struct
import string
# Base-62 alphabet: digit value i is BASE_ALPH[i] ('a' == 0, '9' == 61).
BASE_ALPH = tuple(string.ascii_letters + string.digits)
# Reverse lookup: character -> digit value.
BASE_DICT = {char: value for value, char in enumerate(BASE_ALPH)}


def decode_int(encoded):
    """Decode a base-62 string (``str`` or ``bytes``) into a non-negative int.

    ``bytes`` input is accepted so that ``decode_int(encode_int(n)) == n``
    works directly (``encode_int`` returns bytes). An empty input decodes to 0.

    Raises:
        KeyError: if *encoded* contains a character outside ``BASE_ALPH``.
    """
    if isinstance(encoded, (bytes, bytearray)):
        # latin-1 maps every byte to one character, so invalid bytes still
        # surface as KeyError rather than UnicodeDecodeError.
        encoded = encoded.decode('latin-1')
    base = len(BASE_ALPH)
    num = 0
    for char in encoded:
        num = num * base + BASE_DICT[char]
    return num


def encode_int(num):
    """Encode a non-negative int as base-62 ASCII ``bytes``.

    Zero encodes to ``BASE_ALPH[0]`` (``b'a'``) instead of ``b''``, so the
    result is never empty. An empty URL parameter looks like a missing one.
    ``decode_int`` maps ``b'a'`` back to 0.

    Raises:
        ValueError: if *num* is negative (this previously looped forever,
            because ``divmod`` never drives a negative number to 0).
    """
    if num < 0:
        raise ValueError("encode_int() requires a non-negative integer, got %r" % (num,))
    base = len(BASE_ALPH)
    digits = []
    while True:
        num, rem = divmod(num, base)
        digits.append(BASE_ALPH[rem])
        if not num:
            break
    # Digits were produced least-significant first.
    return ''.join(reversed(digits)).encode('ascii')


def encode_float(num):
    """Encode a float as base-62 ``bytes`` via its big-endian IEEE 754 bit pattern."""
    packed = struct.pack('>d', num)
    return encode_int(int.from_bytes(packed, 'big'))


def decode_float(encoded):
    """Inverse of ``encode_float``: decode base-62 text/bytes back to a float.

    Raises:
        KeyError: on characters outside ``BASE_ALPH``.
        OverflowError: if the decoded integer does not fit in 8 bytes.
    """
    packed = decode_int(encoded).to_bytes(8, 'big')
    return struct.unpack('>d', packed)[0]
import random
from rest_framework.response import Response
from . import models
from . import seed_pagination
class EntryViewSet(BaseViewSet):
    """Entry API whose list endpoint returns entries in a seeded random order.

    The seed comes from the ``seed`` query parameter (base-62 encoded float) or
    is generated fresh. It is passed to the serializer context so that the
    pagination serializer can embed it in ``next``/``previous`` links, which
    keeps the ordering stable across pages.
    """

    queryset = models.Entry.objects.all()
    serializer_class = serializers.EntrySerializer
    paginate_by = 10
    paginate_by_param = 'page_size'
    max_paginate_by = 100
    pagination_serializer_class = seed_pagination.PaginationSeedSerializer

    def list(self, request, *args, **kwargs):
        qs = self.filter_queryset(self.get_queryset())
        # Randomize the queryset based on this request's (cached) seed.
        qs = qs.set_seed(self.seed).order_by('?')
        page = self.paginate_queryset(qs)
        if page is not None:
            serializer = self.get_pagination_serializer(page)
        else:
            serializer = self.get_serializer(qs, many=True)
        return Response(serializer.data)

    def get_serializer_context(self):
        context = super().get_serializer_context()
        context.update({
            'seed': self.seed
        })
        return context

    @property
    def seed(self):
        """Seed for this request's random ordering, resolved once and cached.

        Caching matters: ``list`` reads it to order the queryset and
        ``get_serializer_context`` reads it to build the next/previous links.
        Previously each access could call ``random.random()`` again, so page 1
        was ordered with one seed and the links carried a different one.
        Django/DRF create a new view instance per request, so caching on
        ``self`` is per-request.
        """
        try:
            return self._seed
        except AttributeError:
            self._seed = self._resolve_seed()
            return self._seed

    def _resolve_seed(self):
        """Decode the seed query parameter, falling back to a fresh random seed."""
        # `utils` was not imported in this module; import it locally.
        from . import utils

        field = self.pagination_serializer_class._declared_fields['next']
        encoded = self.request.query_params.get(field.seed_field)
        if not encoded:
            # Any page without a seed gets a fresh one. Previously page > 1
            # yielded None, which crashed when encoding the next/previous link.
            return random.random()
        try:
            value = utils.decode_float(encoded)
        except (KeyError, OverflowError):
            # Malformed client input: start a new ordering instead of a 500.
            return random.random()
        if not 0 < value < 1:
            # Out-of-range (or NaN) seeds would be rejected by SeededSQLCompiler.
            return random.random()
        return value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment