Skip to content

Instantly share code, notes, and snippets.

@NightTsarina
Created November 12, 2020 17:25
Show Gist options
  • Save NightTsarina/bd42736531a52c2843ee3e6d0d8b0ae2 to your computer and use it in GitHub Desktop.
Save NightTsarina/bd42736531a52c2843ee3e6d0d8b0ae2 to your computer and use it in GitHub Desktop.
LMDB wrapper class
# vim:ts=2:sw=2:et:ai:sts=2
"""Database layer common code."""
import contextlib
import lmdb
import logging
import os.path
import threading
import time
from tina.common import exceptions
from tina.common.util import ParametricSingleton
from tina.metrics.lmdb import LMDBMetrics
LOGGER = logging.getLogger(__name__)


def T_(message):
  """Identity marker for strings to be picked up for translation.

  Using a def instead of a lambda assignment (PEP 8 / E731) gives the
  callable a proper name in tracebacks and profiles.
  """
  return message
class LMDB(ParametricSingleton):
  """LMDB.Environment wrapper.

  One instance exists per canonical database path (via ParametricSingleton).
  Handles automatic map resizing when the database is (nearly) full,
  reopening after an external resize, and metrics registration.
  """

  #: Resizing factor when database is full.
  RESIZE_FACTOR = 1.5
  #: Resize if the database usage reached this threshold when being opened.
  AUTO_RESIZE_THRESHOLD = .9

  #: Metrics sink; overridable by subclasses or tests.
  Metrics = LMDBMetrics

  # Per-instance state; class attributes act as defaults.
  _DB = None
  _dbname = None
  _dbpath = None
  _dbargs = None
  _buffers = None
  _tables = None
  _needs_reopen = False
  _needs_resize = False
  _current_size = None
  _txn_no = None
  # Lock/condition guarding all state and the in-flight transaction count.
  _lock = threading.RLock()
  _txn_cond = threading.Condition(_lock)

  @classmethod
  def _singleton_id(cls, dbname, dbpath, *args, **kwargs):
    # Key singletons by the canonical path so two spellings of the same
    # directory share one environment.
    return os.path.realpath(dbpath)

  def __init__(self, dbname, dbpath, buffers=False, **kwargs):
    with self._lock:
      # Watch out for repeated calls (a singleton's __init__ may run again).
      if self._DB is None:
        self._dbname = dbname
        self._dbpath = dbpath
        self._dbargs = kwargs
        self._buffers = buffers
        self._tables = {}
        self._open()

  def _no_transactions(self):
    # Condition predicate: true when no transaction is in flight.
    return self._txn_no == 0

  def _open(self):
    """Open the environment, growing the map first when needed.

    Raises:
      exceptions.DatabaseError: if already open, or on apparent corruption.
    """
    with self._lock:
      if self._DB:
        raise exceptions.DatabaseError('Attempt to re-open database.')
      self._DB = lmdb.Environment(self._dbpath, **self._dbargs)
      info = self._DB.info()
      stat = self._DB.stat()
      current_size = info['map_size']
      page_size = stat['psize']
      used_size = page_size * (info['last_pgno'] + 1)
      if current_size < used_size:
        # BUG FIX: exception constructors do not %-interpolate
        # logging-style arguments; format the message explicitly.
        raise exceptions.DatabaseError(
            'Possible database corruption: '
            'last page (%d) exceeds map size (%d).' % (
                used_size, current_size))
      if self._current_size and current_size < self._current_size:
        # Logger.warn() is a deprecated alias; use warning().
        LOGGER.warning('LMDB reports smaller map_size (%d B) after reopen',
                       current_size)
        current_size = self._current_size
      if self._needs_resize or (
          used_size > current_size * self.AUTO_RESIZE_THRESHOLD):
        # Grow by RESIZE_FACTOR, rounded to a whole number of pages.
        current_size = page_size * round(
            current_size * self.RESIZE_FACTOR / page_size)
        LOGGER.info('Resizing LMDB database %r to %d bytes.', self._dbname,
                    current_size)
      self._needs_reopen = False
      self._needs_resize = False
      self._current_size = current_size
      self._DB.set_mapsize(current_size)
      self._txn_no = 0
      stale = self._DB.reader_check()
      if stale:
        LOGGER.warning(
            'Stale entries detected in reader lock table: %d entries.', stale)
      info = self._DB.info()
      LOGGER.info(
          'LMDB database %r opened: path=%s, capacity=%d B, used=%d B',
          self._dbname, self._dbpath, current_size,
          page_size * (info['last_pgno'] + 1))
      LOGGER.debug(
          'LMDB database %r flags: %s', self._dbname,
          ', '.join('%s=%s' % item for item in self._DB.flags().items()))
      self.Metrics.register_database(self._dbname, self)
      # Re-open every previously known table (or just the main one).
      orig_tables = self._tables
      self._tables = {}
      try:
        for name in orig_tables.keys() or ('',):
          self.open_table(name)
      except:
        # Restore the old table map so a retry starts from a known state,
        # then re-raise whatever went wrong.
        self._tables = orig_tables
        raise
      # Save a record so map_size changes are persisted.
      with self.begin(write=True) as txn:
        txn.put(b'_opened', b'%d' % int(time.time()))

  def open_table(self, name, **kwargs):
    """Open a table and return its DBI handle; idempotent per name.

    Args:
      name: table name; falsy means the main (unnamed) table.
      **kwargs: forwarded to lmdb's open_db (e.g. create, dupsort).

    Raises:
      exceptions.TableNotFoundError: if the table does not exist and
        creation was not requested.
    """
    name = name or ''
    with self._lock:
      if name in self._tables:
        return self._tables[name]
    with self.begin(write=True) as txn:
      try:
        dbi = txn._open_table(self._DB, name, **kwargs)
      except lmdb.NotFoundError as exc:
        raise exceptions.TableNotFoundError(str(exc)) from exc
      stats = txn.table_stats(name)
      flags = txn.table_flags(name)
    with self._lock:
      self._tables[name] = dbi
    used = stats['psize'] * (
        stats['branch_pages'] + stats['leaf_pages'] + stats['overflow_pages'])
    LOGGER.debug(
        'LMDB table %r opened: entries=%d, space used=%d B', name or '<main>',
        stats['entries'], used)
    LOGGER.debug(
        'LMDB table %r flags: %s', name or '<main>',
        ', '.join('%s=%s' % item for item in flags.items()))
    self.Metrics.register_table(self._dbname, name)
    # BUG FIX: also return the handle on first open (previously only the
    # cached-hit path returned it).
    return dbi

  def _close(self):
    # Close the environment after all in-flight transactions drain.
    if not self._DB:
      return
    with self._lock:
      LOGGER.debug(
          'Closing LMDB database %r, waiting for %d transactions to finish.',
          self._dbname, self._txn_no)
      self._txn_cond.wait_for(self._no_transactions)
      self.Metrics.unregister_database(self._dbname)
      self._DB.close()
      self._DB = None

  def close(self):
    """Close the database and remove the singleton registration."""
    with self._lock:
      # BUG FIX: save the name before clearing it; the final log line used
      # to print None because self._dbname was already reset.
      dbname = self._dbname
      dbpath = self._dbpath
      self._close()
      self._dbname = self._dbpath = self._dbargs = self._tables = None
      super().destroy(dbpath)
      LOGGER.info('LMDB database %r closed.', dbname)

  def reopen(self):
    """Close and re-open the environment (e.g. after a resize)."""
    LOGGER.debug('Attempting to reopen LMDB database %r.', self._dbname)
    with self._lock:
      self._close()
      self._open()

  def info(self):
    """Return the raw lmdb environment info dict."""
    return self._DB.info()

  def db_stats(self):
    """Aggregate per-table statistics across all open tables."""
    tables = tuple(self._tables.keys())
    # BUG FIX: initialize the variable actually used in the result dict;
    # previously `psize` was unbound (NameError) when no table was open.
    psize = None
    entries = branch_pages = leaf_pages = overflow_pages = 0
    with self.begin() as txn:
      for table in tables:
        stat = txn.table_stats(table)
        psize = stat['psize']
        entries += stat['entries']
        branch_pages += stat['branch_pages']
        leaf_pages += stat['leaf_pages']
        overflow_pages += stat['overflow_pages']
    return {
        'psize': psize,
        'entries': entries,
        'branch_pages': branch_pages,
        'leaf_pages': leaf_pages,
        'overflow_pages': overflow_pages,
        'used_pages': branch_pages + leaf_pages + overflow_pages,
    }

  def begin(self, table=None, write=False, buffers=None):
    """Start a transaction and return a _Transaction context manager.

    Args:
      table: default table for operations within the transaction.
      write: open a write transaction when true.
      buffers: override the environment-wide buffers setting.
    """
    with self._lock:
      if self._needs_resize or self._needs_reopen:
        # Only attempt if there is no risk of deadlock.
        if write or self._no_transactions():
          self.reopen()
      self._txn_no += 1
      buffers = self._buffers if buffers is None else buffers
      tables = self._tables.copy()

    def callback(needs_resize=False):
      # Invoked exactly once when the transaction finishes.
      with self._lock:
        if needs_resize:
          self._needs_resize = True
        self._txn_no -= 1
        self._txn_cond.notify_all()

    try:
      return _Transaction(
          self._dbname, self._DB, tables, self.Metrics, callback, table, write,
          buffers)
    except exceptions.DatabaseResizedError:
      with self._lock:
        # Consistency fix: use a boolean like everywhere else (was `1`).
        self._needs_reopen = True
      raise
