Skip to content

Commit

Permalink
[AOTI] Remove several overloaded members from WrapperCodegen (pytorch…
Browse files Browse the repository at this point in the history
…#141387)

Summary: Remove several overloaded string members from WrapperCodegen classes, including open_bracket, closed_bracket, size, stride. Instead of relying on polymorphism, we explicitly generate different strings for PythonWrapperCodegen and CppWrapperCodegen. This is to prepare for one-pass AOTI CUDA codegen.

Differential Revision: [D66459991](https://our.internmc.facebook.com/intern/diff/D66459991)

Pull Request resolved: pytorch#141387
Approved by: https://github.com/chenyang78
ghstack dependencies: pytorch#141388
  • Loading branch information
desertfire authored and pytorchmergebot committed Dec 5, 2024
1 parent 4cc0fc2 commit 5f28c42
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 28 deletions.
9 changes: 3 additions & 6 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def __init__(self):
self.declare = "auto "
self.declare_maybe_reference = "decltype(auto) "
self.ending = ";"
self.open_bracket = "{"
self.closed_bracket = "}"
self.comment = "//"
self.namespace = "at::"
self.none_str = "nullptr"
self.size = "sizes()"
self.stride = "strides()"
Expand Down Expand Up @@ -860,7 +857,7 @@ def generate_return(self, output_refs: List[str]):
cst_names = V.graph.constants.keys()
output2idx: Dict[str, int] = {}
for idx, output in enumerate(output_refs):
if output == self.none_str:
if output == "nullptr":
continue

is_constant_buffer = output in cst_names
Expand Down Expand Up @@ -1520,7 +1517,7 @@ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
# in case the outer_output carried a value
# before (e.g., in the while_loop codegen)
self.writeline(f"{outer_output}.reset();")
self.writeline(f"{outer_output} = {src}{self.ending}")
self.writeline(f"{outer_output} = {src};")

def codegen_invoke_subgraph(self, invoke_subgraph):
raise NotImplementedError(
Expand Down Expand Up @@ -1570,7 +1567,7 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
# support lifting of subgraphs as functions for cpp wrapper as well.
try:
self.push_codegened_graph(subgraph.graph)
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
self.writeline(f"// subgraph: {subgraph.name}")
self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
parent_graph = V.graph
with V.set_graph_handler(subgraph.graph):
Expand Down
17 changes: 2 additions & 15 deletions torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,6 @@ def __init__(self):
if not hasattr(self, "device"):
self.device = "cpu"
super().__init__()
self.declare = "auto "
self.declare_maybe_reference = "decltype(auto) "
self.ending = ";"
self.open_bracket = "{"
self.closed_bracket = "}"
self.comment = "//"
self.namespace = "at::"
self.none_str = "nullptr"
self.size = "sizes()"
self.stride = "strides()"
self.supports_intermediate_hooks = False
self.outputs_need_copy = set()
self.kernel_callsite_id = count()
Expand Down Expand Up @@ -418,7 +408,7 @@ def use_thread_local_cached_output_tensor(idx, output):

output2idx: Dict[str, int] = {}
for idx, output in enumerate(output_refs):
if output == self.none_str:
if output == "nullptr":
continue

is_constant_buffer = output in cst_names
Expand Down Expand Up @@ -655,10 +645,7 @@ def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
)
if reinterpret_view in self.stack_allocated_buffers:
self.stack_allocated_buffers[new_name] = new
return (
f"{self.declare_maybe_reference}{new_name} = std::move({reinterpret_view}){del_line}"
f" {self.comment} reuse"
)
return f"{self.declare_maybe_reference}{new_name} = std::move({reinterpret_view}){del_line} // reuse"

def generate_c_shim_extern_kernel_call(self, kernel, args):
# In the abi_compatible mode, we call fallback aten ops through a C shim layer
Expand Down
4 changes: 1 addition & 3 deletions torch/_inductor/codegen/cpp_wrapper_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,4 @@ def generate_kernel_call(
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")

def make_zero_buffer(self, name):
return (
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get())){self.ending}"
)
return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));"
5 changes: 1 addition & 4 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,7 @@ def __init__(self):
self.declare = ""
self.declare_maybe_reference = ""
self.ending = ""
self.open_bracket = "["
self.closed_bracket = "]"
self.comment = "#"
self.namespace = ""
self.none_str = "None"
self.size = "size()"
self.stride = "stride()"
Expand Down Expand Up @@ -1084,7 +1081,7 @@ def generate_scatter_fallback(
self.writeline(line)

def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
indices_str = f"[{', '.join(indices)}]"
args = [x, indices_str, values, accumulate]
self.writeline(self.wrap_kernel_call(kernel, args))

Expand Down

0 comments on commit 5f28c42

Please sign in to comment.