From 75933ff5231b1caed333065ea9f5a847caa4cdaa Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 5 Apr 2024 20:56:26 +0000 Subject: [PATCH] Ignore logging.Logger.* calls during dynamo export (#123402) Follow up for https://github.com/pytorch/pytorch/pull/123368 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123402 Approved by: https://github.com/williamwen42 --- test/export/test_export.py | 24 ++++++++++++++++++++++++ test/onnx/test_fx_to_onnx.py | 14 ++++++++++++++ torch/_dynamo/variables/builder.py | 4 ++++ torch/_dynamo/variables/misc.py | 21 +++++++++++++++++++++ 4 files changed, 63 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index f6522c73c37e86..8d66b3f3e1d1e8 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -4,6 +4,7 @@ import dataclasses import io import re +import logging import unittest import warnings from contextlib import contextmanager @@ -4211,6 +4212,29 @@ def forward(self, x): return (add, add_1)""", ) + def test_logging_logger(self): + logger = logging.getLogger(__name__) + class M(torch.nn.Module): + def forward(self, x): + logger.log("start") + x1 = x + x + logger.debug(x1) + x2 = x1 * x1 + logger.info(1, 2, 3) + x3 = x2 + x2 + return (x1, x3) + + gm = export(M(), (torch.randn(3, 3),)).graph_module + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None + mul = torch.ops.aten.mul.Tensor(add, add) + add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None + return (add, add_1)""", + ) + def test_warning(self): class M(torch.nn.Module): def forward(self, x): diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index c5b71633d71b37..c444ae54c74ebc 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -1,6 +1,8 @@ # Owner(s): ["module: onnx"] from __future__ import annotations +import logging + import tempfile from typing import Mapping, Tuple @@ -707,6 +709,18 @@ def forward(self, input: torch.Tensor): _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4)) + def test_export_with_logging_logger(self): + logger = logging.getLogger(__name__) + + class LoggingLoggerModule(torch.nn.Module): + def forward(self, x): + logger.log("abc") + return x + 1 + + input = torch.randn(2, 3) + model = LoggingLoggerModule() + _ = torch.onnx.dynamo_export(model, input) + def test_checkpoint_cast(self): model_id = "openai/whisper-large-v3" feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index c57fb57987e874..1c6cf687ff884a 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -140,6 +140,7 @@ GetSetDescriptorVariable, InspectSignatureVariable, LambdaVariable, + LoggingLoggerVariable, MethodWrapperVariable, NumpyVariable, PythonModuleVariable, @@ -515,6 +516,9 @@ def build_key_value(i, k, v): # along with other builtin debugging functions self.install_guards(GuardBuilder.BUILTIN_MATCH) return DebuggingVariable(value, source=self.source) + elif isinstance(value, logging.Logger): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return LoggingLoggerVariable(value, source=self.source) elif is_utils_checkpoint(value): return build_checkpoint_variable(source=self.source) elif isinstance(value, functools.partial): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 6bf3ffe3f348ce..f58db0f24a8edd 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -946,6 +946,27 @@ def can_reorder_logs(fn, args, kwargs) -> True: return True +class LoggingLoggerVariable(VariableTracker): + """ + Represents a call to any of logging.Logger methods + """ + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + unimplemented("Logger not supported for non-export cases") + + class StopIterationVariable(VariableTracker): def __init__(self, args, **kwargs): super().__init__(**kwargs)