Skip to content

Commit

Permalink
core: Move assembly format type parsing to directives
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 25, 2024
1 parent ef12f04 commit a2cc01f
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 109 deletions.
185 changes: 85 additions & 100 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Literal, TypeAlias
from typing import Literal

from xdsl.dialects.builtin import UnitAttr
from xdsl.ir import (
Expand Down Expand Up @@ -306,39 +306,22 @@ class TypeableDirective(Directive, ABC):
"""

@abstractmethod
def set_type(self, type: Attribute, state: ParsingState) -> None: ...

@abstractmethod
def get_type(self, op: IRDLOperation) -> Attribute: ...


class VariadicTypeableDirective(AnchorableDirective, ABC):
"""
Directives which can set or get multiple types.
"""

@abstractmethod
def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: ...
def parse_single_type(self, parser: Parser, state: ParsingState) -> None: ...

@abstractmethod
def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: ...


class OptionalTypeableDirective(AnchorableDirective, ABC):
class VariadicTypeableDirective(TypeableDirective, AnchorableDirective, ABC):
"""
Directives which can optionally set or get a single type.
Directives which can set or get multiple types.
"""

@abstractmethod
def set_type(self, type: Attribute | None, state: ParsingState) -> None: ...
def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: ...

@abstractmethod
def get_type(self, op: IRDLOperation) -> Attribute | None: ...


AnyTypeableDirective: TypeAlias = (
TypeableDirective | VariadicTypeableDirective | OptionalTypeableDirective
)
def set_types_empty(self, state: ParsingState) -> None: ...


@dataclass(frozen=True)
Expand All @@ -351,26 +334,18 @@ class TypeDirective(FormatDirective):
inner: TypeableDirective

def parse(self, parser: Parser, state: ParsingState) -> None:
ty = parser.parse_type()
self.inner.set_type(ty, state)
self.inner.parse_single_type(parser, state)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_attribute(self.inner.get_type(op))
printer.print_list(self.inner.get_types(op), printer.print_attribute)
state.last_was_punctuation = False
state.should_emit_space = True


class VariadicLikeTypeDirective(VariadicLikeFormatDirective):
"""
Base class for type checking.
A variadic-like type directive can not be followed by a variadic-like type directive.
"""


@dataclass(frozen=True)
class VariadicTypeDirective(VariadicLikeTypeDirective):
class VariadicTypeDirective(VariadicLikeFormatDirective):
"""
A directive which parses the type of a variadic typeable directive, with format:
type-directive ::= type(typeable-directive)
Expand All @@ -379,13 +354,7 @@ class VariadicTypeDirective(VariadicLikeTypeDirective):
inner: VariadicTypeableDirective

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
types = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
if types is None:
types = ()
self.inner.set_types(types, state)
return bool(types)
return self.inner.parse_many_types(parser, state)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
Expand All @@ -398,37 +367,7 @@ def is_present(self, op: IRDLOperation) -> bool:
return self.inner.is_present(op)

def set_empty(self, state: ParsingState):
self.inner.set_types((), state)


@dataclass(frozen=True)
class OptionalTypeDirective(VariadicLikeTypeDirective):
"""
A directive which parses the type of a optional typeable directive, with format:
type-directive ::= type(typeable-directive)
"""

inner: OptionalTypeableDirective

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
type = parser.parse_optional_type()
self.inner.set_type(type, state)
return bool(type)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
type = self.inner.get_type(op)
if type:
printer.print_attribute(type)
state.last_was_punctuation = False
state.should_emit_space = True

def is_present(self, op: IRDLOperation) -> bool:
return self.inner.is_present(op)

def set_empty(self, state: ParsingState):
self.inner.set_type(None, state)
self.inner.set_types_empty(state)


@dataclass(frozen=True)
Expand Down Expand Up @@ -542,8 +481,8 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
operand = parser.parse_unresolved_operand()
state.operands[self.index] = operand

def set_type(self, type: Attribute, state: ParsingState) -> None:
state.operand_types[self.index] = type
def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
state.operand_types[self.index] = parser.parse_type()

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
Expand All @@ -552,21 +491,21 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.last_was_punctuation = False
state.should_emit_space = True

def get_type(self, op: IRDLOperation) -> Attribute:
return getattr(op, self.name).type
def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
return (getattr(op, self.name).type,)


class VariadicOperandDirective(VariadicLikeFormatDirective, ABC):
class VariadicOperandDirective(
VariadicLikeFormatDirective, VariadicTypeableDirective, ABC
):
"""
Base class for typechecking.
A variadic operand directive cannot follow another variadic operand directive.
"""


@dataclass(frozen=True)
class VariadicOperandVariable(
VariadicVariable, VariadicOperandDirective, VariadicTypeableDirective
):
class VariadicOperandVariable(VariadicVariable, VariadicOperandDirective):
"""
A variadic operand variable, with the following format:
operand-directive ::= ( percent-ident ( `,` percent-id )* )?
Expand All @@ -582,8 +521,18 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
state.operands[self.index] = operands
return bool(operands)

def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None:
def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
state.operand_types[self.index] = (parser.parse_type(),)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
types = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
ret = types is None
if ret:
types = ()
state.operand_types[self.index] = types
return ret

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
Expand All @@ -600,10 +549,11 @@ def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
def set_empty(self, state: ParsingState):
state.operands[self.index] = ()

def set_types_empty(self, state: ParsingState) -> None:
state.operand_types[self.index] = ()

class OptionalOperandVariable(
OptionalVariable, VariadicOperandDirective, OptionalTypeableDirective
):

class OptionalOperandVariable(OptionalVariable, VariadicOperandDirective):
"""
An optional operand variable, with the following format:
operand-directive ::= ( percent-ident )?
Expand All @@ -617,8 +567,16 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
state.operands[self.index] = operand
return bool(operand)

def set_type(self, type: Attribute | None, state: ParsingState) -> None:
state.operand_types[self.index] = type or ()
def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
state.operand_types[self.index] = (parser.parse_type(),)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
type = parser.parse_optional_type()
ret = type is None
if ret:
type = ()
state.operand_types[self.index] = type
return ret

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
Expand All @@ -629,15 +587,18 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.last_was_punctuation = False
state.should_emit_space = True

def get_type(self, op: IRDLOperation) -> Attribute | None:
def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
operand = getattr(op, self.name)
if operand:
return operand.type
return None
return (operand.type,)
return ()

def set_empty(self, state: ParsingState):
state.operands[self.index] = ()

def set_types_empty(self, state: ParsingState) -> None:
state.operand_types[self.index] = ()


@dataclass(frozen=True)
class ResultVariable(VariableDirective, TypeableDirective):
Expand All @@ -648,11 +609,11 @@ class ResultVariable(VariableDirective, TypeableDirective):
parsing is not handled by the custom operation parser.
"""

def set_type(self, type: Attribute, state: ParsingState) -> None:
state.result_types[self.index] = type
def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
state.result_types[self.index] = parser.parse_type()

def get_type(self, op: IRDLOperation) -> Attribute:
return getattr(op, self.name).type
def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
return (getattr(op, self.name).type,)


@dataclass(frozen=True)
Expand All @@ -664,29 +625,53 @@ class VariadicResultVariable(VariadicVariable, VariadicTypeableDirective):
parsing is not handled by the custom operation parser.
"""

def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None:
def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
state.result_types[self.index] = (parser.parse_type(),)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
types = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
ret = types is None
if ret:
types = ()
state.result_types[self.index] = types
return ret

def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
return getattr(op, self.name).types

def set_types_empty(self, state: ParsingState) -> None:
state.result_types[self.index] = ()

class OptionalResultVariable(OptionalVariable, OptionalTypeableDirective):

class OptionalResultVariable(OptionalVariable, VariadicTypeableDirective):
"""
An optional result variable, with the following format:
result-directive ::= ( percent-ident )?
This directive can not be used for parsing and printing directly, as result
parsing is not handled by the custom operation parser.
"""

def set_type(self, type: Attribute | None, state: ParsingState) -> None:
state.result_types[self.index] = type or ()
def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
state.result_types[self.index] = (parser.parse_type(),)

def get_type(self, op: IRDLOperation) -> Attribute | None:
def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
type = parser.parse_optional_type()
ret = type is None
if ret:
type = ()
state.result_types[self.index] = type
return ret

def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
res = getattr(op, self.name)
if res:
return res.type
return None
return (res.type,)
return ()

def set_types_empty(self, state: ParsingState) -> None:
state.result_types[self.index] = ()


class RegionDirective(OptionallyParsableDirective, ABC):
Expand Down
13 changes: 4 additions & 9 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
)
from xdsl.irdl.declarative_assembly_format import (
AnchorableDirective,
AnyTypeableDirective,
AttrDictDirective,
AttributeVariable,
DefaultValuedAttributeVariable,
Expand All @@ -56,18 +55,16 @@
OptionalRegionVariable,
OptionalResultVariable,
OptionalSuccessorVariable,
OptionalTypeableDirective,
OptionalTypeDirective,
OptionalUnitAttrVariable,
ParsingState,
PunctuationDirective,
RegionDirective,
RegionVariable,
ResultVariable,
SuccessorVariable,
TypeableDirective,
TypeDirective,
VariadicLikeFormatDirective,
VariadicLikeTypeDirective,
VariadicOperandDirective,
VariadicOperandVariable,
VariadicRegionDirective,
Expand Down Expand Up @@ -200,7 +197,7 @@ def verify_directives(self, elements: list[FormatDirective]):
self.raise_error(
"A variadic directive cannot be followed by a comma literal."
)
case VariadicLikeTypeDirective(), VariadicLikeTypeDirective():
case VariadicTypeDirective(), VariadicTypeDirective():
self.raise_error(
"A variadic type directive cannot be followed by another variadic type directive."
)
Expand Down Expand Up @@ -447,7 +444,7 @@ def _parse_optional_operand(
case _:
return OperandVariable(variable_name, idx)

def parse_optional_typeable_variable(self) -> AnyTypeableDirective | None:
def parse_optional_typeable_variable(self) -> TypeableDirective | None:
"""
Parse a variable, if present, with the following format:
variable ::= `$` bare-ident
Expand Down Expand Up @@ -611,8 +608,6 @@ def parse_type_directive(self) -> FormatDirective:
self.parse_punctuation(")")
if isinstance(inner, VariadicTypeableDirective):
return VariadicTypeDirective(inner)
if isinstance(inner, OptionalTypeableDirective):
return OptionalTypeDirective(inner)
return TypeDirective(inner)

def parse_optional_group(self) -> FormatDirective:
Expand Down Expand Up @@ -707,7 +702,7 @@ def parse_keyword_or_punctuation(self) -> FormatDirective:
self.parse_characters("`")
return KeywordDirective(ident)

def parse_typeable_directive(self) -> AnyTypeableDirective:
def parse_typeable_directive(self) -> TypeableDirective:
"""
Parse a typeable directive, with the following format:
directive ::= variable
Expand Down

0 comments on commit a2cc01f

Please sign in to comment.