From 0aa41eb52f7e577cf88e0f1b0adb34167a9ae94b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 16 Sep 2024 16:46:13 +0000 Subject: [PATCH] [ONNX] Run type promotion test in CI and update the table (#135915) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135915 Approved by: https://github.com/gramalingam, https://github.com/xadupre --- test/onnx/test_fx_type_promotion.py | 12 +++--------- torch/onnx/_internal/fx/passes/type_promotion.py | 14 +++++++------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/test/onnx/test_fx_type_promotion.py b/test/onnx/test_fx_type_promotion.py index 1e3860ad2a8efa..fc7dc21fba0069 100644 --- a/test/onnx/test_fx_type_promotion.py +++ b/test/onnx/test_fx_type_promotion.py @@ -1,22 +1,16 @@ # Owner(s): ["module: onnx"] -import pytorch_test_common - from torch.onnx._internal.fx.passes import type_promotion from torch.testing._internal import common_utils class TestGeneratedTypePromotionRuleSet(common_utils.TestCase): - @pytorch_test_common.skip_in_ci( - "Reduce noise in CI. " - "The test serves as a tool to validate if the generated rule set is current. " - ) def test_generated_rule_set_is_up_to_date(self): generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET - latest_set = ( - type_promotion.TypePromotionRuleSetGenerator.generate_from_torch_refs() - ) + latest_set = type_promotion.ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs() + # Please update the list in torch/onnx/_internal/fx/passes/type_promotion.py following the instruction + # if this test fails self.assertEqual(generated_set, latest_set) def test_initialize_type_promotion_table_succeeds(self): diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 6397beb5f089a4..81cb6ccb7439d9 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -554,6 +554,9 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ), + ElementwiseTypePromotionRule( + "aten", "dot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), ElementwiseTypePromotionRule( "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), @@ -870,10 +873,7 @@ def preview_type_promotion( "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), ElementwiseTypePromotionRule( - "aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ), - ElementwiseTypePromotionRule( - "aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + "aten", "normal", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), ElementwiseTypePromotionRule( "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT @@ -924,9 +924,6 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ), - ElementwiseTypePromotionRule( - "aten", "rsub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ), ElementwiseTypePromotionRule( "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), @@ -1030,6 +1027,9 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), + ElementwiseTypePromotionRule( + "aten", "vdot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), ElementwiseTypePromotionRule( "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH ),