diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 99cd7e5173..c1f144a4af 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -1601,6 +1601,35 @@ class FunctionalTypeOp(IRDLOperation): ctx.load_op(FunctionalTypeOp) +@pytest.mark.parametrize( + "program", + [ + "%0 = test.functional_type %1, %2 : (i32, i32) -> i32", + "%0, %1 = test.functional_type %2, %3 : (i32, i32) -> (i32, i32)", + "%0 = test.functional_type %1 : (i32) -> i32", + ], +) +def test_functional_type_with_operands_and_results(program: str): + """ + Test the parsing of the functional-type directive using the operands and + results directives + """ + + @irdl_op_definition + class FunctionalTypeOp(IRDLOperation): + name = "test.functional_type" + + op1 = operand_def() + ops2 = var_operand_def() + res1 = var_result_def() + res2 = result_def() + + assembly_format = "operands attr-dict `:` functional-type(operands, results)" + + ctx = MLContext() + ctx.load_op(FunctionalTypeOp) + + ################################################################################ # Regions # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 9bd6d3acb3..727ffe4b2b 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -660,9 +660,11 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(operands) def parse_single_type(self, parser: Parser, state: ParsingState) -> None: - if len(state.operand_types) > 1: - parser.raise_error("Expected multiple types but received one.") - state.operand_types[0] = parser.parse_type() + pos_start = parser.pos + if s := self._set_using_variadic_index( + state.operand_types, "operand types", (parser.parse_type(),) + ): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: pos_start = parser.pos @@ -782,9 +784,11 @@ class ResultsDirective(OperandsOrResultDirective): """ def parse_single_type(self, parser: Parser, state: ParsingState) -> None: - if len(state.result_types) > 1: - parser.raise_error("Expected multiple types but received one.") - state.result_types[0] = parser.parse_type() + pos_start = parser.pos + if s := self._set_using_variadic_index( + state.result_types, "result types", (parser.parse_type(),) + ): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: pos_start = parser.pos