Skip to content

Instantly share code, notes, and snippets.

@fangpenlin
Last active August 29, 2015 14:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fangpenlin/d9119dd9dfdd5ac3836b to your computer and use it in GitHub Desktop.
Save fangpenlin/d9119dd9dfdd5ac3836b to your computer and use it in GitHub Desktop.
import time
import threading
import pytest
import psycopg2
def acquire(conn, poll_period=0.1):
    """Block until the ``'foobar'`` row of ``my_locks`` is claimed for ``conn``.

    Polls with ``SELECT ... FOR UPDATE`` for the row while it is unlocked.
    Once the row is returned, it is set to ``locked = true`` and committed.

    :param conn: open psycopg2 connection. Every transaction this function
        starts is committed or rolled back before it returns or raises.
    :param poll_period: seconds to sleep between polls while the row is
        locked.
    :raises: whatever the driver raises; the transaction is rolled back
        first so ``conn`` stays usable.
    """
    # NOTE(review): the thread name is bound as a parameter but cannot change
    # the result (extra column / ``'' != %s``); presumably it tags statements
    # in server logs -- confirm.
    select_sql = (
        "SELECT id, %s FROM my_locks\n"
        "WHERE locked = false AND id = 'foobar' LIMIT 1 FOR UPDATE\n"
    )
    cur = conn.cursor()
    try:
        while True:
            cur.execute(select_sql, (threading.current_thread().name, ))
            if cur.fetchall():
                # FOR UPDATE returned the row: this transaction holds its lock.
                break
            # The row is locked elsewhere. End the transaction before sleeping.
            # psycopg2 keeps it open after the SELECT, and a FOR UPDATE that
            # waited on a concurrent writer can leave the row locked by this
            # transaction even when the re-checked WHERE clause filtered the
            # row out of the result. Keeping that lock would block the holder's
            # release() UPDATE, so neither side could ever make progress.
            conn.rollback()
            time.sleep(poll_period)
        cur.execute(
            """UPDATE my_locks SET locked = true WHERE id = 'foobar' AND '' != %s""",
            (threading.current_thread().name, )
        )
        conn.commit()
    except Exception:
        # Don't leave an open/aborted transaction (or a row lock) behind.
        conn.rollback()
        raise
    finally:
        cur.close()
def release(conn):
    """Set the ``'foobar'`` row of ``my_locks`` back to ``locked = false``.

    The UPDATE does not check which caller holds the lock; it clears the row
    unconditionally. Call it only after a successful ``acquire``.

    :param conn: open psycopg2 connection. The change is committed on
        success and rolled back on failure.
    :raises: whatever the driver raises, after rolling back.
    """
    cur = conn.cursor()
    try:
        cur.execute(
            """UPDATE my_locks SET locked = false WHERE id = 'foobar' AND '' != %s""",
            (threading.current_thread().name, )
        )
        conn.commit()
    except Exception:
        # A failed statement leaves the transaction aborted; reset it so the
        # connection remains usable by the caller.
        conn.rollback()
        raise
    finally:
        cur.close()
def make_conn(dsn='dbname=test user=postgres'):
    """Open a new psycopg2 connection.

    :param dsn: libpq connection string. The default targets the ``test``
        database as user ``postgres``, matching the previously hard-coded
        value.
    :returns: a new connection; the caller is responsible for closing it.
    """
    return psycopg2.connect(dsn)
@pytest.fixture
def clean_my_locks():
conn = make_conn()
cur = conn.cursor()
cur.execute(
"""DROP TABLE IF EXISTS my_locks """
)
cur.execute(
"""CREATE TABLE my_locks (id varchar PRIMARY KEY, locked boolean);"""
)
cur.execute("""INSERT INTO my_locks VALUES ('foobar', false)""")
conn.commit()
def test_threads_competing(clean_my_locks):
    """Several threads, each on its own connection, contend for the row lock.

    Each thread enters the critical section ``n_rounds`` times. If
    ``acquire``/``release`` give mutual exclusion, the shared log strictly
    alternates ``'enter'``/``'leave'``.
    """
    n_threads = 3
    n_rounds = 3
    join_timeout = 60  # seconds; a stuck lock must fail the test, not hang it
    logs = []
    errors = []

    def critical_section(conn):
        print(threading.current_thread(), 'acquiring')
        acquire(conn)
        print(threading.current_thread(), 'acquired')
        # Release only a lock we actually acquired. Wrapping acquire() in the
        # try would let a failed acquire clear a lock held by another thread.
        try:
            logs.append('enter')
            # Log the exit while still holding the lock, so the next holder's
            # 'enter' cannot be recorded ahead of it.
            logs.append('leave')
        finally:
            print(threading.current_thread(), 'releasing')
            release(conn)
            print(threading.current_thread(), 'released')

    def worker():
        try:
            conn = make_conn()
            try:
                for _ in range(n_rounds):
                    critical_section(conn)
            finally:
                conn.close()
        except Exception as exc:  # thread boundary: hand failures to the test
            errors.append(exc)

    threads = []
    for _ in range(n_threads):
        thread = threading.Thread(target=worker)
        # Daemon so a thread stuck in acquire() cannot keep the process alive.
        thread.daemon = True
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join(join_timeout)

    stuck = [thread.name for thread in threads if thread.is_alive()]
    assert not stuck, 'threads still waiting for the lock: %r' % stuck
    assert not errors, errors
    assert logs == ['enter', 'leave'] * (n_threads * n_rounds)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment