Add structured trace logs (pytorch#120289)
Overall design: https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit

How to read the diff:
* Most files are me augmenting pre-existing logging with structured variants. For the most part it's simple (especially FX graphs, which have a canonical string representation); it gets more complicated where I decided to JSON-ify some data structures instead of keeping the ad hoc printing (notably, guards and dynamo output graph sizes)
* torch/_functorch/_aot_autograd/collect_metadata_analysis.py contains some unrelated fixes I noticed while auditing artifact logs
* torch/_logging/_internal.py has the actual trace log implementation. The trace logger is implemented as a logger named torch.__trace which is disconnected from the logging hierarchy. It gets its own handler and formatter (TorchLogsFormatter with _is_trace True). `trace_structured` is the main way to emit a trace log; a minimal usage sketch follows this list. Unusually, there are separate "metadata" and "payload" fields. The metadata field should not be too long (as it is serialized as a single line) and is always JSON (we put contextual things like compile id in it); the payload field can be long, is emitted after the metadata log line, and can span multiple lines.
* torch/_logging/structured.py contains some helpers for converting Python data structures into JSON form. Notably, we have a string interning implementation here, which helps reduce the cost of serializing filenames into the log (a hypothetical sketch of the idea also appears after this list).
* test/dynamo/test_structured_trace.py contains the tests, cribbed from test_logging.py but rewritten to use expect tests on munged versions of what we'd actually output. Payloads are never tested, since they tend not to be very stable.
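
To make the API concrete, here is a minimal sketch of a call site, mirroring the ones added in this diff. The wrapper function is illustrative only, but the artifact name and call shape follow the dynamo_output_graph call in output_graph.py:

```python
import torch.fx
from torch._logging import trace_structured


def log_output_graph(gm: torch.fx.GraphModule) -> None:
    # Illustrative helper, not part of this PR.
    trace_structured(
        "dynamo_output_graph",                # which structured artifact this is
        lambda: {"sizes": {}},                # metadata: small, JSON-able, serialized on one line
        payload_fn=lambda: gm.print_readable(print_output=False),  # payload: may span many lines
    )
```

Both the metadata and the payload are supplied as callables, which presumably keeps them from being evaluated unless structured tracing is actually enabled.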

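The interning idea in torch/_logging/structured.py can be pictured roughly like this. This is a hypothetical sketch, not the actual implementation (in particular, the real code also has to emit the intern table itself so a consumer can resolve the indices, which is omitted here); only the `from_traceback` name is taken from the diff:

```python
import traceback
from typing import Dict, List

# Hypothetical module-level intern table: string -> integer id.
INTERN_TABLE: Dict[str, int] = {}


def intern_string(s: str) -> int:
    # Hand out a stable integer id per distinct string, assigning on first use.
    idx = INTERN_TABLE.get(s)
    if idx is None:
        idx = len(INTERN_TABLE)
        INTERN_TABLE[s] = idx
    return idx


def from_traceback(tb: List[traceback.FrameSummary]) -> List[dict]:
    # Convert a stack summary into JSON-able frame records; filenames are
    # referenced by intern table index instead of repeating the full path.
    return [
        {"filename": intern_string(frame.filename), "line": frame.lineno, "name": frame.name}
        for frame in tb
    ]
```
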
https://github.com/ezyang/tlparse is a POC Rust program that can interpret these logs.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#120289
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#120712
ezyang authored and pytorchmergebot committed Feb 28, 2024
1 parent 677e67c commit 1a1fc10
Showing 16 changed files with 688 additions and 22 deletions.
394 changes: 394 additions & 0 deletions test/dynamo/test_structured_trace.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions torch/_dynamo/__init__.py
@@ -78,6 +78,8 @@ def reset() -> None:
_reset_guarded_backend_cache()
reset_frame_count()
torch._C._dynamo.compiled_autograd.clear_cache()
convert_frame.FRAME_COUNTER = 0
convert_frame.FRAME_COMPILE_COUNTER.clear()


def reset_code_caches() -> None:
13 changes: 13 additions & 0 deletions torch/_dynamo/backends/distributed.py
@@ -9,6 +9,7 @@
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node

# Regular log messages should go through 'log'.
@@ -312,6 +313,18 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
debug_str += "\n---------------\n"
ddp_graph_log.debug(debug_str)

trace_structured(
"optimize_ddp_split_graph",
payload_fn=lambda: split_gm.print_readable(print_output=False),
)
for name, module in split_gm.named_modules():
if "." not in name and len(name):
trace_structured(
"optimize_ddp_split_child",
lambda: {"name": name},
payload_fn=lambda: module.print_readable(print_output=False),
)

# 3 (lazy compile): Replace submodules with lazily compiling submodule
class SubmoduleReplacer(torch.fx.interpreter.Interpreter):
def __init__(self, module, compiler):
6 changes: 5 additions & 1 deletion torch/_dynamo/compiled_autograd.py
@@ -6,7 +6,7 @@
from torch._dynamo.external_utils import call_backward, call_hook
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code
from torch._logging import getArtifactLogger
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
@@ -201,6 +201,10 @@ def end_capture(self, outputs):
compiled_autograd_log.info(
"%s", lazy_format_graph_code("Compiled autograd graph", graph)
)
trace_structured(
"compiled_autograd_graph",
payload_fn=lambda: graph.print_readable(print_output=False),
)
return self.compiler_fn(graph)

def to_proxy(self, t):
10 changes: 10 additions & 0 deletions torch/_dynamo/convert_frame.py
@@ -26,6 +26,7 @@
import torch
import torch._logging
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import signpost_event
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
@@ -656,6 +657,15 @@ def count_args(code):
# -2: omit current frame, omit contextlib decorator
"".join(traceback.format_list(traceback.extract_stack()[: -2 - skip])),
)
# -4: -2 as above, plus trace_structured frames
torch._logging.trace_structured(
"dynamo_start",
lambda: {
"stack": structured.from_traceback(
traceback.extract_stack()[: -4 - skip]
)
},
)
start_time = time.time()
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
21 changes: 20 additions & 1 deletion torch/_dynamo/guards.py
@@ -43,12 +43,13 @@
GuardSource,
Source,
)

from torch._logging import structured
from torch.fx.experimental.symbolic_shapes import (
EqualityConstraint,
is_symbolic,
SYMPY_INTERP,
)

from torch.utils._traceback import format_frame, report_compile_source_on_error
from torch.utils.weak import TensorWeakRef

@@ -1077,11 +1078,24 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn):
# Don't report this guard, it's always the same, useless!
code_parts = ["___check_global_state()"]
verbose_code_parts = code_parts[:]
structured_guard_fns = []

def add_code_part(code_part, guard, log_only=False):
verbose_code_part = get_verbose_code_part(code_part, guard)
guards_log.debug("%s", verbose_code_part)

structured_guard_fns.append(
lambda: {
"code": code_part,
"stack": structured.from_traceback(guard.stack.summary())
if guard.stack
else None,
"user_stack": structured.from_traceback(guard.user_stack)
if guard.user_stack
else None,
}
)

if verbose_guards_log.isEnabledFor(logging.DEBUG):
maybe_stack = ""
maybe_user_stack = ""
@@ -1176,6 +1190,11 @@ def add_code_part(code_part, guard, log_only=False):
for code in gcl.code_list:
add_code_part(code, gcl.guard)

# OK, all done generating guards
torch._logging.trace_structured(
"dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
)

global_state = convert_frame.initial_global_state
if global_state is None:
# we should only hit this case in NopTests()
20 changes: 16 additions & 4 deletions torch/_dynamo/output_graph.py
@@ -999,7 +999,16 @@ def cleanup_graph(self):
self.graph.erase_node(node1)
self.graph.erase_node(node2)

def get_graph_sizes_log_str(self, name):
def get_graph_sizes_structured(self):
ret = {}
for node in self.graph.nodes:
example_value = node.meta.get("example_value", None)
if isinstance(example_value, torch._subclasses.FakeTensor):
size = example_value.size()
ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
return ret

def get_graph_sizes(self, name: str):
graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
graph_sizes_str += f"===== {name} =====\n"
for node in self.graph.nodes:
@@ -1082,10 +1091,13 @@ def compile_and_call_fx_graph(self, tx, rv, root):
] = self.dynamo_flat_name_to_original_fqn.copy()

graph_code_log.debug("%s", lazy_format_graph_code(name, gm))
graph_tabular_log.debug("%s", lazy_format_graph_tabular(name, gm))
graph_sizes_log.debug(
"%s", LazyString(lambda: self.get_graph_sizes_log_str(name))
torch._logging.trace_structured(
"dynamo_output_graph",
lambda: {"sizes": self.get_graph_sizes_structured()},
payload_fn=lambda: gm.print_readable(print_output=False),
)
graph_tabular_log.debug("%s", lazy_format_graph_tabular(name, gm))
graph_sizes_log.debug("%s", LazyString(lambda: self.get_graph_sizes(name)))
self.call_cleanup_hooks()
old_fake_mode = self.tracing_context.fake_mode
if not self.export:
10 changes: 5 additions & 5 deletions torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -8,13 +8,13 @@
"""

import collections
import logging
from functools import wraps
from typing import Callable, DefaultDict, Dict, List

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._subclasses.meta_utils import safe_is_leaf
from torch.fx.experimental.symbolic_shapes import is_concrete_int
@@ -45,7 +45,7 @@

zip = strict_zip

aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
log = logging.getLogger(__name__)


# This is a version of functionalization that is specifically designed
@@ -413,7 +413,7 @@ def inner(*flat_args):
# However, autograd does not allow users to mutate multi-output views
# in any way that can change the autograd metadata of other aliases.
# So we hide this aliasing from autograd here.
aot_graphs_log.info(
log.debug(
"Encountered AOTAutograd case: differentiable outputs that \
alias each other from a multi-output view call"
)
@@ -455,7 +455,7 @@ def inner(*flat_args):
out_tensor_alias_counts[curr_storage] != 1
and num_aliased_outs_that_are_not_multi_output_views <= 1
):
aot_graphs_log.info(
log.debug(
"Encountered AOTAutograd case: differentiable outputs that alias each other \
from a multi-output view call"
)
@@ -599,7 +599,7 @@ def view_avoid_dupes_with_primals(t):
torch.set_grad_enabled(
prior_grad_enabled
) # Restore the prior state after tracing it
aot_graphs_log.info(
log.debug(
(
"grad_mode mutation encountered in graph. "
"Will emit mutation epilogue, to set grad_mode=%s"
6 changes: 5 additions & 1 deletion torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
@@ -11,7 +11,7 @@
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import lazy_format_graph_code
from torch._logging import getArtifactLogger
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

@@ -109,6 +109,10 @@ def aot_dispatch_base_graph(
aot_graphs_log.info(
"%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id)
)
trace_structured(
"aot_forward_graph",
payload_fn=lambda: fw_module.print_readable(print_output=False),
)

# TODO: should factor this into a separate function for export that always only returns just the graph.
if aot_config.is_export:
3 changes: 0 additions & 3 deletions torch/_functorch/_aot_autograd/input_output_analysis.py
@@ -15,7 +15,6 @@
import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from .schemas import (
@@ -30,8 +29,6 @@

zip = strict_zip

aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")


def remove_dupe_metadata(
m: ViewAndMutationMeta,
14 changes: 13 additions & 1 deletion torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -18,7 +18,7 @@
from torch import Tensor
from torch._dynamo.utils import lazy_format_graph_code
from torch._guards import detect_fake_mode, tracing, TracingContext
from torch._logging import getArtifactLogger
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import CUDARngStateHelper
from torch._subclasses import FakeTensor
from torch.fx.experimental.proxy_tensor import is_sym_node
@@ -168,6 +168,10 @@ def aot_dispatch_autograd(
aot_joint_log.info(
"%s", lazy_format_graph_code("Joint graph", fx_g, aot_config.aot_id)
)
trace_structured(
"aot_joint_graph",
payload_fn=lambda: fx_g.print_readable(print_output=False), # type: ignore[union-attr]
)

with torch.no_grad():
inner_meta = (
@@ -287,6 +291,14 @@ def aot_dispatch_autograd(
"%s",
lazy_format_graph_code("Backward graph", bw_module, aot_config.aot_id),
)
trace_structured(
"aot_forward_graph",
payload_fn=lambda: fw_module.print_readable(print_output=False),
)
trace_structured(
"aot_backward_graph",
payload_fn=lambda: bw_module.print_readable(print_output=False),
)

with track_graph_compiling(aot_config, "forward"):
# flat_args at this point might still be subclasses-
8 changes: 8 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -43,6 +43,7 @@
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._logging import trace_structured
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
@@ -590,6 +591,9 @@ def fx_codegen_and_compile(
f"graph {graph_id}",
)
V.debug.fx_graph(gm, example_inputs)
# TODO: Should we actually dump this? It should be redundant with the aot
# structured logs...
# trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

shape_env = _shape_env_from_inputs(example_inputs)

@@ -627,6 +631,10 @@ def fx_codegen_and_compile(
post_grad_passes(gm, is_inference=is_inference)
V.debug.fx_graph_transformed(gm, example_inputs)
post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
trace_structured(
"inductor_post_grad_graph",
payload_fn=lambda: gm.print_readable(print_output=False),
)
optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
signpost_event(
"optimus",
7 changes: 6 additions & 1 deletion torch/_inductor/graph.py
@@ -16,7 +16,7 @@
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString
from torch._logging import LazyString, trace_structured
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
@@ -1242,6 +1242,11 @@ def compile_to_module(self):
log_module_code(mod.__file__)
log.debug("Output code written to: %s", mod.__file__)
output_code_log.debug("Output code: \n%s", code)
trace_structured(
"inductor_output_code",
lambda: {"filename": mod.__file__},
payload_fn=lambda: code,
)
output_code_log.info("Output code written to: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
1 change: 1 addition & 0 deletions torch/_logging/__init__.py
@@ -11,5 +11,6 @@
getArtifactLogger,
LazyString,
set_logs,
trace_structured,
warning_once,
)