Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pretty printing support for mlir_region #682

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions SSA/Core/MLIRSyntax/GenericParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ end Test
-- ==============================

declare_syntax_cat mlir_bb_operand
syntax mlir_op_operand ":" mlir_type : mlir_bb_operand
syntax mlir_op_operand ":" ppSpace mlir_type : mlir_bb_operand

syntax "[mlir_bb_operand|" mlir_bb_operand "]" : term

Expand All @@ -465,7 +465,7 @@ macro_rules



syntax (mlir_op)* : mlir_ops
syntax (mlir_op ppLine)* : mlir_ops

syntax "[mlir_op|" mlir_op "]" : term
syntax "[mlir_ops|" mlir_ops "]" : term
Expand Down Expand Up @@ -500,8 +500,8 @@ value-use-list ::= value-use (`,` value-use)*


syntax mlir_suffix_id := num <|> ident
syntax "{" ("^" mlir_suffix_id ("(" sepBy(mlir_bb_operand, ",") ")")?
":")? mlir_ops "}" : mlir_region
syntax "{" ppLine ("^" mlir_suffix_id ("(" sepBy(mlir_bb_operand, ",") ")")?
":")? ppLine mlir_ops "}" : mlir_region
syntax "[mlir_region|" mlir_region "]": term

/--
Expand Down Expand Up @@ -556,7 +556,7 @@ syntax "#" strLit : mlir_attr_val -- alias
declare_syntax_cat dialect_attribute_contents
syntax mlir_attr_val : dialect_attribute_contents
/--
Following https://mlir.llvm.org/docs/LangRef/, we define a `dialect-attribute`,
Following https://mlir.llvm.org/docs/LangRef/, we define a `dialect-attribute`,
which is a particular case of an `mlir-attr-val` that is namespaced to a particular dialect

```bnf
Expand Down
2 changes: 1 addition & 1 deletion SSA/Core/MLIRSyntax/PrettyEDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Where the number of repeated `t`s is determined by the number of arguments given
It's also possible to leave out the `: $t` type annotation entirely, in which case `t` will be
assumed to be `_`, the "hole" type.
-/
syntax (mlir_op_operand " = ")? MLIR.Pretty.uniform_op mlir_op_operand,*
syntax (mlir_op_operand " = ")? MLIR.Pretty.uniform_op ppSpace mlir_op_operand,*
(" : " mlir_type)? : mlir_op
macro_rules
| `(mlir_op| $[$resName =]? $name:MLIR.Pretty.uniform_op $xs,* $[: $t]? ) => do
Expand Down
253 changes: 253 additions & 0 deletions SSA/Experimental/ASTPrettyPrinter.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import Lean
import SSA.Core.MLIRSyntax.AST
import SSA.Core.MLIRSyntax.GenericParser
import SSA.Projects.InstCombine.LLVM.PrettyEDSL
open Lean PrettyPrinter Delaborator SubExpr MLIR AST Elab Term Syntax
open PrettyPrinter

-- AFFINE SYTAX
-- ============

@[app_unexpander AST.AffineExpr.Var]
def unexpandAffineExprVar : Unexpander
| `($_ $xstr:str) =>
let xraw := mkIdent $ Name.mkSimple xstr.getString
`([affine_expr| $xraw:ident])
| _ => throw ()


@[app_unexpander AST.AffineTuple.mk]
def unexpandAffineTuplemk : Unexpander
| `($_ [$[$terms],*]) =>
let affexprs : Array (TSyntax `affine_expr) := terms.map fun term => match term with
| `([affine_expr| $n]) => n
| _ => panic! "affine_expr is illformed"
`([affine_tuple|($affexprs,*)])
| _ => throw ()


@[app_unexpander AST.AffineMap.mk]
def unexpandAffineMapmk : Unexpander
| `($_ [affine_tuple|($xs,*)] [affine_tuple|($ys,*)]) =>
`([affine_map|affine_map< ($xs,*) -> ($ys,*)>])
| _ => throw ()

#check [affine_expr|foo]
#check [affine_tuple| (a,b,c) ]
#check [affine_map|affine_map<(a,b)->(c,d)>]


-- EDSL OPERANDS
-- ==============
-- TODO?: unexpander for [mlir_op_operatand | $$($q)]


#check [mlir_op_operand| %0]
@[app_unexpander AST.SSAVal.SSAVal]
def unexpandSSAValSSSAVal: Unexpander
| `($_ $xstr:str) =>
let xraw := mkIdent $ Name.mkSimple xstr.getString
`([mlir_op_operand| % $xraw:ident])
| `($_ $a) => match a with
| `(EDSL.IntToString $n:num) => `([mlir_op_operand| % $n:num])
| _ => throw ()
| _ => throw ()

#check [mlir_op_operand| %x]
#check [mlir_op_operand| %0]


-- EDSL OP-SUCCESSOR-ARGS
-- =================

-- successor-list ::= `[` successor (`,` successor)* `]`
-- successor ::= caret-id (`:` bb-arg-list)?


@[app_unexpander AST.BBName.mk]
def unexpandBBNamemk : Unexpander
| `($_ $xstr:str) =>
let xraw := mkIdent $ Name.mkSimple xstr.getString
`([mlir_op_successor_arg| ^ $xraw:ident ])
| _ => throw ()

#check [mlir_op_successor_arg| ^bb]


-- EDSL MLIR TYPES
-- ===============

-- TODO: Hardcoded meta-variable case?
@[app_unexpander AST.MLIRType.int]
def unexpandMLIRTypeint : Unexpander
| `($_ Signedness.Signless $n:num) =>
let xid := mkIdent $ Name.mkSimple ("i" ++ n.getNat.repr)
`([mlir_type|$xid:ident])
| _ => throw ()

#check [mlir_type| i32]


@[app_unexpander AST.MLIRType.float]
def unexpandMLIRTypefloat : Unexpander
| `($_ $n:num) =>
let xid := mkIdent $ Name.mkSimple ("f" ++ n.getNat.repr)
`([mlir_type|$xid:ident])
| _ => throw ()

#check [mlir_type| f32]


@[app_unexpander AST.MLIRType.index]
def unexpandMLIRTypefindex : Unexpander
| `($_ ) =>
let id := mkIdent $ Name.mkSimple ("index")
`([mlir_type|$id:ident])

#check [mlir_type| index]


@[app_unexpander AST.MLIRType.undefined]
def unexpandMLIRTypeundefined : Unexpander
| `($_ $x:str) =>
`([mlir_type| ! $x:str ])
| `($_ $x:ident) =>
`([mlir_type| ! $x:ident ])
| _ => throw ()

-- unexpander currently has no way of determining between idents and strings because
-- the macro sends them both to strings
#check [mlir_type| !shape.value]
#check [mlir_type| !"lz.int"]


@[app_unexpander AST.MLIRType.tensor1d]
def unexpandMLIRTypetensor1d : Unexpander
| `($_ ) =>
`([mlir_type|tensor1d])

#check [mlir_type| tensor1d]


@[app_unexpander AST.MLIRType.tensor2d]
def unexpandMLIRTypetensor2d : Unexpander
| `($_ ) =>
`([mlir_type|tensor2d])

#check [mlir_type| tensor2d]


-- === VECTOR TYPE ===
--skipping vector <> for now because it may currently have bugs



-- EDSL MLIR USER ATTRIBUTES
-- =========================


-- EDSL MLIR BASIC BLOCK OPERANDS
-- ==============================


-- EDSL MLIR BASIC BLOCKS
-- ======================


def stringToMLIRuniform_op(op_name : String): UnexpandM (TSyntax `MLIR.Pretty.uniform_op) :=
match op_name with
| "llvm.return" => `(MLIR.Pretty.uniform_op|llvm.return)
| "llvm.copy" => `(MLIR.Pretty.uniform_op|llvm.copy)
| "llvm.neg" => `(MLIR.Pretty.uniform_op|llvm.neg)
| "llvm.not" => `(MLIR.Pretty.uniform_op|llvm.not)
| "llvm.add" => `(MLIR.Pretty.uniform_op|llvm.add)
| "llvm.and" => `(MLIR.Pretty.uniform_op|llvm.and)
| "llvm.ashr" => `(MLIR.Pretty.uniform_op|llvm.ashr)
| "llvm.lshr" => `(MLIR.Pretty.uniform_op|llvm.lshr)
| "llvm.mul" => `(MLIR.Pretty.uniform_op|llvm.mul)
| "llvm.or" => `(MLIR.Pretty.uniform_op|llvm.or)
| "llvm.sdiv" => `(MLIR.Pretty.uniform_op|llvm.sdiv)
| "llvm.shl" => `(MLIR.Pretty.uniform_op|llvm.shl)
| "llvm.srem" => `(MLIR.Pretty.uniform_op|llvm.srem)
| "llvm.sub" => `(MLIR.Pretty.uniform_op|llvm.sub)
| "llvm.udiv" => `(MLIR.Pretty.uniform_op|llvm.udiv)
| "llvm.urem" => `(MLIR.Pretty.uniform_op|llvm.urem)
| "llvm.xor" => `(MLIR.Pretty.uniform_op|llvm.xor)
| _ => throw ()


def AttrDictToneg_num(attrDict : TSyntax `term) : UnexpandM (TSyntax `MLIR.EDSL.neg_num) :=
match attrDict with
| `(AttrDict.mk [AttrEntry.mk "value" $attrValue]) =>
match attrValue with
| `(AttrValue.int $val $t) =>
match val with
| `($v:num) => `(MLIR.EDSL.neg_num| $v:num)
| `(-$v:num) => `(MLIR.EDSL.neg_num| -$v:num)
| _ => throw ()
| _ => throw ()
| _ => throw ()


@[app_unexpander AST.Region.mk]
def unexpandRegionmk : Unexpander
| `($_ $xstr:str [$[$argsList],*] [$[$opsList],*]) => do
let xraw := mkIdent $ Name.mkSimple xstr.getString
let mut args : Array (TSyntax `mlir_bb_operand) := Array.empty
for term in argsList do
match term with
| `(([mlir_op_operand| $name], [mlir_type|$ty])) =>
let x ← `(mlir_bb_operand|$name:mlir_op_operand : $ty:mlir_type)
args := args.push x
| _ => throw ()
let mut ops : Array (TSyntax `mlir_op) := Array.empty
for op in opsList do
match op with
| `(Op.mk $name:str [$[$res],*] [$[$operands],*] [$[$rgnsList],*] $attrDict) =>
if name.getString == "llvm.mlir.constant"
then
match (res.get! 0) with
| `(([mlir_op_operand| $arg], [mlir_type| $ty])) =>
let neg ← AttrDictToneg_num attrDict
ops := ops.push (← `(mlir_op| $arg:mlir_op_operand = llvm.mlir.constant ($neg) : $ty))
| _ => throw ()
else
let op_function_name ← stringToMLIRuniform_op name.getString
let mlir_op_operands := operands.map fun operand => match operand with
| `(([mlir_op_operand| $arg], [mlir_type| $ty])) => arg
| _ => panic! ""
if res == Array.empty
then
match (operands.get! 0) with
| `(([mlir_op_operand| $arg], [mlir_type| $ty])) =>
ops := ops.push (← `(mlir_op| $arg:mlir_op_operand = $op_function_name $mlir_op_operands,* : $ty))
| _ => throw ()
else
match (res.get! 0) with
| `(([mlir_op_operand| $arg], [mlir_type| $ty])) =>
ops := ops.push (← `(mlir_op| $arg:mlir_op_operand = $op_function_name $mlir_op_operands,* : $ty))
| _ => throw ()
| _ => pure ()
let rgn_ops ← `(mlir_ops| $[ $ops ]*)
`([mlir_region|{^$xraw:ident ($[$args],*) : $rgn_ops}])
| _ => throw ()


-- [mlir_region|{
-- ^bb0(%arg0: i32):
-- %0 = llvm.mlir.constant(-8) : i32
-- %1 = llvm.mlir.constant(31) : i32
-- %2 = llvm.ashr %arg0, %1 : i32
-- %3 = llvm.and %2, %0 : i32
-- %4 = llvm.add %3, %2 : i32
-- %4 = llvm.return %4 : i32
-- }] : Region ?m.28125
#check [mlir_region| {
^bb0(%arg0: i32):
%0 = llvm.mlir.constant(-8) : i32
%1 = llvm.mlir.constant(31) : i32
%2 = llvm.ashr %arg0, %1 : i32
%3 = llvm.and %2, %0 : i32
%4 = llvm.add %3, %2 : i32
llvm.return %4 : i32
}]
Loading