From 371abe73029bab1b72dc84ad36ce665786a767d4 Mon Sep 17 00:00:00 2001
From: Rohan McGovern <rohan@mcgovern.id.au>
Date: Tue, 15 Jan 2019 13:48:06 +1000
Subject: [PATCH] Fix deadlocks when futures outlive executor

Let futures hold a strong reference to their executor until
completion, otherwise the futures may never be resolved.

Clearing the reference when the future completes avoids keeping
threads alive longer than needed.

Fixes #114
---
 README.md                         |  2 +
 more_executors/poll.py            | 16 ++---
 more_executors/retry.py           | 10 +++-
 more_executors/throttle.py        | 22 ++++---
 tests/test_executor.py            | 97 +++++++++++++++++++------------
 tests/test_executor_threadleak.py | 16 ++++-
 6 files changed, 108 insertions(+), 55 deletions(-)

diff --git a/README.md b/README.md
index 28e8bcf6..95f5573b 100644
--- a/README.md
+++ b/README.md
@@ -72,6 +72,8 @@ def fetch_urls(urls):
 
 - Reduced log verbosity
   ([#115](https://github.com/rohanpm/more-executors/issues/115))
+- Fixed deadlock when awaiting a future whose executor was garbage collected
+  ([#114](https://github.com/rohanpm/more-executors/issues/114))
 
 ### v1.17.0
 
diff --git a/more_executors/poll.py b/more_executors/poll.py
index 3e9d1925..92eb59c1 100644
--- a/more_executors/poll.py
+++ b/more_executors/poll.py
@@ -13,6 +13,7 @@ def __init__(self, delegate, executor):
         self._delegate = delegate
         self._executor = executor
         self._delegate.add_done_callback(self._delegate_resolved)
+        self.add_done_callback(self._clear_executor)
 
     def _delegate_resolved(self, delegate):
         assert delegate is self._delegate, \
@@ -29,6 +30,11 @@ def _clear_delegate(self):
         with self._me_lock:
             self._delegate = None
 
+    @classmethod
+    def _clear_executor(cls, future):
+        future._executor._deregister_poll(future)
+        future._executor = None
+
     def set_result(self, result):
         with self._me_lock:
             if self.done():
@@ -65,10 +71,8 @@ def running(self):
     def _me_cancel(self):
         if self._delegate and not self._delegate.cancel():
             return False
-        if not self._executor._run_cancel_fn(self):
-            return False
-        self._executor._deregister_poll(self)
-        return True
+        executor = self._executor
+        return executor and executor._run_cancel_fn(self)
 
 
 class PollDescriptor(object):
@@ -160,9 +164,7 @@ def __init__(self, delegate, poll_fn, cancel_fn=None, default_interval=5.0, logg
 
     def submit(self, fn, *args, **kwargs):
         delegate_future = self._delegate.submit(fn, *args, **kwargs)
-        out = _PollFuture(delegate_future, self)
-        out.add_done_callback(self._deregister_poll)
-        return out
+        return _PollFuture(delegate_future, self)
 
     def _register_poll(self, future, delegate_future):
         descriptor = PollDescriptor(future, delegate_future.result())
diff --git a/more_executors/retry.py b/more_executors/retry.py
index cc7eb827..cae2a09e 100644
--- a/more_executors/retry.py
+++ b/more_executors/retry.py
@@ -95,7 +95,8 @@ class _RetryFuture(_Future):
     def __init__(self, executor):
         super(_RetryFuture, self).__init__()
         self.delegate_future = None
-        self._executor_ref = weakref.ref(executor)
+        self._executor = executor
+        self.add_done_callback(self._clear_executor)
 
     def running(self):
         with self._me_lock:
@@ -113,6 +114,11 @@ def _clear_delegate(self):
         with self._me_lock:
             self.delegate_future = None
 
+    @classmethod
+    def _clear_executor(cls, future):
+        with future._me_lock:
+            future._executor = None
+
     def set_result(self, result):
         with self._me_lock:
             self._clear_delegate()
@@ -134,7 +140,7 @@ def set_exception_info(self, exception, traceback):
         self._me_invoke_callbacks()
 
     def _me_cancel(self):
-        executor = self._executor_ref()
+        executor = self._executor
         return executor and executor._cancel(self)
 
 
diff --git a/more_executors/throttle.py b/more_executors/throttle.py
index 2d4d8274..4e595e3d 100644
--- a/more_executors/throttle.py
+++ b/more_executors/throttle.py
@@ -1,6 +1,7 @@
 from concurrent.futures import Executor
 from threading import Event, Thread, Lock, Semaphore
 from collections import namedtuple, deque
+from functools import partial
 import logging
 import weakref
 
@@ -11,15 +12,20 @@
 
 class _ThrottleFuture(_MapFuture):
     def __init__(self, executor):
-        self._executor_ref = weakref.ref(executor)
+        self._executor = executor
         super(_ThrottleFuture, self).__init__(delegate=None, map_fn=lambda x: x)
+        self.add_done_callback(self._clear_executor)
 
     def _me_cancel(self):
         if self._delegate:
             return self._delegate.cancel()
-        executor = self._executor_ref()
+        executor = self._executor
         return executor and executor._do_cancel(self)
 
+    @classmethod
+    def _clear_executor(cls, future):
+        future._executor = None
+
 
 _ThrottleJob = namedtuple('_ThrottleJob', ['future', 'fn', 'args', 'kwargs'])
 
@@ -87,7 +93,8 @@ def _do_submit(self, job):
         delegate_future = self._delegate.submit(job.fn, *job.args, **job.kwargs)
         self._log.debug("Submitted %s yielding %s", job, delegate_future)
 
-        delegate_future.add_done_callback(self._delegate_future_done)
+        delegate_future.add_done_callback(
+            partial(self._delegate_future_done, self._log, self._sem, self._event))
         job.future._set_delegate(delegate_future)
 
     def _do_cancel(self, future):
@@ -100,13 +107,14 @@ def _do_cancel(self, future):
         self._log.debug("Could not find for cancel: %s", future)
         return False
 
-    def _delegate_future_done(self, future):
+    @classmethod
+    def _delegate_future_done(cls, log, sem, event, future):
         # Whenever an inner future completes, one more execution slot becomes
         # available, and the thread should wake up in case there's something to
         # be submitted
-        self._log.debug("Delegate future done: %s", future)
-        self._sem.release()
-        self._event.set()
+        log.debug("Delegate future done: %s", future)
+        sem.release()
+        event.set()
 
 
 def _submit_loop_iter(executor):
diff --git a/tests/test_executor.py b/tests/test_executor.py
index c7025565..0958e883 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -49,58 +49,61 @@ def poll_noop(ds):
 
 
 @fixture
-def retry_executor():
-    return Executors.thread_pool().with_retry(max_attempts=1)
+def retry_executor_ctor():
+    return lambda: Executors.thread_pool().with_retry(max_attempts=1)
 
 
 @fixture
-def threadpool_executor():
-    return Executors.thread_pool(max_workers=20)
+def threadpool_executor_ctor():
+    return lambda: Executors.thread_pool(max_workers=20)
 
 
 @fixture
-def sync_executor():
-    return Executors.sync()
+def sync_executor_ctor():
+    return Executors.sync
 
 
 @fixture
-def map_executor(threadpool_executor):
-    return threadpool_executor.with_map(map_noop)
+def map_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().with_map(map_noop)
 
 
 @fixture
-def flat_map_executor(threadpool_executor):
-    return threadpool_executor.with_flat_map(partial(flat_map_noop, threadpool_executor))
+def flat_map_executor_ctor(threadpool_executor_ctor):
+    def out():
+        threadpool_executor = threadpool_executor_ctor()
+        return threadpool_executor.with_flat_map(partial(flat_map_noop, threadpool_executor))
+    return out
 
 
 @fixture
-def throttle_executor(threadpool_executor):
-    return threadpool_executor.with_throttle(10)
+def throttle_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().with_throttle(10)
 
 
 @fixture
-def cancel_on_shutdown_executor(threadpool_executor):
-    return threadpool_executor.with_cancel_on_shutdown()
+def cancel_on_shutdown_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().with_cancel_on_shutdown()
 
 
 @fixture
-def map_retry_executor(threadpool_executor):
-    return threadpool_executor.with_retry(RetryPolicy()).with_map(map_noop)
+def map_retry_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().with_retry(RetryPolicy()).with_map(map_noop)
 
 
 @fixture
-def retry_map_executor(threadpool_executor):
-    return threadpool_executor.with_map(map_noop).with_retry(RetryPolicy())
+def retry_map_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().with_map(map_noop).with_retry(RetryPolicy())
 
 
 @fixture
-def timeout_executor(threadpool_executor):
-    return threadpool_executor.with_timeout(60.0)
+def timeout_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().with_timeout(60.0)
 
 
 @fixture
-def cancel_poll_map_retry_executor(threadpool_executor):
-    return threadpool_executor.\
+def cancel_poll_map_retry_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().\
         with_retry(RetryPolicy()).\
         with_map(map_noop).\
         with_poll(poll_noop).\
@@ -108,8 +111,8 @@ def cancel_poll_map_retry_executor(threadpool_executor):
 
 
 @fixture
-def cancel_retry_map_poll_executor(threadpool_executor):
-    return threadpool_executor.\
+def cancel_retry_map_poll_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().\
         with_poll(poll_noop).\
         with_map(map_noop).\
         with_retry(RetryPolicy()).\
@@ -117,8 +120,8 @@ def cancel_retry_map_poll_executor(threadpool_executor):
 
 
 @fixture
-def retry_map_poll_executor(threadpool_executor):
-    return threadpool_executor.\
+def retry_map_poll_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().\
         with_poll(poll_noop).\
         with_map(map_noop).\
         with_retry(RetryPolicy())
@@ -138,8 +141,8 @@ def random_cancel(_value):
 
 
 @fixture
-def poll_executor(threadpool_executor):
-    return threadpool_executor.\
+def poll_executor_ctor(threadpool_executor_ctor):
+    return lambda: threadpool_executor_ctor().\
         with_poll(poll_noop,
                   random_cancel)
 
@@ -168,21 +171,25 @@ def everything_executor(base_executor):
 
 
 @fixture
-def everything_sync_executor(sync_executor):
-    return everything_executor(sync_executor)
+def everything_sync_executor_ctor(sync_executor_ctor):
+    return lambda: everything_executor(sync_executor_ctor())
 
 
 @fixture
-def everything_threadpool_executor(threadpool_executor):
-    return everything_executor(threadpool_executor)
+def everything_threadpool_executor_ctor(threadpool_executor_ctor):
+    return lambda: everything_executor(threadpool_executor_ctor())
 
 
-@fixture(params=['threadpool', 'retry', 'map', 'retry_map', 'map_retry', 'poll', 'retry_map_poll',
-                 'sync', 'timeout', 'throttle', 'cancel_poll_map_retry', 'cancel_retry_map_poll',
-                 'flat_map',
-                 'everything_sync', 'everything_threadpool'])
+EXECUTOR_TYPES = ['threadpool', 'retry', 'map', 'retry_map', 'map_retry', 'poll', 'retry_map_poll',
+                  'sync', 'timeout', 'throttle', 'cancel_poll_map_retry', 'cancel_retry_map_poll',
+                  'flat_map',
+                  'everything_sync', 'everything_threadpool']
+
+
+@fixture(params=EXECUTOR_TYPES)
 def any_executor(request):
-    ex = request.getfixturevalue(request.param + '_executor')
+    ctor = request.getfixturevalue(request.param + '_executor_ctor')
+    ex = ctor()
 
     # Capture log messages onto the executor itself,
     # for use with dump_executor if test fails.
@@ -212,6 +219,11 @@ def any_executor(request):
     ex.shutdown(False)
 
 
+@fixture(params=EXECUTOR_TYPES)
+def any_executor_ctor(request):
+    return request.getfixturevalue(request.param + '_executor_ctor')
+
+
 def test_submit_results(any_executor):
     values = range(0, 1000)
     expected_results = [v*2 for v in values]
@@ -228,6 +240,17 @@ def fn(x):
     assert_that(results, equal_to(expected_results))
 
 
+def test_future_outlive_executor(any_executor_ctor):
+    def make_futures(executor):
+        return [executor.submit(lambda x: x*2, y)
+                for y in [1, 2, 3, 4]]
+
+    futures = make_futures(any_executor_ctor())
+    results = [f.result(TIMEOUT) for f in futures]
+
+    assert results == [2, 4, 6, 8]
+
+
 def test_broken_callback(any_executor):
     values = range(0, 1000)
     expected_results = [v*2 for v in values]
diff --git a/tests/test_executor_threadleak.py b/tests/test_executor_threadleak.py
index 5e08db87..4a95f917 100644
--- a/tests/test_executor_threadleak.py
+++ b/tests/test_executor_threadleak.py
@@ -84,7 +84,7 @@ def test_no_leak_on_discarded_futures(executor_ctor):
     no_extra_threads = partial(assert_no_extra_threads, thread_names())
 
     executor = executor_ctor()
-    futures = [executor.submit(mult2, n) for n in [10, 20, 30]]
+    futures = [executor.submit(mult2, n) for n in range(0, 1000)]
     del executor
     del futures
 
@@ -99,9 +99,21 @@ def test_no_leak_on_completed_futures(executor_ctor):
     no_extra_threads = partial(assert_no_extra_threads, thread_names())
 
     executor = executor_ctor()
-    results = [executor.submit(mult2, n) for n in [10, 20, 30]]
+    results = [executor.submit(mult2, n) for n in range(0, 1000)]
     results = get_future_results(results)
 
     del executor
 
     assert_soon(no_extra_threads)
+
+
+def test_no_leak_on_completed_held_futures(executor_ctor):
+    no_extra_threads = partial(assert_no_extra_threads, thread_names())
+
+    executor = executor_ctor()
+    futures = [executor.submit(mult2, n) for n in range(0, 1000)]
+    get_future_results(futures)
+
+    del executor
+
+    assert_soon(no_extra_threads)