Index: telemetry/telemetry/testing/browser_test_runner.py |
diff --git a/telemetry/telemetry/testing/browser_test_runner.py b/telemetry/telemetry/testing/browser_test_runner.py |
index 7c973e107b2940d319b0eacd777ed287d2ca7af6..5a4dd4588f78f83ec6e2ecf7438349a52b3121a7 100644 |
--- a/telemetry/telemetry/testing/browser_test_runner.py |
+++ b/telemetry/telemetry/testing/browser_test_runner.py |
@@ -45,9 +45,42 @@ def ValidateTestMethodname(test_name): |
assert not bool(_INVALID_TEST_NAME_RE.search(test_name)) |
+def TestRangeForShard(total_shards, shard_index, num_tests): |
+ """Returns a 2-tuple containing the start (inclusive) and ending |
+ (exclusive) indices of the tests that should be run, given that |
+ |num_tests| tests are split across |total_shards| shards, and that |
+ |shard_index| is currently being run. |
nednguyen
2016/06/17 04:38:00
Now that I think about this, why do we need such a
|
+ """ |
+ assert num_tests >= 0 |
+ assert total_shards >= 1 |
+ assert shard_index >= 0 and shard_index < total_shards, ( |
+ 'shard_index (%d) must be >= 0 and < total_shards (%d)' % |
+ (shard_index, total_shards)) |
+ if num_tests == 0: |
+ return (0, 0) |
+ floored_tests_per_shard = num_tests // total_shards |
+ remaining_tests = num_tests % total_shards |
+ if remaining_tests == 0: |
+ return (floored_tests_per_shard * shard_index, |
+ floored_tests_per_shard * (1 + shard_index)) |
+ # More complicated. Some shards will run floored_tests_per_shard |
+ # tests, and some will run 1 + floored_tests_per_shard. |
+ num_earlier_shards_with_one_extra_test = min(remaining_tests, shard_index) |
+ num_earlier_shards_with_no_extra_tests = max( |
+ 0, shard_index - num_earlier_shards_with_one_extra_test) |
+ num_earlier_tests = ( |
+ num_earlier_shards_with_one_extra_test * (floored_tests_per_shard + 1) + |
+ num_earlier_shards_with_no_extra_tests * floored_tests_per_shard) |
+ tests_for_this_shard = floored_tests_per_shard |
+ if shard_index < remaining_tests: |
+ tests_for_this_shard += 1 |
+ return (num_earlier_tests, num_earlier_tests + tests_for_this_shard) |
+ |
+ |
_TEST_GENERATOR_PREFIX = 'GenerateTestCases_' |
-def LoadTests(test_class, finder_options, filter_regex_str): |
+def LoadTests(test_class, finder_options, filter_regex_str, |
+ total_shards, shard_index): |
test_cases = [] |
filter_regex = re.compile(filter_regex_str) |
for name, method in inspect.getmembers( |
@@ -74,7 +107,8 @@ def LoadTests(test_class, finder_options, filter_regex_str): |
based_method, args)) |
test_cases.append(test_class(generated_test_name)) |
test_cases.sort(key=lambda t: t.id()) |
- return test_cases |
+ test_range = TestRangeForShard(total_shards, shard_index, len(test_cases)) |
+ return test_cases[test_range[0]:test_range[1]] |
class TestRunOptions(object): |
@@ -101,6 +135,11 @@ def Run(project_config, test_run_options, args): |
help=('If specified, writes the full results to that path in json form.')) |
parser.add_argument('--test-filter', type=str, default='', action='store', |
help='Run only tests whose names match the given filter regexp.') |
+ parser.add_argument('--total-shards', default=1, type=int, |
+ help='Total number of shards being used for this test run. (The user of ' |
+ 'this script is responsible for spawning all of the shards.)') |
+ parser.add_argument('--shard-index', default=0, type=int, |
+ help='Shard index (0..total_shards-1) of this test run.') |
option, extra_args = parser.parse_known_args(args) |
for start_dir in project_config.start_dirs: |
@@ -115,9 +154,10 @@ def Run(project_config, test_run_options, args): |
for cl in browser_test_classes: |
if cl.Name() == option.test: |
test_class = cl |
+ break |
if not test_class: |
- print 'Cannot find test class with name matched %s' % option.test |
+ print 'Cannot find test class with name matching %s' % option.test |
print 'Available tests: %s' % '\n'.join( |
cl.Name() for cl in browser_test_classes) |
return 1 |
@@ -125,7 +165,8 @@ def Run(project_config, test_run_options, args): |
options = ProcessCommandLineOptions(test_class, extra_args) |
suite = unittest.TestSuite() |
- for test in LoadTests(test_class, options, option.test_filter): |
+ for test in LoadTests(test_class, options, option.test_filter, |
+ option.total_shards, option.shard_index): |
suite.addTest(test) |
results = unittest.TextTestRunner( |