Index: build/android/devil/utils/reraiser_thread.py |
diff --git a/build/android/devil/utils/reraiser_thread.py b/build/android/devil/utils/reraiser_thread.py |
index 7607352b73a598f32fc4db87ff0c0f3d5d7e9460..964593b5f348b74665abf8fc01d748a8253db21a 100644 |
--- a/build/android/devil/utils/reraiser_thread.py |
+++ b/build/android/devil/utils/reraiser_thread.py |
@@ -19,21 +19,22 @@ class TimeoutError(Exception): |
pass |
-def LogThreadStack(thread): |
+def LogThreadStack(thread, error_log_func=logging.critical): |
"""Log the stack for the given thread. |
Args: |
thread: a threading.Thread instance. |
+ error_log_func: Logging function when logging errors. |
""" |
stack = sys._current_frames()[thread.ident] |
- logging.critical('*' * 80) |
- logging.critical('Stack dump for thread %r', thread.name) |
- logging.critical('*' * 80) |
+ error_log_func('*' * 80) |
+ error_log_func('Stack dump for thread %r', thread.name) |
+ error_log_func('*' * 80) |
for filename, lineno, name, line in traceback.extract_stack(stack): |
- logging.critical('File: "%s", line %d, in %s', filename, lineno, name) |
+ error_log_func('File: "%s", line %d, in %s', filename, lineno, name) |
if line: |
- logging.critical(' %s', line.strip()) |
- logging.critical('*' * 80) |
+ error_log_func(' %s', line.strip()) |
+ error_log_func('*' * 80) |
class ReraiserThread(threading.Thread): |
@@ -59,6 +60,7 @@ class ReraiserThread(threading.Thread): |
self._kwargs = kwargs |
self._ret = None |
self._exc_info = None |
+ self._thread_group = None |
def ReraiseIfException(self): |
"""Reraise exception if an exception was raised in the thread.""" |
@@ -88,9 +90,14 @@ class ReraiserThreadGroup(object): |
Args: |
threads: a list of ReraiserThread objects; defaults to empty. |
""" |
- if not threads: |
- threads = [] |
- self._threads = list(threads) |
+ self._threads = [] |
+ # Set when a thread from one group has called JoinAll on another. It is used |
+ # to detect when a there is a TimeoutRetryThread active that links to the |
+ # current thread. |
+ self.blocked_parent_thread_group = None |
+ if threads: |
+ for thread in threads: |
+ self.Add(thread) |
def Add(self, thread): |
"""Add a thread to the group. |
@@ -98,10 +105,23 @@ class ReraiserThreadGroup(object): |
Args: |
thread: a ReraiserThread object. |
""" |
+ assert thread._thread_group is None |
+ thread._thread_group = self |
self._threads.append(thread) |
- def StartAll(self): |
- """Start all threads.""" |
+ def StartAll(self, will_block=False): |
+ """Start all threads. |
+ |
+ Args: |
+ will_block: Whether the calling thread will subsequently block on this |
+ thread group. Causes the active ReraiserThreadGroup (if there is one) |
+ to be marked as blocking on this thread group. |
+ """ |
+ if will_block: |
+ # Multiple threads blocking on the same outer thread should not happen in |
+ # practice. |
+ assert not self.blocked_parent_thread_group |
+ self.blocked_parent_thread_group = CurrentThreadGroup() |
for thread in self._threads: |
thread.start() |
@@ -121,18 +141,21 @@ class ReraiserThreadGroup(object): |
watcher = watchdog_timer.WatchdogTimer(None) |
alive_threads = self._threads[:] |
end_time = (time.time() + timeout) if timeout else None |
- while alive_threads and (end_time is None or end_time > time.time()): |
- for thread in alive_threads[:]: |
- if watcher.IsTimedOut(): |
- raise TimeoutError('Timed out waiting for %d of %d threads.' % |
- (len(alive_threads), len(self._threads))) |
- # Allow the main thread to periodically check for interrupts. |
- thread.join(0.1) |
- if not thread.isAlive(): |
- alive_threads.remove(thread) |
- # All threads are allowed to complete before reraising exceptions. |
- for thread in self._threads: |
- thread.ReraiseIfException() |
+ try: |
+ while alive_threads and (end_time is None or end_time > time.time()): |
+ for thread in alive_threads[:]: |
+ if watcher.IsTimedOut(): |
+ raise TimeoutError('Timed out waiting for %d of %d threads.' % |
+ (len(alive_threads), len(self._threads))) |
+ # Allow the main thread to periodically check for interrupts. |
+ thread.join(0.1) |
+ if not thread.isAlive(): |
+ alive_threads.remove(thread) |
+ # All threads are allowed to complete before reraising exceptions. |
+ for thread in self._threads: |
+ thread.ReraiseIfException() |
+ finally: |
+ self.blocked_parent_thread_group = None |
def IsAlive(self): |
"""Check whether any of the threads are still alive. |
@@ -142,7 +165,8 @@ class ReraiserThreadGroup(object): |
""" |
return any(t.isAlive() for t in self._threads) |
- def JoinAll(self, watcher=None, timeout=None): |
+ def JoinAll(self, watcher=None, timeout=None, |
+ error_log_func=logging.critical): |
"""Join all threads. |
Reraises exceptions raised by the child threads and supports breaking |
@@ -154,13 +178,14 @@ class ReraiserThreadGroup(object): |
provided, the thread will never be timed out. |
timeout: An optional number of seconds to wait before timing out the join |
operation. This will not time out the threads. |
+ error_log_func: Logging function when logging errors. |
""" |
try: |
self._JoinAll(watcher, timeout) |
except TimeoutError: |
- logging.critical('Timed out. Dumping threads.') |
+ error_log_func('Timed out. Dumping threads.') |
for thread in (t for t in self._threads if t.isAlive()): |
- LogThreadStack(thread) |
+ LogThreadStack(thread, error_log_func=error_log_func) |
raise |
def GetAllReturnValues(self, watcher=None): |
@@ -174,6 +199,18 @@ class ReraiserThreadGroup(object): |
return [t.GetReturnValue() for t in self._threads] |
+def CurrentThreadGroup(): |
+ """Returns the ReraiserThreadGroup that owns the running thread. |
+ |
+ Returns: |
+ The current thread group, otherwise None. |
+ """ |
+ current_thread = threading.current_thread() |
+ if isinstance(current_thread, ReraiserThread): |
+ return current_thread._thread_group # pylint: disable=no-member |
+ return None |
+ |
+ |
def RunAsync(funcs, watcher=None): |
"""Executes the given functions in parallel and returns their results. |
@@ -185,5 +222,5 @@ def RunAsync(funcs, watcher=None): |
A list of return values in the order of the given functions. |
""" |
thread_group = ReraiserThreadGroup(ReraiserThread(f) for f in funcs) |
- thread_group.StartAll() |
+ thread_group.StartAll(will_block=True) |
return thread_group.GetAllReturnValues(watcher=watcher) |