Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Testing code for reliably saving hundreds of transactions on Google App Engine (GAE)
import webapp2
import decimal
import logging
import random
import string
from google.appengine.api import datastore_errors
from google.appengine.datastore import entity_pb
from google.appengine.ext import db
from google.appengine.ext import ndb
from google.appengine.ext.ndb import metadata
class DecimalProperty(ndb.Property):
    """A Property whose value is a decimal.Decimal object.

    The value is written to the datastore as its string representation and
    parsed back into a decimal.Decimal when read, so no precision is lost.
    """

    def _datastore_type(self, value):
        # Query filters compare against the same string form that is stored.
        return str(value)

    def _validate(self, value):
        if not isinstance(value, decimal.Decimal):
            raise datastore_errors.BadValueError('Expected decimal.Decimal, got %r'
                                                 % (value,))
        return value

    def _db_set_value(self, v, p, value):
        """Serialize ``value`` into the protobuf value ``v`` / property ``p``."""
        v.set_stringvalue(str(value))
        if not self._indexed:
            # Unindexed strings are tagged TEXT (as ndb's own TextProperty does),
            # which also lifts the size limit placed on indexed strings.
            p.set_meaning(entity_pb.Property.TEXT)

    def _db_get_value(self, v, _):
        """Parse the stored string back into a Decimal; None when absent."""
        if not v.has_stringvalue():
            return None
        return decimal.Decimal(v.stringvalue())
class Shard(ndb.Model):
    """Shards for each named counter.

    Each shard is a root entity (no parent), so every shard is its own entity
    group and shards of one counter can be written concurrently.
    """
    # No need for the name property, as it is the same as the key id
    # Stored under the short datastore name 'c'. Left unindexed: within this
    # module shards are only fetched by key, never queried by count.
    count = DecimalProperty(name='c', default=decimal.Decimal('0.00'),
                            indexed=False)
class Counter(ndb.Model):
    """Tracks the number of shards for each named counter.

    The counter's value is the sum of up to ``num_shards`` Shard entities.
    Increments go to a randomly chosen shard, spreading write contention.
    """
    # No need for the name property, as it is the same as the key id
    # NOTE: num_shards can only be set to a maximum of 5 due to xg limitations
    num_shards = ndb.IntegerProperty(default=4, indexed=False)

    def _shard_name(self, index):
        """Return the key name of this counter's shard number ``index``."""
        # NOTE(review): "<counter id>-<index>" is an assumed naming scheme --
        # confirm it matches shards already stored, or existing counts will be
        # invisible to this counter.
        return '%s-%d' % (self.key.id(), index)

    def _shard_keys(self):
        """Return the keys of shard slots 0 .. num_shards - 1, in index order."""
        return [ndb.Key(Shard, self._shard_name(index))
                for index in range(self.num_shards)]

    @property
    def shards(self):
        """Existing Shard entities in index order; never-written slots are omitted."""
        fetched = ndb.get_multi(self._shard_keys(), use_memcache=False)
        return [shard for shard in fetched if shard is not None]

    @ndb.tasklet
    def compress_shards_async(self):
        """Atomically fold every shard into shard 0 and delete the others.

        To be used when reducing num_shards: run it *before* lowering
        num_shards, because only slots below num_shards are read. The merge
        target is always slot 0 (created if missing), so the merged total
        stays visible for any num_shards >= 1.
        """
        keys = self._shard_keys()
        if not keys:
            return

        @ndb.tasklet
        def _compress_shards_tx():
            shards = yield ndb.get_multi_async(keys, use_memcache=False)
            target = shards[0]
            if target is None:
                target = Shard(id=keys[0].id())
            stale = [shard for shard in shards[1:] if shard is not None]
            if not stale:
                return  # Nothing to merge (already compressed, or no shards).
            for shard in stale:
                target.count += shard.count
            yield (ndb.delete_multi_async([shard.key for shard in stale]),
                   target.put_async())

        # xg=True: the transaction touches one entity group per shard.
        yield ndb.transaction_async(_compress_shards_tx, use_memcache=False,
                                    xg=True)

    def compress_shards(self):
        """Synchronous wrapper around compress_shards_async()."""
        return self.compress_shards_async().get_result()

    @property
    def total(self):
        """Sum of all existing shard counts (Decimal('0.00') when there are none)."""
        return sum((shard.count for shard in self.shards),
                   decimal.Decimal('0.00'))

    @ndb.tasklet
    def incr_async(self, value):
        """Add ``value`` to one randomly chosen shard inside a transaction.

        Raises:
            db.TransactionFailedError: the increment was not applied (contention,
                or an internal error that left the entity group unchanged) and
                may safely be retried.
            db.InternalError: the entity group changed despite the error, so the
                write may already have been applied; retrying could double-count.
        """
        name = self._shard_name(random.randint(0, self.num_shards - 1))
        key = ndb.Key(Shard, name)
        # Snapshot the entity-group version so a failed commit can be classified.
        version = metadata.get_entity_group_version(key)

        @ndb.tasklet
        def _incr_tx():
            shard = yield Shard.get_by_id_async(name, use_memcache=False)
            if shard is None:
                # Setting the parent key for future queries and maintenance
                # removes the benefit of using shards (shared entity group)
                shard = Shard(id=name)
            shard.count += value
            yield shard.put_async()

        try:
            yield ndb.transaction_async(_incr_tx)
        except db.InternalError:
            if version == metadata.get_entity_group_version(key):
                # Unchanged version: nothing was committed. Signal a retryable
                # failure instead of letting the caller assume success.
                logging.warning('Almost corrupted shard %s', name, exc_info=True)
                raise db.TransactionFailedError(
                    'Increment on shard %s was not applied' % name)
            logging.error('Shard %s is corrupted', name)
            raise  # Bare raise keeps the original traceback.

    def incr(self, value):
        """Synchronous wrapper around incr_async()."""
        return self.incr_async(value).get_result()
