Skip to content

Instantly share code, notes, and snippets.

@polyvertex
Last active January 19, 2021 19:59
Show Gist options
  • Save polyvertex/e5dacc97350910f080fc85c61af20192 to your computer and use it in GitHub Desktop.
Save polyvertex/e5dacc97350910f080fc85c61af20192 to your computer and use it in GitHub Desktop.
sqlite3 with nested transactions for real
# Copyright (c) Jean-Charles Lefebvre
# SPDX-License-Identifier: MIT
import contextlib
import importlib
import importlib.resources
import os
import re
import sqlite3
import sys
import threading
import types
# Names exported by ``from <module> import *``: only the connection and cursor
# wrappers; the schema helper classes defined below are not listed.
__all__ = ("SqliteConnection", "SqliteCursor")

# Default pattern used by `SqliteConnection.get_schema_resources_manifest` to
# recognize schema resources by name (``schema-<version>.sql`` or
# ``schema-<version>.py``).
# The captured group is converted with ``int()`` by that method, so the
# optional ``_NN`` part is folded into the number rather than treated as a
# minor version (``int("1_2") == 12``).
DEFAULT_SCHEMA_RESOURCE_REGEX = re.compile(
    r"^schema\-(\d+(?:_\d+)?)\.(?:sql|py)$", re.A)

# Name of the attribute that `SqliteSchemaResource.apply_update` looks up in a
# ``.py`` schema resource module and calls with the database object.
SCHEMA_UPDATER_CALLABLE_NAME = "sqlitedb_update_schema"
class SqliteCursor(sqlite3.Cursor):
    """
    A wrapper around `sqlite3.Cursor` that is the default Cursor class for
    `SqliteConnection`.

    Most notably it re-implements `executescript`, to honor the support of
    nested transactions offered by `SqliteConnection`. This is because
    CPython's `sqlite3.Connection.executescript` and
    `sqlite3.Cursor.executescript` do not take into account the
    ``isolation_level`` value and forcefully issue a ``COMMIT`` statement
    before executing the passed SQL script.

    This implies that any current transaction is committed at a lower level
    without any chance for us to be notified about the new internal state of
    the sqlite3 connection, thus breaking the support of nested transactions.

    So there was no other choice than reimplementing `executescript`. It is
    done here by relying on the sqlite3 API to parse the SQL script and extract
    its statements one by one, in order to execute them manually.
    """

    # Statements that would change the transaction state behind the back of
    # the savepoint bookkeeping.
    _TRANSACTION_KEYWORDS = frozenset((
        "BEGIN", "COMMIT", "END",
        "SAVEPOINT", "RELEASE", "ROLLBACK"))

    def __del__(self):
        # Best-effort cleanup: errors cannot be reported from a finalizer.
        with contextlib.suppress(Exception):
            self.close()

    def executescript(self, script, *, source="<memory>"):
        """
        Execute every statement of the SQL *script*, one at a time, without
        the implicit ``COMMIT`` issued by the standard implementation.

        *source* is only used to label the script in error messages.

        Return the cursor itself, like `sqlite3.Cursor.executescript` does.

        Raise `sqlite3.ProgrammingError` if the script is incomplete, or if it
        contains a ``SELECT`` or a transaction-related statement. Statements
        preceding the offending one have already been executed at that point.
        """
        statements = sqlite_iterate_script_statements(
            script, source=source, with_keyword=True)

        for keyword, stmt in statements:
            # A statement made of a single word has its terminator glued to
            # the keyword ("COMMIT;"). Strip it so that such statements cannot
            # slip through the checks below.
            keyword = keyword.rstrip(";").upper()

            if keyword == "SELECT":
                raise sqlite3.ProgrammingError(
                    "SELECT statements not permitted in executescript method")

            if keyword in self._TRANSACTION_KEYWORDS:
                raise sqlite3.ProgrammingError(
                    f"transaction-related statements not permitted in "
                    f"executescript method (got {keyword} statement)")

            self.execute(stmt)

        return self
