| OLD | NEW | 
|---|
| 1 # Copyright 2013 The Chromium Authors. All rights reserved. | 1 # Copyright 2013 The Chromium Authors. All rights reserved. | 
| 2 # Use of this source code is governed by a BSD-style license that can be | 2 # Use of this source code is governed by a BSD-style license that can be | 
| 3 # found in the LICENSE file. | 3 # found in the LICENSE file. | 
| 4 | 4 | 
| 5 """A utility to run functions with timeouts and retries.""" | 5 """A utility to run functions with timeouts and retries.""" | 
| 6 # pylint: disable=W0702 | 6 # pylint: disable=W0702 | 
| 7 | 7 | 
| 8 import logging | 8 import logging | 
| 9 import threading | 9 import threading | 
| 10 import time | 10 import time | 
| 11 import traceback | 11 import traceback | 
| 12 | 12 | 
| 13 from devil.utils import reraiser_thread | 13 from devil.utils import reraiser_thread | 
| 14 from devil.utils import watchdog_timer | 14 from devil.utils import watchdog_timer | 
| 15 | 15 | 
| 16 | 16 | 
| 17 class TimeoutRetryThread(reraiser_thread.ReraiserThread): | 17 | 
| 18   def __init__(self, func, timeout, name): | 18 class TimeoutRetryThreadGroup(reraiser_thread.ReraiserThreadGroup): | 
| 19     super(TimeoutRetryThread, self).__init__(func, name=name) | 19   def __init__(self, timeout, threads=None): | 
|  | 20     super(TimeoutRetryThreadGroup, self).__init__(threads) | 
| 20     self._watcher = watchdog_timer.WatchdogTimer(timeout) | 21     self._watcher = watchdog_timer.WatchdogTimer(timeout) | 
| 21     self._expired = False |  | 
| 22 | 22 | 
| 23   def GetWatcher(self): | 23   def GetWatcher(self): | 
| 24     """Returns the watchdog keeping track of this thread's time.""" | 24     """Returns the watchdog keeping track of this thread's time.""" | 
| 25     return self._watcher | 25     return self._watcher | 
| 26 | 26 | 
| 27   def GetElapsedTime(self): | 27   def GetElapsedTime(self): | 
| 28     return self._watcher.GetElapsed() | 28     return self._watcher.GetElapsed() | 
| 29 | 29 | 
| 30   def GetRemainingTime(self, required=0, msg=None): | 30   def GetRemainingTime(self, required=0, msg=None): | 
| 31     """Get the remaining time before the thread times out. | 31     """Get the remaining time before the thread times out. | 
| (...skipping 13 matching lines...) Expand all  Loading... | 
| 45       reraiser_thread.TimeoutError if the remaining time is less than the | 45       reraiser_thread.TimeoutError if the remaining time is less than the | 
| 46         required time. | 46         required time. | 
| 47     """ | 47     """ | 
| 48     remaining = self._watcher.GetRemaining() | 48     remaining = self._watcher.GetRemaining() | 
| 49     if remaining is not None and remaining < required: | 49     if remaining is not None and remaining < required: | 
| 50       if msg is None: | 50       if msg is None: | 
| 51         msg = 'Timeout expired' | 51         msg = 'Timeout expired' | 
| 52       if remaining > 0: | 52       if remaining > 0: | 
| 53         msg += (', wait of %.1f secs required but only %.1f secs left' | 53         msg += (', wait of %.1f secs required but only %.1f secs left' | 
| 54                 % (required, remaining)) | 54                 % (required, remaining)) | 
| 55       self._expired = True |  | 
| 56       raise reraiser_thread.TimeoutError(msg) | 55       raise reraiser_thread.TimeoutError(msg) | 
| 57     return remaining | 56     return remaining | 
| 58 | 57 | 
| 59   def LogTimeoutException(self): |  | 
| 60     """Log the exception that terminated this thread.""" |  | 
| 61     if not self._expired: |  | 
| 62       return |  | 
| 63     logging.critical('*' * 80) |  | 
| 64     logging.critical('%s on thread %r', self._exc_info[0].__name__, self.name) |  | 
| 65     logging.critical('*' * 80) |  | 
| 66     fmt_exc = ''.join(traceback.format_exception(*self._exc_info)) |  | 
| 67     for line in fmt_exc.splitlines(): |  | 
| 68       logging.critical(line.rstrip()) |  | 
| 69     logging.critical('*' * 80) |  | 
| 70 | 58 | 
| 71 | 59 def CurrentTimeoutThreadGroup(): | 
| 72 def CurrentTimeoutThread(): | 60   """Returns the thread group that owns or is blocked on the active thread. | 
| 73   """Get the current thread if it is a TimeoutRetryThread. |  | 
| 74 | 61 | 
| 75   Returns: | 62   Returns: | 
| 76     The current thread if it is a TimeoutRetryThread, otherwise None. | 63     Returns None if no TimeoutRetryThreadGroup is tracking the current thread. | 
| 77   """ | 64   """ | 
| 78   current_thread = threading.current_thread() | 65   thread_group = reraiser_thread.CurrentThreadGroup() | 
| 79   if isinstance(current_thread, TimeoutRetryThread): | 66   while thread_group: | 
| 80     return current_thread | 67     if isinstance(thread_group, TimeoutRetryThreadGroup): | 
| 81   else: | 68       return thread_group | 
| 82     return None | 69     thread_group = thread_group.blocked_parent_thread_group | 
|  | 70   return None | 
| 83 | 71 | 
| 84 | 72 | 
| 85 def WaitFor(condition, wait_period=5, max_tries=None): | 73 def WaitFor(condition, wait_period=5, max_tries=None): | 
| 86   """Wait for a condition to become true. | 74   """Wait for a condition to become true. | 
| 87 | 75 | 
| 88   Repeatedly call the function condition(), with no arguments, until it returns | 76   Repeatedly call the function condition(), with no arguments, until it returns | 
| 89   a true value. | 77   a true value. | 
| 90 | 78 | 
| 91   If called within a TimeoutRetryThread, it cooperates nicely with it. | 79   If called within a TimeoutRetryThreadGroup, it cooperates nicely with it. | 
| 92 | 80 | 
| 93   Args: | 81   Args: | 
| 94     condition: function with the condition to check | 82     condition: function with the condition to check | 
| 95     wait_period: number of seconds to wait before retrying to check the | 83     wait_period: number of seconds to wait before retrying to check the | 
| 96       condition | 84       condition | 
| 97     max_tries: maximum number of checks to make, the default tries forever | 85     max_tries: maximum number of checks to make, the default tries forever | 
| 98       or until the TimeoutRetryThread expires. | 86       or until the TimeoutRetryThreadGroup expires. | 
| 99 | 87 | 
| 100   Returns: | 88   Returns: | 
| 101     The true value returned by the condition, or None if the condition was | 89     The true value returned by the condition, or None if the condition was | 
| 102     not met after max_tries. | 90     not met after max_tries. | 
| 103 | 91 | 
| 104   Raises: | 92   Raises: | 
| 105     reraiser_thread.TimeoutError if the current thread is a TimeoutRetryThread | 93     reraiser_thread.TimeoutError: if the current thread is a | 
| 106       and the timeout expires. | 94       TimeoutRetryThreadGroup and the timeout expires. | 
| 107   """ | 95   """ | 
| 108   condition_name = condition.__name__ | 96   condition_name = condition.__name__ | 
| 109   timeout_thread = CurrentTimeoutThread() | 97   timeout_thread_group = CurrentTimeoutThreadGroup() | 
| 110   while max_tries is None or max_tries > 0: | 98   while max_tries is None or max_tries > 0: | 
| 111     result = condition() | 99     result = condition() | 
| 112     if max_tries is not None: | 100     if max_tries is not None: | 
| 113       max_tries -= 1 | 101       max_tries -= 1 | 
| 114     msg = ['condition', repr(condition_name), 'met' if result else 'not met'] | 102     msg = ['condition', repr(condition_name), 'met' if result else 'not met'] | 
| 115     if timeout_thread: | 103     if timeout_thread_group: | 
| 116       # pylint: disable=no-member | 104       # pylint: disable=no-member | 
| 117       msg.append('(%.1fs)' % timeout_thread.GetElapsedTime()) | 105       msg.append('(%.1fs)' % timeout_thread_group.GetElapsedTime()) | 
| 118     logging.info(' '.join(msg)) | 106     logging.info(' '.join(msg)) | 
| 119     if result: | 107     if result: | 
| 120       return result | 108       return result | 
| 121     if timeout_thread: | 109     if timeout_thread_group: | 
| 122       # pylint: disable=no-member | 110       # pylint: disable=no-member | 
| 123       timeout_thread.GetRemainingTime(wait_period, | 111       timeout_thread_group.GetRemainingTime(wait_period, | 
| 124           msg='Timed out waiting for %r' % condition_name) | 112           msg='Timed out waiting for %r' % condition_name) | 
| 125     time.sleep(wait_period) | 113     time.sleep(wait_period) | 
| 126   return None | 114   return None | 
| 127 | 115 | 
| 128 | 116 | 
| 129 def Run(func, timeout, retries, args=None, kwargs=None, desc=None): | 117 def _LogLastException(thread_name, attempt, max_attempts, log_func): | 
|  | 118   log_func('*' * 80) | 
|  | 119   log_func('Exception on thread %s (attempt %d of %d)', thread_name, | 
|  | 120                    attempt, max_attempts) | 
|  | 121   log_func('*' * 80) | 
|  | 122   fmt_exc = ''.join(traceback.format_exc()) | 
|  | 123   for line in fmt_exc.splitlines(): | 
|  | 124     log_func(line.rstrip()) | 
|  | 125   log_func('*' * 80) | 
|  | 126 | 
|  | 127 | 
|  | 128 def Run(func, timeout, retries, args=None, kwargs=None, desc=None, | 
|  | 129         error_log_func=logging.critical): | 
| 130   """Runs the passed function in a separate thread with timeouts and retries. | 130   """Runs the passed function in a separate thread with timeouts and retries. | 
| 131 | 131 | 
| 132   Args: | 132   Args: | 
| 133     func: the function to be wrapped. | 133     func: the function to be wrapped. | 
| 134     timeout: the timeout in seconds for each try. | 134     timeout: the timeout in seconds for each try. | 
| 135     retries: the number of retries. | 135     retries: the number of retries. | 
| 136     args: list of positional args to pass to |func|. | 136     args: list of positional args to pass to |func|. | 
| 137     kwargs: dictionary of keyword args to pass to |func|. | 137     kwargs: dictionary of keyword args to pass to |func|. | 
| 138     desc: An optional description of |func| used in logging. If omitted, | 138     desc: An optional description of |func| used in logging. If omitted, | 
| 139       |func.__name__| will be used. | 139       |func.__name__| will be used. | 
|  | 140     error_log_func: Logging function when logging errors. | 
| 140 | 141 | 
| 141   Returns: | 142   Returns: | 
| 142     The return value of func(*args, **kwargs). | 143     The return value of func(*args, **kwargs). | 
| 143   """ | 144   """ | 
| 144   if not args: | 145   if not args: | 
| 145     args = [] | 146     args = [] | 
| 146   if not kwargs: | 147   if not kwargs: | 
| 147     kwargs = {} | 148     kwargs = {} | 
| 148 | 149 | 
| 149   # The return value uses a list because Python variables are references, not |  | 
| 150   # values. Closures make a copy of the reference, so updating the closure's |  | 
| 151   # reference wouldn't update where the original reference pointed. |  | 
| 152   ret = [None] |  | 
| 153   def RunOnTimeoutThread(): |  | 
| 154     ret[0] = func(*args, **kwargs) |  | 
| 155 |  | 
| 156   num_try = 1 | 150   num_try = 1 | 
| 157   while True: | 151   while True: | 
| 158     child_thread = TimeoutRetryThread( | 152     thread_name = 'TimeoutThread-%d-for-%s' % (num_try, | 
| 159       RunOnTimeoutThread, timeout, | 153                                                threading.current_thread().name) | 
| 160       name='TimeoutThread-%d-for-%s' % (num_try, | 154     child_thread = reraiser_thread.ReraiserThread(lambda: func(*args, **kwargs), | 
| 161                                         threading.current_thread().name)) | 155                                                   name=thread_name) | 
| 162     try: | 156     try: | 
| 163       thread_group = reraiser_thread.ReraiserThreadGroup([child_thread]) | 157       thread_group = TimeoutRetryThreadGroup(timeout, threads=[child_thread]) | 
| 164       thread_group.StartAll() | 158       thread_group.StartAll(will_block=True) | 
| 165       while True: | 159       while True: | 
| 166         thread_group.JoinAll(watcher=child_thread.GetWatcher(), timeout=60) | 160         thread_group.JoinAll(watcher=thread_group.GetWatcher(), timeout=60, | 
|  | 161                              error_log_func=error_log_func) | 
| 167         if thread_group.IsAlive(): | 162         if thread_group.IsAlive(): | 
| 168           logging.info('Still working on %s', desc if desc else func.__name__) | 163           logging.info('Still working on %s', desc if desc else func.__name__) | 
| 169         else: | 164         else: | 
| 170           return ret[0] | 165           return thread_group.GetAllReturnValues()[0] | 
| 171     except: | 166     except reraiser_thread.TimeoutError: | 
| 172       child_thread.LogTimeoutException() | 167       # Timeouts already get their stacks logged. | 
| 173       if num_try > retries: | 168       if num_try > retries: | 
| 174         raise | 169         raise | 
| 175       num_try += 1 | 170       # Do not catch KeyboardInterrupt. | 
|  | 171     except Exception:  # pylint: disable=broad-except | 
|  | 172       if num_try > retries: | 
|  | 173         raise | 
|  | 174       _LogLastException(thread_name, num_try, retries + 1, error_log_func) | 
|  | 175     num_try += 1 | 
| OLD | NEW | 
|---|