OLD | NEW |
| (Empty) |
1 # Copyright 2011 Google Inc. All Rights Reserved. | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Basic thread pool with exception handler.""" | |
16 | |
17 import logging | |
18 import Queue | |
19 import threading | |
20 | |
21 | |
22 # Magic values used to cleanly bring down threads. | |
23 _THREAD_EXIT_MAGIC = ('Clean', 'Thread', 'Exit') | |
24 | |
25 | |
26 def _DefaultExceptionHandler(e): | |
27 logging.exception(e) | |
28 | |
29 | |
30 class Worker(threading.Thread): | |
31 """Thread executing tasks from a given task's queue.""" | |
32 | |
33 def __init__(self, tasks, exception_handler): | |
34 threading.Thread.__init__(self) | |
35 self.tasks = tasks | |
36 self.daemon = True | |
37 self.exception_handler = exception_handler | |
38 self.start() | |
39 | |
40 def run(self): | |
41 while True: | |
42 func, args, kargs = self.tasks.get() | |
43 | |
44 # Listen for magic value indicating thread exit. | |
45 if (func, args, kargs) == _THREAD_EXIT_MAGIC: | |
46 break | |
47 | |
48 try: | |
49 func(*args, **kargs) | |
50 except Exception, e: | |
51 self.exception_handler(e) | |
52 finally: | |
53 self.tasks.task_done() | |
54 | |
55 | |
56 class ThreadPool(object): | |
57 """Pool of threads consuming tasks from a queue.""" | |
58 | |
59 def __init__(self, num_threads, exception_handler=_DefaultExceptionHandler): | |
60 self.tasks = Queue.Queue(num_threads) | |
61 self.threads = [] | |
62 for _ in range(num_threads): | |
63 self.threads.append(Worker(self.tasks, exception_handler)) | |
64 | |
65 def AddTask(self, func, *args, **kargs): | |
66 """Add a task to the queue.""" | |
67 self.tasks.put((func, args, kargs)) | |
68 | |
69 def WaitCompletion(self): | |
70 """Wait for completion of all the tasks in the queue.""" | |
71 self.tasks.join() | |
72 | |
73 def Shutdown(self): | |
74 """Shutdown the thread pool.""" | |
75 for thread in self.threads: | |
76 self.tasks.put(_THREAD_EXIT_MAGIC) | |
77 | |
78 for thread in self.threads: | |
79 thread.join() | |
OLD | NEW |