Skip to content

Commit

Permalink
Revert "Revert "[frontend] Remove Complex Regex for MLIR Parsing (#4924)"" (#2681)
Browse files Browse the repository at this point in the history

Closes #2653

This reverts commit ecc9bd4.

The issue looks like: 
```bash
loc("/tmp/pytest-of-runner/pytest-0/popen-gw0/test_convert2d_dst_layout8_int21/test_convert2d.ttgir":4:30): error: #"triton_intel_gpu"<"dpas<{repeatCount=8, systolicDepth=8, executionSize = 8, opsPerChan = 1, threadsPerWarp = 32, warpsPerCTA=[4, 1], repCluster=[1, 1]}>"> : 'none' attribute created with unregistered dialect. If this is intended, please call allowUnregisteredDialects() on the MLIRContext, or use -allow-unregistered-dialect with the MLIR opt tool used
```

---------

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Nov 15, 2024
1 parent fd5b82a commit d34a1f3
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 32 deletions.
48 changes: 48 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/LocationSnapshot.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
Expand Down Expand Up @@ -502,6 +503,16 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
// Returns the name of the first FuncOp in the module for which
// LLVM::isKernel() holds, or an empty string when there is none.
.def("get_entry_func_name",
     [](ModuleOp &self) -> std::string {
       // Walk only the FuncOp operations of the module body, in order.
       for (FuncOp candidate : self.getOps<FuncOp>()) {
         if (LLVM::isKernel(candidate))
           return candidate.getName().str();
       }
       return "";
     })
.def("has_function",
[](ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
Expand All @@ -512,6 +523,43 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, std::string &funcName) -> FuncOp {
return self.lookupSymbol<FuncOp>(funcName);
})
/*
 * Consumed by def ty_to_cpp(ty), which expects a pointer type to be spelled
 * with a leading '*' (ty[0] == '*') and any other type to be spelled as the
 * type itself.
 */

.def("get_function_signature",
     [](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
       std::vector<std::string> signature;

       auto fnType = func.getFunctionType();
       const unsigned argCount = fnType.getNumInputs();
       for (unsigned idx = 0; idx < argCount; ++idx) {
         // An argument whose attribute dictionary has a "tt.nv_tma_desc"
         // entry is reported as "nvTmaDesc". Only the presence of the key
         // is tested here, not its value.
         if (auto argAttrs = func.getCallableArgAttrs()) {
           auto dict = dyn_cast<DictionaryAttr>(argAttrs[idx]);
           if (dict && dict.contains("tt.nv_tma_desc")) {
             signature.push_back("nvTmaDesc");
             continue;
           }
         }

         // Otherwise print the type; pointers become '*' + pointee type.
         std::string spelled;
         llvm::raw_string_ostream os(spelled);
         auto argType = fnType.getInput(idx);
         if (auto ptrType = dyn_cast<PointerType>(argType)) {
           os << "*";
           ptrType.getPointeeType().print(os);
         } else {
           argType.print(os);
         }
         signature.push_back(spelled);
       }
       return signature;
     })
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
Expand Down
94 changes: 94 additions & 0 deletions python/test/unit/tools/test_irsource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import tempfile
import triton
from triton.compiler import IRSource, make_backend
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()
backend = make_backend(target)


def test_mlir_attribute_parsing() -> None:
'''
Tests that IRSource extracts kernel metadata from an input ttgir file.
Checks for the following:
1. Name and type signature are parsed correctly
2. parse_options() reports the module's "triton_gpu.num-warps" value
3. Arguments carrying a tt.nv_tma_desc attribute are reported as "nvTmaDesc"
4. A second ttgir sample (vector add) loads through IRSource and compiles
'''

sample_ttgir = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
%desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir)
f.flush()
context = ir.context()
src = IRSource(f.name, context, backend)

# check name and type signature
# should match ty_to_cpp(...)
# %arg8 (tt.nv_tma_desc = 0) and %desc (tt.nv_tma_desc = 1) are both
# expected to be reported as "nvTmaDesc"
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"

# check num warps
# (the module above declares "triton_gpu.num-warps" = 8)
assert src.parse_options()['num_warps'] == 8

sample_ttgir_vector_add = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32})
attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
%10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
%13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
%14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir_vector_add)
f.flush()
context = ir.context()
# the result is not inspected here; this only exercises IRSource
# construction on the second sample
src = IRSource(f.name, context, backend)

# now test compilation
triton.compile(f.name, target=target)
7 changes: 5 additions & 2 deletions python/triton/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
from .errors import CompilationError

__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
__all__ = [
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
"LazyDict"
]
65 changes: 35 additions & 30 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,13 @@
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}

mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}

Expand All @@ -55,16 +49,6 @@ def convert_type_repr(x):
return x


def _get_num_warps_from_ir_str(src: str):
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
# e.g. someone has an instruction (not module) attribute named "num-warps".
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
num_warps = int(num_warps_matches[0])
return num_warps


class ASTSource:

def __init__(self, fn, signature, constants=None, attrs=None) -> None:
Expand Down Expand Up @@ -107,28 +91,42 @@ def parse_options(self):

class IRSource:

def __init__(self, path):
def __init__(self, path, context, backend):
self.path = path
path = Path(path)
self.ext = path.suffix[1:]
self.src = path.read_text()
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
ir.load_dialects(context)
backend.load_dialects(context)

# We don't have an easy-to-use PTX parser available, so keep the regex for now.
# TODO - replace with a proper parser
if self.ext == "ptx":
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
else:
self.module = ir.parse_mlir_module(self.path, context)
fn_name = self.module.get_entry_func_name()
self.name = "@" + fn_name
funcOp = self.module.get_function(fn_name)
func_ty = self.module.get_function_signature(funcOp)
self.signature = {k: ty for k, ty in enumerate(func_ty)}

def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
self.module.context = context
return self.module

def parse_options(self):
if self.ext == "ttgir":
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
num_warps = self.module.get_int_attr("triton_gpu.num-warps")
assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
return {'num_warps': num_warps}
return dict()


Expand Down Expand Up @@ -225,7 +223,9 @@ def compile(src, target=None, options=None):
# create backend
if ir_source:
assert isinstance(src, str), "source must be either AST or a filepath"
src = IRSource(src)
context = ir.context()
src = IRSource(src, context, backend)

extra_options = src.parse_options()
options = backend.parse_options(dict(options or dict(), **extra_options))
# create cache manager
Expand Down Expand Up @@ -266,9 +266,14 @@ def compile(src, target=None, options=None):
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
if ir_source:
first_stage += 1
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)

# For IRSource, we have already grabbed the context + called both
# ir.load_dialects and backend.load_dialects.
if not isinstance(src, IRSource):
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)

codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
try:
Expand Down

0 comments on commit d34a1f3

Please sign in to comment.