Commit 75933ff
Ignore logging.Logger.* calls during dynamo export (pytorch#123402)

Follow-up to pytorch#123368.

Pull Request resolved: pytorch#123402
Approved by: https://github.com/williamwen42
Thiago Crepaldi authored and pytorchmergebot committed Apr 8, 2024
1 parent aa9aed2 commit 75933ff
Showing 4 changed files with 63 additions and 0 deletions.
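
In short, with this change dynamo treats method calls on a logging.Logger instance as no-ops while exporting, so logging statements no longer block export. A minimal sketch of the user-visible behavior (the module name M and the log message are illustrative; torch.export.export is assumed, matching its use in the tests below):

import logging

import torch
from torch.export import export

logger = logging.getLogger(__name__)

class M(torch.nn.Module):
    def forward(self, x):
        # With this change, the logger call is elided at trace time during
        # export, so it does not appear in (or break) the captured graph.
        logger.info("running forward")
        return x + x

ep = export(M(), (torch.randn(3, 3),))
print(ep.graph_module.code)  # contains only the aten.add call
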
24 changes: 24 additions & 0 deletions test/export/test_export.py
@@ -4,6 +4,7 @@
import dataclasses
import io
import re
import logging
import unittest
import warnings
from contextlib import contextmanager
@@ -4211,6 +4212,29 @@ def forward(self, x):
    return (add, add_1)""",
        )

    def test_logging_logger(self):
        logger = logging.getLogger(__name__)
        class M(torch.nn.Module):
            def forward(self, x):
                logger.log("start")
                x1 = x + x
                logger.debug(x1)
                x2 = x1 * x1
                logger.info(1, 2, 3)
                x3 = x2 + x2
                return (x1, x3)

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
    mul = torch.ops.aten.mul.Tensor(add, add)
    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
    return (add, add_1)""",
        )

    def test_warning(self):
        class M(torch.nn.Module):
            def forward(self, x):
14 changes: 14 additions & 0 deletions test/onnx/test_fx_to_onnx.py
@@ -1,6 +1,8 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations

import logging

import tempfile

from typing import Mapping, Tuple
@@ -707,6 +709,18 @@ def forward(self, input: torch.Tensor):

        _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))

    def test_export_with_logging_logger(self):
        logger = logging.getLogger(__name__)

        class LoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                logger.log("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = LoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_checkpoint_cast(self):
        model_id = "openai/whisper-large-v3"
        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -140,6 +140,7 @@
    GetSetDescriptorVariable,
    InspectSignatureVariable,
    LambdaVariable,
    LoggingLoggerVariable,
    MethodWrapperVariable,
    NumpyVariable,
    PythonModuleVariable,
@@ -515,6 +516,9 @@ def build_key_value(i, k, v):
            # along with other builtin debugging functions
            self.install_guards(GuardBuilder.BUILTIN_MATCH)
            return DebuggingVariable(value, source=self.source)
        elif isinstance(value, logging.Logger):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return LoggingLoggerVariable(value, source=self.source)
        elif is_utils_checkpoint(value):
            return build_checkpoint_variable(source=self.source)
        elif isinstance(value, functools.partial):
21 changes: 21 additions & 0 deletions torch/_dynamo/variables/misc.py
Expand Up @@ -946,6 +946,27 @@ def can_reorder_logs(fn, args, kwargs) -> True:
        return True


class LoggingLoggerVariable(VariableTracker):
    """
    Represents a call to any of logging.Logger methods
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return
        unimplemented("Logger not supported for non-export cases")


class StopIterationVariable(VariableTracker):
    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)
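
Note that outside of export the new variable tracker deliberately bails out: call_method raises via unimplemented(...), which is how dynamo signals an unsupported construct. A hedged sketch of what that means for torch.compile users (the function f is illustrative, and the exact exception vs. graph-break behavior depends on dynamo configuration):

import logging

import torch

logger = logging.getLogger(__name__)

@torch.compile(fullgraph=True)
def f(x):
    logger.info("hi")  # hits LoggingLoggerVariable.call_method with tx.export False
    return x * 2

# Expected to fail compilation with a message like
# "Logger not supported for non-export cases"; without fullgraph=True,
# dynamo would instead graph-break around the logger call and run it eagerly.
f(torch.randn(2))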
