Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: add functional-type directive #3517

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,66 @@ class ThreeResultsOp(IRDLOperation):
parser.parse_operation()


################################################################################
# Functional type #
################################################################################


@pytest.mark.parametrize(
"program",
[
"%0 = test.functional_type %1, %2 : (i32, i32) -> i32",
"test.functional_type %0, %1 : (i32, i32) -> ()",
"%0, %1 = test.functional_type %2, %3 : (i32, i32) -> (i32, i32)",
"%0 = test.functional_type %1 : (i32) -> i32",
"%0 = test.functional_type : () -> i32",
],
)
def test_functional_type(program: str):
"""Test the parsing of the functional-type directive"""

@irdl_op_definition
class FunctionalTypeOp(IRDLOperation):
name = "test.functional_type"

ops = var_operand_def()
res = var_result_def()

assembly_format = "$ops attr-dict `:` functional-type($ops, $res)"

ctx = MLContext()
ctx.load_op(FunctionalTypeOp)


@pytest.mark.parametrize(
"program",
[
"%0 = test.functional_type %1, %2 : (i32, i32) -> i32",
"%0, %1 = test.functional_type %2, %3 : (i32, i32) -> (i32, i32)",
"%0 = test.functional_type %1 : (i32) -> i32",
],
)
def test_functional_type_with_operands_and_results(program: str):
"""
Test the parsing of the functional-type directive using the operands and
results directives
"""

@irdl_op_definition
class FunctionalTypeOp(IRDLOperation):
name = "test.functional_type"

op1 = operand_def()
ops2 = var_operand_def()
res1 = var_result_def()
res2 = result_def()

assembly_format = "operands attr-dict `:` functional-type(operands, results)"

ctx = MLContext()
ctx.load_op(FunctionalTypeOp)


################################################################################
# Regions #
################################################################################
Expand Down
71 changes: 65 additions & 6 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,11 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
return bool(operands)

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.operand_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.operand_types[0] = parser.parse_type()
pos_start = parser.pos
if s := self._set_using_variadic_index(
state.operand_types, "operand types", (parser.parse_type(),)
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
Expand Down Expand Up @@ -782,9 +784,11 @@ class ResultsDirective(OperandsOrResultDirective):
"""

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.result_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.result_types[0] = parser.parse_type()
pos_start = parser.pos
if s := self._set_using_variadic_index(
state.result_types, "result types", (parser.parse_type(),)
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
Expand All @@ -811,6 +815,61 @@ def is_present(self, op: IRDLOperation) -> bool:
return bool(op.results)


@dataclass(frozen=True)
class FunctionalTypeDirective(OptionallyParsableDirective):
"""
A directive which parses a functional type, with format:
functional-type-directive ::= functional-type(typeable-directive, typeable-directive)
A functional type is either of the form
`(` type-list `)` `->` `(` type-list `)`
or
`(` type-list `)` `->` type
where type-list is a comma separated list of types (or the empty string to signify the empty list).
The second format is preferred for printing when possible.
"""

operand_typeable_directive: TypeableDirective
result_typeable_directive: TypeableDirective

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
if not parser.parse_optional_punctuation("("):
return False
if isinstance(self.operand_typeable_directive, VariadicTypeableDirective):
self.operand_typeable_directive.parse_many_types(parser, state)
else:
self.operand_typeable_directive.parse_single_type(parser, state)
parser.parse_punctuation(")")
parser.parse_punctuation("->")
if parser.parse_optional_punctuation("("):
if isinstance(self.result_typeable_directive, VariadicTypeableDirective):
self.result_typeable_directive.parse_many_types(parser, state)
else:
self.result_typeable_directive.parse_single_type(parser, state)
parser.parse_punctuation(")")
else:
self.result_typeable_directive.parse_single_type(parser, state)
return True

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print_string(" ")
state.should_emit_space = True
printer.print_string("(")
printer.print_list(
self.operand_typeable_directive.get_types(op), printer.print_attribute
)
printer.print_string(") -> ")
result_types = self.result_typeable_directive.get_types(op)
if len(result_types) == 1:
printer.print_attribute(result_types[0])
state.last_was_punctuation = False
else:
printer.print_string("(")
printer.print_list(result_types, printer.print_attribute)
printer.print_string(")")
state.last_was_punctuation = True


class RegionDirective(OptionallyParsableDirective, ABC):
"""
Baseclass to help keep typechecking simple.
Expand Down
16 changes: 16 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DefaultValuedAttributeVariable,
FormatDirective,
FormatProgram,
FunctionalTypeDirective,
KeywordDirective,
OperandOrResult,
OperandsDirective,
Expand Down Expand Up @@ -604,6 +605,19 @@ def parse_type_directive(self) -> FormatDirective:
return VariadicTypeDirective(inner)
return TypeDirective(inner)

def parse_functional_type_directive(self) -> FormatDirective:
"""
Parse a functional-type directive with the following format
functional-type-directive ::= `functional-type` `(` typeable-directive `,` typeable-directive `)`
`functional-type` is expected to have already been parsed
"""
self.parse_punctuation("(")
operands = self.parse_typeable_directive()
self.parse_punctuation(",")
results = self.parse_typeable_directive()
self.parse_punctuation(")")
return FunctionalTypeDirective(operands, results)

def parse_optional_group(self) -> FormatDirective:
"""
Parse an optional group, with the following format:
Expand Down Expand Up @@ -726,6 +740,8 @@ def parse_format_directive(self) -> FormatDirective:
return self.parse_type_directive()
if self.parse_optional_keyword("operands"):
return self.create_operands_directive(True)
if self.parse_optional_keyword("functional-type"):
return self.parse_functional_type_directive()
if self._current_token.text == "`":
return self.parse_keyword_or_punctuation()
if self.parse_optional_punctuation("("):
Expand Down