diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 4e49f57761c026..518c7205dcf3b3 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -14,6 +14,7 @@ import torch._ops from torch._inductor.runtime.runtime_utils import dynamo_timed from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes +from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config, ir from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name @@ -21,7 +22,13 @@ from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import IndentedBuffer, Kernel from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP -from .wrapper import EnterSubgraphLine, ExitSubgraphLine, PythonWrapperCodegen +from .triton_utils import should_unwrap_unspec_arg +from .wrapper import ( + EnterSubgraphLine, + ExitSubgraphLine, + PythonWrapperCodegen, + SymbolicCallArg, +) class CppWrapperCpu(PythonWrapperCodegen): @@ -64,7 +71,6 @@ def __init__(self): self.custom_op_wrapper_loaded = False # For GEMM kernels that must be initialized and are resolved at linking. self.initialized_kernels: Dict[str, Kernel] = {} - self.expr_printer = cexpr @staticmethod def create( @@ -1037,7 +1043,7 @@ def generate_c_shim_fallback_kernel(self, fallback_kernel, args): output_args.append(f"&{output_name}") elif isinstance(output, sympy.Expr): output_name = f"{output_name_base}_{idx}" - self.writeline(f"auto {output_name} = {self.expr_printer(output)};") + self.writeline(f"auto {output_name} = {cexpr(output)};") output_args.append(f"&{output_name}") elif output is None: output_args.append("nullptr") @@ -1117,7 +1123,7 @@ def add_benchmark_harness(self, output): super().add_benchmark_harness(output) def codegen_cpp_sizevar(self, x: Expr, *, simplify: bool = True) -> str: - return self.expr_printer(V.graph.sizevars.simplify(x) if simplify else x) + return cexpr(V.graph.sizevars.simplify(x) if simplify else x) def codegen_sizevar(self, x: Expr) -> str: return self.codegen_cpp_sizevar(x) @@ -1134,6 +1140,51 @@ def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: return f"{{{parts[0]}, }}" return f"{{{', '.join(parts)}}}" + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline(f"int64_t {sym} = {cexpr(expr)};") + + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): + expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline(f"int64_t {expr} = {cexpr(tree.numel)};") + else: + self.writeline(f"{expr} = {cexpr(tree.numel)};") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. + return SymbolicCallArg(expr, tree.numel) + + def prepare_triton_kernel_call(self, device_index, call_args): + def wrap_arg(arg): + if isinstance(arg, str): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg + elif isinstance(arg, (int, float, bool, SymbolicCallArg)): + return str(arg) + else: + return cexpr(V.graph.sizevars.simplify(arg)) + + call_args = [wrap_arg(arg) for arg in call_args] + + if device_index is None: + current_device = V.graph.get_current_device_or_throw() + device_index = current_device.index + + return device_index, call_args + def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw") @@ -1291,7 +1342,7 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" args = [ name, - self.expr_printer(offset), # bytes not numel + cexpr(offset), # bytes not numel self.codegen_dtype(dtype), str(len(shape)), self.codegen_int_array_var( @@ -1616,7 +1667,7 @@ def fill_args(arg, arg_type): elif isinstance(arg_type, torch.SymIntType): # SymInt expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg - new_int_args.append(self.expr_printer(expr)) + new_int_args.append(cexpr(expr)) elif isinstance(arg_type, torch.NumberType): # Scalar of type int assert isinstance(arg, (int, float, bool)) @@ -1646,9 +1697,7 @@ def fill_args(arg, arg_type): expressions = [ a.node.expr if isinstance(a, torch.SymInt) else a for a in arg ] - new_int_args.extend( - [self.expr_printer(expr) for expr in expressions] - ) + new_int_args.extend([cexpr(expr) for expr in expressions]) # List[Scalar] elif isinstance(arg_type.getElementType(), torch.NumberType): # Only treat int Scalar as dynamic @@ -1859,7 +1908,7 @@ def add_py_newref(): expr = ( raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg ) - return f"PyLong_FromLongLong({self.expr_printer(expr)})" + return f"PyLong_FromLongLong({cexpr(expr)})" elif isinstance(arg_type, torch.FloatType): return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})" elif isinstance(arg_type, torch.BoolType): @@ -1879,7 +1928,7 @@ def add_py_newref(): return f"PyComplex_FromDoubles({raw_arg.real, raw_arg.imag})" elif isinstance(raw_arg, torch.SymInt): expr = raw_arg.node.expr - return f"PyLong_FromLongLong({self.expr_printer(expr)})" + return f"PyLong_FromLongLong({cexpr(expr)})" else: raise NotImplementedError( f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper" @@ -2103,9 +2152,9 @@ def val_to_arg_str_for_prim_type(self, val, type_) -> str: # FIXME: This happens because type_ is not always properly set to torch.ListType return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" elif isinstance(val, SymTypes): - return self.expr_printer(val.node.expr) + return cexpr(val.node.expr) elif isinstance(val, sympy.Expr): - return self.expr_printer(val) + return cexpr(val) else: return repr(val) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index c422af51aa02e8..6906fc851a8efa 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -12,7 +12,7 @@ from .. import config, ir from ..utils import sympy_product from ..virtualized import V -from .cpp_utils import cexpr, DTYPE_TO_CPP +from .cpp_utils import DTYPE_TO_CPP from .cpp_wrapper_cpu import CppWrapperCpu from .wrapper import ( BufferLike, @@ -73,7 +73,6 @@ def __init__(self): self.cached_output_id = count() self.scalar_to_tensor_id = count() self.custom_op_wrapper_loaded = False - self.expr_printer = cexpr self.allow_stack_allocation: Optional[ bool ] = config.aot_inductor.allow_stack_allocation diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 6e4faa00df81ec..0a8dfeaa7e6242 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -363,9 +363,9 @@ def process_args(arg, arg_type, arg_signature=None): ) ) elif arg_type in (sympy.Integer, int): - self.writeline(f"int {var_name} = {self.expr_printer(arg)};") + self.writeline(f"int {var_name} = {cexpr(arg)};") elif arg_type in (sympy.Float, float): - self.writeline(f"float {var_name} = {self.expr_printer(arg)};") + self.writeline(f"float {var_name} = {cexpr(arg)};") # For symbolic call arguments, examine the arg signatures from triton meta # to explicitly cast to the right type # Reason: `auto` can infer unexpected type against kernel input signature. @@ -375,10 +375,10 @@ def process_args(arg, arg_type, arg_signature=None): and arg_signature in signature2dtype.keys() ): self.writeline( - f"{signature2dtype[arg_signature]} {var_name} = {self.expr_printer(arg)};" + f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" ) else: - self.writeline(f"auto {var_name} = {self.expr_printer(arg)};") + self.writeline(f"auto {var_name} = {cexpr(arg)};") new_args.append(f"&{var_name}") for arg, arg_type, arg_signature in zip_longest( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 2ab2b3263547e4..86b35ae2a044ba 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -652,7 +652,6 @@ def __init__(self): self.move_end = ")" if V.graph.cpp_wrapper else "" self.last_seen_device_guard_index: Optional[int] = None self.supports_intermediate_hooks = True - self.expr_printer: Callable[[Any], str] = pexpr self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {} self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol self.computed_sizes: Set[sympy.Symbol] = set() @@ -1304,9 +1303,7 @@ def ensure_size_computed(self, sym: sympy.Symbol): return self.computed_sizes.add(sym) expr = V.graph.sizevars.inv_precomputed_replacements[sym] - self.writeline( - f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}" - ) + self.writeline(f"{sym} = {pexpr(expr)}") def finalize_prefix(self): pass @@ -1684,14 +1681,7 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No expr = f"{kernel_name}_{tree.prefix}numel" if suffix is not None: expr += f"_{suffix}" - if (expr, V.graph) not in self.kernel_numel_expr: - # declare expr once in each graph (scope) - self.kernel_numel_expr.add((expr, V.graph)) - self.writeline( - f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}" - ) - else: - self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}") + self.writeline(f"{expr} = {pexpr(tree.numel)}") # We can get symbolic expressions here, like s0*64 # It is fine to have them here, but we need to handle them correctly as their own type # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* @@ -1811,7 +1801,7 @@ def wrap_arg(arg): elif isinstance(arg, (int, float, bool, SymbolicCallArg)): return str(arg) else: - return self.expr_printer(V.graph.sizevars.simplify(arg)) + return pexpr(V.graph.sizevars.simplify(arg)) call_args = [wrap_arg(arg) for arg in call_args] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4947757f4379d5..d5e4be44a2eb8b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3884,7 +3884,7 @@ def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: return free_unbacked_symbols(self.expr) def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: - return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.expr)) + return V.graph.wrapper_code.codegen_sizevar(self.expr) def has_tensor_output(self) -> bool: return False