public
Last active

Testing code for reliably saving hundreds of transactions on GAE

  • Download Gist
shard_test.py
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
import webapp2
 
import decimal
import logging
import random
import string
 
from google.appengine.api import datastore_errors
from google.appengine.datastore import entity_pb
from google.appengine.ext import db
from google.appengine.ext import ndb
from google.appengine.ext.ndb import metadata
 
 
class DecimalProperty(ndb.Property):
    """An ndb property that holds a ``decimal.Decimal``.

    The value is written to the datastore as its ``str()`` form and is
    parsed back into a ``decimal.Decimal`` when read.
    """

    def _datastore_type(self, value):
        # Hand the datastore the same string form that gets stored.
        return str(value)

    def _validate(self, value):
        # Accept Decimal instances only; anything else is a bad value.
        if isinstance(value, decimal.Decimal):
            return value
        raise datastore_errors.BadValueError(
            'Expected decimal.Decimal, got %r' % (value,))

    def _db_set_value(self, v, p, value):
        # Serialize as a string; unindexed values are tagged with TEXT meaning.
        v.set_stringvalue(str(value))
        if not self._indexed:
            p.set_meaning(entity_pb.Property.TEXT)

    def _db_get_value(self, v, _):
        # A missing string value is reported as None.
        if v.has_stringvalue():
            return decimal.Decimal(v.stringvalue())
        return None
 
class Shard(ndb.Model):
    """One shard of a named counter, holding a partial sum."""

    # The shard's name is not stored as a property: it is the key id.
    # The partial sum is stored under the short datastore name 'c',
    # is unindexed, and starts at 0.00.
    count = DecimalProperty(
        name='c',
        indexed=False,
        default=decimal.Decimal('0.00'),
    )
 
class Counter(ndb.Model):
    """Tracks the number of shards for each named counter.

    The counter's name is its key id.  Shard ``i`` of a counter is the
    ``Shard`` entity whose key id is the counter id followed by ``str(i)``.
    """

    # No need for the name property, as it is the same as the key id
    # NOTE: num_shards can only be set to a maximum of 5 due to xg limitations
    num_shards = ndb.IntegerProperty(default=4, indexed=False)

    @property
    def shards(self):
        """Return the list of this counter's shard entities that exist.

        Shards 0..num_shards-1 are fetched in one batch (bypassing
        memcache); shards that have never been written are left out.
        """
        prefix = self.key.id()
        shard_keys = [ndb.Key(Shard, prefix + str(index))
                      for index in range(self.num_shards)]
        found = ndb.get_multi(shard_keys, use_memcache=False)
        # A real list (not a lazy filter object) so callers can index it.
        return [shard for shard in found if shard is not None]

    @ndb.tasklet
    def compress_shards_async(self):
        """Merge every existing shard into the first one, in one transaction.

        To be used when reducing num_shards.  Does nothing when no shard
        has been written yet.
        """

        @ndb.tasklet
        def compress_tx():
            shards = self.shards
            if not shards:
                # Nothing stored yet, so there is nothing to merge.
                return

            # NOTE(review): the survivor is the lowest-numbered shard that
            # exists, which is not necessarily shard 0 -- confirm callers
            # lower num_shards in a way that still covers it.
            survivor = shards[0]
            obsolete_keys = []
            for shard in shards[1:]:
                survivor.count += shard.count
                obsolete_keys.append(shard.key)

            delete_futs = ndb.delete_multi_async(obsolete_keys)
            put_fut = survivor.put_async()
            yield delete_futs, put_fut

        # xg=True because each shard lives in its own entity group.
        yield ndb.transaction_async(compress_tx, use_memcache=False, xg=True)

    def compress_shards(self):
        """Blocking version of compress_shards_async()."""
        return self.compress_shards_async().get_result()

    @property
    def total(self):
        """Return the sum of all existing shards (0.00 if there are none)."""
        return sum((shard.count for shard in self.shards),
                   decimal.Decimal('0.00'))

    @ndb.tasklet
    def incr_async(self, value):
        """Add ``value`` to one randomly chosen shard inside a transaction.

        If the transaction raises ``db.InternalError``, the shard's entity
        group version is compared with the one read before the transaction:
        when it is unchanged the increment is retried, otherwise the error
        is logged and re-raised.
        """
        index = random.randint(0, self.num_shards - 1)  # Use random shard
        name = self.key.id() + str(index)
        key = ndb.Key(Shard, name)
        version = metadata.get_entity_group_version(key)

        @ndb.tasklet
        def incr_tx():
            shard = yield Shard.get_by_id_async(name, use_memcache=False)
            if not shard:
                # Setting the parent key for future queries and maintenance
                # removes the benefit of using shards (shared entity group)
                shard = Shard(id=name)

            shard.count += value
            yield shard.put_async()

        retry = False
        try:
            yield ndb.transaction_async(incr_tx)
        except db.InternalError as e:
            if version != metadata.get_entity_group_version(key):
                logging.error('Shard %s is corrupted', name)
                raise e
            logging.warning('Almost corrupted shard %s', name)
            retry = True

        if retry:
            # Wait for the retry so its outcome (including any failure) is
            # reported to the caller instead of being dropped.
            yield self.incr_async(value)

    def incr(self, value):
        """Blocking version of incr_async()."""
        return self.incr_async(value).get_result()
 
@ndb.tasklet
def increment_batch(data_set):
    """Apply a batch of increments to named counters, retrying failures.

    Args:
        data_set: dict mapping counter name to the decimal.Decimal amount
            to add.  NOTE: it is modified in place -- zero amounts are
            removed up front and every other entry is removed once its
            increment succeeds, so it is empty when this returns.

    Returns (via ndb.Return):
        dict mapping counter name to its Counter entity, for every
        non-zero entry of data_set.
    """
    max_shards = 5  # Upper bound documented on Counter.num_shards
    zero = decimal.Decimal('0.00')

    # (1/3) drop zero increments, then fire off counter gets
    #       so the futures can autobatch
    for name in [n for n, amount in data_set.items() if amount == zero]:
        del data_set[name]
    ctr_futs = dict((name, Counter.get_by_id_async(name))  # Use cache(s)
                    for name in data_set)

    counters = {}
    ctr_put_futs = []
    while data_set:  # Repeat until all transactions succeed

        # (2/3) wait on counter gets and fire off increment transactions
        #       this way autobatchers should fill time
        incr_futs = {}
        for name, value in data_set.items():
            counter = counters.get(name)
            if not counter:
                counter = yield ctr_futs.pop(name)
                if not counter:
                    logging.info('Creating new counter %s', name)
                    counter = Counter(id=name)
                    ctr_put_futs.append(counter.put_async())
                counters[name] = counter
            incr_futs[name] = counter.incr_async(value)

        # (3/3) wait on increments and handle errors
        for name, incr_fut in incr_futs.items():
            counter = counters[name]
            try:
                yield incr_fut
            except db.TransactionFailedError:
                # Add a shard for the next attempt, but never go past the
                # limit, even if num_shards was somehow stored above it.
                if counter.num_shards < max_shards:
                    counter.num_shards += 1
                    logging.info('Increasing number of shards for %s to %i.',
                                 name, counter.num_shards)
                    ctr_put_futs.append(counter.put_async())
            except db.InternalError:
                # Deliberately ignored: the entry stays in data_set and is
                # attempted again on the next pass of the while loop.
                pass
            else:
                del data_set[name]

        if data_set:
            logging.warning('%i increments failed this batch.', len(data_set))

    yield ctr_put_futs  # In case you get() the Counters later in the handler

    raise ndb.Return(counters)
 
class ShardTestHandler(webapp2.RequestHandler):
    """Test endpoint that either wipes or exercises the sharded counters."""

    @ndb.toplevel
    def get(self):
        """Handle GET.

        With a truthy ``delete`` query parameter, delete every Shard and
        Counter entity.  Otherwise increment 250 randomly named counters
        by random two-decimal amounts.  Writes "Done!" to the response.
        """
        if self.request.GET.get('delete'):
            # Not yielded here; this relies on ndb.toplevel to wait for
            # outstanding async work before the request completes.
            ndb.delete_multi_async(Shard.query().fetch(keys_only=True))
            ndb.delete_multi_async(Counter.query().fetch(keys_only=True))
        else:
            alphabet = string.ascii_letters + string.digits
            data_set_test = {}
            for _ in range(250):
                name = ''.join(random.choice(alphabet) for _ in range(12))
                # Build the Decimal from a formatted string: constructing it
                # from the float would carry the float's full binary
                # expansion instead of exactly two decimal places.
                amount = decimal.Decimal('%.2f' % (random.random() * 100))
                data_set_test[name] = amount
            yield increment_batch(data_set_test)
        self.response.out.write("Done!")
# WSGI application object: routes /shard_test/ to ShardTestHandler.
app = webapp2.WSGIApplication([('/shard_test/', ShardTestHandler)])

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.