Skip to content

Commit

Permalink
[ONNX] Bump onnxscript version in CI; temporarily remove op test (pyt…
Browse files Browse the repository at this point in the history
…orch#133748)

Bump onnxscript version in CI to 0.1.0.dev20240831, and temporarily remove the fx consistency test. We will add a better version back later.
Pull Request resolved: pytorch#133748
Approved by: https://github.com/titaiwangms
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 3, 2024
1 parent 27677ea commit 1b9f51b
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 2,363 deletions.
7 changes: 3 additions & 4 deletions .ci/docker/common/install_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ pip_install \

pip_install coloredlogs packaging

pip_install onnxruntime==1.18
pip_install onnx==1.16.0
# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps
pip_install onnxscript==0.1.0.dev20240613 --no-deps
pip_install onnxruntime==1.18.1
pip_install onnx==1.16.2
pip_install onnxscript==0.1.0.dev20240831 --no-deps
# required by onnxscript
pip_install ml_dtypes

Expand Down
223 changes: 1 addition & 222 deletions test/onnx/dynamo/test_registry_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,11 @@

import onnxscript # type: ignore[import]
from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import]
from onnxscript.function_libs.torch_lib import ops # type: ignore[import]
from onnxscript.onnx_opset import opset15 as op # type: ignore[import]

import torch
import torch.fx
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
analysis,
diagnostics,
onnxfunction_dispatcher,
registration,
)
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration
from torch.testing._internal import common_utils


Expand Down Expand Up @@ -84,60 +77,6 @@ def test_custom(x, y):
[test_original, test_custom],
)

def test_unsupported_nodes_analysis_with_missing_aten_op(self):
# NOTE: simulate unsupported nodes
aten_mul_tensor = registration.OpName.from_name_parts(
namespace="aten", op_name="mul", overload="Tensor"
)
aten_mul_default = registration.OpName.from_name_parts(
namespace="aten", op_name="mul"
)
aten_add_tensor = registration.OpName.from_name_parts(
namespace="aten", op_name="add", overload="Tensor"
)
aten_add_default = registration.OpName.from_name_parts(
namespace="aten", op_name="add"
)

self.registry._registry.pop(aten_mul_tensor)
self.registry._registry.pop(aten_mul_default)
self.registry._registry.pop(aten_add_tensor)
self.registry._registry.pop(aten_add_default)

diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.dynamo_export", torch.__version__
)
dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
self.registry, diagnostic_context
)

graph: torch.fx.Graph = torch.fx.Graph()
x: torch.fx.Node = graph.create_node("placeholder", "x")
x.meta["val"] = torch.tensor(3.0)
b: torch.fx.Node = graph.create_node(
"call_function", target=torch.ops.aten.mul.Tensor, args=(x, x)
)
c: torch.fx.Node = graph.create_node(
"call_function", target=torch.ops.aten.add.Tensor, args=(b, b)
)
output: torch.fx.Node = graph.output(c)
module = torch.fx.GraphModule(torch.nn.Module(), graph)

with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
analysis.UnsupportedFxNodesAnalysis(
diagnostic_context, module, dispatcher
).analyze(infra.levels.ERROR)

try:
analysis.UnsupportedFxNodesAnalysis(
diagnostic_context, module, dispatcher
).analyze(infra.levels.ERROR)
except infra.RuntimeErrorWithDiagnostic as e:
self.assertIn(
"Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.",
e.diagnostic.message,
)


@common_utils.instantiate_parametrized_tests
class TestDispatcher(common_utils.TestCase):
Expand Down Expand Up @@ -488,165 +427,5 @@ def test_first_custom_op(
self.assertEqual(symbolic_fn, test_third_custom_op)


@common_utils.instantiate_parametrized_tests
class TestOpSchemaWrapper(common_utils.TestCase):
def setUp(self):
# overload type: optional dtype
self.onnx_function_new_full = ops.core.aten_new_full
self.onnx_function_new_full_dtype = ops.core.aten_new_full_dtype

@common_utils.parametrize(
"inputs, attributes, assertion",
[
common_utils.subtest(
([torch.randn(3, 4), torch.randn(3, 4)], {"alpha": 2.0}, True),
name="perfect_match_with_kwargs",
),
common_utils.subtest(
(["A", "B"], {}, False),
name="non_perfect_match_due_to_non_tensor_inputs",
),
common_utils.subtest(
([torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)], {}, False),
name="non_perfect_match_due_to_too_many_inputs",
),
common_utils.subtest(
([torch.randn(3, 4), torch.randn(3, 4)], {"wrong_kwargs": 2.0}, False),
name="non_perfect_match_due_to_wrong_kwargs",
),
],
)
def test_perfect_match_inputs(self, inputs, attributes, assertion):
# OnnxFunction with default attributes
dummy_diagnostic = diagnostics.Diagnostic(
rule=diagnostics.rules.find_opschema_matched_symbolic_function,
level=diagnostics.levels.WARNING,
)
op_schema_wrapper_add = onnxfunction_dispatcher._OnnxSchemaChecker(
ops.core.aten_add
)
self.assertEqual(
op_schema_wrapper_add.perfect_match_inputs(
dummy_diagnostic, inputs, attributes
),
assertion,
)

@common_utils.parametrize(
"inputs, kwargs, op, score",
[
common_utils.subtest(
([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul, 2),
name="match_2_inputs",
),
common_utils.subtest(
(
[
torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
],
{},
ops.core.aten_mul,
0,
),
name="match_0_inputs",
),
common_utils.subtest(
([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul_bool, 0),
name="match_0_inputs_bool",
),
common_utils.subtest(
(
[
torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
],
{},
ops.core.aten_mul_bool,
2,
),
name="match_2_inputs_bool",
),
],
)
def test_matching_score_system_on_overload_dtypes(self, inputs, kwargs, op, score):
op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op)
op_schema_wrapper._record_matching_score(inputs, kwargs)
self.assertEqual(op_schema_wrapper.match_score, score)

@common_utils.parametrize(
"inputs, kwargs, op, score",
[
common_utils.subtest(
([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2),
name="match_2_inputs",
),
common_utils.subtest(
(
[torch.randn(3, 4), torch.tensor(3)],
{"dtype": 2}, # at this point, dtype should be converted to int
ops.core.aten_new_full_dtype,
2,
),
name="match_2_input_and_match_1_kwargs_optional",
),
],
)
def test_matching_score_system_on_optional_dtypes(self, inputs, kwargs, op, score):
op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op)
op_schema_wrapper._record_matching_score(inputs, kwargs)
self.assertEqual(op_schema_wrapper.match_score, score)

@common_utils.parametrize(
"value, expected_onnx_str_dtype",
[
common_utils.subtest(
(1, {"tensor(int64)", "tensor(int16)", "tensor(int32)"}),
name="all_ints",
),
common_utils.subtest(
(1.0, {"tensor(float)", "tensor(double)", "tensor(float16)"}),
name="all_floats",
),
common_utils.subtest(
(torch.tensor([True]), {"tensor(bool)"}),
name="bool",
),
common_utils.subtest(
(torch.tensor([1], dtype=torch.int64), {"tensor(int64)"}),
name="int64",
),
common_utils.subtest(
(torch.tensor([1], dtype=torch.int32), {"tensor(int32)"}),
name="int32",
),
common_utils.subtest(
(torch.tensor([1], dtype=torch.int16), {"tensor(int16)"}),
name="int16",
),
common_utils.subtest(
(torch.tensor([1], dtype=torch.float), {"tensor(float)"}),
name="float",
),
common_utils.subtest(
(torch.tensor([1], dtype=torch.float16), {"tensor(float16)"}),
name="float16",
),
common_utils.subtest(
(torch.tensor([1], dtype=torch.double), {"tensor(double)"}),
name="double",
),
common_utils.subtest((None, set()), name="None"), # None allows no dtype
common_utils.subtest(
([], set()), name="empaty_list"
), # Empty list allows no dtype
],
)
def test_find_onnx_data_type(self, value, expected_onnx_str_dtype):
self.assertEqual(
onnxfunction_dispatcher._find_onnx_data_type(value), expected_onnx_str_dtype
)


if __name__ == "__main__":
common_utils.run_tests()
Loading

0 comments on commit 1b9f51b

Please sign in to comment.