Chromium Code Reviews| Index: au_test_harness/parallel_test_job.py |
| diff --git a/au_test_harness/parallel_test_job.py b/au_test_harness/parallel_test_job.py |
| index 096ec170432a06b44f52b12b423f5aaddbc51040..dca781b62e0069b42b75d65f4360392fd7127f1a 100644 |
| --- a/au_test_harness/parallel_test_job.py |
| +++ b/au_test_harness/parallel_test_job.py |
| @@ -4,58 +4,86 @@ |
| """Module containing methods/classes related to running parallel test jobs.""" |
| +import multiprocessing |
| import sys |
| import threading |
| import time |
| import cros_build_lib as cros_lib |
| -class ParallelJob(threading.Thread): |
| - """Small wrapper for threading. Thread that releases a semaphores on exit.""" |
| +class ParallelJobTimeoutError(Exception): |
| + """Thrown when a job ran for longer than expected.""" |
| + pass |
| - def __init__(self, starting_semaphore, ending_semaphore, target, args): |
| + |
| +class ParallelJob(multiprocessing.Process): |
| + """Small wrapper for Process that stores output of a its target method.""" |
|
dgarrett
2011/04/08 03:56:42
its -> it's
sosa
2011/04/08 18:59:04
it's == it is :)
On 2011/04/08 03:56:42, dgarrett
|
| + |
| + MAX_TIMEOUT = 1800 |
|
dgarrett
2011/04/08 03:56:42
Units?
sosa
2011/04/08 18:59:04
Done.
|
| + SLEEP_TIMEOUT = 180 |
| + |
| + def __init__(self, starting_semaphore, target, args): |
| """Initializes an instance of a job. |
| Args: |
| starting_semaphore: Semaphore used by caller to wait on such that |
| - there isn't more than a certain number of threads running. Should |
| - be initialized to a value for the number of threads wanting to be run |
| - at a time. |
| - ending_semaphore: Semaphore is released every time a job ends. Should be |
| - initialized to 0 before starting first job. Should be acquired once for |
| - each job. Threading.Thread.join() has a bug where if the run function |
| - terminates too quickly join() will hang forever. |
| + there isn't more than a certain number of parallel_jobs running. Should |
| + be initialized to a value for the number of parallel_jobs wanting to be |
| + run at a time. |
| target: The func to run. |
| args: Args to pass to the fun. |
| """ |
| - threading.Thread.__init__(self, target=target, args=args) |
| + super(ParallelJob, self).__init__(target=target, args=args) |
| self._target = target |
| self._args = args |
| self._starting_semaphore = starting_semaphore |
| - self._ending_semaphore = ending_semaphore |
| - self._output = None |
| - self._completed = False |
| def run(self): |
| """Thread override. Runs the method specified and sets output.""" |
| try: |
| - self._output = self._target(*self._args) |
| + self._target(*self._args) |
| finally: |
| - # Our own clean up. |
| - self._Cleanup() |
| - self._completed = True |
| - # From threading.py to avoid a refcycle. |
| - del self._target, self._args |
| - |
| - def GetOutput(self): |
| - """Returns the output of the method run.""" |
| - assert self._completed, 'GetOutput called before thread was run.' |
| - return self._output |
| - |
| - def _Cleanup(self): |
| - """Releases semaphores for a waiting caller.""" |
| - self._starting_semaphore.release() |
| - self._ending_semaphore.release() |
| + self._starting_semaphore.release() |
| + |
| + @classmethod |
| + def WaitUntilJobsComplete(cls, parallel_jobs): |
| + """Waits until all parallel_jobs have completed before returning. |
| + |
| + Given an array of parallel_jobs, returns once all parallel_jobs have |
| + completed or a max timeout is reached. |
| + |
| + Raises: |
| + ParallelJobTimeoutError: if max timeout is reached. |
| + """ |
| + def GetCurrentActiveCount(): |
| + """Returns the (number of active jobs, first active job).""" |
| + active_count = 0 |
| + active_job = None |
| + for parallel_job in parallel_jobs: |
| + if parallel_job.is_alive(): |
| + active_count += 1 |
| + if not active_job: |
| + active_job = parallel_job |
| + |
| + return (active_count, parallel_job) |
| + |
| + start_time = time.time() |
| + while (time.time() - start_time) < cls.MAX_TIMEOUT: |
| + (active_count, active_job) = GetCurrentActiveCount() |
| + if active_count == 0: |
| + return |
| + else: |
| + print >> sys.stderr, ( |
| + 'Process Pool Active: Waiting on %d/%d jobs to complete' % |
| + (active_count, len(parallel_jobs))) |
| + active_job.join(cls.SLEEP_TIMEOUT) |
| + |
| + for parallel_job in parallel_jobs: |
| + if parallel_job.is_alive(): |
| + parallel_job.terminate() |
| + |
| + raise ParallelJobTimeoutError('Exceeded max time of %d seconds to wait for ' |
| + 'job completion.' % cls.MAX_TIMEOUT) |
| def __str__(self): |
| return '%s(%s)' % (self._target, self._args) |
| @@ -66,44 +94,47 @@ def RunParallelJobs(number_of_simultaneous_jobs, jobs, jobs_args, |
| """Runs set number of specified jobs in parallel. |
| Args: |
| - number_of_simultaneous_jobs: Max number of threads to be run in parallel. |
| + number_of_simultaneous_jobs: Max number of parallel_jobs to be run in |
| + parallel. |
| jobs: Array of methods to run. |
| jobs_args: Array of args associated with method calls. |
| print_status: True if you'd like this to print out .'s as it runs jobs. |
| Returns: |
| - Returns an array of results corresponding to each thread. |
| + Returns an array of results corresponding to each parallel_job. |
| """ |
| def _TwoTupleize(x, y): |
| return (x, y) |
| - threads = [] |
| - job_start_semaphore = threading.Semaphore(number_of_simultaneous_jobs) |
| - join_semaphore = threading.Semaphore(0) |
| - assert len(jobs) == len(jobs_args), 'Length of args array is wrong.' |
| - |
| - # Create the parallel jobs. |
| - for job, args in map(_TwoTupleize, jobs, jobs_args): |
| - thread = ParallelJob(job_start_semaphore, join_semaphore, target=job, |
| - args=args) |
| - threads.append(thread) |
| + def ProcessOutputWrapper(func, args, output): |
| + """Simple function wrapper that puts the output of a function in a queue.""" |
| + output.put(func(*args)) |
| + assert len(jobs) == len(jobs_args), 'Length of args array is wrong.' |
| # Cache sudo access. |
| cros_lib.RunCommand(['sudo', 'echo', 'Caching sudo credentials'], |
| print_cmd=False, redirect_stdout=True, |
| redirect_stderr=True) |
| + parallel_jobs = [] |
| + output_array = [] |
| + |
| + # Semaphore used to create a Process Pool. |
| + job_start_semaphore = threading.Semaphore(number_of_simultaneous_jobs) |
|
dgarrett
2011/04/08 03:56:42
Is this semaphore really safe across processes ins
sosa
2011/04/08 18:59:04
I think you're right, switching to multiprocessing
|
| + |
| + # Create the parallel jobs. |
| + for job, args in map(_TwoTupleize, jobs, jobs_args): |
|
dgarrett
2011/04/08 03:56:42
_TwoTupleize could be replaced in-line with "lambd
sosa
2011/04/08 18:59:04
Done.
|
| + output = multiprocessing.Queue() |
| + parallel_job = ParallelJob(job_start_semaphore, |
| + target=ProcessOutputWrapper, |
| + args=(job, args, output)) |
| + parallel_jobs.append(parallel_job) |
| + output_array.append(output) |
| + |
| # We use a semaphore to ensure we don't run more jobs than required. |
| - # After each thread finishes, it releases (increments semaphore). |
| - # Acquire blocks of num jobs reached and continues when a thread finishes. |
| - for next_thread in threads: |
| + # After each parallel_job finishes, it releases (increments semaphore). |
| + for next_parallel_job in parallel_jobs: |
| job_start_semaphore.acquire(blocking=True) |
| - next_thread.start() |
| - |
| - # Wait on the rest of the threads to finish. |
| - for thread in threads: |
| - while not join_semaphore.acquire(blocking=False): |
| - time.sleep(5) |
| - if print_status: |
| - print >> sys.stderr, '.', |
| + next_parallel_job.start() |
| - return [thread.GetOutput() for thread in threads] |
| + ParallelJob.WaitUntilJobsComplete(parallel_jobs) |
| + return [output.get() for output in output_array] |