Skip to content

Commit

Permalink
Add operand types tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 26, 2024
1 parent a6d139b commit 3b4c3b0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
43 changes: 36 additions & 7 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,11 +1131,16 @@ class TwoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
@pytest.mark.parametrize(
"program, error",
[
("test.two_operands %0 : i32", "Expected 2 operands but found 1"),
("test.two_operands %0 : i32, i32", "Expected 2 operands but found 1"),
(
"test.two_operands %0, %1, %2 : i32, i32, i32",
"test.two_operands %0, %1, %2 : i32, i32",
"Expected 2 operands but found 3",
),
("test.two_operands %0, %1 : i32", "Expected 2 operand types but found 1"),
(
"test.two_operands %0, %1 : i32, i32, i32",
"Expected 2 operand types but found 3",
),
],
)
def test_operands_directive_bounds(program: str, error: str):
Expand All @@ -1159,11 +1164,22 @@ class TwoOperandsOp(IRDLOperation):
@pytest.mark.parametrize(
"program, error",
[
("test.three_operands %0 : i32", "Expected at least 2 operands but found 1"),
(
"test.three_operands %0, %1, %2, %3 : i32, i32, i32, i32",
"test.three_operands %0 : i32, i32",
"Expected at least 2 operands but found 1",
),
(
"test.three_operands %0, %1, %2, %3 : i32, i32, i32",
"Expected at most 3 operands but found 4",
),
(
"test.three_operands %0, %1 : i32",
"Expected at least 2 operand types but found 1",
),
(
"test.three_operands %0, %1, %3 : i32, i32, i32, i32",
"Expected at most 3 operand types but found 4",
),
],
)
def test_operands_directive_bounds_with_opt(program: str, error: str):
Expand All @@ -1185,7 +1201,20 @@ class ThreeOperandsOp(IRDLOperation):
parser.parse_operation()


def test_operands_directive_bound_with_var():
@pytest.mark.parametrize(
"program, error",
[
(
"test.three_operands %0 : i32, i32",
"Expected at least 2 operands but found 1",
),
(
"test.three_operands %0, %1 : i32",
"Expected at least 2 operand types but found 1",
),
],
)
def test_operands_directive_bound_with_var(program: str, error: str):
@irdl_op_definition
class ThreeOperandsOp(IRDLOperation):
name = "test.three_operands"
Expand All @@ -1199,8 +1228,8 @@ class ThreeOperandsOp(IRDLOperation):
ctx = MLContext()
ctx.load_op(ThreeOperandsOp)

with pytest.raises(ParseError, match="Expected at least 2 operands but found 1"):
parser = Parser(ctx, "test.three_operands %0 : i32")
with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


Expand Down
19 changes: 11 additions & 8 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,22 +616,23 @@ class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective):
variadic_index: tuple[bool, int] | None

def _set_using_variadic_index(
self, field: list[_T | None | Sequence[_T]], set_to: Sequence[_T]
self,
field: list[_T | None | Sequence[_T]],
field_name: str,
set_to: Sequence[_T],
) -> str | None:
if self.variadic_index is None:
if len(set_to) != len(field):
return f"Expected {len(field)} operands but found {len(set_to)}"
return f"Expected {len(field)} {field_name} but found {len(set_to)}"
field = [o for o in set_to] # Copy needed as list is not covariant
return

is_optional, var_position = self.variadic_index
var_length = len(set_to) - len(field) + 1
if var_length < 0:
return (
f"Expected at least {len(field) - 1} operands but found {len(set_to)}"
)
return f"Expected at least {len(field) - 1} {field_name} but found {len(set_to)}"
if var_length > 1 and is_optional:
return f"Expected at most {len(field)} operands but found {len(set_to)}"
return f"Expected at most {len(field)} {field_name} but found {len(set_to)}"
field[:var_position] = set_to[:var_position]
field[var_position] = set_to[var_position : var_position + var_length]
field[var_position + 1 :] = set_to[var_position + var_length :]
Expand All @@ -646,7 +647,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
or []
)

if s := self._set_using_variadic_index(state.operands, operands):
if s := self._set_using_variadic_index(state.operands, "operands", operands):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(operands)

Expand All @@ -664,7 +665,9 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
or []
)

if s := self._set_using_variadic_index(state.operand_types, types):
if s := self._set_using_variadic_index(
state.operand_types, "operand types", types
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(types)

Expand Down

0 comments on commit 3b4c3b0

Please sign in to comment.