Index: build/android/pylib/base/shard.py |
diff --git a/build/android/pylib/base/shard.py b/build/android/pylib/base/shard.py |
index fdca5249a06b06f920d7b9ccf7a4d9e9fe026c60..100eef4ab54710feff86bd5f9aa49a8fc96c9e1c 100644 |
--- a/build/android/pylib/base/shard.py |
+++ b/build/android/pylib/base/shard.py |
@@ -14,6 +14,24 @@ from pylib.utils import reraiser_thread |
import test_result |
+class _ThreadSafeCounter(object): |
+ """A threadsafe counter.""" |
+ def __init__(self): |
+ self._lock = threading.Lock() |
+ self._value = 0 |
+ |
+ def GetAndIncrement(self): |
+ """Get the current value and increment it atomically. |
+ |
+ Returns: |
+ The value before incrementing. |
+ """ |
+ with self._lock: |
+ pre_increment = self._value |
+ self._value += 1 |
+ return pre_increment |
+ |
+ |
class _Test(object): |
"""Holds a test with additional metadata.""" |
def __init__(self, test, tries=0): |
@@ -58,7 +76,7 @@ class _TestCollection(object): |
if self._tests_in_progress == 0: |
return None |
try: |
- return self._tests.pop() |
+ return self._tests.pop(0) |
except IndexError: |
# Another thread beat us to the avaliable test, wait again. |
self._item_avaliable_or_all_done.clear() |
@@ -114,7 +132,7 @@ def _RunTestsFromQueue(runner, test_collection, out_results): |
if retry and test.tries <= 3: |
# Retry non-passing results, only record passing results. |
out_results.append(test_result.TestResults.FromRun(ok=result.ok)) |
- logging.warning('****Retrying test, try #%s.' % test.tries) |
+ logging.warning('****Will retry test, try #%s.' % test.tries) |
test_collection.add(_Test(test=retry, tries=test.tries)) |
else: |
# All tests passed or retry limit reached. Either way, record results. |
@@ -135,24 +153,27 @@ def _RunTestsFromQueue(runner, test_collection, out_results): |
test_collection.test_completed() |
-def _SetUp(runner_factory, device, out_runners): |
+def _SetUp(runner_factory, device, out_runners, threadsafe_counter): |
"""Creates a test runner for each device and calls SetUp() in parallel. |
Note: if a device is unresponsive the corresponding TestRunner will not be |
added to out_runners. |
Args: |
- runner_factory: callable that takes a device and returns a TestRunner. |
+ runner_factory: callable that takes a device and index and returns a |
+ TestRunner object. |
device: the device serial number to set up. |
out_runners: list to add the successfully set up TestRunner object. |
+ threadsafe_counter: a _ThreadSafeCounter object used to get shard indices. |
""" |
try: |
- logging.warning('*****Creating shard for %s.', device) |
- runner = runner_factory(device) |
+ index = threadsafe_counter.GetAndIncrement() |
+ logging.warning('*****Creating shard %s for device %s.', index, device) |
+ runner = runner_factory(device, index) |
runner.SetUp() |
out_runners.append(runner) |
except android_commands.errors.DeviceUnresponsiveError as e: |
- logging.warning('****Failed to create shard for %s: [%s]', (device, e)) |
+ logging.warning('****Failed to create shard for %s: [%s]', device, e) |
def _RunAllTests(runners, tests): |
@@ -183,20 +204,23 @@ def _CreateRunners(runner_factory, devices): |
included in the returned list. |
Args: |
- runner_factory: callable that takes a device and returns a TestRunner. |
+ runner_factory: callable that takes a device and index and returns a |
+ TestRunner object. |
devices: list of device serial numbers as strings. |
Returns: |
A list of TestRunner objects. |
""" |
logging.warning('****Creating %s test runners.' % len(devices)) |
- test_runners = [] |
+ runners = [] |
+ counter = _ThreadSafeCounter() |
threads = reraiser_thread.ReraiserThreadGroup( |
- [reraiser_thread.ReraiserThread(_SetUp, [runner_factory, d, test_runners]) |
+ [reraiser_thread.ReraiserThread(_SetUp, [runner_factory, d, runners, |
+ counter]) |
for d in devices]) |
threads.StartAll() |
threads.JoinAll() |
- return test_runners |
+ return runners |
def _TearDownRunners(runners): |
@@ -215,7 +239,8 @@ def ShardAndRunTests(runner_factory, devices, tests, build_type='Debug'): |
"""Run all tests on attached devices, retrying tests that don't pass. |
Args: |
- runner_factory: callable that takes a device and returns a TestRunner. |
+ runner_factory: callable that takes a device and index and returns a |
+ TestRunner object. |
devices: list of attached device serial numbers as strings. |
tests: list of tests to run. |
build_type: either 'Debug' or 'Release'. |