Index: build/android/devil/utils/timeout_retry.py |
diff --git a/build/android/devil/utils/timeout_retry.py b/build/android/devil/utils/timeout_retry.py |
index e6c77f129ca24ce89d51f9aad17de05117d036c8..1d748a9f1245d1309f3f1d61ad02a65eebf73016 100644 |
--- a/build/android/devil/utils/timeout_retry.py |
+++ b/build/android/devil/utils/timeout_retry.py |
@@ -14,11 +14,11 @@ from devil.utils import reraiser_thread |
from devil.utils import watchdog_timer |
-class TimeoutRetryThread(reraiser_thread.ReraiserThread): |
- def __init__(self, func, timeout, name): |
- super(TimeoutRetryThread, self).__init__(func, name=name) |
+ |
+class TimeoutRetryThreadGroup(reraiser_thread.ReraiserThreadGroup): |
+ def __init__(self, timeout, threads=None): |
+ super(TimeoutRetryThreadGroup, self).__init__(threads) |
self._watcher = watchdog_timer.WatchdogTimer(timeout) |
- self._expired = False |
def GetWatcher(self): |
"""Returns the watchdog keeping track of this thread's time.""" |
@@ -52,34 +52,22 @@ class TimeoutRetryThread(reraiser_thread.ReraiserThread): |
if remaining > 0: |
msg += (', wait of %.1f secs required but only %.1f secs left' |
% (required, remaining)) |
- self._expired = True |
raise reraiser_thread.TimeoutError(msg) |
return remaining |
- def LogTimeoutException(self): |
- """Log the exception that terminated this thread.""" |
- if not self._expired: |
- return |
- logging.critical('*' * 80) |
- logging.critical('%s on thread %r', self._exc_info[0].__name__, self.name) |
- logging.critical('*' * 80) |
- fmt_exc = ''.join(traceback.format_exception(*self._exc_info)) |
- for line in fmt_exc.splitlines(): |
- logging.critical(line.rstrip()) |
- logging.critical('*' * 80) |
- |
-def CurrentTimeoutThread(): |
- """Get the current thread if it is a TimeoutRetryThread. |
+def CurrentTimeoutThreadGroup(): |
+ """Returns the thread group that owns or is blocked on the active thread. |
Returns: |
- The current thread if it is a TimeoutRetryThread, otherwise None. |
+ Returns None if no TimeoutRetryThreadGroup is tracking the current thread. |
""" |
- current_thread = threading.current_thread() |
- if isinstance(current_thread, TimeoutRetryThread): |
- return current_thread |
- else: |
- return None |
+ thread_group = reraiser_thread.CurrentThreadGroup() |
+ while thread_group: |
+ if isinstance(thread_group, TimeoutRetryThreadGroup): |
+ return thread_group |
+ thread_group = thread_group.blocked_parent_thread_group |
+ return None |
def WaitFor(condition, wait_period=5, max_tries=None): |
@@ -88,45 +76,57 @@ def WaitFor(condition, wait_period=5, max_tries=None): |
Repeatedly call the function condition(), with no arguments, until it returns |
a true value. |
- If called within a TimeoutRetryThread, it cooperates nicely with it. |
+ If called within a TimeoutRetryThreadGroup, it cooperates nicely with it. |
Args: |
condition: function with the condition to check |
wait_period: number of seconds to wait before retrying to check the |
condition |
max_tries: maximum number of checks to make, the default tries forever |
- or until the TimeoutRetryThread expires. |
+ or until the TimeoutRetryThreadGroup expires. |
Returns: |
The true value returned by the condition, or None if the condition was |
not met after max_tries. |
Raises: |
- reraiser_thread.TimeoutError if the current thread is a TimeoutRetryThread |
- and the timeout expires. |
+ reraiser_thread.TimeoutError: if the current thread is a |
+ TimeoutRetryThreadGroup and the timeout expires. |
""" |
condition_name = condition.__name__ |
- timeout_thread = CurrentTimeoutThread() |
+ timeout_thread_group = CurrentTimeoutThreadGroup() |
while max_tries is None or max_tries > 0: |
result = condition() |
if max_tries is not None: |
max_tries -= 1 |
msg = ['condition', repr(condition_name), 'met' if result else 'not met'] |
- if timeout_thread: |
+ if timeout_thread_group: |
# pylint: disable=no-member |
- msg.append('(%.1fs)' % timeout_thread.GetElapsedTime()) |
+ msg.append('(%.1fs)' % timeout_thread_group.GetElapsedTime()) |
logging.info(' '.join(msg)) |
if result: |
return result |
- if timeout_thread: |
+ if timeout_thread_group: |
# pylint: disable=no-member |
- timeout_thread.GetRemainingTime(wait_period, |
+ timeout_thread_group.GetRemainingTime(wait_period, |
msg='Timed out waiting for %r' % condition_name) |
time.sleep(wait_period) |
return None |
-def Run(func, timeout, retries, args=None, kwargs=None, desc=None): |
+def _LogLastException(thread_name, attempt, max_attempts, log_func): |
+ log_func('*' * 80) |
+ log_func('Exception on thread %s (attempt %d of %d)', thread_name, |
+ attempt, max_attempts) |
+ log_func('*' * 80) |
+ fmt_exc = ''.join(traceback.format_exc()) |
+ for line in fmt_exc.splitlines(): |
+ log_func(line.rstrip()) |
+ log_func('*' * 80) |
+ |
+ |
+def Run(func, timeout, retries, args=None, kwargs=None, desc=None, |
+ error_log_func=logging.critical): |
"""Runs the passed function in a separate thread with timeouts and retries. |
Args: |
@@ -137,6 +137,7 @@ def Run(func, timeout, retries, args=None, kwargs=None, desc=None): |
kwargs: dictionary of keyword args to pass to |func|. |
desc: An optional description of |func| used in logging. If omitted, |
|func.__name__| will be used. |
+ error_log_func: Logging function when logging errors. |
Returns: |
The return value of func(*args, **kwargs). |
@@ -146,30 +147,29 @@ def Run(func, timeout, retries, args=None, kwargs=None, desc=None): |
if not kwargs: |
kwargs = {} |
- # The return value uses a list because Python variables are references, not |
- # values. Closures make a copy of the reference, so updating the closure's |
- # reference wouldn't update where the original reference pointed. |
- ret = [None] |
- def RunOnTimeoutThread(): |
- ret[0] = func(*args, **kwargs) |
- |
num_try = 1 |
while True: |
- child_thread = TimeoutRetryThread( |
- RunOnTimeoutThread, timeout, |
- name='TimeoutThread-%d-for-%s' % (num_try, |
- threading.current_thread().name)) |
+ thread_name = 'TimeoutThread-%d-for-%s' % (num_try, |
+ threading.current_thread().name) |
+ child_thread = reraiser_thread.ReraiserThread(lambda: func(*args, **kwargs), |
+ name=thread_name) |
try: |
- thread_group = reraiser_thread.ReraiserThreadGroup([child_thread]) |
- thread_group.StartAll() |
+ thread_group = TimeoutRetryThreadGroup(timeout, threads=[child_thread]) |
+ thread_group.StartAll(will_block=True) |
while True: |
- thread_group.JoinAll(watcher=child_thread.GetWatcher(), timeout=60) |
+ thread_group.JoinAll(watcher=thread_group.GetWatcher(), timeout=60, |
+ error_log_func=error_log_func) |
if thread_group.IsAlive(): |
logging.info('Still working on %s', desc if desc else func.__name__) |
else: |
- return ret[0] |
- except: |
- child_thread.LogTimeoutException() |
+ return thread_group.GetAllReturnValues()[0] |
+ except reraiser_thread.TimeoutError: |
+ # Timeouts already get their stacks logged. |
+ if num_try > retries: |
+ raise |
+ # Do not catch KeyboardInterrupt. |
+ except Exception: # pylint: disable=broad-except |
if num_try > retries: |
raise |
- num_try += 1 |
+ _LogLastException(thread_name, num_try, retries + 1, error_log_func) |
+ num_try += 1 |