Chromium Code Reviews| Index: telemetry/telemetry/testing/browser_test_runner.py |
| diff --git a/telemetry/telemetry/testing/browser_test_runner.py b/telemetry/telemetry/testing/browser_test_runner.py |
| index 7c973e107b2940d319b0eacd777ed287d2ca7af6..5a4dd4588f78f83ec6e2ecf7438349a52b3121a7 100644 |
| --- a/telemetry/telemetry/testing/browser_test_runner.py |
| +++ b/telemetry/telemetry/testing/browser_test_runner.py |
| @@ -45,9 +45,42 @@ def ValidateTestMethodname(test_name): |
| assert not bool(_INVALID_TEST_NAME_RE.search(test_name)) |
| +def TestRangeForShard(total_shards, shard_index, num_tests): |
| + """Returns a 2-tuple containing the start (inclusive) and ending |
| + (exclusive) indices of the tests that should be run, given that |
| + |num_tests| tests are split across |total_shards| shards, and that |
| + |shard_index| is currently being run. |
|
nednguyen
2016/06/17 04:38:00
Now that I think about this, why do we need such a
|
| + """ |
| + assert num_tests >= 0 |
| + assert total_shards >= 1 |
| + assert shard_index >= 0 and shard_index < total_shards, ( |
| + 'shard_index (%d) must be >= 0 and < total_shards (%d)' % |
| + (shard_index, total_shards)) |
| + if num_tests == 0: |
| + return (0, 0) |
| + floored_tests_per_shard = num_tests // total_shards |
| + remaining_tests = num_tests % total_shards |
| + if remaining_tests == 0: |
| + return (floored_tests_per_shard * shard_index, |
| + floored_tests_per_shard * (1 + shard_index)) |
| + # More complicated. Some shards will run floored_tests_per_shard |
| + # tests, and some will run 1 + floored_tests_per_shard. |
| + num_earlier_shards_with_one_extra_test = min(remaining_tests, shard_index) |
| + num_earlier_shards_with_no_extra_tests = max( |
| + 0, shard_index - num_earlier_shards_with_one_extra_test) |
| + num_earlier_tests = ( |
| + num_earlier_shards_with_one_extra_test * (floored_tests_per_shard + 1) + |
| + num_earlier_shards_with_no_extra_tests * floored_tests_per_shard) |
| + tests_for_this_shard = floored_tests_per_shard |
| + if shard_index < remaining_tests: |
| + tests_for_this_shard += 1 |
| + return (num_earlier_tests, num_earlier_tests + tests_for_this_shard) |
| + |
| + |
| _TEST_GENERATOR_PREFIX = 'GenerateTestCases_' |
| -def LoadTests(test_class, finder_options, filter_regex_str): |
| +def LoadTests(test_class, finder_options, filter_regex_str, |
| + total_shards, shard_index): |
| test_cases = [] |
| filter_regex = re.compile(filter_regex_str) |
| for name, method in inspect.getmembers( |
| @@ -74,7 +107,8 @@ def LoadTests(test_class, finder_options, filter_regex_str): |
| based_method, args)) |
| test_cases.append(test_class(generated_test_name)) |
| test_cases.sort(key=lambda t: t.id()) |
| - return test_cases |
| + test_range = TestRangeForShard(total_shards, shard_index, len(test_cases)) |
| + return test_cases[test_range[0]:test_range[1]] |
| class TestRunOptions(object): |
| @@ -101,6 +135,11 @@ def Run(project_config, test_run_options, args): |
| help=('If specified, writes the full results to that path in json form.')) |
| parser.add_argument('--test-filter', type=str, default='', action='store', |
| help='Run only tests whose names match the given filter regexp.') |
| + parser.add_argument('--total-shards', default=1, type=int, |
| + help='Total number of shards being used for this test run. (The user of ' |
| + 'this script is responsible for spawning all of the shards.)') |
| + parser.add_argument('--shard-index', default=0, type=int, |
| + help='Shard index (0..total_shards-1) of this test run.') |
| option, extra_args = parser.parse_known_args(args) |
| for start_dir in project_config.start_dirs: |
| @@ -115,9 +154,10 @@ def Run(project_config, test_run_options, args): |
| for cl in browser_test_classes: |
| if cl.Name() == option.test: |
| test_class = cl |
| + break |
| if not test_class: |
| - print 'Cannot find test class with name matched %s' % option.test |
| + print 'Cannot find test class with name matching %s' % option.test |
| print 'Available tests: %s' % '\n'.join( |
| cl.Name() for cl in browser_test_classes) |
| return 1 |
| @@ -125,7 +165,8 @@ def Run(project_config, test_run_options, args): |
| options = ProcessCommandLineOptions(test_class, extra_args) |
| suite = unittest.TestSuite() |
| - for test in LoadTests(test_class, options, option.test_filter): |
| + for test in LoadTests(test_class, options, option.test_filter, |
| + option.total_shards, option.shard_index): |
| suite.addTest(test) |
| results = unittest.TextTestRunner( |