Create a gist now

Instantly share code, notes, and snippets.

Testing code for reliably saving hundreds of transactions on GAE
import webapp2
import decimal
import logging
import random
import string
from google.appengine.api import datastore_errors
from google.appengine.datastore import entity_pb
from google.appengine.ext import db
from google.appengine.ext import ndb
from google.appengine.ext.ndb import metadata
class DecimalProperty(ndb.Property):
    """ndb property holding a decimal.Decimal, persisted as its str() form.

    Writing stores str(value) as the string payload; reading parses that
    payload back with decimal.Decimal().
    """

    def _datastore_type(self, value):
        # Same string representation that _db_set_value writes.
        return str(value)

    def _validate(self, value):
        # Only genuine Decimal instances are accepted (no floats, no strings).
        if isinstance(value, decimal.Decimal):
            return value
        raise datastore_errors.BadValueError('Expected decimal.Decimal, got %r'
                                             % (value,))

    def _db_set_value(self, v, p, value):
        v.set_stringvalue(str(value))
        if self._indexed:
            return
        # Unindexed values are tagged with the TEXT meaning.
        p.set_meaning(entity_pb.Property.TEXT)

    def _db_get_value(self, v, _):
        # No string payload -> None; otherwise parse it back into a Decimal.
        if v.has_stringvalue():
            return decimal.Decimal(v.stringvalue())
        return None
class Shard(ndb.Model):
    """A single slice of a named counter's running total.

    The key id encodes the counter name plus the shard index, so the name is
    not stored as a separate property.
    """
    # Unindexed Decimal stored under the short datastore name 'c'; new shards
    # start at 0.00.
    count = DecimalProperty(indexed=False, name='c',
                            default=decimal.Decimal('0.00'))
class Counter(ndb.Model):
    """Tracks the number of shards for each named counter.

    The counter name is the key id. Its value is spread over Shard entities
    whose ids are the counter name followed by an index in range(num_shards).
    Shards are created without a parent key, so each one is its own entity
    group and increments on different shards do not contend.
    """
    # No need for the name property, as it is the same as the key id
    # NOTE: num_shards can only be set to a maximum of 5 due to xg limitations
    # (compress_shards touches every shard slot in one xg transaction).
    num_shards = ndb.IntegerProperty(default=4, indexed=False)

    def _shard_keys(self):
        """Return the Shard keys for every index in range(num_shards)."""
        prefix = self.key.id()
        return [ndb.Key(Shard, prefix + str(index))
                for index in range(self.num_shards)]

    @property
    def shards(self):
        """List of existing Shard entities; slots never written are skipped."""
        fetched = ndb.get_multi(self._shard_keys(), use_memcache=False)
        return [shard for shard in fetched if shard]

    @ndb.tasklet
    def compress_shards_async(self):
        """Merge all existing shards into shard 0 in one xg transaction.

        Call this BEFORE lowering num_shards: only indices in
        range(num_shards) are read, so shards beyond an already-reduced
        num_shards would be ignored. The merged total is always written to
        shard 0, which stays visible for any num_shards >= 1. Does nothing if
        no shard exists yet.
        """
        @ndb.tasklet
        def _compress_shards_tx():
            shard_keys = self._shard_keys()
            fetched = yield ndb.get_multi_async(shard_keys, use_memcache=False)
            shards = [shard for shard in fetched if shard]
            if not shards:
                # Nothing to merge (popping from an empty list used to crash).
                return
            head_key = shard_keys[0]
            if shards[0].key == head_key:
                head = shards.pop(0)
            else:
                # Shard 0 was never written; create it to hold the total so
                # the value survives a later reduction of num_shards.
                head = Shard(key=head_key)
            for shard in shards:
                head.count += shard.count
            yield (ndb.delete_multi_async([shard.key for shard in shards]),
                   head.put_async())
        yield ndb.transaction_async(_compress_shards_tx, use_memcache=False,
                                    xg=True)

    def compress_shards(self):
        """Synchronous wrapper around compress_shards_async()."""
        return self.compress_shards_async().get_result()

    @property
    def total(self):
        """Sum of all existing shard counts, Decimal('0.00') if none exist."""
        count = decimal.Decimal('0.00')  # Use initial value if no shards
        for shard in self.shards:
            count += shard.count
        return count

    @ndb.tasklet
    def incr_async(self, value):
        """Add value to a randomly chosen shard inside a transaction.

        If the commit fails with db.InternalError, its outcome is unknown. The
        entity-group version read before the attempt is compared with the
        current one. If it is unchanged, the write presumably did not land,
        so the increment is retried (on a freshly chosen shard) and awaited.
        Otherwise the error is logged and re-raised.
        """
        index = random.randint(0, self.num_shards - 1)  # Use random shard
        name = self.key.id() + str(index)
        key = ndb.Key(Shard, name)
        # NOTE(review): a concurrent writer to the same shard also bumps this
        # version, so the "corrupted" branch may fire for a write of ours that
        # never committed.
        version = metadata.get_entity_group_version(key)

        @ndb.tasklet
        def _incr_tx():
            shard = yield Shard.get_by_id_async(name, use_memcache=False)
            if not shard:
                # Setting the parent key for future queries and maintenance
                # removes the benefit of using shards (shared entity group)
                shard = Shard(id=name)
            shard.count += value
            yield shard.put_async()

        retry = False
        try:
            yield ndb.transaction_async(_incr_tx)
        except db.InternalError as err:
            if version != metadata.get_entity_group_version(key):
                logging.error('Shard %s is corrupted', name)
                raise err
            logging.warning('Almost corrupted shard %s', name)
            retry = True
        if retry:
            # Await the retry (outside the except clause) so its outcome and
            # any error reach the caller instead of being fire-and-forget.
            yield self.incr_async(value)

    def incr(self, value):
        """Synchronous wrapper around incr_async()."""
        return self.incr_async(value).get_result()
@ndb.tasklet
def increment_batch(data_set):
    """Apply many counter increments concurrently, retrying until all succeed.

    data_set maps counter name -> decimal.Decimal increment. NOTE: it is
    modified in place. Zero increments are removed up front, and each
    successful increment is removed as it completes, so it is empty when this
    returns.

    Returns (via ndb.Return) a dict mapping each non-zero counter name to its
    Counter entity. Missing counters are created with default settings.
    """
    zero = decimal.Decimal('0.00')
    max_shards = 5  # xg limit, see the NOTE on Counter.num_shards
    # (1/3) drop zero increments, then fire off every counter get up front so
    # the futures can autobatch.
    for name in [name for name, value in data_set.iteritems() if value == zero]:
        del data_set[name]
    ctr_futs = {name: Counter.get_by_id_async(name)  # Use cache(s)
                for name in data_set}
    counters = {}
    ctr_put_futs = []
    while data_set:  # Repeat until all transactions succeed
        # (2/3) wait on counter gets and fire off increment transactions;
        # this way the autobatchers should fill the time.
        incr_futs = {}
        for name, value in data_set.iteritems():
            counter = counters.get(name)
            if not counter:
                counter = counters[name] = yield ctr_futs.pop(name)
                if not counter:
                    logging.info('Creating new counter %s', name)
                    counter = counters[name] = Counter(id=name)
                    ctr_put_futs.append(counter.put_async())
            incr_futs[(name, value)] = counter.incr_async(value)
        # (3/3) wait on increments and handle errors; the (name, value)
        # tuple key gives the loop access to both.
        for (name, value), incr_fut in incr_futs.iteritems():
            counter = counters[name]
            try:
                yield incr_fut
            except db.TransactionFailedError:
                # Contention: spread future increments over more shards.
                if counter.num_shards < max_shards:
                    counter.num_shards += 1
                    logging.info('Increasing number of shards for %s to %i.',
                                 name, counter.num_shards)
                    ctr_put_futs.append(counter.put_async())
            except db.InternalError:
                # NOTE(review): Counter.incr_async only lets InternalError out
                # when the shard version changed, i.e. the write may already
                # have landed; retrying here can double-count. Confirm intent.
                logging.warning('InternalError incrementing %s by %s; '
                                'will retry.', name, value)
            else:
                del data_set[name]
        if data_set:
            logging.warning('%i increments failed this batch.', len(data_set))
    yield ctr_put_futs  # In case you get() the Counters later in the handler
    raise ndb.Return(counters)
class ShardTestHandler(webapp2.RequestHandler):
    """Load-test handler for increment_batch.

    With a non-empty ``delete`` query parameter, deletes every Shard and
    Counter entity. Otherwise, increments 250 randomly named counters by
    random two-decimal amounts.
    """

    @ndb.toplevel
    def get(self):
        if self.request.GET.get('delete'):
            # Async deletes; @ndb.toplevel waits for them before returning.
            ndb.delete_multi_async(Shard.query().fetch(keys_only=True))
            ndb.delete_multi_async(Counter.query().fetch(keys_only=True))
        else:
            yield increment_batch(self._random_data_set())
        self.response.out.write("Done!")

    @staticmethod
    def _random_data_set(size=250, name_length=12):
        """Return {random alphanumeric name: Decimal between 0.00 and 100.00}.

        Amounts are built from a '%.2f' string rather than a float, so each
        Decimal really has two decimal places. decimal.Decimal(float) keeps
        the float's full binary expansion instead.
        """
        alphabet = string.ascii_letters + string.digits
        return {''.join(random.choice(alphabet) for _ in range(name_length)):
                decimal.Decimal('%.2f' % (random.random() * 100))
                for _ in range(size)}
# WSGI entry point: a single route mapping /shard_test/ to ShardTestHandler.
app = webapp2.WSGIApplication(
    routes=[('/shard_test/', ShardTestHandler)],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment