Skip to content

Instantly share code, notes, and snippets.

@hishnash
Created March 13, 2017 09:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hishnash/52e5622de571c10f4d5f12368b205620 to your computer and use it in GitHub Desktop.
Save hishnash/52e5622de571c10f4d5f12368b205620 to your computer and use it in GitHub Desktop.
from contextlib import contextmanager
from functools import wraps

from django.conf import settings
from django.db import OperationalError, connection
from django.db import transaction
from rest_framework.exceptions import APIException
class LockTakenException(APIException, OperationalError):
    """Signal that a row lock could not be acquired.

    Subclasses DRF's ``APIException`` so DRF's exception handling can turn
    it into an HTTP response with ``status_code`` / ``default_detail``.
    Also subclasses ``OperationalError`` so code that already catches
    database operational errors catches this too.
    """

    # NOTE(review): not a standard DRF attribute (DRF uses ``default_code``);
    # presumably read by a project-specific exception handler — confirm.
    client_code = 'ObjectLocked'

    # Message DRF uses when the exception is raised without an explicit detail.
    default_detail = 'This item is currently locked by another request.'

    # HTTP 423 "Locked".
    status_code = 423
def ensure_atomic(func):
    """Make a context-manager factory run inside ``transaction.atomic()``.

    ``func`` must return a context manager (for example, a function decorated
    with ``@contextmanager``). The returned wrapper is itself a context-manager
    factory. It takes the same arguments as ``func`` plus a keyword-only flag
    ``existing_transaction``:

    * ``False`` (default): open a new ``transaction.atomic()`` block and enter
      ``func``'s context manager inside it.
    * ``True``: the caller vouches that a transaction is already open; enter
      ``func``'s context manager directly.

    The wrapper yields whatever ``func``'s context manager yields. ``wraps``
    keeps the decorated function's name and docstring, which the original
    version discarded.
    """
    @wraps(func)
    @contextmanager
    def wrapper(*args, existing_transaction=False, **kwargs):
        if existing_transaction:
            with func(*args, **kwargs) as data:
                yield data
        else:
            with transaction.atomic():
                with func(*args, **kwargs) as data:
                    yield data
    return wrapper
# NOTE(review): ``TimeStampedModel`` is not imported in this snippet —
# confirm which package provides it.
class LockableModel(TimeStampedModel):
    """Abstract model adding PostgreSQL row-level locking helpers.

    The lock methods run ``SELECT ... FOR SHARE`` / ``SELECT ... FOR UPDATE``
    on this instance's row. PostgreSQL keeps row locks until the surrounding
    transaction ends, so they must run inside a transaction.

    Both lock methods are wrapped with ``ensure_atomic``. By default they open
    their own ``transaction.atomic()`` block. Pass
    ``existing_transaction=True`` to reuse a transaction the caller already
    opened.

    ``lock_state`` queries the ``pgrowlocks`` PostgreSQL extension, which must
    be installed in the database.
    """

    class Meta:
        abstract = True

    def _acquire_row_lock(self, lock_clause, _connection, timeout, wait):
        """Lock this instance's row with ``SELECT ... <lock_clause>``.

        This is the shared implementation behind ``get_shared_lock`` and
        ``get_exclusive_lock``.

        Args:
            lock_clause: ``'FOR SHARE'`` or ``'FOR UPDATE'``. It is a trusted
                constant and is interpolated directly into the SQL.
            _connection: Django database connection to run the queries on.
            timeout: ``statement_timeout`` to apply while waiting for the lock
                (milliseconds when given as a bare number).
            wait: if ``False``, append ``NOWAIT`` so the query fails
                immediately when the row is already locked.

        Raises:
            LockTakenException: if the lock query raised ``OperationalError``
                (for example, the lock was unavailable under ``NOWAIT`` or the
                timeout expired).

        If no row matches (an unsaved instance or a deleted row), the query
        succeeds without locking anything.
        """
        quote_name = _connection.ops.quote_name
        # Quote identifiers and use the real primary-key column instead of
        # assuming it is named ``id``.
        sql = 'SELECT ctid FROM {table} WHERE {pk_column} = %s {lock_clause}'.format(
            table=quote_name(self._meta.db_table),
            pk_column=quote_name(self._meta.pk.column),
            lock_clause=lock_clause,
        )
        if not wait:
            sql += ' NOWAIT'
        with _connection.cursor() as cursor:
            # Remember the current timeout so it can be restored once the lock
            # is held. Otherwise ``SET LOCAL`` would keep applying the locking
            # timeout to every later statement in the caller's transaction,
            # including the body of the ``with`` block.
            cursor.execute('SHOW statement_timeout;')
            previous_timeout = cursor.fetchone()[0]
            cursor.execute('SET LOCAL statement_timeout = %s;', (timeout,))
            try:
                cursor.execute(sql, (self.pk,))
            except OperationalError as err:
                raise LockTakenException() from err
            cursor.execute('SET LOCAL statement_timeout = %s;', (previous_timeout,))

    @ensure_atomic
    @contextmanager
    def get_shared_lock(self,
                        _connection=connection,
                        timeout=settings.PG_LOCKING_TIMEOUT_MS,
                        wait=True):
        """Hold a shared (``FOR SHARE``) row lock on this instance.

        Usage: ``with obj.get_shared_lock(): ...``. The optional keyword
        ``existing_transaction`` (added by ``ensure_atomic``) controls whether
        a new ``transaction.atomic()`` block is opened.

        Args:
            _connection: Django database connection to lock through.
            timeout: ``statement_timeout`` applied while waiting for the lock.
            wait: if ``False``, fail immediately when the row is locked.

        Yields:
            ``True`` once the lock query has succeeded.

        Raises:
            LockTakenException: if the lock could not be acquired.
        """
        self._acquire_row_lock('FOR SHARE', _connection, timeout, wait)
        yield True

    @ensure_atomic
    @contextmanager
    def get_exclusive_lock(self,
                           _connection=connection,
                           timeout=settings.PG_LOCKING_TIMEOUT_MS,
                           wait=True):
        """Hold an exclusive (``FOR UPDATE``) row lock on this instance.

        Usage: ``with obj.get_exclusive_lock(): ...``. The optional keyword
        ``existing_transaction`` (added by ``ensure_atomic``) controls whether
        a new ``transaction.atomic()`` block is opened.

        Args:
            _connection: Django database connection to lock through.
            timeout: ``statement_timeout`` applied while waiting for the lock.
            wait: if ``False``, fail immediately when the row is locked.

        Yields:
            ``True`` once the lock query has succeeded.

        Raises:
            LockTakenException: if the lock could not be acquired.
        """
        self._acquire_row_lock('FOR UPDATE', _connection, timeout, wait)
        yield True

    def lock_state(self, _connection=connection):
        """Return the row-lock modes ``pgrowlocks`` reports for this row.

        Args:
            _connection: Django database connection to query. Defaults to the
                default connection, matching the lock methods.

        Returns:
            A list with the ``modes`` value of each matching ``pgrowlocks``
            row, or ``None`` if the instance is unsaved or no lock was found.
        """
        if self.pk is None:
            return None
        quote_name = _connection.ops.quote_name
        table = quote_name(self._meta.db_table)
        # Qualify ``modes`` with ``p.`` so it cannot clash with a model column
        # of the same name.
        sql = (
            'SELECT p.modes FROM {table} AS a, pgrowlocks(%s) AS p '
            'WHERE p.locked_row = a.ctid AND a.{pk_column} = %s;'
        ).format(table=table, pk_column=quote_name(self._meta.pk.column))
        with _connection.cursor() as cursor:
            # Pass the relation name as a bound parameter rather than splicing
            # it into a string literal.
            cursor.execute(sql, (table, self.pk))
            rows = cursor.fetchall()
        return [row[0] for row in rows] or None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment