diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst
index 8c8032bd26b4d..aec370f4411d5 100644
--- a/docs/source/onnx_torchscript.rst
+++ b/docs/source/onnx_torchscript.rst
@@ -697,7 +697,6 @@ Functions
 ^^^^^^^^^

 .. autofunction:: export
-.. autofunction:: export_to_pretty_string
 .. autofunction:: register_custom_op_symbolic
 .. autofunction:: unregister_custom_op_symbolic
 .. autofunction:: select_model_mode_for_export
diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py
index 37ca3836e5387..380a208bf9881 100644
--- a/test/onnx/test_pytorch_onnx_no_runtime.py
+++ b/test/onnx/test_pytorch_onnx_no_runtime.py
@@ -83,7 +83,7 @@ def forward(self, x):

         x = torch.ones(3, 3)
         f = io.BytesIO()
-        torch.onnx.export(AddmmModel(), x, f, verbose=False)
+        torch.onnx.export(AddmmModel(), x, f)

     def test_onnx_transpose_incomplete_tensor_type(self):
         # Smoke test to get us into the state where we are attempting to export
@@ -115,7 +115,8 @@ def foo(x):

         traced = torch.jit.trace(foo, (torch.rand([2])))

-        torch.onnx.export_to_pretty_string(traced, (torch.rand([2]),))
+        f = io.BytesIO()
+        torch.onnx.export(traced, (torch.rand([2]),), f)

     def test_onnx_export_script_module(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -125,7 +126,8 @@ def forward(self, x):
                 return x + x

         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

     @common_utils.suppress_warnings
     def test_onnx_export_func_with_warnings(self):
@@ -138,9 +140,8 @@ def forward(self, x):
                 return func_with_warning(x)

         # no exception
-        torch.onnx.export_to_pretty_string(
-            WarningTest(), torch.randn(42), verbose=False
-        )
+        f = io.BytesIO()
+        torch.onnx.export(WarningTest(), torch.randn(42), f)

     def test_onnx_export_script_python_fail(self):
         class PythonModule(torch.jit.ScriptModule):
@@ -161,7 +162,7 @@ def forward(self, x):
         mte = ModuleToExport()
         f = io.BytesIO()
         with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
-            torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
+            torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

     def test_onnx_export_script_inline_trace(self):
         class ModuleToInline(torch.nn.Module):
@@ -179,7 +180,8 @@ def forward(self, x):
             return y + y

         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

     def test_onnx_export_script_inline_script(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -198,7 +200,8 @@ def forward(self, x):
             return y + y

         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

     def test_onnx_export_script_module_loop(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -212,7 +215,8 @@ def forward(self, x):
             return x

         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

     @common_utils.suppress_warnings
     def test_onnx_export_script_truediv(self):
@@ -224,9 +228,8 @@ def forward(self, x):

         mte = ModuleToExport()

-        torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3, dtype=torch.float),), verbose=False
-        )
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f)

     def test_onnx_export_script_non_alpha_add_sub(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -236,7 +239,8 @@ def forward(self, x):
             return bs - 1

         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.rand(3, 4),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.rand(3, 4),), f)

     def test_onnx_export_script_module_if(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -247,7 +251,8 @@ def forward(self, x):
             return x

         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

     def test_onnx_export_script_inline_params(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -277,7 +282,8 @@ def forward(self, x):
             torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)
         )
         self.assertEqual(result, reference)
-        torch.onnx.export_to_pretty_string(mte, (torch.ones(2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.ones(2, 3),), f)

     def test_onnx_export_speculate(self):
         class Foo(torch.jit.ScriptModule):
@@ -312,8 +318,10 @@ def transpose(x):
         f1 = Foo(transpose)
         f2 = Foo(linear)

-        torch.onnx.export_to_pretty_string(f1, (torch.ones(1, 10, dtype=torch.float),))
-        torch.onnx.export_to_pretty_string(f2, (torch.ones(1, 10, dtype=torch.float),))
+        f = io.BytesIO()
+        torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f)
+        f = io.BytesIO()
+        torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f)

     def test_onnx_export_shape_reshape(self):
         class Foo(torch.nn.Module):
@@ -326,7 +334,8 @@ def forward(self, x):
             return reshaped

         foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
-        torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)))
+        f = io.BytesIO()
+        torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f)

     def test_listconstruct_erasure(self):
         class FooMod(torch.nn.Module):
@@ -334,9 +343,11 @@ def forward(self, x):
             mask = x < 0.0
             return x[mask]

-        torch.onnx.export_to_pretty_string(
+        f = io.BytesIO()
+        torch.onnx.export(
             FooMod(),
             (torch.rand(3, 4),),
+            f,
             add_node_names=False,
             do_constant_folding=False,
             operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
@@ -351,13 +362,10 @@ def forward(self, x):
                 retval += torch.sum(x[0:i], dim=0)
             return retval

-        mod = DynamicSliceExportMod()
-
         input = torch.rand(3, 4, 5)

-        torch.onnx.export_to_pretty_string(
-            DynamicSliceExportMod(), (input,), opset_version=10
-        )
+        f = io.BytesIO()
+        torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10)

     def test_export_dict(self):
         class DictModule(torch.nn.Module):
@@ -368,10 +376,12 @@ def forward(self, x_in: torch.Tensor) -> Dict[str, torch.Tensor]:
         mod = DictModule()
         mod.train(False)

-        torch.onnx.export_to_pretty_string(mod, (x_in,))
+        f = io.BytesIO()
+        torch.onnx.export(mod, (x_in,), f)

         with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
-            torch.onnx.export_to_pretty_string(torch.jit.script(mod), (x_in,))
+            f = io.BytesIO()
+            torch.onnx.export(torch.jit.script(mod), (x_in,), f)

     def test_source_range_propagation(self):
         class ExpandingModule(torch.nn.Module):
@@ -497,11 +507,11 @@ def forward(self, box_regression: Tensor, proposals: List[Tensor]):
         proposal = [torch.randn(2, 4), torch.randn(2, 4)]

         with self.assertRaises(RuntimeError) as cm:
-            onnx_model = io.BytesIO()
+            f = io.BytesIO()
             torch.onnx.export(
                 model,
                 (box_regression, proposal),
-                onnx_model,
+                f,
             )

     def test_initializer_sequence(self):
@@ -637,7 +647,7 @@ def forward(self, x):

         x = torch.randn(1, 2, 3, requires_grad=True)
         f = io.BytesIO()
-        torch.onnx.export(Model(), x, f)
+        torch.onnx.export(Model(), (x,), f)

         model = onnx.load(f)
         model.ir_version = 0
@@ -744,7 +754,7 @@ def forward(self, x):

         f = io.BytesIO()
         with warnings.catch_warnings(record=True):
-            torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
+            torch.onnx.export(MyDrop(), (eg,), f)

     def test_pack_padded_pad_packed_trace(self):
         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -791,7 +801,7 @@ def forward(self, x, seq_lens):
         self.assertEqual(grad, grad_traced)

         f = io.BytesIO()
-        torch.onnx.export(m, (x, seq_lens), f, verbose=False)
+        torch.onnx.export(m, (x, seq_lens), f)

     # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
     @common_utils.suppress_warnings
@@ -851,7 +861,7 @@ def forward(self, x, seq_lens):
         self.assertEqual(grad, grad_traced)

         f = io.BytesIO()
-        torch.onnx.export(m, (x, seq_lens), f, verbose=False)
+        torch.onnx.export(m, (x, seq_lens), f)

     def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self):
         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -931,7 +941,8 @@ class Mod(torch.nn.Module):
             def forward(self, x, w):
                 return torch.matmul(x, w).detach()

-        torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
+        f = io.BytesIO()
+        torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)

     def test_aten_fallback_must_fallback(self):
         class ModelWithAtenNotONNXOp(torch.nn.Module):
@@ -1088,12 +1099,12 @@ def sym_scatter_max(g, src, index, dim, out, dim_size):
         torch.onnx.register_custom_op_symbolic(
             "torch_scatter::scatter_max", sym_scatter_max, 1
         )
+        f = io.BytesIO()
         with torch.no_grad():
             torch.onnx.export(
                 m,
                 (src, idx),
-                "mymodel.onnx",
-                verbose=False,
+                f,
                 opset_version=13,
                 custom_opsets={"torch_scatter": 1},
                 do_constant_folding=True,
@@ -1176,7 +1187,7 @@ def forward(self, x):
         model = Net(C).cuda().half()
         x = torch.randn(N, C).cuda().half()
         f = io.BytesIO()
-        torch.onnx.export(model, x, f, opset_version=14)
+        torch.onnx.export(model, (x,), f, opset_version=14)
         onnx_model = onnx.load_from_string(f.getvalue())
         const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"]
         self.assertNotEqual(len(const_node), 0)
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index b25c823049dc5..b15a45a4d17bf 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -30,7 +30,6 @@
     "JitScalarType",
     # Public functions
     "export",
-    "export_to_pretty_string",
     "is_in_onnx_export",
     "select_model_mode_for_export",
     "register_custom_op_symbolic",
@@ -68,7 +67,6 @@
 from .utils import (
     _run_symbolic_function,
     _run_symbolic_method,
-    export_to_pretty_string,
     is_in_onnx_export,
     register_custom_op_symbolic,
     select_model_mode_for_export,
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 3d94581c79b04..7561438924591 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -35,7 +35,6 @@
     "model_signature",
     "warn_on_static_input_change",
     "unpack_quantized_tensor",
-    "export_to_pretty_string",
     "unconvertible_ops",
     "register_custom_op_symbolic",
     "unregister_custom_op_symbolic",
@@ -1140,84 +1139,6 @@ def _model_to_graph(
     return graph, params_dict, torch_out


-@torch._disable_dynamo
-@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
-def export_to_pretty_string(
-    model,
-    args,
-    export_params=True,
-    verbose=False,
-    training=_C_onnx.TrainingMode.EVAL,
-    input_names=None,
-    output_names=None,
-    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
-    export_type=None,
-    google_printer=False,
-    opset_version=None,
-    keep_initializers_as_inputs=None,
-    custom_opsets=None,
-    add_node_names=True,
-    do_constant_folding=True,
-    dynamic_axes=None,
-):
-    """Similar to :func:`export`, but returns a text representation of the ONNX model.
-
-    Only differences in args listed below. All other args are the same
-    as :func:`export`.
-
-    Args:
-        add_node_names (bool, default True): Whether or not to set
-            NodeProto.name. This makes no difference unless
-            ``google_printer=True``.
-        google_printer (bool, default False): If False, will return a custom,
-            compact representation of the model. If True will return the
-            protobuf's `Message::DebugString()`, which is more verbose.
-
-    Returns:
-        A UTF-8 str containing a human-readable representation of the ONNX model.
-    """
-    if opset_version is None:
-        opset_version = _constants.ONNX_DEFAULT_OPSET
-    if custom_opsets is None:
-        custom_opsets = {}
-    GLOBALS.export_onnx_opset_version = opset_version
-    GLOBALS.operator_export_type = operator_export_type
-
-    with exporter_context(model, training, verbose):
-        val_keep_init_as_ip = _decide_keep_init_as_input(
-            keep_initializers_as_inputs, operator_export_type, opset_version
-        )
-        val_add_node_names = _decide_add_node_names(
-            add_node_names, operator_export_type
-        )
-        val_do_constant_folding = _decide_constant_folding(
-            do_constant_folding, operator_export_type, training
-        )
-        args = _decide_input_format(model, args)
-        graph, params_dict, torch_out = _model_to_graph(
-            model,
-            args,
-            verbose,
-            input_names,
-            output_names,
-            operator_export_type,
-            val_do_constant_folding,
-            training=training,
-            dynamic_axes=dynamic_axes,
-        )
-
-        return graph._pretty_print_onnx(  # type: ignore[attr-defined]
-            params_dict,
-            opset_version,
-            False,
-            operator_export_type,
-            google_printer,
-            val_keep_init_as_ip,
-            custom_opsets,
-            val_add_node_names,
-        )
-
-
 @_deprecation.deprecated("2.5", "the future", "avoid using this function")
 def unconvertible_ops(
     model,
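Migration note (not part of the patch): the deprecation message on the removed function points callers at onnx.printer.to_text() as the replacement for the pretty-string output. A minimal sketch of the migration, assuming onnx >= 1.14 is installed and using a hypothetical AddModel module; it mirrors what the updated tests above do (export into an io.BytesIO buffer, then inspect the serialized proto):

import io

import onnx
import torch


class AddModel(torch.nn.Module):  # hypothetical stand-in for a real model
    def forward(self, x):
        return x + x


# Export to an in-memory buffer instead of a file on disk.
f = io.BytesIO()
torch.onnx.export(AddModel(), (torch.zeros(1, 2, 3),), f)

# Pretty-print the serialized model; this replaces the old
# torch.onnx.export_to_pretty_string(model, args) call.
model_proto = onnx.load_from_string(f.getvalue())
print(onnx.printer.to_text(model_proto))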