public
Created

Trigger a callback when a thread dies. Demonstrate some oddities with thread locals in Python 2.6, before http://bugs.python.org/issue1868 was fixed. See http://emptysquare.net/blog/knowing-when-a-python-thread-has-died/ for more.

  • Download Gist
thread_watcher.py
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
from __future__ import print_function
 
import gc
import threading
import time
import unittest
import weakref
 
from functools import partial
 
 
class ThreadWatcher(object):
    """Run a callback when a watched thread dies.

    A thread that calls ``watch()`` stores a private ``Vigil`` object in this
    watcher's ``threading.local``. The watcher itself keeps only a weak
    reference to that vigil (in ``self._refs``). When the thread's locals are
    torn down, the vigil is destroyed and the weak reference's callback
    fires, which runs the user's callback with no arguments.

    Notes:
        * The pending weak references are owned by this watcher, so keep the
          watcher alive for as long as death notifications are needed.
        * The callback runs from the weakref machinery, not necessarily in
          the watched thread. On old Pythons (see
          http://bugs.python.org/issue1868) it may run later, from whichever
          thread next touches the thread-local. Avoid relying on
          thread-local state inside the callback.
    """

    class Vigil(object):
        """Marker object whose lifetime tracks one watched thread."""

    def __init__(self):
        # id(vigil) -> weakref.ref(vigil, ...) for every currently watched thread.
        self._refs = {}
        # Per-thread storage; a watched thread keeps its Vigil in ``.vigil``.
        self._local = threading.local()

    def _on_death(self, vigil_id, callback, ref):
        """Weakref callback: forget the dead vigil, then notify the user.

        ``ref`` is the dead weak reference supplied by ``weakref``; unused.
        """
        self._refs.pop(vigil_id)
        callback()

    def watch(self, callback):
        """Arrange for ``callback()`` to be called when the current thread dies.

        If the current thread is already being watched this is a no-op and the
        new ``callback`` is ignored.
        """
        if self.is_watching():
            return
        self._local.vigil = vigil = ThreadWatcher.Vigil()
        on_death = partial(self._on_death, id(vigil), callback)
        # Only a weak reference is kept here: the thread-local is the sole
        # strong owner of the vigil, so the vigil dies with the thread.
        self._refs[id(vigil)] = weakref.ref(vigil, on_death)

    def is_watching(self):
        "Is the current thread being watched?"
        try:
            vigil = self._local.vigil
        except AttributeError:
            return False
        return id(vigil) in self._refs

    def unwatch(self):
        """Stop watching the current thread; a no-op if it is not watched.

        The weak reference is discarded while ``vigil`` is still strongly
        held by this frame, so its death callback can never fire.
        """
        try:
            vigil = self._local.vigil
        except AttributeError:
            return
        del self._local.vigil
        self._refs.pop(id(vigil))
 
 
try:
    # Python 2 exposes the low-level threading primitives as ``thread``.
    import thread
except ImportError:
    # Python 3: ``threading.get_ident`` is public API.
    def get_ident():
        """Return the identifier of the calling thread (Python 3)."""
        return threading.get_ident()
else:
    def get_ident():
        """Return the identifier of the calling thread (Python 2)."""
        return thread.get_ident()
 
 
class TestWatch(unittest.TestCase):
    """Functional tests for ThreadWatcher.watch / unwatch / is_watching."""

    def test_watch(self):
        print('main', get_ident())
        watcher = ThreadWatcher()
        callback_ran = [False]

        def callback():
            print('callback', get_ident())
            callback_ran[0] = True

        def target():
            watcher.watch(callback)

        t = threading.Thread(target=target)
        t.start()
        t.join()
        # Trigger collection in Py 2.6, see http://bugs.python.org/issue1868
        watcher.is_watching()
        gc.collect()
        # Poll for up to ~1s in case the callback is delivered after join().
        for _ in range(10):
            if callback_ran[0]:
                break
            time.sleep(.1)
        self.assertTrue(callback_ran[0])
        # id(v) removed from _refs
        self.assertFalse(watcher._refs)

    def test_unwatch(self):
        watcher = ThreadWatcher()
        callback_ran = [False]

        def callback():
            callback_ran[0] = True

        def target():
            watcher.watch(callback)
            watcher.unwatch()

        t = threading.Thread(target=target)
        t.start()
        t.join()
        # Trigger collection in Py 2.6, see http://bugs.python.org/issue1868
        watcher.is_watching()
        gc.collect()
        self.assertFalse(callback_ran[0])

    def test_unwatch_twice(self):
        watcher = ThreadWatcher()
        self.assertFalse(watcher.is_watching())
        watcher.unwatch()
        self.assertFalse(watcher.is_watching())
        # Death callbacks are invoked with no arguments, so the dummy
        # callback must accept none.
        watcher.watch(lambda: None)
        self.assertTrue(watcher.is_watching())
        watcher.unwatch()
        self.assertFalse(watcher.is_watching())
        watcher.unwatch()
        self.assertFalse(watcher.is_watching())
 
 
class TestRefLeak(unittest.TestCase):
    """Check that the total number of death callbacks equals the number of
    watched threads, even when the callback (unwisely) touches the watcher's
    thread-local state."""

    def test_leak(self):
        nthreads = 10
        watcher = ThreadWatcher()
        counter = [0]

        def callback():
            # BAD, NO!:
            # Accessing thread-local in callback
            watcher.is_watching()
            counter[0] += 1

        def target():
            watcher.watch(callback)

        # Run the workers one after another; each registers a single watch.
        for _ in range(nthreads):
            worker = threading.Thread(target=target)
            worker.start()
            worker.join()

        # Touch the local and collect so lazily-cleaned locals are released.
        watcher.is_watching()
        gc.collect()

        # Wait up to ten 0.1s intervals for all callbacks to arrive.
        attempts = 10
        while attempts and counter[0] != nthreads:
            time.sleep(.1)
            attempts -= 1

        self.assertEqual(nthreads, counter[0])
 
 
if __name__ == '__main__':
    # Invoked as a script: hand control to unittest's command-line runner,
    # which runs every TestCase defined in this module.
    unittest.main()

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.