Skip to content

Commit

Permalink
[ONNX] Run type promotion test in CI and update the table (pytorch#13…
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 16, 2024
1 parent 090046b commit 0aa41eb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
12 changes: 3 additions & 9 deletions test/onnx/test_fx_type_promotion.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
# Owner(s): ["module: onnx"]

import pytorch_test_common

from torch.onnx._internal.fx.passes import type_promotion
from torch.testing._internal import common_utils


class TestGeneratedTypePromotionRuleSet(common_utils.TestCase):
@pytorch_test_common.skip_in_ci(
"Reduce noise in CI. "
"The test serves as a tool to validate if the generated rule set is current. "
)
def test_generated_rule_set_is_up_to_date(self):
generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET
latest_set = (
type_promotion.TypePromotionRuleSetGenerator.generate_from_torch_refs()
)
latest_set = type_promotion.ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs()

# Please update the list in torch/onnx/_internal/fx/passes/type_promotion.py following the instruction
# if this test fails
self.assertEqual(generated_set, latest_set)

def test_initialize_type_promotion_table_succeeds(self):
Expand Down
14 changes: 7 additions & 7 deletions torch/onnx/_internal/fx/passes/type_promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,9 @@ def preview_type_promotion(
ElementwiseTypePromotionRule(
"aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
),
ElementwiseTypePromotionRule(
"aten", "dot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
Expand Down Expand Up @@ -870,10 +873,7 @@ def preview_type_promotion(
"aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
"aten", "normal", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
Expand Down Expand Up @@ -924,9 +924,6 @@ def preview_type_promotion(
ElementwiseTypePromotionRule(
"aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
),
ElementwiseTypePromotionRule(
"aten", "rsub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
Expand Down Expand Up @@ -1030,6 +1027,9 @@ def preview_type_promotion(
ElementwiseTypePromotionRule(
"aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "vdot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),
ElementwiseTypePromotionRule(
"aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
),
Expand Down

0 comments on commit 0aa41eb

Please sign in to comment.