class _Transaction(contextlib.AbstractContextManager):
  """Wrapper around an lmdb transaction adding metrics and error mapping.

  Usable as a context manager: commits on normal exit, aborts on exception.
  The owner-supplied callback is invoked exactly once when the transaction
  finishes, carrying a flag telling the owner whether a map resize is needed.
  """

  def __init__(
      self, dbname, dbenv, tables, metrics, callback, table, write, buffers):
    self._dbname = dbname
    self._tables = tables
    self._metrics = metrics
    self._write = write
    self._callback = callback
    self._start = time.time()
    self._needs_resize = False
    self._table = None
    self._txn = None
    if table:
      # Validate the default table name early.
      table, dbi = self._find_table(table)
    else:
      table, dbi = '', None
    self._table = table
    LOGGER.debug('Starting DB transaction (write=%s).', write)
    try:
      with self._op_context(table, 'begin'):
        self._txn = dbenv.begin(db=dbi, write=write, buffers=buffers)
    except:
      # Creation failed: release the owner's transaction slot, re-raise.
      self._callback(self._needs_resize)
      raise
    LOGGER.debug('Started DB transaction %s.', self._txn.id())

  def __exit__(self, exc_type, exc, exc_tb):
    if exc:
      self.abort()
    else:
      self.commit()

  class _OpContext:
    """Times one LMDB operation and maps lmdb errors to tina exceptions."""

    def __init__(self, txn, table, op):
      self.txn = txn
      self.table = table
      self.op = op

    def __enter__(self):
      self.start = time.time()
      return self

    def __exit__(self, exc_type, exc, exc_tb):
      if self.txn._write and self.op != 'open_db':
        # BUG FIX: duration is `now - start`; the operands were swapped,
        # so max(..., 0) always recorded 0.
        dur = max(time.time() - self.start, 0)
        self.txn._metrics.OP_DURATION.labels(
            self.txn._dbname, self.table, self.op).observe(dur)
      if not exc_type or not issubclass(exc_type, lmdb.Error):
        return  # let any non-lmdb exception propagate.
      self.txn._metrics.ERRORS.labels(
          self.txn._dbname, exc_type.__name__).inc()
      # Consolidated: all branches logged the identical message.
      LOGGER.error('Database error executing %r: %s', self.op, exc)
      if exc_type is lmdb.MapFullError:
        self.txn._needs_resize = True
        raise exceptions.DatabaseFullError(str(exc)) from None
      elif exc_type is lmdb.MapResizedError:
        raise exceptions.DatabaseResizedError(str(exc)) from None
      elif exc_type is lmdb.NotFoundError:
        raise exceptions.TableNotFoundError(str(exc)) from None
      else:
        raise exceptions.DatabaseError(str(exc)) from None

  def _op_context(self, table, op):
    # Small factory so call sites read naturally.
    return self._OpContext(self, table, op)

  def _open_table(self, env, name, **kwargs):
    # Open a table inside this transaction and cache its handle.
    name = name or ''
    key = name.encode('ascii') if name else None
    with self._op_context(name, 'open_db'):
      dbi = env.open_db(key=key, txn=self._txn, **kwargs)
    self._tables[name] = dbi
    return dbi

  def _find_table(self, name):
    # Resolve a table name (falling back to the transaction default) to a
    # (name, dbi) pair.
    name = name or self._table or ''
    if name in self._tables:
      return name, self._tables[name]
    else:
      raise exceptions.TableNotFoundError('Unknown table: %r' % name)

  def abort(self):
    """Abort the transaction; safe to call when already finished."""
    if self._txn:
      LOGGER.debug('Aborting DB transaction %s.', self._txn.id())
      try:
        with self._op_context(self._table, 'abort'):
          self._txn.abort()
      finally:
        self._close()

  def commit(self):
    """Commit the transaction; safe to call when already finished."""
    if self._txn:
      LOGGER.debug('Committing DB transaction %s.', self._txn.id())
      try:
        with self._op_context(self._table, 'commit'):
          self._txn.commit()
      finally:
        self._close()

  def _close(self):
    # Record transaction duration and notify the owner we are done.
    self._metrics.TXN_DURATION.labels(
        self._dbname, 'write' if self._write else 'read').observe(
            max(time.time() - self._start, 0))
    self._txn = None
    self._callback(self._needs_resize)

  def table_flags(self, table):
    """Return the lmdb flags dict for one table."""
    return self._find_table(table)[1].flags(self._txn)

  def table_stats(self, table):
    """Return the lmdb stat dict for one table."""
    return self._txn.stat(self._find_table(table)[1])

  def get(self, key, default=None, table=None):
    """Fetch one element from the table."""
    tablename, dbi = self._find_table(table)
    with self._op_context(tablename, 'get'):
      return self._txn.get(key, default, db=dbi)

  def put(self, key, value, dupdata=True, overwrite=True, table=None):
    """Store one element in the table."""
    tablename, dbi = self._find_table(table)
    with self._op_context(tablename, 'put'):
      return self._txn.put(
          key, value, dupdata=dupdata, overwrite=overwrite, db=dbi)

  def delete(self, key, value=None, table=None):
    """Delete one element from the table."""
    tablename, dbi = self._find_table(table)
    with self._op_context(tablename, 'del'):
      return self._txn.delete(key, value or b'', db=dbi)

  def putmulti(self, items, dupdata=True, overwrite=True, table=None):
    """Store multiple elements in the table."""
    tablename, dbi = self._find_table(table)
    cursor = self._txn.cursor(db=dbi)
    with self._op_context(tablename, 'putmulti'):
      return cursor.putmulti(items, dupdata=dupdata, overwrite=overwrite)

  def replace(self, key, value, table=None):
    """Replace one element in the table, and return its previous value."""
    tablename, dbi = self._find_table(table)
    if dbi.flags(self._txn)['dupsort']:
      # BUG FIX: the message said the opposite of what it meant.
      raise TypeError('Cannot run `replace` on a `dupsort` table.')
    with self._op_context(tablename, 'replace'):
      return self._txn.replace(key, value, db=dbi)

  def _cursor_getter(self, cursor, keys, values):
    # Pick the cursor accessor matching the requested fields.
    if not values:
      return cursor.key
    elif not keys:
      return cursor.value
    else:
      return cursor.item

  def iter(self, table=None, start=None, stop=None, prefix=None,
           skipdup=False, keys=True, values=True):
    """Yield elements according to search criteria.

    Args:
      start/stop: inclusive key range to scan.
      prefix: yield only keys with this prefix (exclusive with start/stop).
      skipdup: skip duplicate values (dupsort tables only).
      keys/values: select which fields each yielded item contains.
    """
    tablename, dbi = self._find_table(table)
    dupsort = dbi.flags(self._txn)['dupsort']
    cursor = self._txn.cursor(db=dbi)
    if (start or stop) and prefix:
      raise ValueError(
          'Cannot pass `start` or `stop` and `prefix` at the same time')
    if not dupsort and skipdup:
      raise ValueError('Cannot pass `skipdup` in a dupsort=True table')
    ok = True
    if start:
      ok = cursor.set_range(start)
    elif prefix:
      ok = cursor.set_range(prefix)
    else:
      ok = cursor.first()
    if not ok:
      return
    if stop:
      # Cast to bytes as memoryview has no __ge__ method.
      cont_cond = lambda x: bytes(x) <= stop
    elif prefix:
      p_len = len(prefix)
      cont_cond = lambda x: x[:p_len] == prefix
    else:
      cont_cond = None
    nextf = cursor.next_nodup if skipdup else cursor.next
    get = self._cursor_getter(cursor, keys, values)
    key = cursor.key
    context = self._op_context(tablename, 'get')
    cur_key = key()
    while len(cur_key) and (not cont_cond or cont_cond(cur_key)):
      with context:
        rv = get()
        nextf()
        cur_key = key()
      yield rv

  def first(self, table=None, key=True, value=True):
    """Return the first element in the table, or (None, None) if empty."""
    tablename, dbi = self._find_table(table)
    cursor = self._txn.cursor(db=dbi)
    # BUG FIX: the original referenced undefined names `keys`/`values`
    # instead of this method's `key`/`value` parameters (NameError).
    get = self._cursor_getter(cursor, key, value)
    if cursor.first():
      return get()
    return None, None

  def last(self, table=None, key=True, value=True):
    """Return the last element in the table, or (None, None) if empty."""
    tablename, dbi = self._find_table(table)
    cursor = self._txn.cursor(db=dbi)
    get = self._cursor_getter(cursor, key, value)
    # Consistency fix: mirror first() instead of returning an empty
    # cursor result when the table has no entries.
    if cursor.last():
      return get()
    return None, None

  def get_key_dup(self, key, table=None):
    """Return all the duplicate values for one key."""
    return tuple(self.iter(table, start=key, stop=key, keys=False))

  def get_all_prefix(self, prefix, table=None, keys=True, values=True):
    """Return all the elements for a specified key prefix."""
    return tuple(self.iter(table, prefix=prefix, keys=keys, values=values))

  def count_dup(self, key, table=None):
    """Return the count of duplicate values for one key."""
    tablename, dbi = self._find_table(table)
    cursor = self._txn.cursor(db=dbi)
    with self._op_context(tablename, 'get'):
      if not cursor.set_key(key):
        return 0
      return cursor.count()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment