Skip to content

Commit

Permalink
refactor sharding so that it works with listing tests, not just execu…
Browse files Browse the repository at this point in the history
…ting them
  • Loading branch information
ChewyGumball committed Jul 11, 2021
1 parent efe1634 commit 98d5d84
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/catch2/catch_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
#include <catch2/internal/catch_result_type.hpp>
#include <catch2/internal/catch_run_context.hpp>
#include <catch2/internal/catch_section.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/internal/catch_singletons.hpp>
#include <catch2/internal/catch_source_line_info.hpp>
#include <catch2/internal/catch_startup_exception_registry.hpp>
Expand Down
4 changes: 2 additions & 2 deletions src/catch2/catch_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ namespace Catch {
int abortAfter = -1;
unsigned int rngSeed = 0;

int shardCount = 1;
int shardIndex = 0;
unsigned int shardCount = 1;
unsigned int shardIndex = 0;

bool benchmarkNoAnalysis = false;
unsigned int benchmarkSamples = 100;
Expand Down
19 changes: 2 additions & 17 deletions src/catch2/catch_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <catch2/catch_version.hpp>
#include <catch2/interfaces/catch_interfaces_reporter.hpp>
#include <catch2/internal/catch_startup_exception_registry.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/internal/catch_textflow.hpp>
#include <catch2/internal/catch_windows_h_proxy.hpp>
#include <catch2/reporters/catch_reporter_listening.hpp>
Expand Down Expand Up @@ -78,23 +79,7 @@ namespace Catch {
m_tests.insert(match.tests.begin(), match.tests.end());
}

if (m_config->shardCount() > 1) {
int shardCount = m_config->shardCount();
if (shardCount > m_tests.size()) {
shardCount = m_tests.size();
}

int shardIndex = m_config->shardIndex();
if (shardIndex >= shardCount) {
shardIndex = shardCount - 1;
}

int shardSize = m_tests.size() / shardCount;
auto firstIndex = std::next(m_tests.begin(), shardSize * shardIndex);
auto lastIndex = std::next(firstIndex, shardSize);

m_tests = std::set(firstIndex, lastIndex);
}
m_tests = createShard(m_tests, *m_config);
}

Totals execute() {
Expand Down
47 changes: 47 additions & 0 deletions src/catch2/internal/catch_sharding.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)

// SPDX-License-Identifier: BSL-1.0
#ifndef CATCH_SHARDING_HPP_INCLUDED
#define CATCH_SHARDING_HPP_INCLUDED

#include <catch2/catch_session.hpp>

#include <cmath>

namespace Catch {

template<typename CONTAINER>
CONTAINER createShard(CONTAINER const& container, IConfig const& config) {
if (config.shardCount() > 1) {
unsigned int totalTestCount = container.size();

unsigned int shardCount = (std::min)(config.shardCount(), totalTestCount);
unsigned int shardIndex = (std::min)(config.shardIndex(), shardCount - 1);

double shardSize = totalTestCount / static_cast<double>(shardCount);
double startIndex = shardIndex * shardSize;

auto startIterator = std::next(container.begin(), std::floor(startIndex));
auto endIterator = std::next(container.begin(), std::floor(startIndex + shardSize));

// Since we are calculating the end index with floating point numbers, but flooring
// the value, we can't guarantee that the end index of the last shard lines up exactly
// with the end of input container. If we want the last shard, force the end index to
// be the end of the input container.
if (shardIndex == shardCount - 1) {
endIterator = container.end();
}

return CONTAINER(startIterator, endIterator);
} else {
return container;
}
}

}

#endif // CATCH_SHARDING_HPP_INCLUDED
3 changes: 2 additions & 1 deletion src/catch2/internal/catch_test_case_registry_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <catch2/interfaces/catch_interfaces_registry_hub.hpp>
#include <catch2/internal/catch_random_number_generator.hpp>
#include <catch2/internal/catch_run_context.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/catch_test_case_info.hpp>
#include <catch2/catch_test_spec.hpp>

Expand Down Expand Up @@ -110,7 +111,7 @@ namespace {
filtered.push_back(testCase);
}
}
return filtered;
return createShard(filtered, config);
}
std::vector<TestCaseHandle> const& getAllTestCasesSorted( IConfig const& config ) {
return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config );
Expand Down
3 changes: 3 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ set_tests_properties(TagAlias PROPERTIES
add_test(NAME RandomTestOrdering COMMAND ${PYTHON_EXECUTABLE}
${CATCH_DIR}/tests/TestScripts/testRandomOrder.py $<TARGET_FILE:SelfTest>)

add_test(NAME TestSharding COMMAND ${PYTHON_EXECUTABLE}
${CATCH_DIR}/tests/TestScripts/testSharding.py $<TARGET_FILE:SelfTest>)

add_test(NAME CheckConvenienceHeaders
COMMAND
${PYTHON_EXECUTABLE} ${CATCH_DIR}/tools/scripts/checkConvenienceHeaders.py
Expand Down
68 changes: 68 additions & 0 deletions tests/TestScripts/testSharding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3

"""
This test script verifies that the random ordering of tests inside
Catch2 is invariant in regards to subsetting. This is done by running
the binary 3 times, once with all tests selected, and twice with smaller
subsets of tests selected, and verifying that the selected tests are in
the same relative order.
"""

import subprocess
import sys
import random
import xml.etree.ElementTree as ET

def list_tests(self_test_exe, tags, rng_seed):
cmd = [self_test_exe, '--reporter', 'xml', '--list-tests', '--order', 'rand',
'--rng-seed', str(rng_seed)]
tags_arg = ','.join('[{}]~[.]'.format(t) for t in tags)
if tags_arg:
cmd.append(tags_arg)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()
if stderr:
raise RuntimeError("Unexpected error output:\n" + process.stderr)

root = ET.fromstring(stdout)
result = [elem.text for elem in root.findall('./TestCase/Name')]

if len(result) < 2:
raise RuntimeError("Unexpectedly few tests listed (got {})".format(
len(result)))
return result

def check_is_sublist_of(shorter, longer):
assert len(shorter) < len(longer)
assert len(set(longer)) == len(longer)

indexes_in_longer = {s: i for i, s in enumerate(longer)}
for s1, s2 in zip(shorter, shorter[1:]):
assert indexes_in_longer[s1] < indexes_in_longer[s2], (
'{} comes before {} in longer list.\n'
'Longer: {}\nShorter: {}'.format(s2, s1, longer, shorter))

def main():
self_test_exe, = sys.argv[1:]

test_cases = [
(1, 0), # default values, 1 shard, execute index 0
(1, 1), # 1 shard, invalid index (should still execute using the last valid index)
(4, 2), # 4 shards, second index
]

# We want a random seed for the test, but want to avoid 0,
# because it has special meaning
seed = random.randint(1, 2 ** 32 - 1)

list_one_tag = list_tests(self_test_exe, ['generators'], seed)
list_two_tags = list_tests(self_test_exe, ['generators', 'matchers'], seed)
list_all = list_tests(self_test_exe, [], seed)

# First, verify that restricting to a subset yields the same order
check_is_sublist_of(list_two_tags, list_all)
check_is_sublist_of(list_one_tag, list_two_tags)

if __name__ == '__main__':
sys.exit(main())

0 comments on commit 98d5d84

Please sign in to comment.