Skip to content

Instantly share code, notes, and snippets.

@mgedmin
Created December 22, 2015 08:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mgedmin/a91872054884dbaaa344 to your computer and use it in GitHub Desktop.
diff --git a/Lib/unittest/signals.py b/Lib/unittest/signals.py
index e6a5fc5..1ba8b43 100644
--- a/Lib/unittest/signals.py
+++ b/Lib/unittest/signals.py
@@ -6,6 +6,10 @@ from functools import wraps
__unittest = True
+def _do_nothing_handler(unused_signum, unused_frame):
+ pass
+
+
class _InterruptHandler(object):
def __init__(self, default_handler):
self.called = False
@@ -17,8 +21,7 @@ class _InterruptHandler(object):
elif default_handler == signal.SIG_IGN:
# Not quite the same thing as SIG_IGN, but the closest we
# can make it: do nothing.
- def default_handler(unused_signum, unused_frame):
- pass
+ default_handler = _do_nothing_handler
else:
raise TypeError("expected SIGINT signal handler to be "
"signal.SIG_IGN, signal.SIG_DFL, or a "
@@ -31,12 +34,19 @@ class _InterruptHandler(object):
# if we aren't the installed handler, then delegate immediately
# to the default handler
self.default_handler(signum, frame)
+ return
if self.called:
self.default_handler(signum, frame)
+ return
self.called = True
+ stopped = False
for result in _results.keys():
result.stop()
+ stopped = True
+ if not stopped:
+ # if there aren't any registered results, delegate immediately
+ self.default_handler(signum, frame)
_results = weakref.WeakKeyDictionary()
def registerResult(result):
diff --git a/Lib/unittest/test/test_signals.py b/Lib/unittest/test/test_signals.py
new file mode 100644
index 0000000..5116796
--- /dev/null
+++ b/Lib/unittest/test/test_signals.py
@@ -0,0 +1,74 @@
+"""Tests for unittest.signals._InterruptHandler."""
+
+import inspect
+import signal
+import unittest
+from unittest import mock, signals
+
+
+class Test_InterruptHandler(unittest.TestCase):
+    """Exercise construction and __call__ dispatch of _InterruptHandler."""
+
+    def test_init_recognizes_default_handler(self):
+        # SIG_DFL is replaced by the callable signal.default_int_handler.
+        handler = signals._InterruptHandler(signal.SIG_DFL)
+        self.assertEqual(handler.default_handler, signal.default_int_handler)
+
+    def test_init_recognizes_sigign(self):
+        # SIG_IGN is approximated by the module-level no-op handler.
+        handler = signals._InterruptHandler(signal.SIG_IGN)
+        self.assertEqual(handler.default_handler, signals._do_nothing_handler)
+
+    def test_init_refuses_unexpected_values(self):
+        with self.assertRaises(TypeError):
+            signals._InterruptHandler(42)
+
+    @mock.patch('unittest.signals._results', {})
+    def test_call_with_no_registered_results(self):
+        # With no results to stop, the first SIGINT is delegated straight
+        # to the default handler.
+        default_handler = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('signal.getsignal', lambda sig: handler):
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertTrue(handler.called)
+        self.assertEqual(default_handler.call_count, 1)
+
+    def test_call_with_registered_results(self):
+        # The first SIGINT stops every registered result instead of
+        # invoking the default handler.
+        default_handler = mock.Mock()
+        result = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('signal.getsignal', lambda sig: handler), \
+                mock.patch('unittest.signals._results', {result: 1}):
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertTrue(handler.called)
+        self.assertEqual(default_handler.call_count, 0)
+        self.assertEqual(result.stop.call_count, 1)
+
+    def test_call_twice(self):
+        # A second SIGINT is delegated to the default handler rather than
+        # stopping the results again.
+        default_handler = mock.Mock()
+        result = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('signal.getsignal', lambda sig: handler), \
+                mock.patch('unittest.signals._results', {result: 1}):
+            handler(signal.SIGINT, inspect.currentframe())
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertTrue(handler.called)
+        self.assertEqual(default_handler.call_count, 1)
+        self.assertEqual(result.stop.call_count, 1)
+
+    def test_call_when_not_installed(self):
+        # signal.getsignal is not patched here, so this handler is not the
+        # installed one and must delegate without stopping any results.
+        default_handler = mock.Mock()
+        result = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('unittest.signals._results', {result: 1}):
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertFalse(handler.called)
+        self.assertEqual(default_handler.call_count, 1)
+        self.assertEqual(result.stop.call_count, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment