@IlyaSkriblovsky
Created December 4, 2018 08:41
Batching txmongo requests
from typing import Dict, Iterable, List, Union

from bson import BSON
from bson.raw_bson import RawBSONDocument
from pymongo import DeleteMany, DeleteOne, InsertOne, ReplaceOne, UpdateMany, UpdateOne
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure
from twisted.internet import defer, reactor
from txmongo.collection import Collection
from txmongo.database import Database

WriteOp = Union[
    InsertOne,
    UpdateOne, UpdateMany,
    DeleteOne, DeleteMany,
    ReplaceOne,
]
class MongoCollectionQueue:
    # FIXME:
    # Do not use ordered=True for inserts! In case of an AutoReconnect error, MongoCollectionQueue
    # tries to re-send the batch of operations. But if that batch was already applied, MongoDB
    # will return a DuplicateKeyError for the first duplicated op. In ordered mode,
    # MongoCollectionQueue will remove that op and retry all the others from the queue. Then
    # MongoDB will return a DuplicateKeyError for the next op, and the process repeats for every
    # other op in the failed batch, making the queue VERY slow. (Bear in mind that documents
    # to be inserted get their _ids before being sent to MongoDB, so the retry is performed
    # with the same _ids.)
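    #
    # A minimal sketch of the setup this note recommends (ordered=False with per-op
    # duplicate-key retries); the collection and queue names are made up for illustration:
    #
    #     queue = MongoCollectionQueue(db.events, ordered=False, name='events', noisy=True)
    #     queue.enqueue_one_insert({'value': 42}, retry_on_duplicate_key=True)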
    max_batch_size = 10000

    active = True
    paused = True

    __FLAG_RETRY_ON_DUP = 1

    def __init__(self, collection: Collection, ordered, name, noisy):
        self.collection = collection
        self.ordered = ordered
        self.name = name
        self.noisy = noisy
        self._queue: List[WriteOp] = []
        self.__op_flags: Dict[int, int] = {}
        self.__stop_deferreds = []

    def queue_size(self):
        return len(self._queue)

    def stop(self):
        d = defer.Deferred()
        self.__stop_deferreds.append(d)
        self.active = False
        if self.paused:
            self.__step()
        return d

    @classmethod
    def __flags_to_int(cls, retry_on_duplicate_key):
        flags = 0
        if retry_on_duplicate_key:
            flags |= cls.__FLAG_RETRY_ON_DUP
        return flags
    def enqueue_one(self, op: WriteOp, *, retry_on_duplicate_key=False):
        self._queue.append(op)
        flags = self.__flags_to_int(retry_on_duplicate_key)
        if flags != 0:
            self.__op_flags[len(self._queue) - 1] = flags
        if self.paused:
            self.__step()

    def enqueue_many(self, ops: Iterable[WriteOp], *, retry_on_duplicate_key=False):
        old_len = len(self._queue)
        self._queue.extend(ops)
        flags = self.__flags_to_int(retry_on_duplicate_key)
        if flags != 0:
            self.__op_flags.update({i: flags for i in range(old_len, len(self._queue))})
        if self.paused:
            self.__step()

    def enqueue_one_insert(self, doc: Dict, *, retry_on_duplicate_key=False):
        self.enqueue_one(InsertOne(RawBSONDocument(BSON.encode(doc))), retry_on_duplicate_key=retry_on_duplicate_key)

    def enqueue_many_inserts(self, docs: Iterable[dict], *, retry_on_duplicate_key=False):
        self.enqueue_many(
            (InsertOne(RawBSONDocument(BSON.encode(doc))) for doc in docs),
            retry_on_duplicate_key=retry_on_duplicate_key
        )

    def enqueue_update(self, filter: dict, update: dict, upsert=False, *, retry_on_duplicate_key=False):
        self.enqueue_one(
            UpdateOne(
                RawBSONDocument(BSON.encode(filter)),
                RawBSONDocument(BSON.encode(update)),
                upsert=upsert
            ),
            retry_on_duplicate_key=retry_on_duplicate_key
        )

    def enqueue_replace(self, filter: dict, replacement: dict, upsert=False, *, retry_on_duplicate_key=False):
        self.enqueue_one(
            ReplaceOne(
                RawBSONDocument(BSON.encode(filter)),
                RawBSONDocument(BSON.encode(replacement)),
                upsert=upsert
            ),
            retry_on_duplicate_key=retry_on_duplicate_key
        )
    def __log(self, message):
        if self.noisy:
            print('MQ {}: {}'.format(self.name, message))

    def __err(self, message):
        print('MQ ERR {}: {}'.format(self.name, message))
    @defer.inlineCallbacks
    def __step(self):
        if not self._queue:
            if not self.active:
                self.__log('Gracefully stopped')
                self.paused = False
                for dfr in self.__stop_deferreds:
                    dfr.callback(None)
                self.__stop_deferreds = []
            else:
                self.paused = True
            return

        self.paused = False

        batch_size = min(len(self._queue), self.max_batch_size)
        suffix = ''
        if len(self._queue) > batch_size:
            suffix = ' ({} more in queue)'.format(len(self._queue) - batch_size)
        self.__log('Sending {} ops{}'.format(batch_size, suffix))

        retry_ops = []
        processed_ops = 0
        try:
            yield self.collection.bulk_write(self._queue[:batch_size], ordered=self.ordered)
            processed_ops = batch_size
            self.__log('done')
        except BulkWriteError as e:
            write_errors = e.details.get('writeErrors', [])
            if write_errors:
                for err in write_errors:
                    if err['code'] == 11000 and self.__op_flags.get(err['index'], 0) & self.__FLAG_RETRY_ON_DUP:
                        # 11000 is a DuplicateKeyError
                        op = self._queue[err['index']]
                        self.__log('Will retry after DuplicateKeyError: {}'.format(op))
                        retry_ops.append(op)
                    else:
                        self.__err('ERROR performing queued op #{} {}: {}'.format(
                            err['index'], err['op'], err['errmsg']))
                if self.ordered:
                    # In ordered mode MongoDB stops at the first error, so only the ops
                    # up to and including it have been consumed.
                    processed_ops = write_errors[0]['index'] + 1
                    self.__log('keeping {} ops in queue'.format(batch_size - processed_ops))
                else:
                    processed_ops = batch_size
            else:
                self.__err('UNEXPECTED ERROR {}, {}'.format(e.code, e.details))
        except AutoReconnect as e:
            self.__err(f'ERROR AutoReconnect: {e}')
        except OperationFailure as e:
            self.__err(f'ERROR OperationFailure: {e}')
        except Exception as e:
            self.__err(f'UNEXPECTED MONGOQUEUE ERROR: {e}')
            yield self._sleep(2)

        # Replace the processed prefix of the queue with the ops to retry and shift the
        # flag indices of the remaining ops. Note that retried ops lose their flags, so
        # an op is retried on DuplicateKeyError at most once.
        retry_count = len(retry_ops)
        self._queue[:processed_ops] = retry_ops
        self.__op_flags = {
            op_idx - processed_ops + retry_count: op_flags
            for op_idx, op_flags in self.__op_flags.items()
            if op_idx >= processed_ops
        }

        # Not calling __step directly to avoid too much recursion under high load
        reactor.callLater(0, self.__step)
    @staticmethod
    def _sleep(delay):
        d = defer.Deferred()
        reactor.callLater(delay, d.callback, None)
        return d
class MongoDatabaseQueue:
    active = True

    def __init__(self, database: Database, *, ordered, name='', noisy=False):
        self.database = database
        self.ordered = ordered
        self.noisy = noisy
        self.name = name
        self._collqueues = {}

    def __create_collection_queue(self, collection):
        return MongoCollectionQueue(collection, self.ordered, self.name + ' ' + collection.name, self.noisy)

    def __getitem__(self, collection_name) -> MongoCollectionQueue:
        if collection_name not in self._collqueues:
            queue = self.__create_collection_queue(self.database[collection_name])
            self._collqueues[collection_name] = queue
            if not self.active:
                queue.stop()
            return queue
        return self._collqueues[collection_name]

    def __getattr__(self, item):
        return self[item]

    def stop(self):
        self.active = False
        return defer.gatherResults([queue.stop() for queue in self._collqueues.values()])
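
Usage sketch (not part of the original gist): assuming a standard txmongo ConnectionPool and made-up database/collection names, the queues can be driven like this:

# Hypothetical usage example; database 'mydb' and collections 'events'/'stats'
# are illustrative only.
from twisted.internet import defer, task
from txmongo.connection import ConnectionPool

@defer.inlineCallbacks
def main(reactor):
    conn = ConnectionPool('mongodb://localhost:27017')
    queue = MongoDatabaseQueue(conn.mydb, ordered=False, name='mydb', noisy=True)

    # Writes are buffered per collection and flushed in batches of up to
    # MongoCollectionQueue.max_batch_size ops.
    queue.events.enqueue_many_inserts({'n': i} for i in range(100000))
    queue.stats.enqueue_update({'_id': 'total'}, {'$inc': {'n': 100000}}, upsert=True)

    # stop() resolves once every per-collection queue has drained.
    yield queue.stop()
    yield conn.disconnect()

task.react(main)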