From 9df2f8687f484f3f8cc055d016d4a389c6ba9fa6 Mon Sep 17 00:00:00 2001 From: dshi7 Date: Tue, 14 May 2024 17:09:22 +0000 Subject: [PATCH] cprofile every compile id [x/y] to keep consistent with tlparse (#125659) This PR moves cprofile decorator to keep consistent with `torch_inductor_stats` logging and is needed by fbcode diffs of profiling enablement in internal e2e jobs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125659 Approved by: https://github.com/ezyang --- test/inductor/test_padding.py | 2 +- torch/_dynamo/convert_frame.py | 87 +++++++++++++++++++++++++++++++++- torch/_dynamo/utils.py | 63 ------------------------ torch/_inductor/compile_fx.py | 1 - 4 files changed, 86 insertions(+), 67 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 7aef585842e613..e08ac285801d75 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -7,9 +7,9 @@ import torch from torch import nn, Tensor +from torch._dynamo.convert_frame import maybe_cprofile from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss -from torch._dynamo.utils import maybe_cprofile from torch._inductor import config, ir, metrics from torch._inductor.fx_passes import pad_mm as pad_mm_pass from torch._inductor.runtime.runtime_utils import do_bench diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 77447bc17dee11..38795341be2167 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,10 +1,14 @@ +import base64 import collections +import cProfile import dis import functools import itertools import logging import os +import pstats import random +import subprocess import sys import threading import time @@ -12,8 +16,11 @@ import types import typing import weakref +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set +from torch._utils_internal import maybe_upload_prof_stats_to_manifold + from torch.fx._lazy_graph_module import ( # type: ignore[attr-defined] _use_lazy_graph_module, ) @@ -87,7 +94,6 @@ is_namedtuple, istype, LazyString, - maybe_cprofile, orig_code_map, record_compilation_metrics, reset_graph_break_dup_checker, @@ -286,6 +292,83 @@ def exception_handler(e, code, frame=None, export=False): FRAME_COMPILE_COUNTER: typing.Counter[int] = collections.Counter() +def maybe_cprofile(func): + if config.cprofile: + return cprofile_wrapper(func) + return func + + +def cprofile_wrapper(func): + @functools.wraps(func) + def profile_wrapper(*args, **kwargs): + trace_id = CompileContext.current_trace_id() + assert trace_id, "Trace id is None" + profile_path = Path( + f"/tmp/{func.__name__}_{str(trace_id).replace('/','_')}.profile" + ) + prof = cProfile.Profile() + prof.enable() + start_ts = time.time() + retval = prof.runcall(func, *args, **kwargs) + profile_latency = time.time() - start_ts + prof.disable() + log.info( + "### Cprofile for %s trace id [%s] took %.3f seconds ###", + func.__name__, + trace_id, + profile_latency, + ) + ps = pstats.Stats(prof) + try: + prof.dump_stats(profile_path) + except PermissionError: + log.info("Cannot write to %s", str(profile_path)) + svg_path = profile_path.with_suffix(".svg") + try: + gprof2dot_process = subprocess.Popen( + [ + "gprof2dot", + "-f", + "pstats", + "--node-label=total-time-percentage", + "--node-label=self-time-percentage", + "--node-label=total-time", + str(profile_path), + ], + stdout=subprocess.PIPE, + ) + subprocess.check_call( + ["dot", "-Tsvg", "-o", str(svg_path)], + stdin=gprof2dot_process.stdout, + ) + log.info("Generated SVG from profile at %s", str(svg_path)) + except FileNotFoundError: + log.info( + "Failed to generate SVG from profile -- dumping stats instead." + "Try installing gprof2dot and dot for a better visualization" + ) + ps.sort_stats(pstats.SortKey.TIME).print_stats(20) + ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) + + maybe_upload_prof_stats_to_manifold(str(profile_path)) # fb-only + + torch._logging.trace_structured( + "artifact", + lambda: { + "name": "dynamo_cprofile_prof", + "type": "prof", + "encoding": "base64", + }, + payload_fn=lambda: base64.encodebytes( + open(profile_path, "rb").read() + ).decode("ascii"), + ) + + return retval + + return profile_wrapper + + def convert_frame_assert( compiler_fn: CompilerFn, one_graph: bool = True, @@ -428,7 +511,6 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle: @compile_time_strobelight_meta(phase_name="_compile") @_use_lazy_graph_module(config.use_lazy_graph_module) -@maybe_cprofile def _compile( code: types.CodeType, globals: Dict[str, object], @@ -512,6 +594,7 @@ def transform(instructions, code_options): instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) @dynamo_timed(phase_name="entire_frame_compile") + @maybe_cprofile def compile_inner( code: types.CodeType, one_graph: bool, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ff9438085c5291..9c050d84a5eee0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2,7 +2,6 @@ import collections import contextlib import copy -import cProfile import dataclasses import datetime import dis @@ -16,9 +15,7 @@ import math import operator import os -import pstats import re -import subprocess import sys import textwrap import threading @@ -28,7 +25,6 @@ import weakref from contextlib import contextmanager from functools import lru_cache, wraps -from pathlib import Path from types import MethodWrapperType from typing import ( Any, @@ -50,8 +46,6 @@ ValuesView, ) -from torch._utils_internal import maybe_upload_prof_stats_to_manifold - from ..utils.hooks import RemovableHandle try: @@ -135,63 +129,6 @@ def tabulate(rows, headers): ) -def maybe_cprofile(func): - if config.cprofile: - return cprofile_wrapper(func) - return func - - -def cprofile_wrapper(func): - @wraps(func) - def profile_wrapper(*args, **kwargs): - global timer_counter - profile_cnt = next(timer_counter) - profile_path = Path("/tmp/" + func.__name__ + f"{profile_cnt}.profile") - prof = cProfile.Profile() - prof.enable() - start_ts = time.time() - retval = prof.runcall(func, *args, **kwargs) - profile_latency = time.time() - start_ts - prof.disable() - print( - f"### Cprofile for {func.__name__} iter {profile_cnt} took {profile_latency:.3f} seconds ###" - ) - ps = pstats.Stats(prof) - prof.dump_stats(profile_path) - svg_path = profile_path.with_suffix(".svg") - try: - gprof2dot_process = subprocess.Popen( - [ - "gprof2dot", - "-f", - "pstats", - "--node-label=total-time-percentage", - "--node-label=self-time-percentage", - "--node-label=total-time", - str(profile_path), - ], - stdout=subprocess.PIPE, - ) - subprocess.check_call( - ["dot", "-Tsvg", "-o", str(svg_path)], - stdin=gprof2dot_process.stdout, - ) - print(f"Generated SVG from profile at {str(svg_path)}") - except FileNotFoundError: - print( - "Failed to generate SVG from profile -- dumping stats instead." - "Try installing gprof2dot and dot for a better visualization" - ) - ps.sort_stats(pstats.SortKey.TIME).print_stats(20) - ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) - - maybe_upload_prof_stats_to_manifold(str(profile_path)) # fb-only - - return retval - - return profile_wrapper - - curr_frame = 0 diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 3df2f5258b655e..bdbfef2eee28fb 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1428,7 +1428,6 @@ def partition_fn(graph, joint_inputs, **kwargs): @compile_time_strobelight_meta(phase_name="bw_compiler") @dynamo_utils.dynamo_timed - @dynamo_utils.maybe_cprofile def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): user_visible_outputs = {}