diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 7057586ca9..aade6fbfd5 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -1383,6 +1383,193 @@ class OptionalResultOp(IRDLOperation): check_equivalence(program, generic_program, ctx) +@pytest.mark.parametrize( + "program", + [ + "%0 = test.results_directive : i32", + "%0, %1 = test.results_directive : i32, i32", + "%0, %1, %2 = test.results_directive : i32, i32, i32", + ], +) +def test_results_directive(program: str): + """Test the results directive""" + + @irdl_op_definition + class ResultsDirectiveOp(IRDLOperation): + name = "test.results_directive" + + res1 = result_def() + res2 = var_result_def() + + assembly_format = "attr-dict `:` type(results)" + + ctx = MLContext() + ctx.load_op(ResultsDirectiveOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + + +@pytest.mark.parametrize( + "program", + [ + "%0 = test.results_directive : i32", + "%0, %1 = test.results_directive : i32, i32", + ], +) +def test_results_directive_with_optional(program: str): + """Test the results directive with an optional result""" + + @irdl_op_definition + class ResultsDirectiveOp(IRDLOperation): + name = "test.results_directive" + + res1 = opt_result_def() + res2 = result_def() + + assembly_format = "attr-dict `:` type(results)" + + ctx = MLContext() + ctx.load_op(ResultsDirectiveOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + + +def test_results_directive_fails_with_two_var(): + """Test results directive cannot be used with two variadic results""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'results' is ambiguous with multiple variadic results", + ): + + @irdl_op_definition + class TwoVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.two_var_op" + + res1 = var_result_def() + res2 = var_result_def() + + irdl_options = [AttrSizedResultSegments()] + + assembly_format = "attr-dict `:` type(results)" + + +def test_results_directive_fails_with_no_results(): + """Test results directive cannot be used with no results""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'results' should not be used when there are no results", + ): + + @irdl_op_definition + class NoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.no_results_op" + + assembly_format = "attr-dict `:` type(results)" + + +def test_results_directive_fails_with_other_type_directive(): + """Test results directive cannot be used with no results""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'results' cannot be used in a type directive with other result type directives", + ): + + @irdl_op_definition + class TwoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.two_results_op" + + res1 = result_def() + res2 = result_def() + + assembly_format = "attr-dict `:` type($res1) `,` type(results)" + + +@pytest.mark.parametrize( + "program, error", + [ + ("%0 = test.two_results : i32", "Expected 2 result types but found 1"), + ( + "%0, %1, %2 = test.two_results : i32, i32, i32", + "Expected 2 result types but found 3", + ), + ], +) +def test_results_directive_bounds(program: str, error: str): + @irdl_op_definition + class TwoResultsOp(IRDLOperation): + name = "test.two_results" + + res1 = result_def() + res2 = result_def() + + assembly_format = "attr-dict `:` type(results)" + + ctx = MLContext() + ctx.load_op(TwoResultsOp) + + with pytest.raises(ParseError, match=error): + parser = Parser(ctx, program) + parser.parse_operation() + + +@pytest.mark.parametrize( + "program, error", + [ + ( + "%0 = test.three_results : i32", + "Expected at least 2 result types but found 1", + ), + ( + "%0, %1, %2, %3 = test.three_results : i32, i32, i32, i32", + "Expected at most 3 result types but found 4", + ), + ], +) +def test_results_directive_bounds_with_opt(program: str, error: str): + @irdl_op_definition + class ThreeResultsOp(IRDLOperation): + name = "test.three_results" + + res1 = result_def() + res2 = opt_result_def() + res3 = result_def() + + assembly_format = "attr-dict `:` type(results)" + + ctx = MLContext() + ctx.load_op(ThreeResultsOp) + + with pytest.raises(ParseError, match=error): + parser = Parser(ctx, program) + parser.parse_operation() + + +def test_results_directive_bound_with_var(): + @irdl_op_definition + class ThreeResultsOp(IRDLOperation): + name = "test.three_results" + + res1 = result_def() + res2 = opt_result_def() + res3 = result_def() + + assembly_format = "attr-dict `:` type(results)" + + ctx = MLContext() + ctx.load_op(ThreeResultsOp) + + with pytest.raises( + ParseError, match="Expected at least 2 result types but found 1" + ): + parser = Parser(ctx, "%0 = test.three_results : i32") + parser.parse_operation() + + ################################################################################ # Regions # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 1c5d3a6500..4e4fb1c20f 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -604,11 +604,9 @@ def set_types_empty(self, state: ParsingState) -> None: @dataclass(frozen=True) -class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective): +class OperandsOrResultDirective(VariadicTypeableDirective, ABC): """ - An operands directive, with the following format: - operands-directive ::= operands - Prints each operand of the operation, inserting a comma between each. + Base class for the 'operands' and 'results' directives. """ variadic_index: tuple[bool, int] | None @@ -639,6 +637,14 @@ def _set_using_variadic_index( field[var_position] = set_to[var_position : var_position + var_length] field[var_position + 1 :] = set_to[var_position + var_length :] + +class OperandsDirective(VariadicOperandDirective, OperandsOrResultDirective): + """ + An operands directive, with the following format: + operands-directive ::= operands + Prints each operand of the operation, inserting a comma between each. + """ + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: pos_start = parser.pos operands = ( @@ -768,6 +774,43 @@ def set_types_empty(self, state: ParsingState) -> None: state.result_types[self.index] = () +class ResultsDirective(OperandsOrResultDirective): + """ + A results directive, with the following format: + results-directive ::= results + A typeable directive which processes the result types of the operation. + """ + + def parse_single_type(self, parser: Parser, state: ParsingState) -> None: + if len(state.result_types) > 1: + parser.raise_error("Expected multiple types but received one.") + state.result_types[0] = parser.parse_type() + + def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: + pos_start = parser.pos + types = ( + parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type + ) + or [] + ) + + if s := self._set_using_variadic_index( + state.result_types, "result types", types + ): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) + return bool(types) + + def set_types_empty(self, state: ParsingState) -> None: + state.result_types = [() for _ in state.operand_types] + + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: + return op.result_types + + def is_present(self, op: IRDLOperation) -> bool: + return bool(op.results) + + class RegionDirective(OptionallyParsableDirective, ABC): """ Baseclass to help keep typechecking simple. diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 92f0d1e17e..4250ff6b55 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -60,6 +60,7 @@ PunctuationDirective, RegionDirective, RegionVariable, + ResultsDirective, ResultVariable, SuccessorVariable, TypeableDirective, @@ -702,6 +703,8 @@ def parse_typeable_directive(self) -> TypeableDirective: """ if self.parse_optional_keyword("operands"): return self.create_operands_directive(False) + if self.parse_optional_keyword("results"): + return self.create_results_directive() if variable := self.parse_optional_typeable_variable(): return variable self.raise_error(f"unexpected token '{self._current_token.text}'") @@ -773,3 +776,22 @@ def create_operands_directive(self, top_level: bool) -> OperandsDirective: if not variadics: return OperandsDirective(None) return OperandsDirective(variadics[0]) + + def create_results_directive(self) -> ResultsDirective: + if not self.op_def.results: + self.raise_error("'results' should not be used when there are no results") + if any(self.seen_result_types): + self.raise_error( + "'results' cannot be used in a type directive with other result type directives" + ) + variadics = tuple( + (isinstance(o, OptResultDef), i) + for i, (_, o) in enumerate(self.op_def.results) + if isinstance(o, VarResultDef) + ) + if len(variadics) > 1: + self.raise_error("'results' is ambiguous with multiple variadic results") + self.seen_result_types = [True] * len(self.seen_result_types) + if not variadics: + return ResultsDirective(None) + return ResultsDirective(variadics[0])