[AOTI] Remove WrapperCodegen.expr_printer (pytorch#141388)
Summary: Avoid using expr_printer as an overridden class member of WrapperCodegen. Instead, call pexpr and cexpr explicitly for Python and C++ expression printing, respectively. This prepares for one-pass AOTI CUDA codegen, where PythonWrapperCodegen generates the autotune block and CppWrapperCodegen generates the model code.
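To illustrate the refactor in isolation, here is a minimal before/after sketch. The class names and printers are simplified stand-ins, not the actual torch._inductor implementations of `pexpr`/`cexpr`:

```python
import sympy

# Stand-in printers; the real ones live in torch._inductor.codegen
# (pexpr for Python syntax, cexpr for C++ syntax) and do much more.
def pexpr(e: sympy.Expr) -> str:
    return str(e)  # pretend Python-syntax printer

def cexpr(e: sympy.Expr) -> str:
    return str(e)  # pretend C++-syntax printer

# Before: the printer is an instance attribute, so overriding it in a
# subclass silently changes the output language of every shared call site.
class PythonWrapperBefore:
    def __init__(self) -> None:
        self.expr_printer = pexpr

    def codegen_sizevar(self, x: sympy.Expr) -> str:
        return self.expr_printer(x)

class CppWrapperBefore(PythonWrapperBefore):
    def __init__(self) -> None:
        super().__init__()
        self.expr_printer = cexpr  # flips the whole instance to C++ output

# After: each class calls its printer explicitly, so Python codegen
# (e.g. the autotune block) and C++ codegen (the model code) can be
# produced side by side in one pass without sharing a mutable attribute.
class PythonWrapperAfter:
    def codegen_sizevar(self, x: sympy.Expr) -> str:
        return pexpr(x)

class CppWrapperAfter(PythonWrapperAfter):
    def codegen_sizevar(self, x: sympy.Expr) -> str:
        return cexpr(x)
```

With the attribute gone, switching the output language requires overriding the method itself, which is what lets the two wrapper classes coexist in the planned one-pass CUDA codegen.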

Differential Revision: [D66459992](https://our.internmc.facebook.com/intern/diff/D66459992)

Pull Request resolved: pytorch#141388
Approved by: https://github.com/chenyang78
desertfire authored and pytorchmergebot committed Dec 5, 2024
1 parent 12b8c2f commit 4cc0fc2
Showing 5 changed files with 71 additions and 33 deletions.
75 changes: 62 additions & 13 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -14,14 +14,21 @@
import torch._ops
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer, Kernel
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, PythonWrapperCodegen
from .triton_utils import should_unwrap_unspec_arg
from .wrapper import (
EnterSubgraphLine,
ExitSubgraphLine,
PythonWrapperCodegen,
SymbolicCallArg,
)


class CppWrapperCpu(PythonWrapperCodegen):
@@ -64,7 +71,6 @@ def __init__(self):
self.custom_op_wrapper_loaded = False
# For GEMM kernels that must be initialized and are resolved at linking.
self.initialized_kernels: Dict[str, Kernel] = {}
self.expr_printer = cexpr

@staticmethod
def create(
@@ -1037,7 +1043,7 @@ def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
output_args.append(f"&{output_name}")
elif isinstance(output, sympy.Expr):
output_name = f"{output_name_base}_{idx}"
self.writeline(f"auto {output_name} = {self.expr_printer(output)};")
self.writeline(f"auto {output_name} = {cexpr(output)};")
output_args.append(f"&{output_name}")
elif output is None:
output_args.append("nullptr")
@@ -1117,7 +1123,7 @@ def add_benchmark_harness(self, output):
super().add_benchmark_harness(output)

def codegen_cpp_sizevar(self, x: Expr, *, simplify: bool = True) -> str:
return self.expr_printer(V.graph.sizevars.simplify(x) if simplify else x)
return cexpr(V.graph.sizevars.simplify(x) if simplify else x)

def codegen_sizevar(self, x: Expr) -> str:
return self.codegen_cpp_sizevar(x)
@@ -1134,6 +1140,51 @@ def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str:
return f"{{{parts[0]}, }}"
return f"{{{', '.join(parts)}}}"

def ensure_size_computed(self, sym: sympy.Symbol):
if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
if sym in self.computed_sizes:
return
self.computed_sizes.add(sym)
expr = V.graph.sizevars.inv_precomputed_replacements[sym]
self.writeline(f"int64_t {sym} = {cexpr(expr)};")

def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None):
expr = f"{kernel_name}_{tree.prefix}numel"
if suffix is not None:
expr += f"_{suffix}"
if (expr, V.graph) not in self.kernel_numel_expr:
# declare expr once in each graph (scope)
self.kernel_numel_expr.add((expr, V.graph))
self.writeline(f"int64_t {expr} = {cexpr(tree.numel)};")
else:
self.writeline(f"{expr} = {cexpr(tree.numel)};")
# We can get symbolic expressions here, like s0*64
# It is fine to have them here, but we need to handle them correctly as their own type
# This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
# scalars as well.
# This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
# constant now, need type info. I agree, this needs type info, and while this is not true type info
# it suffices as a type hint for the purposes of producing the correct code for this type.
return SymbolicCallArg(expr, tree.numel)

def prepare_triton_kernel_call(self, device_index, call_args):
def wrap_arg(arg):
if isinstance(arg, str):
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg
elif isinstance(arg, (int, float, bool, SymbolicCallArg)):
return str(arg)
else:
return cexpr(V.graph.sizevars.simplify(arg))

call_args = [wrap_arg(arg) for arg in call_args]

if device_index is None:
current_device = V.graph.get_current_device_or_throw()
device_index = current_device.index

return device_index, call_args

def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")
@@ -1291,7 +1342,7 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
args = [
name,
self.expr_printer(offset), # bytes not numel
cexpr(offset), # bytes not numel
self.codegen_dtype(dtype),
str(len(shape)),
self.codegen_int_array_var(
@@ -1616,7 +1667,7 @@ def fill_args(arg, arg_type):
elif isinstance(arg_type, torch.SymIntType):
# SymInt
expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
new_int_args.append(self.expr_printer(expr))
new_int_args.append(cexpr(expr))
elif isinstance(arg_type, torch.NumberType):
# Scalar of type int
assert isinstance(arg, (int, float, bool))
@@ -1646,9 +1697,7 @@ def fill_args(arg, arg_type):
expressions = [
a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
]
new_int_args.extend(
[self.expr_printer(expr) for expr in expressions]
)
new_int_args.extend([cexpr(expr) for expr in expressions])
# List[Scalar]
elif isinstance(arg_type.getElementType(), torch.NumberType):
# Only treat int Scalar as dynamic
@@ -1859,7 +1908,7 @@ def add_py_newref():
expr = (
raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg
)
return f"PyLong_FromLongLong({self.expr_printer(expr)})"
return f"PyLong_FromLongLong({cexpr(expr)})"
elif isinstance(arg_type, torch.FloatType):
return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})"
elif isinstance(arg_type, torch.BoolType):
@@ -1879,7 +1928,7 @@ def add_py_newref():
return f"PyComplex_FromDoubles({raw_arg.real, raw_arg.imag})"
elif isinstance(raw_arg, torch.SymInt):
expr = raw_arg.node.expr
return f"PyLong_FromLongLong({self.expr_printer(expr)})"
return f"PyLong_FromLongLong({cexpr(expr)})"
else:
raise NotImplementedError(
f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper"
@@ -2103,9 +2152,9 @@ def val_to_arg_str_for_prim_type(self, val, type_) -> str:
# FIXME: This happens because type_ is not always properly set to torch.ListType
return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}"
elif isinstance(val, SymTypes):
return self.expr_printer(val.node.expr)
return cexpr(val.node.expr)
elif isinstance(val, sympy.Expr):
return self.expr_printer(val)
return cexpr(val)
else:
return repr(val)

3 changes: 1 addition & 2 deletions torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
@@ -12,7 +12,7 @@
from .. import config, ir
from ..utils import sympy_product
from ..virtualized import V
from .cpp_utils import cexpr, DTYPE_TO_CPP
from .cpp_utils import DTYPE_TO_CPP
from .cpp_wrapper_cpu import CppWrapperCpu
from .wrapper import (
BufferLike,
@@ -73,7 +73,6 @@ def __init__(self):
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False
self.expr_printer = cexpr
self.allow_stack_allocation: Optional[
bool
] = config.aot_inductor.allow_stack_allocation
8 changes: 4 additions & 4 deletions torch/_inductor/codegen/cpp_wrapper_gpu.py
@@ -363,9 +363,9 @@ def process_args(arg, arg_type, arg_signature=None):
)
)
elif arg_type in (sympy.Integer, int):
self.writeline(f"int {var_name} = {self.expr_printer(arg)};")
self.writeline(f"int {var_name} = {cexpr(arg)};")
elif arg_type in (sympy.Float, float):
self.writeline(f"float {var_name} = {self.expr_printer(arg)};")
self.writeline(f"float {var_name} = {cexpr(arg)};")
# For symbolic call arguments, examine the arg signatures from triton meta
# to explicitly cast to the right type
# Reason: `auto` can infer unexpected type against kernel input signature.
@@ -375,10 +375,10 @@
and arg_signature in signature2dtype.keys()
):
self.writeline(
f"{signature2dtype[arg_signature]} {var_name} = {self.expr_printer(arg)};"
f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
)
else:
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
self.writeline(f"auto {var_name} = {cexpr(arg)};")
new_args.append(f"&{var_name}")

for arg, arg_type, arg_signature in zip_longest(
16 changes: 3 additions & 13 deletions torch/_inductor/codegen/wrapper.py
@@ -652,7 +652,6 @@ def __init__(self):
self.move_end = ")" if V.graph.cpp_wrapper else ""
self.last_seen_device_guard_index: Optional[int] = None
self.supports_intermediate_hooks = True
self.expr_printer: Callable[[Any], str] = pexpr
self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {}
self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol
self.computed_sizes: Set[sympy.Symbol] = set()
@@ -1304,9 +1303,7 @@ def ensure_size_computed(self, sym: sympy.Symbol):
return
self.computed_sizes.add(sym)
expr = V.graph.sizevars.inv_precomputed_replacements[sym]
self.writeline(
f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
)
self.writeline(f"{sym} = {pexpr(expr)}")

def finalize_prefix(self):
pass
@@ -1684,14 +1681,7 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No
expr = f"{kernel_name}_{tree.prefix}numel"
if suffix is not None:
expr += f"_{suffix}"
if (expr, V.graph) not in self.kernel_numel_expr:
# declare expr once in each graph (scope)
self.kernel_numel_expr.add((expr, V.graph))
self.writeline(
f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}"
)
else:
self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}")
self.writeline(f"{expr} = {pexpr(tree.numel)}")
# We can get symbolic expressions here, like s0*64
# It is fine to have them here, but we need to handle them correctly as their own type
# This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
@@ -1811,7 +1801,7 @@ def wrap_arg(arg):
elif isinstance(arg, (int, float, bool, SymbolicCallArg)):
return str(arg)
else:
return self.expr_printer(V.graph.sizevars.simplify(arg))
return pexpr(V.graph.sizevars.simplify(arg))

call_args = [wrap_arg(arg) for arg in call_args]

2 changes: 1 addition & 1 deletion torch/_inductor/ir.py
@@ -3884,7 +3884,7 @@ def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return free_unbacked_symbols(self.expr)

def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.expr))
return V.graph.wrapper_code.codegen_sizevar(self.expr)

def has_tensor_output(self) -> bool:
return False
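For context on the ir.py hunk above: `codegen_reference` now asks the active wrapper for a size-var string via `codegen_sizevar` instead of reading the wrapper's `expr_printer` attribute, and each wrapper class supplies its own printer. A hedged usage sketch with stand-in printers (the real IR node and printers differ):

```python
import sympy

# Illustrative stand-ins only, not the real Inductor printers.
def pexpr(e: sympy.Expr) -> str:
    return f"py<{e}>"   # pretend Python-syntax output, just to show the dispatch

def cexpr(e: sympy.Expr) -> str:
    return f"cpp<{e}>"  # pretend C++-syntax output

class PythonWrapper:
    def codegen_sizevar(self, x: sympy.Expr) -> str:
        return pexpr(x)

class CppWrapper(PythonWrapper):
    def codegen_sizevar(self, x: sympy.Expr) -> str:
        return cexpr(x)

def codegen_reference(wrapper: PythonWrapper, expr: sympy.Expr) -> str:
    # The IR node delegates to whichever wrapper is active, so the same
    # symbolic shape renders as Python or C++ without touching the node.
    return wrapper.codegen_sizevar(expr)

s0 = sympy.Symbol("s0")
print(codegen_reference(PythonWrapper(), s0 * 64))  # py<64*s0>
print(codegen_reference(CppWrapper(), s0 * 64))     # cpp<64*s0>
```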