class SqliteConnection:
    """
    A wrapper around `sqlite3.Connection` with a schema updating feature and
    that truly supports nested transactions with context management by using
    sqlite3's ``SAVEPOINT`` feature.

    Every ``with connection:`` block opens a savepoint that is released on
    normal exit and rolled back if the block raises. Attributes that are not
    defined by this class are forwarded to the wrapped `sqlite3.Connection`
    (available as the ``conn`` attribute).

    .. seealso::
        Python's `issue16958 <https://bugs.python.org/issue16958>`_ about using
        `sqlite3.Connection` as a context manager.
    """

    def __init__(self, database, **kwargs):
        """
        Open *database* (a path-like object) in autocommit mode.

        *kwargs* are passed to `sqlite3.connect`. ``isolation_level`` is
        managed by this class: passing anything else than `None` raises
        `ValueError`.
        """
        # Define every attribute first so that close(), __del__() and
        # __getattr__() stay well-behaved if the constructor fails part-way.
        self.uri = None
        self.conn = None
        self._savepoint_lock = threading.RLock()
        self._savepoint_id = 0
        self._savepoint_stack = []

        isolation_level = kwargs.pop("isolation_level", None)
        if isolation_level is not None:
            raise ValueError(
                "isolation_level arg specified and different than None")

        self.uri = os.fspath(database)
        self.conn = sqlite3.connect(self.uri, isolation_level=None, **kwargs)
        assert self.conn.isolation_level is None

        self.conn.execute("PRAGMA temp_store = MEMORY")
        self.conn.execute("PRAGMA journal_mode = WAL")

    def __del__(self):
        # Best-effort cleanup: errors cannot be reported from a finalizer.
        with contextlib.suppress(Exception):
            self.close()

    def __getattr__(self, name):
        # Only called when regular attribute lookup failed: forward to the
        # wrapped connection. The instance dict is read directly so that a
        # partially constructed object (no "conn" attribute yet) gets a plain
        # AttributeError instead of recursing into this method forever.
        state = self.__dict__
        if "conn" not in state:
            raise AttributeError(name)

        conn = state["conn"]
        if conn is None:
            raise sqlite3.OperationalError(
                f"trying to get or call {name} but database is closed: "
                f"{state.get('uri')}")

        return getattr(conn, name)

    def __bool__(self):
        return self.conn is not None

    def __str__(self):
        return self.uri

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.uri}>"

    def __enter__(self):
        self._push_savepoint()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Commit on normal exit, roll back if the block raised. The exception,
        # if any, is never swallowed.
        self._pop_savepoint(commit=exc_type is None, pop=True)

    @property
    def isolation_level(self):
        self.ensure_open()
        return self.conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value):
        raise sqlite3.NotSupportedError(
            "isolation_level change not supported by this wrapper")

    @property
    def in_transaction(self):
        return self.conn is not None and self.conn.in_transaction

    def close(self):
        """Close the connection, if still open, and forget all savepoints"""
        with self._savepoint_lock:
            if self.conn is not None:
                self.conn.close()
                self.conn = None

            if self._savepoint_stack:
                self._savepoint_stack = []

    def ensure_open(self):
        """Raise `sqlite3.OperationalError` if `close` has been called already"""
        if self.conn is None:
            raise sqlite3.OperationalError(
                f"database connection closed: {self.uri}")

        # Sanity check of the savepoint bookkeeping. Stack entries set to None
        # stand for contexts whose savepoint has already been ended by
        # commit() or rollback(), so they do not keep a transaction open.
        assert (
            any(sp is not None for sp in self._savepoint_stack)
            == bool(self.conn.in_transaction))

    def cursor(self, factory=SqliteCursor):
        """
        The Cursor class factory.

        It is important to use `SqliteCursor` or a derived class as a factory,
        due to the reimplementation of `executescript`.
        """
        self.ensure_open()
        return self.conn.cursor(factory=factory)

    def executescript(self, script, *, source="<memory>"):
        """
        Execute the SQL *script* with a temporary cursor.

        See `SqliteCursor` for the rationale behind this reimplementation.
        """
        # closing() releases the cursor even when the script fails
        with contextlib.closing(self.cursor()) as cursor:
            cursor.executescript(script, source=source)

    def commit(self):
        """
        Commit the current transaction if not already committed or released.

        This method must only be called from a context. It raises
        `sqlite3.OperationalError` otherwise.
        """
        with self._savepoint_lock:
            self.ensure_open()
            if not self._savepoint_stack:
                raise sqlite3.OperationalError(
                    "commit() called outside of a transaction context")
            self._pop_savepoint(commit=True, pop=False)

    def rollback(self):
        """
        Rollback the current transaction if not already committed or released.

        This method must only be called from a context. It raises
        `sqlite3.OperationalError` otherwise.
        """
        with self._savepoint_lock:
            self.ensure_open()
            if not self._savepoint_stack:
                raise sqlite3.OperationalError(
                    "rollback() called outside of a transaction context")
            self._pop_savepoint(commit=False, pop=False)

    def fetchone(self, sql, parameters=()):
        """Shorthand for an `execute` call followed by `fetchone`"""
        with contextlib.closing(self.execute(sql, parameters)) as cursor:
            return cursor.fetchone()

    def fetchmany(self, sql, parameters=(), size=None):
        """
        Shorthand for an `execute` call followed by `fetchmany`.

        A false *size* stands for the cursor's ``arraysize``.
        """
        with contextlib.closing(self.execute(sql, parameters)) as cursor:
            return cursor.fetchmany(size or cursor.arraysize)

    def fetchall(self, sql, parameters=()):
        """Shorthand for an `execute` call followed by `fetchall`"""
        with contextlib.closing(self.execute(sql, parameters)) as cursor:
            return cursor.fetchall()

    def create_or_update_schema(
            self, meta_table, meta_column, resource_package, *,
            schema_resource_regex=None):
        """
        Get the current schema version of the database using *meta_table* and
        *meta_column* names, then apply all the schema updates found in
        *resource_package* if any.

        *resource_package* must be a module object. *schema_resource_regex*,
        if given, replaces the default pattern used to recognize schema
        resources (see `get_schema_resources_manifest`).

        Return a `tuple` of two `int`: the detected version number before
        applying any update (may be zero if database was not created), and the
        version number of the latest update applied by this method, which may
        be equal to the first value in the tuple.

        Raise `sqlite3.OperationalError` if no schema resource was found.
        """
        self.ensure_open()

        initial_version = self.get_installed_schema_version(
            meta_table, meta_column)

        # get all the available schema updates
        manifest = self.get_schema_resources_manifest(
            resource_package, schema_resource_regex=schema_resource_regex)
        if not manifest:
            raise sqlite3.OperationalError(
                f"empty SQL schema manifest for package: "
                f"{resource_package.__name__}")

        # apply updates
        latest_version = manifest.apply_updates(self, initial_version)

        return (initial_version, latest_version)

    def get_installed_schema_version(self, table_name, column_name):
        """
        Used by `create_or_update_schema` to get the current database schema
        version.

        This method executes a ``SELECT`` statement using *table_name* and
        *column_name* and returns the value of *column_name* in the first row,
        which is expected to be an `int`. Zero is returned if the table does
        not exist.

        Additionally, the requested table is expected to be a one-row table.
        This method raises `sqlite3.DatabaseError` in case the number of rows
        is different than one, or if the value is not an `int`.
        """
        self.ensure_open()

        # NOTE: identifiers cannot be bound as SQL parameters, so both names
        # are interpolated as-is and must come from trusted code.
        try:
            cursor = self.execute(
                f"SELECT {column_name} FROM {table_name} LIMIT 2")
        except sqlite3.OperationalError as exc:
            # Only a missing table means "database not created yet". Any other
            # failure (locked database, I/O error, ...) must not be mistaken
            # for version zero, which would re-apply every schema update.
            if "no such table" in str(exc).lower():
                return 0
            raise

        with contextlib.closing(cursor):
            rows = cursor.fetchall()

        if not rows:
            raise sqlite3.DatabaseError(
                f"missing {table_name}.{column_name} value in database: "
                f"{self.uri}")
        if len(rows) > 1:
            raise sqlite3.DatabaseError(
                f"unexpected multiple rows in table {table_name} in database: "
                f"{self.uri}")

        version = rows[0][0]
        if not isinstance(version, int):
            raise sqlite3.DatabaseError(
                f"unexpected {table_name}.{column_name} value type in "
                f"database: {self.uri}")

        return version

    def get_schema_resources_manifest(
            self, resource_package, *, schema_resource_regex=None):
        """
        Used by `create_or_update_schema` to get a `SqliteSchemasManifest`
        object populated with all the database schema resources found under
        *resource_package*.

        *resource_package* must be a module object. *schema_resource_regex*
        must be a compiled pattern whose first group captures the version
        number; it defaults to `DEFAULT_SCHEMA_RESOURCE_REGEX`.

        Raise `ValueError` if a resource has a version number of zero.
        """
        if not schema_resource_regex:
            schema_resource_regex = DEFAULT_SCHEMA_RESOURCE_REGEX

        manifest = SqliteSchemasManifest(resource_package)

        for res_name in importlib.resources.contents(resource_package):
            rem = schema_resource_regex.fullmatch(res_name)
            if not rem:
                continue

            dbver = int(rem.group(1))
            if not dbver:
                raise ValueError(
                    f"SQL package {resource_package.__name__} contains a "
                    f"resource with a schema version value of zero: "
                    f"{res_name}")

            manifest.register_resource(dbver, res_name)

        return manifest

    def _push_savepoint(self):
        """Open a new, uniquely named savepoint and push it on the stack"""
        with self._savepoint_lock:
            self.ensure_open()

            if __debug__:
                if not self._savepoint_stack:
                    assert not self.conn.in_transaction

            self._savepoint_id += 1
            savepoint = f"SqliteConnTx_{self._savepoint_id}"

            self.conn.execute(f"SAVEPOINT {savepoint}")
            self._savepoint_stack.append(savepoint)

            assert self.conn.in_transaction

    def _pop_savepoint(self, *, commit, pop):
        """
        End the savepoint of the innermost context.

        The savepoint is released if *commit* is true, otherwise it is rolled
        back then released. If *pop* is true the context's entry is removed
        from the stack (context exit), otherwise the entry is kept but set to
        `None` so that the context's exit does not end the savepoint twice.

        Return the name of the savepoint, or `None` if there was nothing to
        end.
        """
        with self._savepoint_lock:
            if not self._savepoint_stack:
                return None

            if self.conn is None:
                self._savepoint_stack = []
                return None

            if pop:
                savepoint = self._savepoint_stack.pop()
            else:
                savepoint = self._savepoint_stack[-1]
                self._savepoint_stack[-1] = None

            # reminder: savepoint may be None due to commit() or rollback()
            # methods
            if savepoint:
                assert self.conn.in_transaction

                if not commit:
                    self.conn.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")

                # ROLLBACK TO only rewinds to the savepoint: the savepoint,
                # and the transaction it may have started, remain active until
                # it is released. So RELEASE is needed in both cases.
                self.conn.execute(f"RELEASE SAVEPOINT {savepoint}")

            return savepoint
