diff --git a/test/run_test.py b/test/run_test.py index 5889cd21aa583..33b7d5805feb6 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -32,6 +32,7 @@ set_cwd, shell, TEST_WITH_ASAN, + TEST_WITH_CROSSREF, TEST_WITH_ROCM, TEST_WITH_SLOW_GRADCHECK, ) @@ -1164,6 +1165,7 @@ def parse_args(): action="store_true", help="Enables removing tests based on TD", default=IS_CI + and TEST_WITH_CROSSREF and os.getenv("BRANCH", "") != "main" and not strtobool(os.environ.get("NO_TD", "False")), ) @@ -1462,7 +1464,7 @@ def do_sharding( test_file_times: Dict[str, float], test_class_times: Dict[str, Dict[str, float]], sort_by_time: bool = True, -) -> List[ShardedTest]: +) -> Tuple[float, List[ShardedTest]]: which_shard, num_shards = get_sharding_opts(options) # Do sharding @@ -1474,10 +1476,7 @@ def do_sharding( must_serial=must_serial, sort_by_time=sort_by_time, ) - _, tests_from_shard = shards[which_shard - 1] - selected_tests = tests_from_shard - - return selected_tests + return shards[which_shard - 1] class TestFailure(NamedTuple): @@ -1666,7 +1665,7 @@ def __init__( ): self.name = name self.failures = [] - self.sharded_tests = do_sharding( + self.time, self.sharded_tests = do_sharding( options, raw_tests, test_file_times_dict, @@ -1675,7 +1674,7 @@ def __init__( ) def __str__(self): - s = f"Name: {self.name}\n" + s = f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)\n" serial = [test for test in self.sharded_tests if must_serial(test)] parallel = [test for test in self.sharded_tests if not must_serial(test)] s += f" Serial tests ({len(serial)}):\n" @@ -1684,9 +1683,19 @@ def __str__(self): s += "".join(f" {test}\n" for test in parallel) return s.strip() - test_batch = TestBatch("all_tests", test_prioritizations.get_all_tests(), False) + percent_to_run = 25 if options.enable_td else 100 + print_to_stderr( + f"Running {percent_to_run}% of tests based on TD" + if options.enable_td + else "Running all tests" + ) + include, exclude = test_prioritizations.get_top_per_tests(percent_to_run) + + test_batch = TestBatch("tests to run", include, False) + test_batch_exclude = TestBatch("excluded", exclude, True) print_to_stderr(test_batch) + print_to_stderr(test_batch_exclude) if options.dry_run: return diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py index 809033940d611..c6935643ce589 100644 --- a/tools/testing/target_determination/heuristics/interface.py +++ b/tools/testing/target_determination/heuristics/interface.py @@ -112,6 +112,12 @@ def get_all_tests(self) -> List[TestRun]: """Returns all tests in the TestPrioritizations""" return [x[1] for x in self._traverse_scores()] + def get_top_per_tests(self, n: int) -> Tuple[List[TestRun], List[TestRun]]: + """Divides list of tests into two based on the top n% of scores. The + first list is the top, and the second is the rest.""" + tests = [x[1] for x in self._traverse_scores()] + return tests[: n * len(tests) // 100], tests[n * len(tests) // 100 :] + def get_info_str(self) -> str: info = ""