@crizCraig
Created July 14, 2012 21:51
GAE Sharded counter
class BaseShardedCountModel(BaseModel):
    # Services:
    # - in-memory counts
    # - getting the counter name from the entity
    # - checking for a transition from the old (non-sharded) counts

    def counter(self, name):
        from lib.shardedcounter import Counter
        return Counter(str(name + '_' + str(self.key().id())))

    def getshardedcount(self, name):
        if name in self.inmemorycounts():
            return self.inmemorycounts(name)
        else:
            counter = self.counter(name)
            self.checkfortransition(counter, name)
            return self.inmemorycounts(name, counter.get_count())

    def increment(self, name, incr):
        counter = self.counter(name)
        self.checkfortransition(counter, name)
        return self.inmemorycounts(name, counter.increment(incr))

    def checkfortransition(self, counter, name):
        from const import NEWDEFAULTFOROLDCOUNT
        oldvalue = self.__getattribute__(name)
        if oldvalue != NEWDEFAULTFOROLDCOUNT:
            # Transition from the old counts.
            # I made the default value for the old counts negative, so that I
            # know whether an entity was created after the switch.
            counter.set_count(oldvalue)
            self.__setattr__(name, NEWDEFAULTFOROLDCOUNT)
            self.put()

    def inmemorycounts(self, name=None, val=None):
        if not hasattr(self, '_inmemorycounts'):
            self._inmemorycounts = {}
        if name and val is not None:
            self._inmemorycounts[name] = val
            return val
        elif name:
            return self._inmemorycounts[name]
        else:
            return self._inmemorycounts
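
# A minimal usage sketch (not part of the original gist), written as comments so
# the module stays importable. It illustrates how a concrete model could subclass
# BaseShardedCountModel; the Poll model, its votecount property, and the
# NEWDEFAULTFOROLDCOUNT default are assumptions inferred from the tests below.
#
# class Poll(BaseShardedCountModel):
#     votecount = db.IntegerProperty(default=NEWDEFAULTFOROLDCOUNT)
#
# poll = Poll.get_by_id(some_poll_id)            # some_poll_id is hypothetical
# poll.increment('votecount', 1)                 # sharded, contention-friendly write
# current = poll.getshardedcount('votecount')    # cached in-process after the first read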
# -*- coding: utf-8 -*-
"""
tipfy.ext.sharded_counter
~~~~~~~~~~~~~~~~~~~~~~~~~
A general purpose sharded counter implementation for the datastore.
:copyright: 2008 William T Katz.
:copyright: 2010 Rodrigo Moraes.
:copyright: 2011 Craig Quiter.
:license: Apache, see LICENSE.txt for more details.
"""
import random
import logging
from google.appengine.api import memcache
from google.appengine.ext import db
from google.appengine.ext.db import NotSavedError
from google.appengine.runtime import apiproxy_errors
MAXSHARDS = 20  # Decreasing this will cause data loss: counts are read with fetch(limit=MAXSHARDS), so existing shards beyond a smaller limit would be silently dropped.

class MemcachedCount(object):

    @property
    def namespace(self):
        return __name__ + '.' + self.__class__.__name__

    def __init__(self, name, counter):
        self.key = 'MemcachedCount' + name
        self.counter = counter
        # Maintain an in-process count for quicker lookups, i.e. not repeating memcache gets.
        self._count = memcache.get(self.key, namespace=self.namespace)

    def get_count(self):
        return self._count

    def set_count(self, value):
        self._count = value
        memcache.Client().set(self.key, value, namespace=self.namespace)  # cas, retries, delete, error

    def delete_count(self):
        self._count = None
        memcache.delete(self.key, namespace=self.namespace)

    count = property(get_count, set_count, delete_count)

    def increment(self, incr=1):
        # memcache's incr/decr use unsigned ints and can't go negative, so use gets/cas instead.
        memcacheclient = memcache.Client()
        for i in range(10):  # Retry loop
            curvalue = memcacheclient.gets(self.key, namespace=self.namespace)
            if curvalue is None:
                # Memcache value lost since instantiation...weird, but it seems to have happened.
                self._count = self.counter.get_count_and_cache()  # very expensive
                logging.warning('fetching count from db during increment')
                # The datastore was already incremented (by CounterShard.increment), so just return.
                return self._count
            else:
                self._count = curvalue + incr
                if memcacheclient.cas(self.key, self._count, namespace=self.namespace):
                    return self._count
                else:
                    logging.error('error cas incrementing count: ' + self.key + ' to: ' + str(self._count))
        logging.error('gave up incrementing count: ' + self.key + ' to: ' + str(self._count))

class Counter(object):
    """A counter using sharded writes to prevent datastore contention.

    Should be used for counters that handle a lot of concurrent use.
    Follows the pattern described in the Google I/O talk:
    http://sites.google.com/site/io/building-scalable-web-applications-with-google-app-engine

    Memcache is used for caching counts, and if a cached count is available, it
    is the most correct value. If there are datastore put issues, the un-put
    values are stored in a delayed_incr memcache entry that is applied as soon
    as the next shard put succeeds. Changes are only lost if memcache is lost
    before a successful datastore shard put, or if there is a failure/error in
    memcache.

    Usage:
        hits = Counter('hits')
        hits.increment()
        my_hits = hits.count
        hits.get_count(nocache=True)  # Forces a non-cached count of all shards
        hits.count = 6                # Set the counter to an arbitrary value
        hits.increment(incr=-1)       # Decrement
        hits.increment(10)
    """
    def __init__(self, name, model=None):
        if model:
            self.name = name = str(name + '_' + str(model.key().id()))
            self.model = model
        else:
            self.name = name
        self.memcached = MemcachedCount(name='counter:' + name, counter=self)
        self.delayed_incr = MemcachedCount(name='delayed:' + name, counter=self)

    def delete(self):
        q = db.Query(CounterShard).filter('name =', self.name)
        shards = q.fetch(limit=MAXSHARDS)
        db.delete(shards)

    def get_count_and_cache(self, return_isnew=False):
        is_new = True
        q = db.Query(CounterShard).filter('name =', self.name)
        shards = q.fetch(limit=MAXSHARDS)
        datastore_count = 0
        for shard in shards:
            datastore_count += shard.count
            is_new = False
        if self.delayed_incr.count is None:
            self.delayed_incr.count = 0
        count = datastore_count + self.delayed_incr.count
        self.memcached.count = count
        if return_isnew:
            return count, is_new
        else:
            return count

    def get_count(self, nocache=False, return_isnew=False):
        """
        Returns the count, and optionally a bool describing whether the counter is new.
        - nocache: bypass memcache and recount all shards
        """
        total = self.memcached.count
        if nocache or total is None:
            return self.get_count_and_cache(return_isnew)
        else:
            if return_isnew:
                return total, False  # Found the count in memcache, so the counter already existed
            else:
                return total

    def set_count(self, value):
        cur_value = Counter.get_count(self)
        self.memcached.count = value
        delta = value - cur_value
        if delta != 0:
            CounterShard.increment(self, incr=delta)

    count = property(get_count, set_count)

    def increment(self, incr=1):
        # Load the count into memcache if it isn't there already.
        # This fixed the bug that caused incrementbeforeview in shardedcountertests to fail.
        self.get_count()
        CounterShard.increment(self, incr)
        return self.memcached.increment(incr)

class TransitionCounter(Counter):
    """A Counter that migrates an entity's old non-sharded count on first use."""

    def __init__(self, name, model):
        self.oldname = name
        super(TransitionCounter, self).__init__(name, model)

    def increment(self, incr=1):
        self.checkfortransition()
        return super(TransitionCounter, self).increment(incr)

    def get_count(self):
        self.checkfortransition()
        return super(TransitionCounter, self).get_count()

    def checkfortransition(self):
        from const import NEWDEFAULTFOROLDCOUNT
        oldvalue = self.model.__getattribute__(self.oldname)
        if oldvalue != NEWDEFAULTFOROLDCOUNT:
            # Transition from the old count.
            # I made the default value for the old counts negative, so that I
            # know whether an entity was created after the switch.
            self.set_count(oldvalue)
            self.model.__setattr__(self.oldname, NEWDEFAULTFOROLDCOUNT)
            self.model.put()

    def set_count(self, value):
        super(TransitionCounter, self).set_count(value)

    count = property(get_count, set_count)

class CounterShard(db.Model):
    name = db.StringProperty(required=True)
    count = db.IntegerProperty(default=0)

    @classmethod
    def increment(cls, counter, incr=1):
        index = random.randint(1, MAXSHARDS)
        counter_name = counter.name
        delayed_incr = counter.delayed_incr.count or 0
        shard_key_name = 'Shard' + counter_name + '_' + str(index)

        def get_or_create_shard():
            shard = CounterShard.get_by_key_name(shard_key_name)
            if shard is None:
                shard = CounterShard(key_name=shard_key_name, name=counter_name)
            shard.count += incr + delayed_incr
            shard.put()

        try:
            db.run_in_transaction(get_or_create_shard)
        except (db.Error, apiproxy_errors.Error), e:
            counter.delayed_incr.increment(incr)
            logging.error("CounterShard (%s) delayed increment %d: %s",
                          counter_name, incr, e)
            return False
        if delayed_incr:
            counter.delayed_incr.count = 0
        return True

def getcounterproperty(name, model):
    if not hasattr(model, name):
        setattr(model, name, Counter(name, model))
    return getattr(model, name)


def gettransitioncounterproperty(name, model):
    newname = 'shardedcounter__' + name
    if not hasattr(model, newname):
        setattr(model, newname, TransitionCounter(name, model))
    return getattr(model, newname)
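
# Hypothetical wiring sketch (not in the original gist): the tests below read
# poll.shardedvotecounter.count, which suggests the helper above is exposed
# through a property on the model. One plausible way to do that, assuming a
# Poll model with an old votecount property:
#
# class Poll(db.Model):
#     votecount = db.IntegerProperty(default=NEWDEFAULTFOROLDCOUNT)
#
#     @property
#     def shardedvotecounter(self):
#         return gettransitioncounterproperty('votecount', self)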

def transitiontoshardedvotecounttests(t):

    def oldvotecountzeronewevote():
        deletedbandmemcache()
        user, poll, option0, option1, requestinfo = gettestentities(t)
        # Set old defaults to simulate a transition.
        poll.votecount = 0
        option0.votecount = 0
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=OPTIONNAMEPREFIX + str(option0.order),
                     isskip=False,
                     newoptiondescription='')
        poll = Poll.get(poll.key())
        option0 = PollOption.get(option0.key())
        t.assertEqual(option0.getvotecount(), 1)
        t.assertEqual(poll.shardedvotecounter.count, 1)
        t.assertEqual(option0.votecount, NEWDEFAULTFOROLDCOUNT)
        t.response.out.write('old vote count zero, one new vote passed<br>')

    def oldvotecountnonzeronewevote():
        deletedbandmemcache()
        user, poll, option0, option1, requestinfo = gettestentities(t)
        # Set old defaults to simulate a transition.
        poll.votecount = 1000
        option0.votecount = 1000
        poll.put()
        option0.put()
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=OPTIONNAMEPREFIX + str(option0.order),
                     isskip=False,
                     newoptiondescription='')
        poll = Poll.get(poll.key())
        option0 = PollOption.get(option0.key())
        t.assertEqual(option0.getvotecount(), 1001)
        t.assertEqual(poll.shardedvotecounter.count, 1001)
        t.assertEqual(option0.votecount, NEWDEFAULTFOROLDCOUNT)
        t.response.out.write('old vote count non-zero, one new vote passed<br>')

    def changevote(peek):
        deletedbandmemcache()
        user, poll, option0, option1, requestinfo = gettestentities(t)
        # Set old defaults to simulate a transition.
        poll.votecount = 1
        option0.votecount = 1
        option1.votecount = 0
        db.put([poll, option0, option1])
        poll = Poll.get(poll.key())
        if peek:
            # Peek at the counts before voting.
            t.assertEqual(poll.options[0].votecount, 1)
            t.assertEqual(option0.getvotecount(), 1)
            t.assertEqual(option1.getvotecount(), 0)
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=OPTIONNAMEPREFIX + str(option0.order),
                     isskip=False,
                     newoptiondescription='')
        # Refresh: the entities were written in castvote, which did not update
        # the options we hold handles to in heap memory.
        poll = Poll.get(poll.key())
        option0 = PollOption.get(option0.key())
        option1 = PollOption.get(option1.key())
        t.assertEqual(poll.shardedvotecounter.count, 2)
        t.assertEqual(option0.getvotecount(), 2)
        t.assertEqual(option1.getvotecount(), 0)
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=OPTIONNAMEPREFIX + str(option1.order),
                     isskip=False,
                     newoptiondescription='')
        poll = Poll.get(poll.key())
        option0 = PollOption.get(option0.key())
        option1 = PollOption.get(option1.key())
        t.assertEqual(option0.getvotecount(), 1)
        t.assertEqual(option1.getvotecount(), 1)
        t.assertEqual(poll.shardedvotecounter.count, 2)
        t.assertEqual(poll.votecount, NEWDEFAULTFOROLDCOUNT)
        t.assertEqual(option0.votecount, NEWDEFAULTFOROLDCOUNT)
        t.assertEqual(option1.votecount, NEWDEFAULTFOROLDCOUNT)
        t.response.out.write('change vote ' + ('peek' if peek else 'no peek') + ' passed<br>')

    def newoption():
        NEWOPTIONDESCRIPTION = '2'
        deletedbandmemcache()
        user, poll, option0, option1, requestinfo = gettestentities(t)
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=NONEABOVEOPTIONVALUE,
                     isskip=False,
                     newoptiondescription=NEWOPTIONDESCRIPTION)
        poll = Poll.get(poll.key())
        user = getuser(t)
        optionnew = PollOption.all().filter('description = ', NEWOPTIONDESCRIPTION).get()
        t.assertEqual(option0.getvotecount(), 0)
        t.assertEqual(option1.getvotecount(), 0)
        t.assertEqual(optionnew.getvotecount(), 1)
        t.assertEqual(poll.shardedvotecounter.count, 1)
        t.assertEqual(poll.votecount, NEWDEFAULTFOROLDCOUNT)
        t.assertEqual(option0.votecount, NEWDEFAULTFOROLDCOUNT)
        t.assertEqual(option1.votecount, NEWDEFAULTFOROLDCOUNT)
        t.response.out.write('new option passed<br>')

    def dupevote():
        deletedbandmemcache()
        user, poll, option0, option1, requestinfo = gettestentities(t)
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=OPTIONNAMEPREFIX + str(option0.order),
                     isskip=False,
                     newoptiondescription='')
        testcastvote(t,
                     poll,
                     user,
                     requestinfo=requestinfo,
                     selectedoption=OPTIONNAMEPREFIX + str(option0.order),
                     isskip=False,
                     newoptiondescription='')
        poll = Poll.get(poll.key())
        # user = getuser(t)
        option0 = PollOption.get(option0.key())
        t.assertEqual(Poll.all().count(), 1)
        t.assertEqual(PollOption.all().count(), 2)
        t.assertEqual(option0.getvotecount(), 1)
        t.assertEqual(Vote.all().count(), 2)
        t.assertEqual(poll.shardedvotecounter.count, 1)
        t.assertEqual(user.votesbycount, 1)
        t.assertEqual(user.shardedvotesoncounter.count, 1)
        checklists(t, user, poll, votecount=1, skipcount=0)
        t.response.out.write('dupe vote passed<br>')

    t.response.out.write('<br><br><br>transition to sharded vote count tests:<br><br>')
    oldvotecountzeronewevote()
    oldvotecountnonzeronewevote()
    # changevote(peek=True)  # Fails due to eventual consistency, I think.
    changevote(peek=False)
    newoption()

def shardedcountertests(t):

    def incrementbeforeview():
        from lib.shardedcounter import Counter
        KEY_NAME = '_incrementbeforeview_counter_'
        counter = Counter(KEY_NAME, model=None)
        counter.set_count(100)
        memcache.flush_all()  # Reset to replicate a count that hasn't been viewed yet.
        counter = Counter(KEY_NAME, model=None)
        counter.increment()
        t.assertEqual(counter.count, 101)
        t.response.out.write('increment before view passed<br>')

    def increment_nonzero_nocache():
        from lib.shardedcounter import Counter
        KEY_NAME = '_increment_non-zero_nocache_'
        counter = Counter(KEY_NAME, model=None)
        counter.set_count(1)
        counter = Counter(KEY_NAME, model=None)
        memcache.flush_all()  # Reset to replicate a count that was lost in memcache.
        counter.increment()
        t.assertEqual(counter.count, 2)
        t.response.out.write('increment nonzero no cache passed<br>')

    t.response.out.write('<br><br><br>sharded counter tests:<br><br>')
    incrementbeforeview()
    increment_nonzero_nocache()