Index: au_test_harness/parallel_test_job.py |
diff --git a/au_test_harness/parallel_test_job.py b/au_test_harness/parallel_test_job.py |
index 096ec170432a06b44f52b12b423f5aaddbc51040..89befffe530175174deb4597f0ee8f84f1b4454e 100644 |
--- a/au_test_harness/parallel_test_job.py |
+++ b/au_test_harness/parallel_test_job.py |
@@ -4,58 +4,86 @@ |
"""Module containing methods/classes related to running parallel test jobs.""" |
+import multiprocessing |
import sys |
-import threading |
import time |
import cros_build_lib as cros_lib |
-class ParallelJob(threading.Thread): |
- """Small wrapper for threading. Thread that releases a semaphores on exit.""" |
+class ParallelJobTimeoutError(Exception): |
+ """Thrown when a job ran for longer than expected.""" |
+ pass |
- def __init__(self, starting_semaphore, ending_semaphore, target, args): |
+ |
+class ParallelJob(multiprocessing.Process): |
+ """Small wrapper for Process that stores output of its target method.""" |
+ |
+ MAX_TIMEOUT_SECONDS = 1800 |
+ SLEEP_TIMEOUT_SECONDS = 180 |
+ |
+ def __init__(self, starting_semaphore, target, args): |
"""Initializes an instance of a job. |
Args: |
starting_semaphore: Semaphore used by caller to wait on such that |
- there isn't more than a certain number of threads running. Should |
- be initialized to a value for the number of threads wanting to be run |
- at a time. |
- ending_semaphore: Semaphore is released every time a job ends. Should be |
- initialized to 0 before starting first job. Should be acquired once for |
- each job. Threading.Thread.join() has a bug where if the run function |
- terminates too quickly join() will hang forever. |
+ there isn't more than a certain number of parallel_jobs running. Should |
+ be initialized to a value for the number of parallel_jobs wanting to be |
+ run at a time. |
target: The func to run. |
args: Args to pass to the fun. |
""" |
- threading.Thread.__init__(self, target=target, args=args) |
+ super(ParallelJob, self).__init__(target=target, args=args) |
self._target = target |
self._args = args |
self._starting_semaphore = starting_semaphore |
- self._ending_semaphore = ending_semaphore |
- self._output = None |
- self._completed = False |
def run(self): |
"""Thread override. Runs the method specified and sets output.""" |
try: |
- self._output = self._target(*self._args) |
+ self._target(*self._args) |
finally: |
- # Our own clean up. |
- self._Cleanup() |
- self._completed = True |
- # From threading.py to avoid a refcycle. |
- del self._target, self._args |
- |
- def GetOutput(self): |
- """Returns the output of the method run.""" |
- assert self._completed, 'GetOutput called before thread was run.' |
- return self._output |
- |
- def _Cleanup(self): |
- """Releases semaphores for a waiting caller.""" |
- self._starting_semaphore.release() |
- self._ending_semaphore.release() |
+ self._starting_semaphore.release() |
+ |
+ @classmethod |
+ def WaitUntilJobsComplete(cls, parallel_jobs): |
+ """Waits until all parallel_jobs have completed before returning. |
+ |
+ Given an array of parallel_jobs, returns once all parallel_jobs have |
+ completed or a max timeout is reached. |
+ |
+ Raises: |
+ ParallelJobTimeoutError: if max timeout is reached. |
+ """ |
+ def GetCurrentActiveCount(): |
+ """Returns the (number of active jobs, first active job).""" |
+ active_count = 0 |
+ active_job = None |
+ for parallel_job in parallel_jobs: |
+ if parallel_job.is_alive(): |
+ active_count += 1 |
+ if not active_job: |
+ active_job = parallel_job |
+ |
+ return (active_count, parallel_job) |
+ |
+ start_time = time.time() |
+ while (time.time() - start_time) < cls.MAX_TIMEOUT_SECONDS: |
+ (active_count, active_job) = GetCurrentActiveCount() |
+ if active_count == 0: |
+ return |
+ else: |
+ print >> sys.stderr, ( |
+ 'Process Pool Active: Waiting on %d/%d jobs to complete' % |
+ (active_count, len(parallel_jobs))) |
+ active_job.join(cls.SLEEP_TIMEOUT_SECONDS) |
+ time.sleep(5) # Prevents lots of printing out as job is ending. |
+ |
+ for parallel_job in parallel_jobs: |
+ if parallel_job.is_alive(): |
+ parallel_job.terminate() |
+ |
+ raise ParallelJobTimeoutError('Exceeded max time of %d seconds to wait for ' |
+ 'job completion.' % cls.MAX_TIMEOUT_SECONDS) |
def __str__(self): |
return '%s(%s)' % (self._target, self._args) |
@@ -66,44 +94,44 @@ def RunParallelJobs(number_of_simultaneous_jobs, jobs, jobs_args, |
"""Runs set number of specified jobs in parallel. |
Args: |
- number_of_simultaneous_jobs: Max number of threads to be run in parallel. |
+ number_of_simultaneous_jobs: Max number of parallel_jobs to be run in |
+ parallel. |
jobs: Array of methods to run. |
jobs_args: Array of args associated with method calls. |
print_status: True if you'd like this to print out .'s as it runs jobs. |
Returns: |
- Returns an array of results corresponding to each thread. |
+ Returns an array of results corresponding to each parallel_job. |
""" |
- def _TwoTupleize(x, y): |
- return (x, y) |
+ def ProcessOutputWrapper(func, args, output): |
+ """Simple function wrapper that puts the output of a function in a queue.""" |
+ output.put(func(*args)) |
- threads = [] |
- job_start_semaphore = threading.Semaphore(number_of_simultaneous_jobs) |
- join_semaphore = threading.Semaphore(0) |
assert len(jobs) == len(jobs_args), 'Length of args array is wrong.' |
- |
- # Create the parallel jobs. |
- for job, args in map(_TwoTupleize, jobs, jobs_args): |
- thread = ParallelJob(job_start_semaphore, join_semaphore, target=job, |
- args=args) |
- threads.append(thread) |
- |
# Cache sudo access. |
cros_lib.RunCommand(['sudo', 'echo', 'Caching sudo credentials'], |
print_cmd=False, redirect_stdout=True, |
redirect_stderr=True) |
+ parallel_jobs = [] |
+ output_array = [] |
+ |
+ # Semaphore used to create a Process Pool. |
+ job_start_semaphore = multiprocessing.Semaphore(number_of_simultaneous_jobs) |
+ |
+ # Create the parallel jobs. |
+ for job, args in map(lambda x, y: (x, y), jobs, jobs_args): |
+ output = multiprocessing.Queue() |
+ parallel_job = ParallelJob(job_start_semaphore, |
+ target=ProcessOutputWrapper, |
+ args=(job, args, output)) |
+ parallel_jobs.append(parallel_job) |
+ output_array.append(output) |
+ |
# We use a semaphore to ensure we don't run more jobs than required. |
- # After each thread finishes, it releases (increments semaphore). |
- # Acquire blocks of num jobs reached and continues when a thread finishes. |
- for next_thread in threads: |
- job_start_semaphore.acquire(blocking=True) |
- next_thread.start() |
- |
- # Wait on the rest of the threads to finish. |
- for thread in threads: |
- while not join_semaphore.acquire(blocking=False): |
- time.sleep(5) |
- if print_status: |
- print >> sys.stderr, '.', |
- |
- return [thread.GetOutput() for thread in threads] |
+ # After each parallel_job finishes, it releases (increments semaphore). |
+ for next_parallel_job in parallel_jobs: |
+ job_start_semaphore.acquire(block=True) |
+ next_parallel_job.start() |
+ |
+ ParallelJob.WaitUntilJobsComplete(parallel_jobs) |
+ return [output.get() for output in output_array] |