OLD | NEW |
| (Empty) |
1 #!/usr/bin/env python | |
2 # Copyright 2013 The Chromium Authors. All rights reserved. | |
3 # Use of this source code is governed by a BSD-style license that can be | |
4 # found in the LICENSE file. | |
5 | |
6 # Lambda may not be necessary. | |
7 # pylint: disable=W0108 | |
8 | |
9 import functools | |
10 import logging | |
11 import os | |
12 import signal | |
13 import sys | |
14 import threading | |
15 import time | |
16 import unittest | |
17 | |
18 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
19 sys.path.insert(0, ROOT_DIR) | |
20 | |
21 from utils import threading_utils | |
22 | |
23 | |
24 def timeout(max_running_time): | |
25 """Test method decorator that fails the test if it executes longer | |
26 than |max_running_time| seconds. | |
27 | |
28 It exists to terminate tests in case of deadlocks. There's a high chance that | |
29 process is broken after such timeout (due to hanging deadlocked threads that | |
30 can own some shared resources). But failing early (maybe not in a cleanest | |
31 way) due to timeout is generally better than hanging indefinitely. | |
32 | |
33 |max_running_time| should be an order of magnitude (or even two orders) larger | |
34 than the expected run time of the test to compensate for slow machine, high | |
35 CPU utilization by some other processes, etc. | |
36 | |
37 Can not be nested. | |
38 | |
39 Noop on windows (since win32 doesn't support signal.setitimer). | |
40 """ | |
41 if sys.platform == 'win32': | |
42 return lambda method: method | |
43 | |
44 def decorator(method): | |
45 @functools.wraps(method) | |
46 def wrapper(self, *args, **kwargs): | |
47 signal.signal(signal.SIGALRM, lambda *_args: self.fail('Timeout')) | |
48 signal.setitimer(signal.ITIMER_REAL, max_running_time) | |
49 try: | |
50 return method(self, *args, **kwargs) | |
51 finally: | |
52 signal.signal(signal.SIGALRM, signal.SIG_DFL) | |
53 signal.setitimer(signal.ITIMER_REAL, 0) | |
54 return wrapper | |
55 | |
56 return decorator | |
57 | |
58 | |
59 class ThreadPoolTest(unittest.TestCase): | |
60 MIN_THREADS = 0 | |
61 MAX_THREADS = 32 | |
62 | |
63 # Append custom assert messages to default ones (works with python >= 2.7). | |
64 longMessage = True | |
65 | |
66 @staticmethod | |
67 def sleep_task(duration=0.01): | |
68 """Returns function that sleeps |duration| sec and returns its argument.""" | |
69 def task(arg): | |
70 time.sleep(duration) | |
71 return arg | |
72 return task | |
73 | |
74 def retrying_sleep_task(self, duration=0.01): | |
75 """Returns function that adds sleep_task to the thread pool.""" | |
76 def task(arg): | |
77 self.thread_pool.add_task(0, self.sleep_task(duration), arg) | |
78 return task | |
79 | |
80 @staticmethod | |
81 def none_task(): | |
82 """Returns function that returns None.""" | |
83 return lambda _arg: None | |
84 | |
85 def setUp(self): | |
86 super(ThreadPoolTest, self).setUp() | |
87 self.thread_pool = threading_utils.ThreadPool( | |
88 self.MIN_THREADS, self.MAX_THREADS, 0) | |
89 | |
90 @timeout(1) | |
91 def tearDown(self): | |
92 super(ThreadPoolTest, self).tearDown() | |
93 self.thread_pool.close() | |
94 | |
95 def get_results_via_join(self, _expected): | |
96 return self.thread_pool.join() | |
97 | |
98 def get_results_via_get_one_result(self, expected): | |
99 return [self.thread_pool.get_one_result() for _ in expected] | |
100 | |
101 def get_results_via_iter_results(self, _expected): | |
102 return list(self.thread_pool.iter_results()) | |
103 | |
104 def run_results_test(self, task, results_getter, args=None, expected=None): | |
105 """Template function for tests checking that pool returns all results. | |
106 | |
107 Will add multiple instances of |task| to the thread pool, then call | |
108 |results_getter| to get back all results and compare them to expected ones. | |
109 """ | |
110 args = range(0, 100) if args is None else args | |
111 expected = args if expected is None else expected | |
112 msg = 'Using \'%s\' to get results.' % (results_getter.__name__,) | |
113 | |
114 for i in args: | |
115 self.thread_pool.add_task(0, task, i) | |
116 results = results_getter(expected) | |
117 | |
118 # Check that got all results back (exact same set, no duplicates). | |
119 self.assertEqual(set(expected), set(results), msg) | |
120 self.assertEqual(len(expected), len(results), msg) | |
121 | |
122 # Queue is empty, result request should fail. | |
123 with self.assertRaises(threading_utils.ThreadPoolEmpty): | |
124 self.thread_pool.get_one_result() | |
125 | |
126 @timeout(1) | |
127 def test_get_one_result_ok(self): | |
128 self.thread_pool.add_task(0, lambda: 'OK') | |
129 self.assertEqual(self.thread_pool.get_one_result(), 'OK') | |
130 | |
131 @timeout(1) | |
132 def test_get_one_result_fail(self): | |
133 # No tasks added -> get_one_result raises an exception. | |
134 with self.assertRaises(threading_utils.ThreadPoolEmpty): | |
135 self.thread_pool.get_one_result() | |
136 | |
137 @timeout(5) | |
138 def test_join(self): | |
139 self.run_results_test(self.sleep_task(), | |
140 self.get_results_via_join) | |
141 | |
142 @timeout(5) | |
143 def test_get_one_result(self): | |
144 self.run_results_test(self.sleep_task(), | |
145 self.get_results_via_get_one_result) | |
146 | |
147 @timeout(5) | |
148 def test_iter_results(self): | |
149 self.run_results_test(self.sleep_task(), | |
150 self.get_results_via_iter_results) | |
151 | |
152 @timeout(5) | |
153 def test_retry_and_join(self): | |
154 self.run_results_test(self.retrying_sleep_task(), | |
155 self.get_results_via_join) | |
156 | |
157 @timeout(5) | |
158 def test_retry_and_get_one_result(self): | |
159 self.run_results_test(self.retrying_sleep_task(), | |
160 self.get_results_via_get_one_result) | |
161 | |
162 @timeout(5) | |
163 def test_retry_and_iter_results(self): | |
164 self.run_results_test(self.retrying_sleep_task(), | |
165 self.get_results_via_iter_results) | |
166 | |
167 @timeout(5) | |
168 def test_none_task_and_join(self): | |
169 self.run_results_test(self.none_task(), | |
170 self.get_results_via_join, | |
171 expected=[]) | |
172 | |
173 @timeout(5) | |
174 def test_none_task_and_get_one_result(self): | |
175 self.thread_pool.add_task(0, self.none_task(), 0) | |
176 with self.assertRaises(threading_utils.ThreadPoolEmpty): | |
177 self.thread_pool.get_one_result() | |
178 | |
179 @timeout(5) | |
180 def test_none_task_and_and_iter_results(self): | |
181 self.run_results_test(self.none_task(), | |
182 self.get_results_via_iter_results, | |
183 expected=[]) | |
184 | |
185 @timeout(5) | |
186 def test_generator_task(self): | |
187 MULTIPLIER = 1000 | |
188 COUNT = 10 | |
189 | |
190 # Generator that yields [i * MULTIPLIER, i * MULTIPLIER + COUNT). | |
191 def generator_task(i): | |
192 for j in xrange(COUNT): | |
193 time.sleep(0.001) | |
194 yield i * MULTIPLIER + j | |
195 | |
196 # Arguments for tasks and expected results. | |
197 args = range(0, 10) | |
198 expected = [i * MULTIPLIER + j for i in args for j in xrange(COUNT)] | |
199 | |
200 # Test all possible ways to pull results from the thread pool. | |
201 getters = (self.get_results_via_join, | |
202 self.get_results_via_iter_results, | |
203 self.get_results_via_get_one_result,) | |
204 for results_getter in getters: | |
205 self.run_results_test(generator_task, results_getter, args, expected) | |
206 | |
207 @timeout(5) | |
208 def test_concurrent_iter_results(self): | |
209 def poller_proc(result): | |
210 result.extend(self.thread_pool.iter_results()) | |
211 | |
212 args = range(0, 100) | |
213 for i in args: | |
214 self.thread_pool.add_task(0, self.sleep_task(), i) | |
215 | |
216 # Start a bunch of threads, all calling iter_results in parallel. | |
217 pollers = [] | |
218 for _ in xrange(0, 4): | |
219 result = [] | |
220 poller = threading.Thread(target=poller_proc, args=(result,)) | |
221 poller.start() | |
222 pollers.append((poller, result)) | |
223 | |
224 # Collects results from all polling threads. | |
225 all_results = [] | |
226 for poller, results in pollers: | |
227 poller.join() | |
228 all_results.extend(results) | |
229 | |
230 # Check that got all results back (exact same set, no duplicates). | |
231 self.assertEqual(set(args), set(all_results)) | |
232 self.assertEqual(len(args), len(all_results)) | |
233 | |
234 @timeout(1) | |
235 def test_adding_tasks_after_close(self): | |
236 pool = threading_utils.ThreadPool(1, 1, 0) | |
237 pool.add_task(0, lambda: None) | |
238 pool.close() | |
239 with self.assertRaises(threading_utils.ThreadPoolClosed): | |
240 pool.add_task(0, lambda: None) | |
241 | |
242 @timeout(1) | |
243 def test_double_close(self): | |
244 pool = threading_utils.ThreadPool(1, 1, 0) | |
245 pool.close() | |
246 with self.assertRaises(threading_utils.ThreadPoolClosed): | |
247 pool.close() | |
248 | |
249 def test_priority(self): | |
250 # Verifies that a lower priority is run first. | |
251 with threading_utils.ThreadPool(1, 1, 0) as pool: | |
252 lock = threading.Lock() | |
253 | |
254 def wait_and_return(x): | |
255 with lock: | |
256 return x | |
257 | |
258 def return_x(x): | |
259 return x | |
260 | |
261 with lock: | |
262 pool.add_task(0, wait_and_return, 'a') | |
263 pool.add_task(2, return_x, 'b') | |
264 pool.add_task(1, return_x, 'c') | |
265 | |
266 actual = pool.join() | |
267 self.assertEqual(['a', 'c', 'b'], actual) | |
268 | |
269 @timeout(2) | |
270 def test_abort(self): | |
271 # Trigger a ridiculous amount of tasks, and abort the remaining. | |
272 with threading_utils.ThreadPool(2, 2, 0) as pool: | |
273 # Allow 10 tasks to run initially. | |
274 sem = threading.Semaphore(10) | |
275 | |
276 def grab_and_return(x): | |
277 sem.acquire() | |
278 return x | |
279 | |
280 for i in range(100): | |
281 pool.add_task(0, grab_and_return, i) | |
282 | |
283 # Running at 11 would hang. | |
284 results = [pool.get_one_result() for _ in xrange(10)] | |
285 # At that point, there's 10 completed tasks and 2 tasks hanging, 88 | |
286 # pending. | |
287 self.assertEqual(88, pool.abort()) | |
288 # Calling .join() before these 2 .release() would hang. | |
289 sem.release() | |
290 sem.release() | |
291 results.extend(pool.join()) | |
292 # The results *may* be out of order. Even if the calls are processed | |
293 # strictly in FIFO mode, a thread may preempt another one when returning the | |
294 # values. | |
295 self.assertEqual(range(12), sorted(results)) | |
296 | |
297 | |
298 class AutoRetryThreadPoolTest(unittest.TestCase): | |
299 def test_bad_class(self): | |
300 exceptions = [AutoRetryThreadPoolTest] | |
301 with self.assertRaises(AssertionError): | |
302 threading_utils.AutoRetryThreadPool(exceptions, 1, 0, 1, 0) | |
303 | |
304 def test_no_exception(self): | |
305 with self.assertRaises(AssertionError): | |
306 threading_utils.AutoRetryThreadPool([], 1, 0, 1, 0) | |
307 | |
308 def test_bad_retry(self): | |
309 exceptions = [IOError] | |
310 with self.assertRaises(AssertionError): | |
311 threading_utils.AutoRetryThreadPool(exceptions, 256, 0, 1, 0) | |
312 | |
313 def test_bad_priority(self): | |
314 exceptions = [IOError] | |
315 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool: | |
316 pool.add_task(0, lambda x: x, 0) | |
317 pool.add_task(256, lambda x: x, 0) | |
318 pool.add_task(512, lambda x: x, 0) | |
319 with self.assertRaises(AssertionError): | |
320 pool.add_task(1, lambda x: x, 0) | |
321 with self.assertRaises(AssertionError): | |
322 pool.add_task(255, lambda x: x, 0) | |
323 | |
324 def test_priority(self): | |
325 # Verifies that a lower priority is run first. | |
326 exceptions = [IOError] | |
327 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool: | |
328 lock = threading.Lock() | |
329 | |
330 def wait_and_return(x): | |
331 with lock: | |
332 return x | |
333 | |
334 def return_x(x): | |
335 return x | |
336 | |
337 with lock: | |
338 pool.add_task(pool.HIGH, wait_and_return, 'a') | |
339 pool.add_task(pool.LOW, return_x, 'b') | |
340 pool.add_task(pool.MED, return_x, 'c') | |
341 | |
342 actual = pool.join() | |
343 self.assertEqual(['a', 'c', 'b'], actual) | |
344 | |
345 def test_retry_inherited(self): | |
346 # Exception class inheritance works. | |
347 class CustomException(IOError): | |
348 pass | |
349 ran = [] | |
350 def throw(to_throw, x): | |
351 ran.append(x) | |
352 if to_throw: | |
353 raise to_throw.pop(0) | |
354 return x | |
355 with threading_utils.AutoRetryThreadPool([IOError], 1, 1, 1, 0) as pool: | |
356 pool.add_task(pool.MED, throw, [CustomException('a')], 'yay') | |
357 actual = pool.join() | |
358 self.assertEqual(['yay'], actual) | |
359 self.assertEqual(['yay', 'yay'], ran) | |
360 | |
361 def test_retry_2_times(self): | |
362 exceptions = [IOError, OSError] | |
363 to_throw = [OSError('a'), IOError('b')] | |
364 def throw(x): | |
365 if to_throw: | |
366 raise to_throw.pop(0) | |
367 return x | |
368 with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool: | |
369 pool.add_task(pool.MED, throw, 'yay') | |
370 actual = pool.join() | |
371 self.assertEqual(['yay'], actual) | |
372 | |
373 def test_retry_too_many_times(self): | |
374 exceptions = [IOError, OSError] | |
375 to_throw = [OSError('a'), IOError('b')] | |
376 def throw(x): | |
377 if to_throw: | |
378 raise to_throw.pop(0) | |
379 return x | |
380 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool: | |
381 pool.add_task(pool.MED, throw, 'yay') | |
382 with self.assertRaises(IOError): | |
383 pool.join() | |
384 | |
385 def test_retry_mutation_1(self): | |
386 # This is to warn that mutable arguments WILL be mutated. | |
387 def throw(to_throw, x): | |
388 if to_throw: | |
389 raise to_throw.pop(0) | |
390 return x | |
391 exceptions = [IOError, OSError] | |
392 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool: | |
393 pool.add_task(pool.MED, throw, [OSError('a'), IOError('b')], 'yay') | |
394 with self.assertRaises(IOError): | |
395 pool.join() | |
396 | |
397 def test_retry_mutation_2(self): | |
398 # This is to warn that mutable arguments WILL be mutated. | |
399 def throw(to_throw, x): | |
400 if to_throw: | |
401 raise to_throw.pop(0) | |
402 return x | |
403 exceptions = [IOError, OSError] | |
404 with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool: | |
405 pool.add_task(pool.MED, throw, [OSError('a'), IOError('b')], 'yay') | |
406 actual = pool.join() | |
407 self.assertEqual(['yay'], actual) | |
408 | |
409 def test_retry_interleaved(self): | |
410 # Verifies that retries are interleaved. This is important, we don't want a | |
411 # retried task to take all the pool during retries. | |
412 exceptions = [IOError, OSError] | |
413 lock = threading.Lock() | |
414 ran = [] | |
415 with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool: | |
416 def lock_and_throw(to_throw, x): | |
417 with lock: | |
418 ran.append(x) | |
419 if to_throw: | |
420 raise to_throw.pop(0) | |
421 return x | |
422 with lock: | |
423 pool.add_task( | |
424 pool.MED, lock_and_throw, [OSError('a'), IOError('b')], 'A') | |
425 pool.add_task( | |
426 pool.MED, lock_and_throw, [OSError('a'), IOError('b')], 'B') | |
427 | |
428 actual = pool.join() | |
429 self.assertEqual(['A', 'B'], actual) | |
430 # Retries are properly interleaved: | |
431 self.assertEqual(['A', 'B', 'A', 'B', 'A', 'B'], ran) | |
432 | |
433 def test_add_task_with_channel_success(self): | |
434 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool: | |
435 channel = threading_utils.TaskChannel() | |
436 pool.add_task_with_channel(channel, 0, lambda: 0) | |
437 self.assertEqual(0, channel.pull()) | |
438 | |
439 def test_add_task_with_channel_fatal_error(self): | |
440 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool: | |
441 channel = threading_utils.TaskChannel() | |
442 def throw(exc): | |
443 raise exc | |
444 pool.add_task_with_channel(channel, 0, throw, ValueError()) | |
445 with self.assertRaises(ValueError): | |
446 channel.pull() | |
447 | |
448 def test_add_task_with_channel_retryable_error(self): | |
449 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool: | |
450 channel = threading_utils.TaskChannel() | |
451 def throw(exc): | |
452 raise exc | |
453 pool.add_task_with_channel(channel, 0, throw, OSError()) | |
454 with self.assertRaises(OSError): | |
455 channel.pull() | |
456 | |
457 | |
458 class FakeProgress(object): | |
459 @staticmethod | |
460 def print_update(): | |
461 pass | |
462 | |
463 | |
464 class WorkerPoolTest(unittest.TestCase): | |
465 def test_normal(self): | |
466 mapper = lambda value: -value | |
467 progress = FakeProgress() | |
468 with threading_utils.ThreadPoolWithProgress(progress, 8, 8, 0) as pool: | |
469 for i in range(32): | |
470 pool.add_task(0, mapper, i) | |
471 results = pool.join() | |
472 self.assertEqual(range(-31, 1), sorted(results)) | |
473 | |
474 def test_exception(self): | |
475 class FearsomeException(Exception): | |
476 pass | |
477 def mapper(value): | |
478 raise FearsomeException(value) | |
479 task_added = False | |
480 try: | |
481 progress = FakeProgress() | |
482 with threading_utils.ThreadPoolWithProgress(progress, 8, 8, 0) as pool: | |
483 pool.add_task(0, mapper, 0) | |
484 task_added = True | |
485 pool.join() | |
486 self.fail() | |
487 except FearsomeException: | |
488 self.assertEqual(True, task_added) | |
489 | |
490 | |
491 class TaskChannelTest(unittest.TestCase): | |
492 def test_passes_simple_value(self): | |
493 with threading_utils.ThreadPool(1, 1, 0) as tp: | |
494 channel = threading_utils.TaskChannel() | |
495 tp.add_task(0, lambda: channel.send_result(0)) | |
496 self.assertEqual(0, channel.pull()) | |
497 | |
498 def test_passes_exception_value(self): | |
499 with threading_utils.ThreadPool(1, 1, 0) as tp: | |
500 channel = threading_utils.TaskChannel() | |
501 tp.add_task(0, lambda: channel.send_result(Exception())) | |
502 self.assertTrue(isinstance(channel.pull(), Exception)) | |
503 | |
504 def test_wrap_task_passes_simple_value(self): | |
505 with threading_utils.ThreadPool(1, 1, 0) as tp: | |
506 channel = threading_utils.TaskChannel() | |
507 tp.add_task(0, channel.wrap_task(lambda: 0)) | |
508 self.assertEqual(0, channel.pull()) | |
509 | |
510 def test_wrap_task_passes_exception_value(self): | |
511 with threading_utils.ThreadPool(1, 1, 0) as tp: | |
512 channel = threading_utils.TaskChannel() | |
513 tp.add_task(0, channel.wrap_task(lambda: Exception())) | |
514 self.assertTrue(isinstance(channel.pull(), Exception)) | |
515 | |
516 def test_send_exception_raises_exception(self): | |
517 class CustomError(Exception): | |
518 pass | |
519 with threading_utils.ThreadPool(1, 1, 0) as tp: | |
520 channel = threading_utils.TaskChannel() | |
521 tp.add_task(0, lambda: channel.send_exception(CustomError())) | |
522 with self.assertRaises(CustomError): | |
523 channel.pull() | |
524 | |
525 def test_wrap_task_raises_exception(self): | |
526 class CustomError(Exception): | |
527 pass | |
528 with threading_utils.ThreadPool(1, 1, 0) as tp: | |
529 channel = threading_utils.TaskChannel() | |
530 def task_func(): | |
531 raise CustomError() | |
532 tp.add_task(0, channel.wrap_task(task_func)) | |
533 with self.assertRaises(CustomError): | |
534 channel.pull() | |
535 | |
536 | |
537 if __name__ == '__main__': | |
538 VERBOSE = '-v' in sys.argv | |
539 logging.basicConfig(level=logging.DEBUG if VERBOSE else logging.ERROR) | |
540 unittest.main() | |
OLD | NEW |