-
-
Save NightTsarina/bd42736531a52c2843ee3e6d0d8b0ae2 to your computer and use it in GitHub Desktop.
LMDB wrapper class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# vim:ts=2:sw=2:et:ai:sts=2 | |
"""Database layer common code.""" | |
import contextlib | |
import lmdb | |
import logging | |
import os.path | |
import threading | |
import time | |
from tina.common import exceptions | |
from tina.common.util import ParametricSingleton | |
from tina.metrics.lmdb import LMDBMetrics | |
LOGGER = logging.getLogger(__name__) | |
T_ = lambda message: message | |
class LMDB(ParametricSingleton):
  """LMDB.Environment wrapper.

  One instance exists per database path (enforced via ParametricSingleton's
  `_singleton_id`).  Handles automatic map resizing, reopening after the map
  is resized by another process, table-handle caching and transaction
  reference counting.
  """

  #: Resizing factor when database is full.
  RESIZE_FACTOR = 1.5
  #: Resize if the database usage reached this threshold when being opened.
  AUTO_RESIZE_THRESHOLD = .9

  #: Metrics implementation used for registering databases/tables.
  Metrics = LMDBMetrics

  _DB = None
  _dbname = None
  _dbpath = None
  _dbargs = None
  _buffers = None
  _tables = None
  _needs_reopen = False
  _needs_resize = False
  _current_size = None
  _txn_no = None
  _lock = threading.RLock()
  _txn_cond = threading.Condition(_lock)

  @classmethod
  def _singleton_id(cls, dbname, dbpath, *args, **kwargs):
    # One shared instance per canonical filesystem path.
    return os.path.realpath(dbpath)

  def __init__(self, dbname, dbpath, buffers=False, **kwargs):
    with self._lock:
      # Watch out for repeated calls: the singleton machinery may invoke
      # __init__ again on an already-initialized instance.
      if self._DB is None:
        self._dbname = dbname
        self._dbpath = dbpath
        self._dbargs = kwargs
        self._buffers = buffers
        self._tables = {}
        self._open()

  def _no_transactions(self):
    """Predicate for _txn_cond: True when no transactions are in flight."""
    return self._txn_no == 0

  def _open(self):
    """Open the LMDB environment, resizing the map if needed.

    Raises:
      exceptions.DatabaseError: if the environment is already open or the
        used size exceeds the map size (possible corruption).
    """
    with self._lock:
      if self._DB:
        raise exceptions.DatabaseError('Attempt to re-open database.')
      self._DB = lmdb.Environment(self._dbpath, **self._dbargs)
      info = self._DB.info()
      stat = self._DB.stat()
      current_size = info['map_size']
      page_size = stat['psize']
      used_size = page_size * (info['last_pgno'] + 1)
      if current_size < used_size:
        # Fixed: the message previously received printf-style arguments that
        # were never interpolated into the exception text.
        raise exceptions.DatabaseError(
            'Possible database corruption: '
            'last page (%d) exceeds map size (%d).' % (used_size,
                                                       current_size))
      if self._current_size and current_size < self._current_size:
        # Logger.warn is a deprecated alias of Logger.warning.
        LOGGER.warning('LMDB reports smaller map_size (%d B) after reopen',
                       current_size)
        current_size = self._current_size
      if self._needs_resize or (
          used_size > current_size * self.AUTO_RESIZE_THRESHOLD):
        # Grow by RESIZE_FACTOR, rounded to a whole number of pages.
        current_size = page_size * round(
            current_size * self.RESIZE_FACTOR / page_size)
        LOGGER.info('Resizing LMDB database %r to %d bytes.', self._dbname,
                    current_size)
      self._needs_reopen = False
      self._needs_resize = False
      self._current_size = current_size
      self._DB.set_mapsize(current_size)
      self._txn_no = 0
      stale = self._DB.reader_check()
      if stale:
        LOGGER.warning(
            'Stale entries detected in reader lock table: %d entries.', stale)
      info = self._DB.info()
      LOGGER.info(
          'LMDB database %r opened: path=%s, capacity=%d B, used=%d B',
          self._dbname, self._dbpath, current_size,
          page_size * (info['last_pgno'] + 1))
      LOGGER.debug(
          'LMDB database %r flags: %s', self._dbname,
          ', '.join('%s=%s' % item for item in self._DB.flags().items()))
      self.Metrics.register_database(self._dbname, self)
      # Re-open every previously known table (or just the main table on the
      # first open); restore the old handles if anything fails.
      orig_tables = self._tables
      self._tables = {}
      try:
        for name in orig_tables.keys() or ('',):
          self.open_table(name)
      except:
        self._tables = orig_tables
        raise
      # Save a record so map_size changes are persisted.
      with self.begin(write=True) as txn:
        txn.put(b'_opened', b'%d' % time.time())

  def open_table(self, name, **kwargs):
    """Open (and cache) a named table, returning its LMDB handle.

    Raises:
      exceptions.TableNotFoundError: the table does not exist.
    """
    name = name or ''
    with self._lock:
      if name in self._tables:
        return self._tables[name]
    with self.begin(write=True) as txn:
      try:
        dbi = txn._open_table(self._DB, name, **kwargs)
      except lmdb.NotFoundError as exc:
        raise exceptions.TableNotFoundError(str(exc)) from exc
      stats = txn.table_stats(name)
      flags = txn.table_flags(name)
    with self._lock:
      self._tables[name] = dbi
    used = stats['psize'] * (
        stats['branch_pages'] + stats['leaf_pages'] + stats['overflow_pages'])
    LOGGER.debug(
        'LMDB table %r opened: entries=%d, space used=%d B', name or '<main>',
        stats['entries'], used)
    LOGGER.debug(
        'LMDB table %r flags: %s', name or '<main>',
        ', '.join('%s=%s' % item for item in flags.items()))
    self.Metrics.register_table(self._dbname, name)
    # Fixed: return the handle on this path too, matching the cache-hit
    # early return above.
    return dbi

  def _close(self):
    """Close the environment after all in-flight transactions finish."""
    if not self._DB:
      return
    with self._lock:
      LOGGER.debug(
          'Closing LMDB database %r, waiting for %d transactions to finish.',
          self._dbname, self._txn_no)
      self._txn_cond.wait_for(self._no_transactions)
      self.Metrics.unregister_database(self._dbname)
      self._DB.close()
      self._DB = None

  def close(self):
    """Close the database and tear down the singleton registration."""
    with self._lock:
      # Fixed: capture the name before clearing it, so the final log line
      # does not report `None`.
      dbname = self._dbname
      dbpath = self._dbpath
      self._close()
      self._dbname = self._dbpath = self._dbargs = self._tables = None
      super().destroy(dbpath)
      LOGGER.info('LMDB database %r closed.', dbname)

  def reopen(self):
    """Close and re-open the environment (e.g. after a map resize)."""
    LOGGER.debug('Attempting to reopen LMDB database %r.', self._dbname)
    with self._lock:
      self._close()
      self._open()

  def info(self):
    """Return the raw lmdb.Environment.info() dictionary."""
    return self._DB.info()

  def db_stats(self):
    """Return statistics aggregated over all currently open tables."""
    tables = tuple(self._tables.keys())
    # Fixed: `psize` is what is returned below; it previously stayed unbound
    # (NameError) when no tables were open.
    psize = None
    entries = branch_pages = leaf_pages = overflow_pages = 0
    with self.begin() as txn:
      for table in tables:
        stat = txn.table_stats(table)
        psize = stat['psize']
        entries += stat['entries']
        branch_pages += stat['branch_pages']
        leaf_pages += stat['leaf_pages']
        overflow_pages += stat['overflow_pages']
    return {
        'psize': psize,
        'entries': entries,
        'branch_pages': branch_pages,
        'leaf_pages': leaf_pages,
        'overflow_pages': overflow_pages,
        'used_pages': branch_pages + leaf_pages + overflow_pages,
    }

  def begin(self, table=None, write=False, buffers=None):
    """Start a transaction; returns a _Transaction context manager.

    Args:
      table: default table for operations inside the transaction.
      write: whether the transaction may modify the database.
      buffers: pass buffers=True to lmdb; defaults to the value given at
        construction time.
    """
    with self._lock:
      if self._needs_resize or self._needs_reopen:
        # Only attempt if there is no risk of deadlock.
        if write or self._no_transactions():
          self.reopen()
      self._txn_no += 1
    buffers = self._buffers if buffers is None else buffers

    def callback(needs_resize=False):
      # Invoked exactly once when the transaction finishes (or fails to
      # start); releases the reference count taken above.
      with self._lock:
        if needs_resize:
          self._needs_resize = True
        self._txn_no -= 1
        self._txn_cond.notify_all()

    tables = self._tables.copy()
    try:
      return _Transaction(
          self._dbname, self._DB, tables, self.Metrics, callback, table, write,
          buffers)
    except exceptions.DatabaseResizedError:
      with self._lock:
        # Fixed: use a proper boolean instead of the integer 1.
        self._needs_reopen = True
      raise
class _Transaction(contextlib.AbstractContextManager): | |
def __init__( | |
self, dbname, dbenv, tables, metrics, callback, table, write, buffers): | |
self._dbname = dbname | |
self._tables = tables | |
self._metrics = metrics | |
self._write = write | |
self._callback = callback | |
self._start = time.time() | |
self._needs_resize = False | |
self._table = None | |
self._txn = None | |
if table: | |
# Validate. | |
table, dbi = self._find_table(table) | |
else: | |
table, dbi = '', None | |
self._table = table | |
LOGGER.debug('Starting DB transaction (write=%s).', write) | |
try: | |
with self._op_context(table, 'begin'): | |
self._txn = dbenv.begin(db=dbi, write=write, buffers=buffers) | |
except: | |
self._callback(self._needs_resize) | |
raise | |
LOGGER.debug('Started DB transaction %s.', self._txn.id()) | |
def __exit__(self, exc_type, exc, exc_tb): | |
if exc: | |
self.abort() | |
else: | |
self.commit() | |
class _OpContext: | |
def __init__(self, txn, table, op): | |
self.txn = txn | |
self.table = table | |
self.op = op | |
def __enter__(self): | |
self.start = time.time() | |
return self | |
def __exit__(self, exc_type, exc, exc_tb): | |
if self.txn._write and self.op != 'open_db': | |
dur = max(self.start - time.time(), 0) | |
self.txn._metrics.OP_DURATION.labels( | |
self.txn._dbname, self.table, self.op).observe(dur) | |
if not exc_type or not issubclass(exc_type, lmdb.Error): | |
return # let the exception propagate. | |
self.txn._metrics.ERRORS.labels( | |
self.txn._dbname, exc_type.__name__).inc() | |
if exc_type is lmdb.MapFullError: | |
LOGGER.error('Database error executing %r: %s', self.op, exc) | |
self.txn._needs_resize = True | |
raise exceptions.DatabaseFullError(str(exc)) from None | |
elif exc_type is lmdb.MapResizedError: | |
LOGGER.error('Database error executing %r: %s', self.op, exc) | |
raise exceptions.DatabaseResizedError(str(exc)) from None | |
elif exc_type is lmdb.NotFoundError: | |
LOGGER.error('Database error executing %r: %s', self.op, exc) | |
raise exceptions.TableNotFoundError(str(exc)) from None | |
else: | |
LOGGER.error('Database error executing %r: %s', self.op, exc) | |
raise exceptions.DatabaseError(str(exc)) from None | |
def _op_context(self, table, op): | |
return self._OpContext(self, table, op) | |
def _open_table(self, env, name, **kwargs): | |
name = name or '' | |
key = name.encode('ascii') if name else None | |
with self._op_context(name, 'open_db'): | |
dbi = env.open_db(key=key, txn=self._txn, **kwargs) | |
self._tables[name] = dbi | |
return dbi | |
def _find_table(self, name): | |
name = name or self._table or '' | |
if name in self._tables: | |
return name, self._tables[name] | |
else: | |
raise exceptions.TableNotFoundError('Unknown table: %r' % name) | |
def abort(self): | |
if self._txn: | |
LOGGER.debug('Aborting DB transaction %s.', self._txn.id()) | |
try: | |
with self._op_context(self._table, 'abort'): | |
self._txn.abort() | |
finally: | |
self._close() | |
def commit(self): | |
if self._txn: | |
LOGGER.debug('Committing DB transaction %s.', self._txn.id()) | |
try: | |
with self._op_context(self._table, 'commit'): | |
self._txn.commit() | |
finally: | |
self._close() | |
def _close(self): | |
#LOGGER.debug('Finished DB transaction.') | |
self._metrics.TXN_DURATION.labels( | |
self._dbname, 'write' if self._write else 'read').observe( | |
max(time.time() - self._start, 0)) | |
self._txn = None | |
self._callback(self._needs_resize) | |
def table_flags(self, table): | |
return self._find_table(table)[1].flags(self._txn) | |
def table_stats(self, table): | |
return self._txn.stat(self._find_table(table)[1]) | |
def get(self, key, default=None, table=None): | |
"""Fetch one element from the table.""" | |
tablename, dbi = self._find_table(table) | |
with self._op_context(tablename, 'get'): | |
return self._txn.get(key, default, db=dbi) | |
def put(self, key, value, dupdata=True, overwrite=True, table=None): | |
"""Store one element in the table.""" | |
tablename, dbi = self._find_table(table) | |
with self._op_context(tablename, 'put'): | |
return self._txn.put( | |
key, value, dupdata=dupdata, overwrite=overwrite, db=dbi) | |
def delete(self, key, value=None, table=None): | |
"""Delete one element from the table.""" | |
tablename, dbi = self._find_table(table) | |
with self._op_context(tablename, 'del'): | |
return self._txn.delete(key, value or b'', db=dbi) | |
def putmulti(self, items, dupdata=True, overwrite=True, table=None): | |
"""Store multiple elements in the table.""" | |
tablename, dbi = self._find_table(table) | |
cursor = self._txn.cursor(db=dbi) | |
with self._op_context(tablename, 'putmulti'): | |
return cursor.putmulti(items, dupdata=dupdata, overwrite=overwrite) | |
def replace(self, key, value, table=None): | |
"""Replace one element in the table, and return its previous value.""" | |
tablename, dbi = self._find_table(table) | |
if dbi.flags(self._txn)['dupsort']: | |
raise TypeError('Can run `replace` on a `dupsort` table.') | |
with self._op_context(tablename, 'replace'): | |
return self._txn.replace(key, value, db=dbi) | |
def _cursor_getter(self, cursor, keys, values): | |
if not values: | |
return cursor.key | |
elif not keys: | |
return cursor.value | |
else: | |
return cursor.item | |
def iter(self, table=None, start=None, stop=None, prefix=None, | |
skipdup=False, keys=True, values=True): | |
"""Return an iterator that yields elements according to search criteria.""" | |
tablename, dbi = self._find_table(table) | |
dupsort = dbi.flags(self._txn)['dupsort'] | |
cursor = self._txn.cursor(db=dbi) | |
if (start or stop) and prefix: | |
raise ValueError( | |
'Cannot pass `start` or `stop` and `prefix` at the same time') | |
if not dupsort and skipdup: | |
raise ValueError('Cannot pass `skipdup` in a dupsort=True table') | |
ok = True | |
if start: | |
ok = cursor.set_range(start) | |
elif prefix: | |
ok = cursor.set_range(prefix) | |
else: | |
ok = cursor.first() | |
if not ok: | |
return | |
if stop: | |
# Cast to bytes as memoryview has no __ge__ method. | |
cont_cond = lambda x: bytes(x) <= stop | |
elif prefix: | |
p_len = len(prefix) | |
cont_cond = lambda x: x[:p_len] == prefix | |
else: | |
cont_cond = None | |
nextf = cursor.next_nodup if skipdup else cursor.next | |
get = self._cursor_getter(cursor, keys, values) | |
key = cursor.key | |
context = self._op_context(tablename, 'get') | |
cur_key = key() | |
while len(cur_key) and (not cont_cond or cont_cond(cur_key)): | |
with context: | |
rv = get() | |
cont = nextf() | |
cur_key = key() | |
yield rv | |
def first(self, table=None, key=True, value=True): | |
"""Return the first element in the table.""" | |
tablename, dbi = self._find_table(table) | |
cursor = self._txn.cursor(db=dbi) | |
get = self._cursor_getter(cursor, keys, values) | |
if cursor.first(): | |
return get() | |
return None, None | |
def last(self, table=None, key=True, value=True): | |
"""Return the last element in the table.""" | |
tablename, dbi = self._find_table(table) | |
cursor = self._txn.cursor(db=dbi) | |
get = self._cursor_getter(cursor, key, value) | |
cursor.last() | |
return get() | |
def get_key_dup(self, key, table=None): | |
"""Return all the duplicate values for one key.""" | |
return tuple(self.iter(table, start=key, stop=key, keys=False)) | |
def get_all_prefix(self, prefix, table=None, keys=True, values=True): | |
"""Return all the elements for a specified key prefix.""" | |
return tuple(self.iter(table, prefix=prefix, keys=keys, values=values)) | |
def count_dup(self, key, table=None): | |
"""Return the count of duplicate values for one key.""" | |
tablename, dbi = self._find_table(table) | |
cursor = self._txn.cursor(db=dbi) | |
with self._op_context(tablename, 'get'): | |
if not cursor.set_key(key): | |
return 0 | |
return cursor.count() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment