From 7107b3d72d2fd03973c6223685eea6fdf2e2e3eb Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 26 Dec 2024 16:41:25 +0100 Subject: [PATCH] Revert "Revert "[FRONTEND] do not hold to references unnecessarily (#5402)"" (#3067) This reverts commit ea006f2a21e95bcd53ace3505414539f8987fddd. Closes #3006 --- python/triton/_utils.py | 18 +++++++++++---- python/triton/compiler/code_generator.py | 28 ++++++++++-------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/python/triton/_utils.py b/python/triton/_utils.py index 0ce1a53a70..5c69751518 100644 --- a/python/triton/_utils.py +++ b/python/triton/_utils.py @@ -22,6 +22,16 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]: return ret +def get_iterable_path(iterable, path): + from functools import reduce + return reduce(lambda a, idx: a[idx], path, iterable) + + +def set_iterable_path(iterable, path, val): + prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1]) + prev[path[-1]] = val + + def find_paths_if(iterable, pred): from .language import core is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) @@ -34,17 +44,17 @@ def _impl(current, path): _impl(item, path + (idx, )) elif pred(path, current): if len(path) == 1: - ret[(path[0], )] = current + ret[(path[0], )] = None else: - ret[tuple(path)] = current + ret[tuple(path)] = None if is_iterable(iterable): _impl(iterable, []) elif pred(list(), iterable): - ret = {tuple(): iterable} + ret = {tuple(): None} else: ret = dict() - return ret + return list(ret.keys()) def parse_list_string(s): diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 20249336e4..b5df19ca9a 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -14,8 +14,7 @@ from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction -from .._utils import find_paths_if, list_list_flatten, list_list_unflatten -from functools import reduce +from .._utils import list_list_flatten, list_list_unflatten, find_paths_if, get_iterable_path, set_iterable_path from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) @@ -196,13 +195,6 @@ def visit_Call(self, node: ast.Call) -> bool: class ASTFunction: - def get_path(self, x, path): - return reduce(lambda a, idx: a[idx], path, x) - - def set_path(self, x, path, val): - prev = x if len(path) == 1 else self.get_path(x, path[:-1]) - prev[path[-1]] = val - def __init__(self, ret_types, arg_types, constexprs, constants, attrs): self.ret_types = ret_types self.arg_types = arg_types @@ -214,8 +206,8 @@ def serialize(self, builder: ir.builder): # fill up IR values in template # > build function is_val = lambda path, _: path not in self.constexprs and _ is not None - val_paths = list(find_paths_if(self.arg_types, is_val).keys()) - arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths] + val_paths = list(find_paths_if(self.arg_types, is_val)) + arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths] ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] return builder.get_function_ty(arg_types, ret_types) @@ -228,24 +220,24 @@ def make_template(val): vals = make_template(self.arg_types) is_val = lambda path, _: path not in self.constexprs and _ is not None - val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + val_paths = list(find_paths_if(self.arg_types, is_val)) # > set attributes for attr_path, attr_specs in self.attrs.items(): for attr_name, attr_val in attr_specs: if attr_path in val_paths: fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) for i, path in enumerate(val_paths): - ty = self.get_path(self.arg_types, path) + ty = get_iterable_path(self.arg_types, path) if isinstance(ty, nv_tma_desc_type): fn.set_arg_attr(i, "tt.nv_tma_desc", 1) # > add IR values to the template for i, path in enumerate(val_paths): - ty = self.get_path(self.arg_types, path) - self.set_path(vals, path, language.tensor(fn.args(i), ty)) + ty = get_iterable_path(self.arg_types, path) + set_iterable_path(vals, path, language.tensor(fn.args(i), ty)) # > add constexpr values to the template constants = self.constants | self.constexprs for path, val in constants.items(): - self.set_path(vals, path, language.constexpr(val)) + set_iterable_path(vals, path, language.constexpr(val)) return vals @@ -1140,7 +1132,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): if isinstance(arg, (language.dtype, float, int, bool)): args[i] = language.core.constexpr(arg) args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) - args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values() + args_cst = {path: get_iterable_path(args, path) for path in args_cst} + args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x)) + args_val = [get_iterable_path(args, path) for path in args_path] # mangle fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) # generate function def if necessary