@ndb.tasklet
def increment_batch(data_set):
    """Apply a batch of counter increments, retrying until each is handled.

    Args:
        data_set: dict mapping counter name -> decimal.Decimal increment.
            NOTE: data_set is modified in place -- zero values are removed
            up front and each entry is removed once its increment is handled.

    Returns (via ndb.Return):
        dict mapping counter name -> the Counter entity that was used.
    """
    zero = decimal.Decimal('0.00')

    # (1/3) Drop zero values (incrementing by zero is a no-op), then fire off
    # every counter get up front so the futures can autobatch.
    for name in [name for name, value in data_set.iteritems() if value == zero]:
        del data_set[name]
    ctr_futs = {name: Counter.get_by_id_async(name) for name in data_set}  # Use cache(s)

    counters = {}
    ctr_put_futs = []
    while data_set:  # Repeat until every increment has been handled
        # (2/3) Wait on counter gets and fire off increment transactions;
        # this way the autobatchers should fill the waiting time.
        incr_futs = {}
        for name, value in data_set.iteritems():
            counter = counters.get(name)
            if counter is None:
                counter = yield ctr_futs.pop(name)
                if counter is None:
                    logging.info('Creating new counter %s', name)
                    counter = Counter(id=name)
                    # Persist it so the counter can be looked up later.
                    ctr_put_futs.append(counter.put_async())
                counters[name] = counter
            incr_futs[(name, value)] = counter.incr_async(value)

        # (3/3) Wait on increments and handle errors; the tuple key gives
        # access to the counter name alongside each future.
        for (name, _value), incr_fut in incr_futs.iteritems():
            counter = counters[name]
            try:
                yield incr_fut
            except db.TransactionFailedError:
                # Not applied (e.g. contention): keep the entry for the next
                # round, spreading future writes over one more shard.
                if counter.num_shards < 5:
                    counter.num_shards += 1
                    logging.info('Increasing number of shards for %s to %i.',
                                 name, counter.num_shards)
                    ctr_put_futs.append(counter.put_async())
                continue
            except db.InternalError:
                # The write may already have been applied; retrying could
                # double-count, so this increment is given up on.
                pass
            del data_set[name]

        if data_set:
            logging.warning('%i increments failed this batch.', len(data_set))

    yield ctr_put_futs  # In case you get() the Counters later in the handler
    raise ndb.Return(counters)
class ShardTestHandler(webapp2.RequestHandler):
    """Load test: applies 250 random 2-decimal-place increments in one batch."""

    @ndb.toplevel  # get() is a generator (it yields futures); toplevel drives it.
    def get(self):
        if self.request.GET.get('delete'):
            # NOTE(review): this branch's body was lost; a full reset of all
            # Counter and Shard entities is assumed here -- confirm.
            keys = Counter.query().fetch(keys_only=True)
            keys.extend(Shard.query().fetch(keys_only=True))
            ndb.delete_multi(keys)

        # ascii_letters is locale-independent, unlike string.letters.
        alphabet = string.ascii_letters + string.digits
        data_set_test = {
            ''.join(random.choice(alphabet) for _ in range(12)):
                # Build from a 2-dp string: Decimal(float) would keep the
                # float's full binary expansion instead of two decimals.
                decimal.Decimal('%.2f' % (random.random() * 100))
            for _ in range(250)
        }
        counters = yield increment_batch(data_set_test)
        self.response.write('Incremented %i counters.' % len(counters))
# WSGI entry point: the single load-test route.
app = webapp2.WSGIApplication(
    routes=[
        ('/shard_test/', ShardTestHandler),
    ],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment