Index: build/android/pylib/base/shard_unittest.py |
diff --git a/build/android/pylib/base/shard_unittest.py b/build/android/pylib/base/shard_unittest.py |
index 5f8b9908965d16e54f2db8db4a340a8c407387eb..65669e00d48e676f9af424908f71ebe0de72d646 100644 |
--- a/build/android/pylib/base/shard_unittest.py |
+++ b/build/android/pylib/base/shard_unittest.py |
@@ -122,6 +122,9 @@ class TestThreadGroupFunctions(unittest.TestCase): |
"""Tests for shard._RunAllTests and shard._CreateRunners.""" |
def setUp(self): |
self.tests = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] |
+ shared_test_collection = shard._TestCollection( |
+ [shard._Test(t) for t in self.tests]) |
+ self.test_collection_factory = lambda: shared_test_collection |
def testCreate(self): |
runners = shard._CreateRunners(MockRunner, ['0', '1']) |
@@ -134,7 +137,8 @@ class TestThreadGroupFunctions(unittest.TestCase): |
def testRun(self): |
runners = [MockRunner('0'), MockRunner('1')] |
- results, exit_code = shard._RunAllTests(runners, self.tests, 0) |
+ results, exit_code = shard._RunAllTests( |
+ runners, self.test_collection_factory, 0) |
self.assertEqual(len(results.GetPass()), len(self.tests)) |
self.assertEqual(exit_code, 0) |
@@ -146,21 +150,22 @@ class TestThreadGroupFunctions(unittest.TestCase): |
def testRetry(self): |
runners = shard._CreateRunners(MockRunnerFail, ['0', '1']) |
- results, exit_code = shard._RunAllTests(runners, self.tests, 0) |
+ results, exit_code = shard._RunAllTests( |
+ runners, self.test_collection_factory, 0) |
self.assertEqual(len(results.GetFail()), len(self.tests)) |
self.assertEqual(exit_code, constants.ERROR_EXIT_CODE) |
def testReraise(self): |
runners = shard._CreateRunners(MockRunnerException, ['0', '1']) |
with self.assertRaises(TestException): |
- shard._RunAllTests(runners, self.tests, 0) |
+ shard._RunAllTests(runners, self.test_collection_factory, 0) |
class TestShard(unittest.TestCase): |
- """Tests for shard.Shard.""" |
+ """Tests for shard.ShardAndRunTests.""" |
@staticmethod |
def _RunShard(runner_factory): |
- return shard.ShardAndRunTests(runner_factory, ['0', '1'], ['a', 'b', 'c']) |
+ return shard.ShardAndRunTests(['a', 'b', 'c'], runner_factory, ['0', '1']) |
def testShard(self): |
results, exit_code = TestShard._RunShard(MockRunner) |
@@ -174,7 +179,31 @@ class TestShard(unittest.TestCase): |
self.assertEqual(exit_code, 0) |
def testNoTests(self): |
- results, exit_code = shard.ShardAndRunTests(MockRunner, ['0', '1'], []) |
+ results, exit_code = shard.ShardAndRunTests([], MockRunner, ['0', '1']) |
+ self.assertEqual(len(results.GetAll()), 0) |
+ self.assertEqual(exit_code, constants.ERROR_EXIT_CODE) |
+ |
+ |
+class TestReplicate(unittest.TestCase): |
+ """Tests for shard.ReplicateAndRunTests.""" |
+ @staticmethod |
+ def _RunReplicate(runner_factory): |
+ return shard.ReplicateAndRunTests(['a', 'b', 'c'], runner_factory, |
+ ['0', '1']) |
+ |
+ def testReplicate(self): |
+ results, exit_code = TestShard._RunShard(MockRunner) |
+ # We expect 6 results since each test should have been run on every device |
+ self.assertEqual(len(results.GetPass()), 6) |
+ self.assertEqual(exit_code, 0) |
+ |
+ def testFailing(self): |
+ results, exit_code = TestShard._RunShard(MockRunnerFail) |
+ self.assertEqual(len(results.GetPass()), 0) |
+ self.assertEqual(len(results.GetFail()), 6) |
+ |
+ def testNoTests(self): |
+ results, exit_code = shard.ReplicateAndRunTests([], MockRunner, ['0', '1']) |
self.assertEqual(len(results.GetAll()), 0) |
self.assertEqual(exit_code, constants.ERROR_EXIT_CODE) |