From 92f73301b115a9681dc6fe2fed6f63982bb1e00d Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Thu, 8 Aug 2024 09:08:24 -0700 Subject: [PATCH] [test/profiler] Make test_profiler_pattern_matcher_json_report write torchtidy_report.json to a tmpdir instead of CWD --- test/profiler/test_profiler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 0115a9eb0ca05..89dc29bdd0e5a 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -11,6 +11,7 @@ import struct import subprocess import sys +import tempfile import threading import time import unittest @@ -2528,9 +2529,11 @@ def test_profiler_pattern_matcher_json_report(self): loss.backward() optimizer.step() optimizer.zero_grad() - report_all_anti_patterns(prof, json_report_dir=".", print_enable=False) - try: - with open("./torchtidy_report.json") as f: + + with tempfile.TemporaryDirectory() as tmpdir: + report_all_anti_patterns(prof, json_report_dir=tmpdir, print_enable=False) + + with open(os.path.join(tmpdir, "torchtidy_report.json")) as f: report = json.load(f) # It is platform dependent whether the path will include "profiler/" @@ -2543,8 +2546,6 @@ def test_profiler_pattern_matcher_json_report(self): for event in entry: actual_fields = sorted(event.keys()) self.assertEqual(expected_fields, actual_fields) - finally: - os.remove("torchtidy_report.json") @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") def test_fuzz_symbolize(self):