Skip to content

Commit

Permalink
Revert "Revert "[FRONTEND] do not hold to references unnecessarily (#…
Browse files Browse the repository at this point in the history
…5402)"" (#3067)

This reverts commit ea006f2.

Closes #3006
  • Loading branch information
anmyachev authored Dec 26, 2024
1 parent 6ee08cd commit 7107b3d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
18 changes: 14 additions & 4 deletions python/triton/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]:
return ret


def get_iterable_path(iterable, path):
from functools import reduce
return reduce(lambda a, idx: a[idx], path, iterable)


def set_iterable_path(iterable, path, val):
prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
prev[path[-1]] = val


def find_paths_if(iterable, pred):
from .language import core
is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
Expand All @@ -34,17 +44,17 @@ def _impl(current, path):
_impl(item, path + (idx, ))
elif pred(path, current):
if len(path) == 1:
ret[(path[0], )] = current
ret[(path[0], )] = None
else:
ret[tuple(path)] = current
ret[tuple(path)] = None

if is_iterable(iterable):
_impl(iterable, [])
elif pred(list(), iterable):
ret = {tuple(): iterable}
ret = {tuple(): None}
else:
ret = dict()
return ret
return list(ret.keys())


def parse_list_string(s):
Expand Down
28 changes: 11 additions & 17 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .._utils import find_paths_if, list_list_flatten, list_list_unflatten
from functools import reduce
from .._utils import list_list_flatten, list_list_unflatten, find_paths_if, get_iterable_path, set_iterable_path

from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)

Expand Down Expand Up @@ -196,13 +195,6 @@ def visit_Call(self, node: ast.Call) -> bool:

class ASTFunction:

def get_path(self, x, path):
return reduce(lambda a, idx: a[idx], path, x)

def set_path(self, x, path, val):
prev = x if len(path) == 1 else self.get_path(x, path[:-1])
prev[path[-1]] = val

def __init__(self, ret_types, arg_types, constexprs, constants, attrs):
self.ret_types = ret_types
self.arg_types = arg_types
Expand All @@ -214,8 +206,8 @@ def serialize(self, builder: ir.builder):
# fill up IR values in template
# > build function
is_val = lambda path, _: path not in self.constexprs and _ is not None
val_paths = list(find_paths_if(self.arg_types, is_val).keys())
arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths]
val_paths = list(find_paths_if(self.arg_types, is_val))
arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
return builder.get_function_ty(arg_types, ret_types)

Expand All @@ -228,24 +220,24 @@ def make_template(val):

vals = make_template(self.arg_types)
is_val = lambda path, _: path not in self.constexprs and _ is not None
val_paths = list(find_paths_if(self.arg_types, is_val).keys())
val_paths = list(find_paths_if(self.arg_types, is_val))
# > set attributes
for attr_path, attr_specs in self.attrs.items():
for attr_name, attr_val in attr_specs:
if attr_path in val_paths:
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
for i, path in enumerate(val_paths):
ty = self.get_path(self.arg_types, path)
ty = get_iterable_path(self.arg_types, path)
if isinstance(ty, nv_tma_desc_type):
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
# > add IR values to the template
for i, path in enumerate(val_paths):
ty = self.get_path(self.arg_types, path)
self.set_path(vals, path, language.tensor(fn.args(i), ty))
ty = get_iterable_path(self.arg_types, path)
set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
# > add constexpr values to the template
constants = self.constants | self.constexprs
for path, val in constants.items():
self.set_path(vals, path, language.constexpr(val))
set_iterable_path(vals, path, language.constexpr(val))
return vals


Expand Down Expand Up @@ -1140,7 +1132,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
if isinstance(arg, (language.dtype, float, int, bool)):
args[i] = language.core.constexpr(arg)
args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values()
args_cst = {path: get_iterable_path(args, path) for path in args_cst}
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
args_val = [get_iterable_path(args, path) for path in args_path]
# mangle
fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
# generate function def if necessary
Expand Down

0 comments on commit 7107b3d

Please sign in to comment.