| OLD | NEW |
| (Empty) |
| 1 # Copyright 2011 Google Inc. All Rights Reserved. | |
| 2 # | |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
| 4 # you may not use this file except in compliance with the License. | |
| 5 # You may obtain a copy of the License at | |
| 6 # | |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 | |
| 8 # | |
| 9 # Unless required by applicable law or agreed to in writing, software | |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, | |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 12 # See the License for the specific language governing permissions and | |
| 13 # limitations under the License. | |
| 14 | |
| 15 """Basic thread pool with exception handler.""" | |
| 16 | |
| 17 import logging | |
| 18 import Queue | |
| 19 import threading | |
| 20 | |
| 21 | |
| 22 # Magic values used to cleanly bring down threads. | |
| 23 _THREAD_EXIT_MAGIC = ('Clean', 'Thread', 'Exit') | |
| 24 | |
| 25 | |
| 26 def _DefaultExceptionHandler(e): | |
| 27 logging.exception(e) | |
| 28 | |
| 29 | |
| 30 class Worker(threading.Thread): | |
| 31 """Thread executing tasks from a given task's queue.""" | |
| 32 | |
| 33 def __init__(self, tasks, exception_handler): | |
| 34 threading.Thread.__init__(self) | |
| 35 self.tasks = tasks | |
| 36 self.daemon = True | |
| 37 self.exception_handler = exception_handler | |
| 38 self.results = [] | |
| 39 self.start() | |
| 40 | |
| 41 def run(self): | |
| 42 while True: | |
| 43 func, args, kargs = self.tasks.get() | |
| 44 | |
| 45 # Listen for magic value indicating thread exit. | |
| 46 if (func, args, kargs) == _THREAD_EXIT_MAGIC: | |
| 47 break | |
| 48 | |
| 49 try: | |
| 50 result = func(*args, **kargs) | |
| 51 if result is not None: | |
| 52 self.results.append(result) | |
| 53 except Exception, e: | |
| 54 self.exception_handler(e) | |
| 55 finally: | |
| 56 self.tasks.task_done() | |
| 57 | |
| 58 | |
| 59 class ThreadPool(object): | |
| 60 """Pool of threads consuming tasks from a queue.""" | |
| 61 | |
| 62 def __init__(self, num_threads, exception_handler=_DefaultExceptionHandler): | |
| 63 self.tasks = Queue.Queue(num_threads) | |
| 64 self.threads = [] | |
| 65 for _ in range(num_threads): | |
| 66 self.threads.append(Worker(self.tasks, exception_handler)) | |
| 67 | |
| 68 def AddTask(self, func, *args, **kargs): | |
| 69 """Add a task to the queue.""" | |
| 70 self.tasks.put((func, args, kargs)) | |
| 71 | |
| 72 def Shutdown(self, should_return_results=False): | |
| 73 """Shutdown the thread pool.""" | |
| 74 self.tasks.join() | |
| 75 for thread in self.threads: | |
| 76 self.tasks.put(_THREAD_EXIT_MAGIC) | |
| 77 | |
| 78 results = [] | |
| 79 for thread in self.threads: | |
| 80 thread.join() | |
| 81 if should_return_results: | |
| 82 results += thread.results | |
| 83 return results | |
| OLD | NEW |