diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 0072813..1ce090a 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -1,5 +1,11 @@ from __future__ import print_function +from itertools import chain +import multiprocessing +import os +import signal +import sys +import traceback import unittest import six @@ -166,6 +172,73 @@ def test_disconnect_all_delete_multi(self): ret = self.mc.delete_multi({'keyhere': 'a', 'keythere': 'b'}) self.assertEqual(ret, 1) + def test_exception_handling(self): + """Tests closing socket when custom exception raised""" + queue = multiprocessing.Queue() + process = multiprocessing.Process(target=worker, args=(self.mc, queue)) + process.start() + if queue.get() != 'loop started': + raise ValueError( + 'Expected "loop started" message from the child process' + ) + + # maximum test duration is 0.5 second + num_iters = 50 + timeout = 0.01 + for i in range(num_iters): + os.kill(process.pid, signal.SIGUSR1) + try: + exc = WorkerError(*queue.get(timeout=timeout)) + raise exc + except six.moves.queue.Empty: + pass + if not process.is_alive(): + break + + if process.is_alive(): + os.kill(process.pid, signal.SIGTERM) + process.join() + + +class SignalException(Exception): + pass + + +def sighandler(signum, frame): + raise SignalException() + + +class WorkerError(Exception): + def __init__(self, exc, assert_tb, signal_tb=None): + super(WorkerError, self).__init__( + ''.join(chain(assert_tb, signal_tb or [])) + ) + self.cause = exc + + +def worker(mc, queue): + signal.signal(signal.SIGUSR1, sighandler) + + signal_tb = None + for i in range(100000): + if i == 0: + queue.put('loop started') + try: + k = str(i) + mc.set(k, i) + # This loop is just to increase chance to get previous value + # for clarity + for j in range(10): + mc.get(str(i-1)) + res = mc.get(k) + assert res == i, 'Expected {} but was {}'.format(i, res) + except AssertionError as e: + assert_tb = traceback.format_exception(*sys.exc_info()) + queue.put((e, assert_tb, signal_tb)) + break + except SignalException as e: + signal_tb = traceback.format_exception(*sys.exc_info()) + if __name__ == '__main__': unittest.main()