shard_test.py
Testing code for reliably saving hundreds of transactions on GAE
import webapp2
import decimal
import logging
import random
import string
from google.appengine.api import datastore_errors
from google.appengine.datastore import entity_pb
from google.appengine.ext import db
from google.appengine.ext import ndb
from google.appengine.ext.ndb import metadata
class DecimalProperty(ndb.Property):
    """A Property whose value is a decimal.Decimal object."""

    def _datastore_type(self, value):
        return str(value)

    def _validate(self, value):
        if not isinstance(value, decimal.Decimal):
            raise datastore_errors.BadValueError(
                'Expected decimal.Decimal, got %r' % (value,))
        return value

    def _db_set_value(self, v, p, value):
        value = str(value)
        v.set_stringvalue(value)
        if not self._indexed:
            p.set_meaning(entity_pb.Property.TEXT)

    def _db_get_value(self, v, _):
        if not v.has_stringvalue():
            return None
        value = v.stringvalue()
        return decimal.Decimal(value)
class Shard(ndb.Model):
    """Shards for each named counter"""
    # No need for the name property, as it is the same as the key id
    count = DecimalProperty(name='c', default=decimal.Decimal('0.00'),
                            indexed=False)
class Counter(ndb.Model):
    """Tracks the number of shards for each named counter"""
    # No need for the name property, as it is the same as the key id
    # NOTE: num_shards can only be set to a maximum of 5 due to XG
    # (cross-group) transaction limitations
    num_shards = ndb.IntegerProperty(default=4, indexed=False)

    @property
    def shards(self):
        prefix = self.key.id()  # Cache for use in loop
        dbkeys = []
        for index in range(self.num_shards):
            name = prefix + str(index)
            dbkey = ndb.Key(Shard, name)
            dbkeys.append(dbkey)
        return filter(None, ndb.get_multi(dbkeys, use_memcache=False))

    @ndb.tasklet
    def compress_shards_async(self):
        """To be used when reducing num_shards"""
        @ndb.tasklet
        def __compress_shards_tx():
            shards = self.shards
            first_shard = shards.pop(0)
            dbkeys = []
            for shard in shards:
                first_shard.count += shard.count
                dbkeys.append(shard.key)
            del_fut = ndb.delete_multi_async(dbkeys)
            put_fut = first_shard.put_async()
            yield del_fut, put_fut
        yield ndb.transaction_async(__compress_shards_tx, use_memcache=False,
                                    xg=True)

    def compress_shards(self):
        return self.compress_shards_async().get_result()

    @property
    def total(self):
        count = decimal.Decimal('0.00')  # Use initial value if no shards
        for shard in self.shards:
            count += shard.count
        return count

    @ndb.tasklet
    def incr_async(self, value):
        index = random.randint(0, self.num_shards - 1)  # Use random shard
        name = self.key.id() + str(index)
        key = ndb.Key(Shard, name)
        version = metadata.get_entity_group_version(key)

        @ndb.tasklet
        def __incr_tx():
            shard = yield Shard.get_by_id_async(name, use_memcache=False)
            if not shard:
                # Setting the parent key for future queries and maintenance
                # removes the benefit of using shards (shared entity group)
                shard = Shard(id=name)
            shard.count += value
            yield shard.put_async()

        try:
            yield ndb.transaction_async(__incr_tx)
        except db.InternalError as e:
            if version == metadata.get_entity_group_version(key):
                # Entity group version unchanged: the write did not land,
                # so it is safe to retry the increment
                logging.warning('Almost corrupted shard %s' % name)
                yield self.incr_async(value)
            else:
                logging.error('Shard %s is corrupted' % name)
                raise e

    def incr(self, value):
        return self.incr_async(value).get_result()
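# Usage sketch (illustration only, not part of the original gist): a Counter
# can also be used on its own, outside increment_batch. The name
# 'example-counter' is hypothetical; get_or_insert is the standard
# ndb.Model classmethod.
#
#   counter = Counter.get_or_insert('example-counter')  # 4 shards by default
#   counter.incr(decimal.Decimal('12.34'))  # adds to one random shard
#   print counter.total                     # sums the counts of all shards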
@ndb.tasklet
def increment_batch(data_set):
    # NOTE: data_set is modified in place
    # (1/3) filter and fire off counter gets
    # so the futures can autobatch
    counters = {}
    ctr_futs = {}
    ctr_put_futs = []
    zero_values = set()
    for name, value in data_set.iteritems():
        if value != decimal.Decimal('0.00'):
            ctr_fut = Counter.get_by_id_async(name)  # Use cache(s)
            ctr_futs[name] = ctr_fut
        else:
            # Skip zero values; a zero increment is a no-op and not worth
            # a transaction
            zero_values.add(name)
    for name in zero_values:
        del data_set[name]  # Remove all zero values from the data_set
    del zero_values

    while data_set:  # Repeat until all transactions succeed
        # (2/3) wait on counter gets and fire off increment transactions
        # this way autobatchers should fill time
        incr_futs = {}
        for name, value in data_set.iteritems():
            counter = counters.get(name)
            if not counter:
                counter = counters[name] = yield ctr_futs.pop(name)
            if not counter:
                logging.info('Creating new counter %s' % name)
                counter = counters[name] = Counter(id=name)
                ctr_put_futs.append(counter.put_async())
            incr_futs[(name, value)] = counter.incr_async(value)

        # (3/3) wait on increments and handle errors
        # by using a tuple key for variable access
        for (name, value), incr_fut in incr_futs.iteritems():
            counter = counters[name]
            try:
                yield incr_fut
            except db.TransactionFailedError:
                if counter.num_shards != 5:
                    counter.num_shards += 1
                    logging.info('Increasing number of shards for %s to %i.'
                                 % (name, counter.num_shards))
                    ctr_put_futs.append(counter.put_async())
            except db.InternalError:
                pass
            else:
                del data_set[name]

        if data_set:
            logging.warning('%i increments failed this batch.' % len(data_set))

    yield ctr_put_futs  # In case you get() the Counters later in the handler
    raise ndb.Return(counters)
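# Illustrative call pattern for increment_batch (hypothetical data; it must
# run where an ndb event loop is active, e.g. inside an @ndb.toplevel handler
# such as ShardTestHandler below). Zero-valued entries are dropped before any
# datastore work is done.
#
#   data = {'counter-a': decimal.Decimal('1.50'),
#           'counter-b': decimal.Decimal('0.00')}  # skipped as a zero value
#   counters = yield increment_batch(data)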
class ShardTestHandler(webapp2.RequestHandler):
    @ndb.toplevel
    def get(self):
        if self.request.GET.get('delete'):
            ndb.delete_multi_async(Shard.query().fetch(keys_only=True))
            ndb.delete_multi_async(Counter.query().fetch(keys_only=True))
        else:
            # 250 random 12-character counter names, each with a random
            # increment rounded to two decimal places (built from a string
            # so the Decimal stays exact)
            data_set_test = {
                ''.join([random.choice(string.letters + string.digits)
                         for _ in range(12)]):
                    decimal.Decimal('%.2f' % (random.random() * 100))
                for _ in range(250)}
            result = yield increment_batch(data_set_test)
        self.response.out.write("Done!")


app = webapp2.WSGIApplication([('/shard_test/', ShardTestHandler)])
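# Example requests (assuming this WSGI app is routed in app.yaml; any
# non-empty value works for the delete flag):
#   GET /shard_test/           runs one batch of 250 random increments
#   GET /shard_test/?delete=1  deletes all Shard and Counter entities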