class SqliteSchemaResource:
    """
    Utility class for the schema resource(s) associated with a single version.

    A version may come with a ``.sql`` resource, a ``.py`` resource, or both.

    Created by `SqliteSchemasManifest`. Not meant to be instantiated nor used
    directly.
    """

    def __init__(self, resource_package, version):
        assert isinstance(resource_package, types.ModuleType)
        assert isinstance(version, int)

        self.resource_package = resource_package
        self.version = version
        self.sql_resource_name = None  # name of the ".sql" resource, if any
        self.py_resource_name = None  # name of the ".py" resource, if any

    @property
    def has_sql(self):
        return bool(self.sql_resource_name)

    @property
    def has_py(self):
        return bool(self.py_resource_name)

    @property
    def py_module_name(self):
        """
        The dotted name of the Python module of this version.

        Raise `ValueError` if this version has no ``.py`` resource.
        """
        if not self.has_py:
            raise ValueError(
                f"no Python module for schema version {self.version}")

        assert self.py_resource_name.lower().endswith(".py")
        stem = self.py_resource_name[:-len(".py")]

        return f"{self.resource_package.__name__}.{stem}"

    def apply_update(self, db):
        """
        Apply schema update to the provided Connection or Cursor object *db*.

        The ``.sql`` resource, if any, is executed first with
        ``db.executescript``. Then the callable named after
        `SCHEMA_UPDATER_CALLABLE_NAME` of the ``.py`` resource, if any, is
        called with *db* as its only argument.

        Raise `ValueError` if the Python module does not define that callable.
        """
        if self.has_sql:
            db.executescript(self._extract_sql())

        if self.has_py:
            module = self._import_py()
            try:
                func = getattr(module, SCHEMA_UPDATER_CALLABLE_NAME, None)
                if not callable(func):
                    raise ValueError(
                        f"{module.__name__}.{SCHEMA_UPDATER_CALLABLE_NAME} "
                        f"missing or is not a callable")

                func(db)
            finally:
                # Release the module. pop() is used rather than del so that
                # this cleanup can never mask the actual error.
                sys.modules.pop(module.__name__, None)

    def _extract_sql(self):
        """Return the content of the ``.sql`` resource as a `str`"""
        if not self.has_sql:
            raise ValueError(
                f"no SQL resource for schema version {self.version}")

        return importlib.resources.read_text(
            self.resource_package, self.sql_resource_name,
            encoding="utf-8", errors="strict")

    def _import_py(self):
        """
        Import and return the module of the ``.py`` resource.

        Raise `ImportError` if this version has no ``.py`` resource.
        """
        if not self.has_py:
            raise ImportError(
                f"no Python module for schema version {self.version}")

        return importlib.import_module(self.py_module_name)
class SqliteSchemasManifest:
    """
    A snapshot of the schema resources (``.sql`` and ``.py``) embedded in the
    passed Python package.

    Iterating a manifest yields its `SqliteSchemaResource` objects ordered by
    ascending version number.

    Created by `SqliteConnection.get_schema_resources_manifest`. Not meant to
    be instantiated directly.
    """

    def __init__(self, resource_package):
        assert isinstance(resource_package, types.ModuleType)

        self.resource_package = resource_package
        self._modified = False  # true if _resources needs to be re-ordered
        self._resources = {}  # version number -> SqliteSchemaResource
        self._oldest_version = None
        self._latest_version = None

    def __len__(self):
        return len(self._resources)

    def __iter__(self):
        self._sort()
        # guaranteed by _sort() to be ordered by ascending version
        return iter(self._resources.values())

    def __contains__(self, version):
        return self.has_schema(version)

    def __getitem__(self, version):
        return self.get_schema(version)

    @property
    def oldest_version(self):
        """The smallest version number registered (`int`)"""
        self._sort()
        return self._oldest_version

    @property
    def latest_version(self):
        """The biggest version number registered (`int`)"""
        self._sort()
        return self._latest_version

    @property
    def versions(self):
        """
        The `list` of registered versions so far.

        List is ordered by ascending version number.
        """
        self._sort()
        # guaranteed by _sort() to be ordered by ascending version
        return list(self._resources)

    def has_schema(self, version):
        """
        Check if a schema of the passed *version* exists and return a `bool`
        value.

        A null *version* value stands for "the oldest version".
        """
        try:
            self.get_schema(version)
        except IndexError:
            return False
        return True

    def get_schema(self, version):
        """
        Get the `SqliteSchemaResource` object associated to the passed
        *version*.

        A null *version* value stands for "the oldest version".

        Raise `IndexError` if *version* was not found.
        """
        assert isinstance(version, int)

        self._sort()

        if not self._resources:
            raise IndexError("no schema in manifest")

        assert self._oldest_version is not None
        assert self._latest_version is not None

        if not version:
            version = self._oldest_version

        try:
            return self._resources[version]
        except KeyError:
            # a dict lookup raises KeyError, which is not an IndexError
            raise IndexError(f"unknown schema version {version}") from None

    def apply_updates(self, db, from_version):
        """
        Apply every schema update available from the specified *from_version*
        number (non-included unless it is zero), up to the latest version
        available in embedded resources.

        *db* must be either a Connection or a Cursor compatible object.

        Return the latest version number applied. This value may be equal to
        *from_version* if database was up-to-date already.

        Raise `ValueError` if *from_version* is non-zero and not registered.
        """
        assert isinstance(from_version, int)

        self._sort()

        if from_version and from_version not in self._resources:
            raise ValueError(f"unknown schema version {from_version}")

        # _sort() guarantees self._resources to be ordered by ascending version
        latest_version = from_version
        for schema in self._resources.values():
            if not from_version or schema.version > from_version:
                schema.apply_update(db)
                latest_version = schema.version

        return latest_version

    def register_resource(self, schema_version, resource_name):
        """
        Used by `SqliteConnection.get_schema_resources_manifest` to register an
        embedded resource in the manifest.

        Raise `ValueError` if *resource_name* is neither a ``.sql`` nor a
        ``.py`` resource, or if a resource of the same type is registered
        already for *schema_version*.
        """
        assert isinstance(schema_version, int)

        # validate first, so that a rejected resource leaves the manifest
        # untouched
        lowered_name = resource_name.lower()
        if lowered_name.endswith(".sql"):
            attr_name = "sql_resource_name"
        elif lowered_name.endswith(".py"):
            attr_name = "py_resource_name"
        else:
            raise ValueError(f"unknown resource type: {resource_name}")

        schema = self._resources.get(schema_version)
        if schema is None:
            schema = SqliteSchemaResource(
                self.resource_package, schema_version)
            self._resources[schema_version] = schema
            self._modified = True
        elif getattr(schema, attr_name) is not None:
            raise ValueError(
                f"duplicate resource for schema version {schema_version}: "
                f"{resource_name}")

        setattr(schema, attr_name, resource_name)

    def _sort(self):
        """Ensure ``self._resources`` is ordered by ascending version"""
        if not self._modified:
            return

        # CAUTION: this assumes Python 3.7+ (i.e. ordered dict)
        unordered = self._resources
        versions = sorted(unordered)

        self._resources = {ver: unordered[ver] for ver in versions}
        self._oldest_version = versions[0] if versions else None
        self._latest_version = versions[-1] if versions else None
        self._modified = False
def _sqlite_statement_keyword(stmt):
    """Return the upper-cased leading keyword of *stmt* (``""`` if none)"""
    pos = 0
    end = len(stmt)

    # skip the blanks and the comments that may precede the keyword
    while pos < end:
        if stmt[pos].isspace():
            pos += 1
        elif stmt.startswith("--", pos):
            newline = stmt.find("\n", pos)
            pos = end if newline < 0 else newline + 1
        elif stmt.startswith("/*", pos):
            closing = stmt.find("*/", pos + 2)
            pos = end if closing < 0 else closing + 2
        else:
            break

    # The keyword stops at the first character that cannot be part of a word,
    # so that the terminator of a one-word statement ("COMMIT;") is left out.
    start = pos
    while pos < end and (stmt[pos].isalpha() or stmt[pos] == "_"):
        pos += 1

    return stmt[start:pos].upper()


def sqlite_iterate_script_statements(
        script, *, source="<memory>", with_keyword=False):
    """
    Yield SQL `str` statements from a sqlite3-compatible *script*.

    If *with_keyword* is true, each yielded value is a `tuple` containing a
    pair of `str` objects: the keyword of the statement and the statement
    itself. The keyword is upper-cased and is extracted after any leading
    comment. It is an empty string if the statement does not start with a
    word.

    *source* is only used to label the script in error messages.

    Raise `sqlite3.ProgrammingError` if *script* is not made of complete
    statements. Since this is a generator, the error is raised at the first
    iteration.
    """
    def _prepare_yield(st):
        st = st.strip()
        if not with_keyword:
            return st
        return (_sqlite_statement_keyword(st), st)

    if not sqlite3.complete_statement(script):
        raise sqlite3.ProgrammingError(
            f"not a complete SQL statement or script: {source}")

    stmt = ""  # current statement
    for line in script.splitlines(keepends=True):
        # Every ";" is a potential end of statement. sqlite3 is asked to
        # confirm since a ";" may also belong to a string literal, a comment
        # or the body of a trigger. A line without ";" gives a single part.
        parts = line.split(";")
        last_idx = len(parts) - 1

        for idx, part in enumerate(parts):
            stmt += part
            if idx < last_idx:
                stmt += ";"

            if sqlite3.complete_statement(stmt):
                yield _prepare_yield(stmt)
                stmt = ""

    # Trailing data, if any, can be safely ignored because the whole script has
    # been validated at the beginning of this function so that any remaining
    # data is likely to be space characters or comment.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment