cprofile every compile id [x/y] to keep consistent with tlparse (pytorch#125659)

This PR moves the cprofile decorator to keep it consistent with `torch_inductor_stats` logging; it is needed by fbcode diffs that enable profiling in internal e2e jobs.

Pull Request resolved: pytorch#125659
Approved by: https://github.com/ezyang
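
For context, a minimal sketch of how this per-compile-id profiling is typically enabled. The `torch._dynamo.config.cprofile` flag and the `/tmp/<func>_<trace id>.profile` naming come from the diff below; the toy function and the concrete example path are illustrative assumptions.

```python
import torch

# Turn on the maybe_cprofile decorator (flag used by the code in this diff).
torch._dynamo.config.cprofile = True

def fn(x):
    return torch.relu(x) + 1

compiled = torch.compile(fn)
compiled(torch.randn(8))

# After this PR, each compile id gets its own dump, e.g. for trace id "0/0":
#   /tmp/compile_inner_0_0.profile   (assumed example path)
# which lines up with the [x/y] compile ids that tlparse reports.
```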
dshi7 authored and pytorchmergebot committed May 14, 2024
1 parent 2e4d011 commit 9df2f86
Showing 4 changed files with 86 additions and 67 deletions.
2 changes: 1 addition & 1 deletion test/inductor/test_padding.py
@@ -7,9 +7,9 @@

import torch
from torch import nn, Tensor
from torch._dynamo.convert_frame import maybe_cprofile
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss
from torch._dynamo.utils import maybe_cprofile
from torch._inductor import config, ir, metrics
from torch._inductor.fx_passes import pad_mm as pad_mm_pass
from torch._inductor.runtime.runtime_utils import do_bench
87 changes: 85 additions & 2 deletions torch/_dynamo/convert_frame.py
@@ -1,19 +1,26 @@
import base64
import collections
import cProfile
import dis
import functools
import itertools
import logging
import os
import pstats
import random
import subprocess
import sys
import threading
import time
import traceback
import types
import typing
import weakref
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set

from torch._utils_internal import maybe_upload_prof_stats_to_manifold

from torch.fx._lazy_graph_module import ( # type: ignore[attr-defined]
_use_lazy_graph_module,
)
@@ -87,7 +94,6 @@
is_namedtuple,
istype,
LazyString,
maybe_cprofile,
orig_code_map,
record_compilation_metrics,
reset_graph_break_dup_checker,
@@ -286,6 +292,83 @@ def exception_handler(e, code, frame=None, export=False):
FRAME_COMPILE_COUNTER: typing.Counter[int] = collections.Counter()


def maybe_cprofile(func):
if config.cprofile:
return cprofile_wrapper(func)
return func


def cprofile_wrapper(func):
@functools.wraps(func)
def profile_wrapper(*args, **kwargs):
trace_id = CompileContext.current_trace_id()
assert trace_id, "Trace id is None"
profile_path = Path(
f"/tmp/{func.__name__}_{str(trace_id).replace('/','_')}.profile"
)
prof = cProfile.Profile()
prof.enable()
start_ts = time.time()
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()
log.info(
"### Cprofile for %s trace id [%s] took %.3f seconds ###",
func.__name__,
trace_id,
profile_latency,
)
ps = pstats.Stats(prof)
try:
prof.dump_stats(profile_path)
except PermissionError:
log.info("Cannot write to %s", str(profile_path))
svg_path = profile_path.with_suffix(".svg")
try:
gprof2dot_process = subprocess.Popen(
[
"gprof2dot",
"-f",
"pstats",
"--node-label=total-time-percentage",
"--node-label=self-time-percentage",
"--node-label=total-time",
str(profile_path),
],
stdout=subprocess.PIPE,
)
subprocess.check_call(
["dot", "-Tsvg", "-o", str(svg_path)],
stdin=gprof2dot_process.stdout,
)
log.info("Generated SVG from profile at %s", str(svg_path))
except FileNotFoundError:
log.info(
"Failed to generate SVG from profile -- dumping stats instead."
"Try installing gprof2dot and dot for a better visualization"
)
ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)

maybe_upload_prof_stats_to_manifold(str(profile_path)) # fb-only

torch._logging.trace_structured(
"artifact",
lambda: {
"name": "dynamo_cprofile_prof",
"type": "prof",
"encoding": "base64",
},
payload_fn=lambda: base64.encodebytes(
open(profile_path, "rb").read()
).decode("ascii"),
)

return retval

return profile_wrapper


def convert_frame_assert(
compiler_fn: CompilerFn,
one_graph: bool = True,
@@ -428,7 +511,6 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:

@compile_time_strobelight_meta(phase_name="_compile")
@_use_lazy_graph_module(config.use_lazy_graph_module)
@maybe_cprofile
def _compile(
code: types.CodeType,
globals: Dict[str, object],
@@ -512,6 +594,7 @@ def transform(instructions, code_options):
instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

@dynamo_timed(phase_name="entire_frame_compile")
@maybe_cprofile
def compile_inner(
code: types.CodeType,
one_graph: bool,
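
As a follow-up note, the dumped `.profile` files are plain `cProfile` stats, so they can also be examined offline with `pstats` when `gprof2dot`/`dot` are not installed. A hedged sketch; the path is an assumed example of the naming scheme above.

```python
import pstats

# Path follows /tmp/{func.__name__}_{trace id with "/" -> "_"}.profile;
# "compile_inner_0_0" is an assumed example for compile id 0/0.
stats = pstats.Stats("/tmp/compile_inner_0_0.profile")
stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)
```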
63 changes: 0 additions & 63 deletions torch/_dynamo/utils.py
@@ -2,7 +2,6 @@
import collections
import contextlib
import copy
import cProfile
import dataclasses
import datetime
import dis
@@ -16,9 +15,7 @@
import math
import operator
import os
import pstats
import re
import subprocess
import sys
import textwrap
import threading
@@ -28,7 +25,6 @@
import weakref
from contextlib import contextmanager
from functools import lru_cache, wraps
from pathlib import Path
from types import MethodWrapperType
from typing import (
Any,
@@ -50,8 +46,6 @@
ValuesView,
)

from torch._utils_internal import maybe_upload_prof_stats_to_manifold

from ..utils.hooks import RemovableHandle

try:
@@ -135,63 +129,6 @@ def tabulate(rows, headers):
)


def maybe_cprofile(func):
if config.cprofile:
return cprofile_wrapper(func)
return func


def cprofile_wrapper(func):
@wraps(func)
def profile_wrapper(*args, **kwargs):
global timer_counter
profile_cnt = next(timer_counter)
profile_path = Path("/tmp/" + func.__name__ + f"{profile_cnt}.profile")
prof = cProfile.Profile()
prof.enable()
start_ts = time.time()
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()
print(
f"### Cprofile for {func.__name__} iter {profile_cnt} took {profile_latency:.3f} seconds ###"
)
ps = pstats.Stats(prof)
prof.dump_stats(profile_path)
svg_path = profile_path.with_suffix(".svg")
try:
gprof2dot_process = subprocess.Popen(
[
"gprof2dot",
"-f",
"pstats",
"--node-label=total-time-percentage",
"--node-label=self-time-percentage",
"--node-label=total-time",
str(profile_path),
],
stdout=subprocess.PIPE,
)
subprocess.check_call(
["dot", "-Tsvg", "-o", str(svg_path)],
stdin=gprof2dot_process.stdout,
)
print(f"Generated SVG from profile at {str(svg_path)}")
except FileNotFoundError:
print(
"Failed to generate SVG from profile -- dumping stats instead."
"Try installing gprof2dot and dot for a better visualization"
)
ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)

maybe_upload_prof_stats_to_manifold(str(profile_path)) # fb-only

return retval

return profile_wrapper


curr_frame = 0


1 change: 0 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -1428,7 +1428,6 @@ def partition_fn(graph, joint_inputs, **kwargs):

@compile_time_strobelight_meta(phase_name="bw_compiler")
@dynamo_utils.dynamo_timed
@dynamo_utils.maybe_cprofile
def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
user_visible_outputs = {}

