From 3b4c3b0f8938de0c9e994b69e1e59b4b8ee2e2bd Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 26 Nov 2024 09:55:46 +0000 Subject: [PATCH] Add operand types tests --- .../irdl/test_declarative_assembly_format.py | 43 ++++++++++++++++--- xdsl/irdl/declarative_assembly_format.py | 19 ++++---- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 1b7be9b550..7cd2d21e11 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -1131,11 +1131,16 @@ class TwoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] @pytest.mark.parametrize( "program, error", [ - ("test.two_operands %0 : i32", "Expected 2 operands but found 1"), + ("test.two_operands %0 : i32, i32", "Expected 2 operands but found 1"), ( - "test.two_operands %0, %1, %2 : i32, i32, i32", + "test.two_operands %0, %1, %2 : i32, i32", "Expected 2 operands but found 3", ), + ("test.two_operands %0, %1 : i32", "Expected 2 operand types but found 1"), + ( + "test.two_operands %0, %1 : i32, i32, i32", + "Expected 2 operand types but found 3", + ), ], ) def test_operands_directive_bounds(program: str, error: str): @@ -1159,11 +1164,22 @@ class TwoOperandsOp(IRDLOperation): @pytest.mark.parametrize( "program, error", [ - ("test.three_operands %0 : i32", "Expected at least 2 operands but found 1"), ( - "test.three_operands %0, %1, %2, %3 : i32, i32, i32, i32", + "test.three_operands %0 : i32, i32", + "Expected at least 2 operands but found 1", + ), + ( + "test.three_operands %0, %1, %2, %3 : i32, i32, i32", "Expected at most 3 operands but found 4", ), + ( + "test.three_operands %0, %1 : i32", + "Expected at least 2 operand types but found 1", + ), + ( + "test.three_operands %0, %1, %3 : i32, i32, i32, i32", + "Expected at most 3 operand types but found 4", + ), ], ) def test_operands_directive_bounds_with_opt(program: str, error: str): @@ -1185,7 +1201,20 @@ class ThreeOperandsOp(IRDLOperation): parser.parse_operation() -def test_operands_directive_bound_with_var(): +@pytest.mark.parametrize( + "program, error", + [ + ( + "test.three_operands %0 : i32, i32", + "Expected at least 2 operands but found 1", + ), + ( + "test.three_operands %0, %1 : i32", + "Expected at least 2 operand types but found 1", + ), + ], +) +def test_operands_directive_bound_with_var(program: str, error: str): @irdl_op_definition class ThreeOperandsOp(IRDLOperation): name = "test.three_operands" @@ -1199,8 +1228,8 @@ class ThreeOperandsOp(IRDLOperation): ctx = MLContext() ctx.load_op(ThreeOperandsOp) - with pytest.raises(ParseError, match="Expected at least 2 operands but found 1"): - parser = Parser(ctx, "test.three_operands %0 : i32") + with pytest.raises(ParseError, match=error): + parser = Parser(ctx, program) parser.parse_operation() diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index df9887abec..3b23359fdf 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -616,22 +616,23 @@ class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective): variadic_index: tuple[bool, int] | None def _set_using_variadic_index( - self, field: list[_T | None | Sequence[_T]], set_to: Sequence[_T] + self, + field: list[_T | None | Sequence[_T]], + field_name: str, + set_to: Sequence[_T], ) -> str | None: if self.variadic_index is None: if len(set_to) != len(field): - return f"Expected {len(field)} operands but found {len(set_to)}" + return f"Expected {len(field)} {field_name} but found {len(set_to)}" field = [o for o in set_to] # Copy needed as list is not covariant return is_optional, var_position = self.variadic_index var_length = len(set_to) - len(field) + 1 if var_length < 0: - return ( - f"Expected at least {len(field) - 1} operands but found {len(set_to)}" - ) + return f"Expected at least {len(field) - 1} {field_name} but found {len(set_to)}" if var_length > 1 and is_optional: - return f"Expected at most {len(field)} operands but found {len(set_to)}" + return f"Expected at most {len(field)} {field_name} but found {len(set_to)}" field[:var_position] = set_to[:var_position] field[var_position] = set_to[var_position : var_position + var_length] field[var_position + 1 :] = set_to[var_position + var_length :] @@ -646,7 +647,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: or [] ) - if s := self._set_using_variadic_index(state.operands, operands): + if s := self._set_using_variadic_index(state.operands, "operands", operands): parser.raise_error(s, at_position=pos_start, end_position=parser.pos) return bool(operands) @@ -664,7 +665,9 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: or [] ) - if s := self._set_using_variadic_index(state.operand_types, types): + if s := self._set_using_variadic_index( + state.operand_types, "operand types", types + ): parser.raise_error(s, at_position=pos_start, end_position=parser.pos) return bool(types)