diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 518c7205dcf3b..6d4cd8aa2b440 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -43,10 +43,7 @@ def __init__(self): self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " self.ending = ";" - self.open_bracket = "{" - self.closed_bracket = "}" self.comment = "//" - self.namespace = "at::" self.none_str = "nullptr" self.size = "sizes()" self.stride = "strides()" @@ -860,7 +857,7 @@ def generate_return(self, output_refs: List[str]): cst_names = V.graph.constants.keys() output2idx: Dict[str, int] = {} for idx, output in enumerate(output_refs): - if output == self.none_str: + if output == "nullptr": continue is_constant_buffer = output in cst_names @@ -1520,7 +1517,7 @@ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): # in case the outer_output carried a value # before (e.g., in the while_loop codegen) self.writeline(f"{outer_output}.reset();") - self.writeline(f"{outer_output} = {src}{self.ending}") + self.writeline(f"{outer_output} = {src};") def codegen_invoke_subgraph(self, invoke_subgraph): raise NotImplementedError( @@ -1570,7 +1567,7 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): # support lifting of subgraphs as functions for cpp wrapper as well. try: self.push_codegened_graph(subgraph.graph) - self.writeline(f"{self.comment} subgraph: {subgraph.name}") + self.writeline(f"// subgraph: {subgraph.name}") self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) parent_graph = V.graph with V.set_graph_handler(subgraph.graph): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index 6906fc851a8ef..9e842728a334a 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -46,16 +46,6 @@ def __init__(self): if not hasattr(self, "device"): self.device = "cpu" super().__init__() - self.declare = "auto " - self.declare_maybe_reference = "decltype(auto) " - self.ending = ";" - self.open_bracket = "{" - self.closed_bracket = "}" - self.comment = "//" - self.namespace = "at::" - self.none_str = "nullptr" - self.size = "sizes()" - self.stride = "strides()" self.supports_intermediate_hooks = False self.outputs_need_copy = set() self.kernel_callsite_id = count() @@ -418,7 +408,7 @@ def use_thread_local_cached_output_tensor(idx, output): output2idx: Dict[str, int] = {} for idx, output in enumerate(output_refs): - if output == self.none_str: + if output == "nullptr": continue is_constant_buffer = output in cst_names @@ -655,10 +645,7 @@ def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): ) if reinterpret_view in self.stack_allocated_buffers: self.stack_allocated_buffers[new_name] = new - return ( - f"{self.declare_maybe_reference}{new_name} = std::move({reinterpret_view}){del_line}" - f" {self.comment} reuse" - ) + return f"{self.declare_maybe_reference}{new_name} = std::move({reinterpret_view}){del_line} // reuse" def generate_c_shim_extern_kernel_call(self, kernel, args): # In the abi_compatible mode, we call fallback aten ops through a C shim layer diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 0a8dfeaa7e624..d68106a8eae01 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -527,6 +527,4 @@ def generate_kernel_call( self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") def make_zero_buffer(self, name): - return ( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get())){self.ending}" - ) + return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 86b35ae2a044b..5be708e736874 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -641,10 +641,7 @@ def __init__(self): self.declare = "" self.declare_maybe_reference = "" self.ending = "" - self.open_bracket = "[" - self.closed_bracket = "]" self.comment = "#" - self.namespace = "" self.none_str = "None" self.size = "size()" self.stride = "stride()" @@ -1084,7 +1081,7 @@ def generate_scatter_fallback( self.writeline(line) def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): - indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" + indices_str = f"[{', '.join(indices)}]" args = [x, indices_str, values, accumulate] self.writeline(self.wrap_kernel_call(kernel, args))