Skip to content

Commit

Permalink
Feat: Auto-generate inst-combine files to lean (#452)
Browse files Browse the repository at this point in the history
I auto-generated lean files from the InstCombine benchmark in LLVM. The
ultimate goal is to turn these files into bitvector theorems for Henrik.

These tests come from the LLVM repo, specifically

https://github.com/llvm/llvm-project/tree/main/llvm/test/Transforms/InstCombine

I then used `mlir-opt` to optimize these test cases and embedded them in
Lean to extract the relevant theorems behind those optimizations. See
the `SSA/Projects/InstCombine/scripts` directory for the scripts that I
used to do this.

---------

Co-authored-by: Atticus Kuhn <[email protected]>
Co-authored-by: Tobias Grosser <[email protected]>
Co-authored-by: AnotherAlexHere <[email protected]>
Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
5 people authored Jul 16, 2024
1 parent 102be68 commit ab81800
Show file tree
Hide file tree
Showing 241 changed files with 38,237 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SSA/Projects/InstCombine/tests/LLVM/** linguist-generated=true
SSA/Projects/InstCombine/all.lean linguist-generated=true
1 change: 0 additions & 1 deletion SSA/Projects/InstCombine/TacticAuto.lean
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ macro "of_bool_tactic" : tactic =>
| simp only [← decide_not]
| simp only [decide_eq_decide]
| simp [of_decide_eq_true]
| simp only [BitVec.toNat_eq]
)
try omega
)
Expand Down
117 changes: 117 additions & 0 deletions SSA/Projects/InstCombine/all.lean

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions SSA/Projects/InstCombine/scripts/all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env sh
# Regenerate all.lean: one `import` line for every .lean file found in $DIR,
# except the companion *_proof.lean files.
#
# Paths are relative to the current working directory, so run this from the
# directory that contains tests/LLVM.
#
# Only POSIX sh constructs are used, so the script behaves the same whether
# `sh` is bash, dash or another POSIX shell.

# Directory containing the .lean files
DIR="./tests/LLVM"

# Output file
OUTPUT_FILE="all.lean"

# Initialize the output file (previous content is discarded; the file starts
# with a single empty line, as before).
echo "" > "$OUTPUT_FILE"

for file in "$DIR"/*.lean; do
  # When nothing matches, sh leaves the pattern itself in $file; skip it so
  # no import line is written for a literal "*".
  [ -e "$file" ] || continue

  # Proof files are not imported.
  case "$file" in
    *_proof.lean) continue ;;
  esac

  # Module name: file name without directory and without the .lean suffix.
  filename=$(basename "$file" .lean)

  # NOTE(review): files are read from tests/LLVM but imported under the
  # `...InstCombine.lean.` prefix -- confirm this matches the module layout.
  echo "import SSA.Projects.InstCombine.lean.$filename" >> "$OUTPUT_FILE"
done

echo "all.lean file has been created with the necessary import statements."
186 changes: 186 additions & 0 deletions SSA/Projects/InstCombine/scripts/mlir-tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.llvm import LLVM
from xdsl.utils.exceptions import ParseError
from xdsl.context import MLContext
from xdsl.dialects import get_all_dialects
from xdsl.dialects.llvm import FuncOp
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.dialects.builtin import (
Builtin,
IndexType,
IntegerAttr,
IntegerType,
ModuleOp,
StringAttr,
i32,
i64,
)
import os
import io
from pathlib import Path
from xdsl.printer import Printer

# Shared xDSL context used by every parse below. It is created with
# allow_unregistered=True and has the LLVM and Builtin dialects loaded,
# in that order.
ctx = MLContext(allow_unregistered=True)
for _dialect in (LLVM, Builtin):
    ctx.load_dialect(_dialect)

# Directory (encoded as bytes) whose entries are iterated by the main loop.
directory = os.fsencode("../vcombined-mlir")


# Fully qualified operation names that `allowed` accepts by name.
allowed_names = {
    f"llvm.{mnemonic}"
    for mnemonic in (
        "return",
        "mlir.constant",
        "add",
        "sub",
        "mul",
        "sdiv",
        "urem",
        "srem",
        "shl",
        "lshr",
        "ashr",
        "and",
        "or",
        "xor",
    )
}

# Names of unregistered operations that `allowed` accepts. Currently empty;
# "llvm.icmp" is the one candidate that has been left commented out.
allowed_unregistered = set()
def allowed(op):
    """Tell whether ``op`` is acceptable for theorem extraction.

    An operation is accepted when any of the following holds, checked in
    this order:

    * it has a ``sym_name`` attribute;
    * its ``name`` is ``"builtin.unregistered"`` and ``op.op_name.data`` is
      in ``allowed_unregistered``;
    * its ``name`` is in ``allowed_names``.
    """
    if hasattr(op, "sym_name"):
        return True
    if op.name == "builtin.unregistered" and op.op_name.data in allowed_unregistered:
        return True
    return op.name in allowed_names


def show(block):
    """Return the text that ``block.print`` writes through a ``Printer``.

    The printer is bound to an in-memory buffer, and the buffer's content is
    returned once ``block`` has printed itself.
    """
    with io.StringIO() as buffer:
        block.print(Printer(stream=buffer))
        return buffer.getvalue()


def showr(region):
    """Return the text produced by ``Printer.print_region`` for ``region``.

    The printer is bound to an in-memory buffer, and the buffer's content is
    returned after the region has been printed.
    """
    with io.StringIO() as buffer:
        Printer(stream=buffer).print_region(region)
        return buffer.getvalue()


def size(func):
    """Return how many items ``func.walk()`` yields."""
    total = 0
    for _ in func.walk():
        total += 1
    return total


def read_file(file_name):
    """Open ``file_name`` in text mode and return its whole content."""
    with open(file_name, mode="r") as source:
        text = source.read()
    return text


def parse_module(module):
    """Parse MLIR source text into a module using the shared ``ctx``.

    Args:
        module: the MLIR source text to parse.

    Returns:
        The result of ``Parser.parse_module()``, or ``None`` when the parser
        raises ``ParseError`` (a message is printed in that case).
    """
    parser = Parser(ctx, module)
    try:
        return parser.parse_module()
    except ParseError:
        # The exception imported from xdsl.utils.exceptions is `ParseError`;
        # catching the undefined name `ParserError` turned every parse
        # failure into a NameError.
        print("failed to parse the module")
        return None


def parse_from_file(file_name):
    """Read ``file_name`` with ``read_file`` and parse it with ``parse_module``."""
    source_text = read_file(file_name)
    return parse_module(source_text)


# Main driver. For every entry of `directory`, parse the optimized module and
# the original module of the same file name, pair up their functions by
# symbol name, and append a Lean refinement-theorem stub for every pair whose
# bodies differ.
#
# NOTE(review): the Lean template lines below are reproduced exactly as they
# appear in the reviewed text; confirm the indentation of the tactic lines
# against the generated files before regenerating.
for file in os.listdir(directory):
    filename = os.fsdecode(file)
    print(filename)
    # Lean file stem: "g" + the text before the first "." with "-" -> "h".
    stem = "g" + filename.split(".")[0].replace("-", "h")

    # Parse the file itself inside each directory; passing the bare directory
    # path to open() cannot work. process.sh stores the optimized and the
    # original module under the same file name.
    module1 = parse_from_file(os.path.join("../vcombined-mlir", filename))
    module2 = parse_from_file(os.path.join("../vbefore-mlir", filename))
    if module1 is None or module2 is None:
        # parse_module returns None when parsing fails.
        print(f"Skipping {filename}: could not parse both modules")
        continue

    # Optimized functions that contain only allowed operations and consist of
    # more than one walked item.
    funcs = [
        func
        for func in module1.walk()
        if isinstance(func, FuncOp)
        and all(allowed(o) for o in func.walk())
        and size(func) > 1
    ]
    # Original functions, indexed by symbol name.
    funcs2 = {f.sym_name.data: f for f in module2.walk() if isinstance(f, FuncOp)}
    for func in funcs:
        other = funcs2.get(func.sym_name.data, None)
        if other is None:
            print(f"Cannot find function with sym name {func.sym_name}")
            continue

        if not all(allowed(o) for o in other.walk()):
            print(f"{other.sym_name} contains unsupported operations, ignoring")
            continue

        s1 = showr(func.body)
        s2 = showr(other.body)
        name = func.sym_name.data
        if s1 == s2:
            # The function was not changed; there is nothing to state.
            continue
        if "vector" in (s1 + s2):
            # Any function whose printed body mentions "vector" is skipped.
            continue
        print(f"-----{filename}.{func.sym_name}-----")
        o1 = f"""
def {name}_before := [llvm|
{s2}
]
def {name}_after := [llvm|
{s1}
]
theorem {name}_proof : {name}_before ⊑ {name}_after := by
unfold {name}_before {name}_after
simp_alive_peephole
simp_alive_undef
simp_alive_ops
simp_alive_case_bash
try alive_auto
---BEGIN {name}
all_goals (try extract_goal ; sorry)
---END {name}\n\n\n"""
        print(o1)
        write_file = os.path.join(
            "../lean-mlir",
            "SSA",
            "Projects",
            "InstCombine",
            "tests",
            "LLVM",
            f"{stem}.lean",
        )
        # Append mode: theorems for the same stem accumulate in one file, and
        # re-running the tool appends to files left by an earlier run.
        with open(write_file, "a+") as f3:
            if os.stat(write_file).st_size == 0:
                # First write to this file: emit the Lean prelude once.
                f3.write(
                    """
import SSA.Projects.InstCombine.LLVM.PrettyEDSL
import SSA.Projects.InstCombine.TacticAuto
import SSA.Projects.InstCombine.LLVM.Semantics
open LLVM
open BitVec
open MLIR AST
open Std (BitVec)
open Ctxt (Var)
set_option linter.deprecated false
set_option linter.unreachableTactic false
set_option linter.unusedTactic false
"""
                )
            f3.write(o1)
34 changes: 34 additions & 0 deletions SSA/Projects/InstCombine/scripts/process.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env bash
# For every InstCombine test of the LLVM checkout, run the test's own first
# RUN command to obtain the optimized IR, then translate both the optimized
# and the original IR to generic-form MLIR.
#
# Outputs, relative to the current directory:
#   vcombined/<name>.ll            output of the RUN command
#   vcombined-mlir/<name>.ll.mlir  that output translated to MLIR
#   vbefore-mlir/<name>.ll.mlir    the original test translated to MLIR
#
# bash is required: the script uses [[ =~ ]] and ${var/pattern/replacement},
# which plain POSIX sh does not provide.

set -u

#######################################
# Print the first "; RUN:" line of a test file, without the "; RUN: " prefix
# and without everything from the first " |" onwards.
# Arguments: $1 - path to the .ll file
# Outputs:   the command on stdout (nothing if the file has no RUN line)
#######################################
extract_run_cmd() {
  grep '^; RUN:' -- "$1" | head -n 1 | sed -e 's/^; RUN: //' -e 's/ |.*$//'
}

#######################################
# Turn an extracted RUN command into the command that is executed:
# append -S unless the command already ends with it, drop the first
# -disable-output, and replace every %s with the test file path.
# Arguments: $1 - extracted RUN command, $2 - path to the .ll file
# Outputs:   the final command on stdout
#######################################
finalize_run_cmd() {
  local cmd=$1
  local src=$2
  if [[ ! $cmd =~ -S$ ]]; then
    cmd="$cmd -S"
  fi
  cmd="${cmd/-disable-output/}"
  cmd="${cmd//%s/$src}"
  printf '%s\n' "$cmd"
}

# Create necessary directories
# mkdir -p combined combined-mlir before-mlir
mkdir -p vcombined vcombined-mlir vbefore-mlir

# Loop through each .ll file in the specified directory
for file in llvm-project-main/llvm/test/Transforms/InstCombine/*.ll; do
  # Without nullglob an unmatched pattern is passed through literally.
  [[ -e "$file" ]] || continue

  # File name without the path, and without the .ll suffix
  filename=$(basename -- "$file")
  d="${filename%.ll}"

  run_cmd=$(extract_run_cmd "$file")
  printf '%s\n' "$run_cmd"
  if [[ -z "$run_cmd" ]]; then
    printf 'warning: no RUN line in %s, skipping\n' "$file" >&2
    continue
  fi

  run_cmd=$(finalize_run_cmd "$run_cmd" "$file")
  printf '%s\n' "$run_cmd"

  # eval is used because the command is a shell command line stored in a
  # string. NOTE(review): this executes text taken verbatim from the test
  # files, so only run it on a checkout you trust.
  if ! eval "$run_cmd" > "vcombined/${d}.ll"; then
    printf 'warning: RUN command failed for %s, skipping\n' "$file" >&2
    continue
  fi

  # Convert the processed LLVM to MLIR
  mlir-translate -import-llvm "vcombined/${d}.ll" \
    | mlir-opt --mlir-print-op-generic > "vcombined-mlir/${d}.ll.mlir"

  # Convert the original LLVM to MLIR
  mlir-translate -import-llvm "$file" \
    | mlir-opt --mlir-print-op-generic > "vbefore-mlir/${d}.ll.mlir"
done
Loading

0 comments on commit ab81800

Please sign in to comment.