| OLD | NEW |
| 1 # Copyright 2013 The Chromium Authors. All rights reserved. | 1 # Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 # Use of this source code is governed by a BSD-style license that can be | 2 # Use of this source code is governed by a BSD-style license that can be |
| 3 # found in the LICENSE file. | 3 # found in the LICENSE file. |
| 4 | 4 |
| 5 """A utility to run functions with timeouts and retries.""" | 5 """A utility to run functions with timeouts and retries.""" |
| 6 # pylint: disable=W0702 | 6 # pylint: disable=W0702 |
| 7 | 7 |
| 8 import logging | 8 import logging |
| 9 import threading | 9 import threading |
| 10 import time | 10 import time |
| 11 import traceback | 11 import traceback |
| 12 | 12 |
| 13 from devil.utils import reraiser_thread | 13 from devil.utils import reraiser_thread |
| 14 from devil.utils import watchdog_timer | 14 from devil.utils import watchdog_timer |
| 15 | 15 |
| 16 | 16 |
| 17 class TimeoutRetryThread(reraiser_thread.ReraiserThread): | 17 |
| 18 def __init__(self, func, timeout, name): | 18 class TimeoutRetryThreadGroup(reraiser_thread.ReraiserThreadGroup): |
| 19 super(TimeoutRetryThread, self).__init__(func, name=name) | 19 def __init__(self, timeout, threads=None): |
| 20 super(TimeoutRetryThreadGroup, self).__init__(threads) |
| 20 self._watcher = watchdog_timer.WatchdogTimer(timeout) | 21 self._watcher = watchdog_timer.WatchdogTimer(timeout) |
| 21 self._expired = False | |
| 22 | 22 |
| 23 def GetWatcher(self): | 23 def GetWatcher(self): |
| 24 """Returns the watchdog keeping track of this thread's time.""" | 24 """Returns the watchdog keeping track of this thread's time.""" |
| 25 return self._watcher | 25 return self._watcher |
| 26 | 26 |
| 27 def GetElapsedTime(self): | 27 def GetElapsedTime(self): |
| 28 return self._watcher.GetElapsed() | 28 return self._watcher.GetElapsed() |
| 29 | 29 |
| 30 def GetRemainingTime(self, required=0, msg=None): | 30 def GetRemainingTime(self, required=0, msg=None): |
| 31 """Get the remaining time before the thread times out. | 31 """Get the remaining time before the thread times out. |
| (...skipping 13 matching lines...) Expand all Loading... |
| 45 reraiser_thread.TimeoutError if the remaining time is less than the | 45 reraiser_thread.TimeoutError if the remaining time is less than the |
| 46 required time. | 46 required time. |
| 47 """ | 47 """ |
| 48 remaining = self._watcher.GetRemaining() | 48 remaining = self._watcher.GetRemaining() |
| 49 if remaining is not None and remaining < required: | 49 if remaining is not None and remaining < required: |
| 50 if msg is None: | 50 if msg is None: |
| 51 msg = 'Timeout expired' | 51 msg = 'Timeout expired' |
| 52 if remaining > 0: | 52 if remaining > 0: |
| 53 msg += (', wait of %.1f secs required but only %.1f secs left' | 53 msg += (', wait of %.1f secs required but only %.1f secs left' |
| 54 % (required, remaining)) | 54 % (required, remaining)) |
| 55 self._expired = True | |
| 56 raise reraiser_thread.TimeoutError(msg) | 55 raise reraiser_thread.TimeoutError(msg) |
| 57 return remaining | 56 return remaining |
| 58 | 57 |
| 59 def LogTimeoutException(self): | |
| 60 """Log the exception that terminated this thread.""" | |
| 61 if not self._expired: | |
| 62 return | |
| 63 logging.critical('*' * 80) | |
| 64 logging.critical('%s on thread %r', self._exc_info[0].__name__, self.name) | |
| 65 logging.critical('*' * 80) | |
| 66 fmt_exc = ''.join(traceback.format_exception(*self._exc_info)) | |
| 67 for line in fmt_exc.splitlines(): | |
| 68 logging.critical(line.rstrip()) | |
| 69 logging.critical('*' * 80) | |
| 70 | 58 |
| 71 | 59 def CurrentTimeoutThreadGroup(): |
| 72 def CurrentTimeoutThread(): | 60 """Returns the thread group that owns or is blocked on the active thread. |
| 73 """Get the current thread if it is a TimeoutRetryThread. | |
| 74 | 61 |
| 75 Returns: | 62 Returns: |
| 76 The current thread if it is a TimeoutRetryThread, otherwise None. | 63 Returns None if no TimeoutRetryThreadGroup is tracking the current thread. |
| 77 """ | 64 """ |
| 78 current_thread = threading.current_thread() | 65 thread_group = reraiser_thread.CurrentThreadGroup() |
| 79 if isinstance(current_thread, TimeoutRetryThread): | 66 while thread_group: |
| 80 return current_thread | 67 if isinstance(thread_group, TimeoutRetryThreadGroup): |
| 81 else: | 68 return thread_group |
| 82 return None | 69 thread_group = thread_group.blocked_parent_thread_group |
| 70 return None |
| 83 | 71 |
| 84 | 72 |
| 85 def WaitFor(condition, wait_period=5, max_tries=None): | 73 def WaitFor(condition, wait_period=5, max_tries=None): |
| 86 """Wait for a condition to become true. | 74 """Wait for a condition to become true. |
| 87 | 75 |
| 88 Repeatedly call the function condition(), with no arguments, until it returns | 76 Repeatedly call the function condition(), with no arguments, until it returns |
| 89 a true value. | 77 a true value. |
| 90 | 78 |
| 91 If called within a TimeoutRetryThread, it cooperates nicely with it. | 79 If called within a TimeoutRetryThreadGroup, it cooperates nicely with it. |
| 92 | 80 |
| 93 Args: | 81 Args: |
| 94 condition: function with the condition to check | 82 condition: function with the condition to check |
| 95 wait_period: number of seconds to wait before retrying to check the | 83 wait_period: number of seconds to wait before retrying to check the |
| 96 condition | 84 condition |
| 97 max_tries: maximum number of checks to make, the default tries forever | 85 max_tries: maximum number of checks to make, the default tries forever |
| 98 or until the TimeoutRetryThread expires. | 86 or until the TimeoutRetryThreadGroup expires. |
| 99 | 87 |
| 100 Returns: | 88 Returns: |
| 101 The true value returned by the condition, or None if the condition was | 89 The true value returned by the condition, or None if the condition was |
| 102 not met after max_tries. | 90 not met after max_tries. |
| 103 | 91 |
| 104 Raises: | 92 Raises: |
| 105 reraiser_thread.TimeoutError if the current thread is a TimeoutRetryThread | 93 reraiser_thread.TimeoutError: if the current thread is a |
| 106 and the timeout expires. | 94 TimeoutRetryThreadGroup and the timeout expires. |
| 107 """ | 95 """ |
| 108 condition_name = condition.__name__ | 96 condition_name = condition.__name__ |
| 109 timeout_thread = CurrentTimeoutThread() | 97 timeout_thread_group = CurrentTimeoutThreadGroup() |
| 110 while max_tries is None or max_tries > 0: | 98 while max_tries is None or max_tries > 0: |
| 111 result = condition() | 99 result = condition() |
| 112 if max_tries is not None: | 100 if max_tries is not None: |
| 113 max_tries -= 1 | 101 max_tries -= 1 |
| 114 msg = ['condition', repr(condition_name), 'met' if result else 'not met'] | 102 msg = ['condition', repr(condition_name), 'met' if result else 'not met'] |
| 115 if timeout_thread: | 103 if timeout_thread_group: |
| 116 # pylint: disable=no-member | 104 # pylint: disable=no-member |
| 117 msg.append('(%.1fs)' % timeout_thread.GetElapsedTime()) | 105 msg.append('(%.1fs)' % timeout_thread_group.GetElapsedTime()) |
| 118 logging.info(' '.join(msg)) | 106 logging.info(' '.join(msg)) |
| 119 if result: | 107 if result: |
| 120 return result | 108 return result |
| 121 if timeout_thread: | 109 if timeout_thread_group: |
| 122 # pylint: disable=no-member | 110 # pylint: disable=no-member |
| 123 timeout_thread.GetRemainingTime(wait_period, | 111 timeout_thread_group.GetRemainingTime(wait_period, |
| 124 msg='Timed out waiting for %r' % condition_name) | 112 msg='Timed out waiting for %r' % condition_name) |
| 125 time.sleep(wait_period) | 113 time.sleep(wait_period) |
| 126 return None | 114 return None |
| 127 | 115 |
| 128 | 116 |
| 129 def Run(func, timeout, retries, args=None, kwargs=None, desc=None): | 117 def _LogLastException(thread_name, attempt, max_attempts, log_func): |
| 118 log_func('*' * 80) |
| 119 log_func('Exception on thread %s (attempt %d of %d)', thread_name, |
| 120 attempt, max_attempts) |
| 121 log_func('*' * 80) |
| 122 fmt_exc = ''.join(traceback.format_exc()) |
| 123 for line in fmt_exc.splitlines(): |
| 124 log_func(line.rstrip()) |
| 125 log_func('*' * 80) |
| 126 |
| 127 |
| 128 def Run(func, timeout, retries, args=None, kwargs=None, desc=None, |
| 129 error_log_func=logging.critical): |
| 130 """Runs the passed function in a separate thread with timeouts and retries. | 130 """Runs the passed function in a separate thread with timeouts and retries. |
| 131 | 131 |
| 132 Args: | 132 Args: |
| 133 func: the function to be wrapped. | 133 func: the function to be wrapped. |
| 134 timeout: the timeout in seconds for each try. | 134 timeout: the timeout in seconds for each try. |
| 135 retries: the number of retries. | 135 retries: the number of retries. |
| 136 args: list of positional args to pass to |func|. | 136 args: list of positional args to pass to |func|. |
| 137 kwargs: dictionary of keyword args to pass to |func|. | 137 kwargs: dictionary of keyword args to pass to |func|. |
| 138 desc: An optional description of |func| used in logging. If omitted, | 138 desc: An optional description of |func| used in logging. If omitted, |
| 139 |func.__name__| will be used. | 139 |func.__name__| will be used. |
| 140 error_log_func: Logging function when logging errors. |
| 140 | 141 |
| 141 Returns: | 142 Returns: |
| 142 The return value of func(*args, **kwargs). | 143 The return value of func(*args, **kwargs). |
| 143 """ | 144 """ |
| 144 if not args: | 145 if not args: |
| 145 args = [] | 146 args = [] |
| 146 if not kwargs: | 147 if not kwargs: |
| 147 kwargs = {} | 148 kwargs = {} |
| 148 | 149 |
| 149 # The return value uses a list because Python variables are references, not | |
| 150 # values. Closures make a copy of the reference, so updating the closure's | |
| 151 # reference wouldn't update where the original reference pointed. | |
| 152 ret = [None] | |
| 153 def RunOnTimeoutThread(): | |
| 154 ret[0] = func(*args, **kwargs) | |
| 155 | |
| 156 num_try = 1 | 150 num_try = 1 |
| 157 while True: | 151 while True: |
| 158 child_thread = TimeoutRetryThread( | 152 thread_name = 'TimeoutThread-%d-for-%s' % (num_try, |
| 159 RunOnTimeoutThread, timeout, | 153 threading.current_thread().name) |
| 160 name='TimeoutThread-%d-for-%s' % (num_try, | 154 child_thread = reraiser_thread.ReraiserThread(lambda: func(*args, **kwargs), |
| 161 threading.current_thread().name)) | 155 name=thread_name) |
| 162 try: | 156 try: |
| 163 thread_group = reraiser_thread.ReraiserThreadGroup([child_thread]) | 157 thread_group = TimeoutRetryThreadGroup(timeout, threads=[child_thread]) |
| 164 thread_group.StartAll() | 158 thread_group.StartAll(will_block=True) |
| 165 while True: | 159 while True: |
| 166 thread_group.JoinAll(watcher=child_thread.GetWatcher(), timeout=60) | 160 thread_group.JoinAll(watcher=thread_group.GetWatcher(), timeout=60, |
| 161 error_log_func=error_log_func) |
| 167 if thread_group.IsAlive(): | 162 if thread_group.IsAlive(): |
| 168 logging.info('Still working on %s', desc if desc else func.__name__) | 163 logging.info('Still working on %s', desc if desc else func.__name__) |
| 169 else: | 164 else: |
| 170 return ret[0] | 165 return thread_group.GetAllReturnValues()[0] |
| 171 except: | 166 except reraiser_thread.TimeoutError: |
| 172 child_thread.LogTimeoutException() | 167 # Timeouts already get their stacks logged. |
| 173 if num_try > retries: | 168 if num_try > retries: |
| 174 raise | 169 raise |
| 175 num_try += 1 | 170 # Do not catch KeyboardInterrupt. |
| 171 except Exception: # pylint: disable=broad-except |
| 172 if num_try > retries: |
| 173 raise |
| 174 _LogLastException(thread_name, num_try, retries + 1, error_log_func) |
| 175 num_try += 1 |
| OLD | NEW |