Skip to content

Commit

Permalink
Test TD (test removal) on crossref (pytorch#119426)
Browse files Browse the repository at this point in the history
Current threshold is to cut the bottom 75% of test files, which results in 13 min of tests getting cut.
test_ops, functorch/test_ops, and test_decomp and other really long running test files are not getting cut and make the top 25% to take really long (still 90+ min)

The original plan was to test on rocm but I'm worried about queuing given that cutting 75% of test files only cuts off 13 min, and crossref is rarely referenced by others and people keep talking about getting rid of it, so it's a good alternative

Pull Request resolved: pytorch#119426
Approved by: https://github.com/huydhn
  • Loading branch information
clee2000 authored and pytorchmergebot committed Feb 29, 2024
1 parent 1458f1d commit 0290fe6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
25 changes: 17 additions & 8 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
set_cwd,
shell,
TEST_WITH_ASAN,
TEST_WITH_CROSSREF,
TEST_WITH_ROCM,
TEST_WITH_SLOW_GRADCHECK,
)
Expand Down Expand Up @@ -1164,6 +1165,7 @@ def parse_args():
action="store_true",
help="Enables removing tests based on TD",
default=IS_CI
and TEST_WITH_CROSSREF
and os.getenv("BRANCH", "") != "main"
and not strtobool(os.environ.get("NO_TD", "False")),
)
Expand Down Expand Up @@ -1462,7 +1464,7 @@ def do_sharding(
test_file_times: Dict[str, float],
test_class_times: Dict[str, Dict[str, float]],
sort_by_time: bool = True,
) -> List[ShardedTest]:
) -> Tuple[float, List[ShardedTest]]:
which_shard, num_shards = get_sharding_opts(options)

# Do sharding
Expand All @@ -1474,10 +1476,7 @@ def do_sharding(
must_serial=must_serial,
sort_by_time=sort_by_time,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard

return selected_tests
return shards[which_shard - 1]


class TestFailure(NamedTuple):
Expand Down Expand Up @@ -1666,7 +1665,7 @@ def __init__(
):
self.name = name
self.failures = []
self.sharded_tests = do_sharding(
self.time, self.sharded_tests = do_sharding(
options,
raw_tests,
test_file_times_dict,
Expand All @@ -1675,7 +1674,7 @@ def __init__(
)

def __str__(self):
s = f"Name: {self.name}\n"
s = f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)\n"
serial = [test for test in self.sharded_tests if must_serial(test)]
parallel = [test for test in self.sharded_tests if not must_serial(test)]
s += f" Serial tests ({len(serial)}):\n"
Expand All @@ -1684,9 +1683,19 @@ def __str__(self):
s += "".join(f" {test}\n" for test in parallel)
return s.strip()

test_batch = TestBatch("all_tests", test_prioritizations.get_all_tests(), False)
percent_to_run = 25 if options.enable_td else 100
print_to_stderr(
f"Running {percent_to_run}% of tests based on TD"
if options.enable_td
else "Running all tests"
)
include, exclude = test_prioritizations.get_top_per_tests(percent_to_run)

test_batch = TestBatch("tests to run", include, False)
test_batch_exclude = TestBatch("excluded", exclude, True)

print_to_stderr(test_batch)
print_to_stderr(test_batch_exclude)

if options.dry_run:
return
Expand Down
6 changes: 6 additions & 0 deletions tools/testing/target_determination/heuristics/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def get_all_tests(self) -> List[TestRun]:
"""Returns all tests in the TestPrioritizations"""
return [x[1] for x in self._traverse_scores()]

def get_top_per_tests(self, n: int) -> Tuple[List[TestRun], List[TestRun]]:
"""Divides list of tests into two based on the top n% of scores. The
first list is the top, and the second is the rest."""
tests = [x[1] for x in self._traverse_scores()]
return tests[: n * len(tests) // 100], tests[n * len(tests) // 100 :]

def get_info_str(self) -> str:
info = ""

Expand Down

0 comments on commit 0290fe6

Please sign in to comment.