diff --git a/dipu/SupportedDiopiFunctions.txt b/dipu/SupportedDiopiFunctions.txt index 547e75955..4d75be61b 100644 --- a/dipu/SupportedDiopiFunctions.txt +++ b/dipu/SupportedDiopiFunctions.txt @@ -44,7 +44,6 @@ diopiBitwiseOrInp diopiBitwiseOrInpScalar diopiBitwiseOrScalar diopiBmm -diopiCastDtype diopiCat diopiCdist diopiCdistBackward diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py index 0adcb4eb1..a5902ddf3 100644 --- a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py +++ b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py @@ -1,10 +1,10 @@ # Copyright (c) 2023, DeepLink. +import textwrap import yaml import re import json import os -from collections import OrderedDict -from typing import Mapping, Match, Optional, Sequence +from typing import Mapping, Match, Optional, Sequence, Tuple, List from diopi_wrapper_template import ( diopi_wrapper_file_template_content, diopi_wrapper_function_template_content, @@ -77,23 +77,382 @@ def replace(match: Match[str]) -> str: return self.substitution.sub(replace, self.pattern) +def parse_function_signature( + schema: str, +) -> Tuple[str, List[Tuple[str, str, str]], List[Tuple[str, str]]]: + # Note: switch to parser if rules grow complex. + + def unwarp_brackets(input: str) -> str: + input = input.strip() + while input.startswith("("): + input = input.removeprefix("(").removesuffix(")").strip() + return input + + def parse_type_name_value_tuple(input: str) -> List[Tuple[str, str, str]]: + output = [] + for pair in input.split(","): + # for 'type name = value' + left, _, value = pair.partition("=") + left, _, right = left.strip().rpartition(" ") + left = left.strip() + right = right.strip() + type = left if left != "" else right + name = right if left != "" else "" + if type != "": + output.append((type, name, value.strip())) + return output + + function_name, _, others = schema.partition("(") + parameters, _, return_values = others.rpartition("->") + parameters, _, _ = parameters.rpartition(")") + parameters = parse_type_name_value_tuple(parameters) + return_values = parse_type_name_value_tuple(unwarp_brackets(return_values)) + return_values = [(a, b) for a, b, _ in return_values] + return (function_name, parameters, return_values) + + +def generate_cpp_function_return_values( + parameters: List[Tuple[str, str, str]], return_values: List[Tuple[str, str]] +) -> List[str]: + REFERENCE_PATTERN = re.compile(r"\w+(\([a-z]+!\))") + + output = [] + for i, (type, name) in enumerate(return_values): + if reference := REFERENCE_PATTERN.match(type): + tag = reference.group(1) + target = (name for type, name, _ in parameters if tag in type) + if found := next(target, None): + output.append(found) + + elif re.match(r"\w+", type): + if name == "": + name = "out" + ("" if len(return_values) == 1 else str(i)) + output.append(name) + + elif "bool" == type: + output.append("out") + + return output + + +def generate_code_to_process_symint_array( + name_list: List[str], vector_suffix: str = "Vector", size_suffix: str = "DiopiSize" +) -> str: + HEAD = r""" +auto to_int = [](const c10::SymInt& t) -> int64_t { return t.expect_int(); };""" + BODY = r""" +auto {0}{1} = c10::DimVector({0}.size()); +std::transform({0}.cbegin(), {0}.cend(), {0}{1}.begin(), to_int); +auto {0}{2} = ::diopiSize_t{{{0}{1}.data(), static_cast({0}{1}.size())}};""" + + if len(name_list) == 0: + return "" + return HEAD + "".join( + BODY.format(name, vector_suffix, size_suffix) for name in name_list + 
) + + +def generate_parameters_from_schema(schema: str) -> List[str]: + MAPPING = [ + (re.compile(p), r) + for p, r in [ # Order matters + # preprocess + (r"\*", r""), + (r"=.+", r""), + # process + (r"^Tensor\?\[\]", r"const c10::List>&"), + (r"^Tensor\?$", r"const c10::optional&"), + (r"^Tensor\[\]", r"at::ArrayRef"), + (r"^Tensor\([a-z]!\)\[\]", r"at::ArrayRef"), + (r"^Tensor\([a-z]!\)", r"at::Tensor&"), + (r"^Tensor$", r"const at::Tensor&"), + (r"^SymInt\[\d*\]\?", r"at::OptionalSymIntArrayRef"), + (r"^SymInt\[\d*\]", r"c10::SymIntArrayRef"), + (r"^SymInt", r"c10::SymInt"), + (r"^ScalarType", r"at::ScalarType"), + (r"^Scalar\?", r"const c10::optional&"), + (r"^Scalar\[\]", r"at::ArrayRef"), + (r"^Scalar", r"const at::Scalar&"), + (r"^Layout", r"at::Layout"), + (r"^Generator", r"at::Generator"), + (r"^Device", r"c10::Device"), + (r"^str", r"c10::string_view"), + (r"^int\[\d*\]\?", r"at::OptionalIntArrayRef"), + (r"^int\[\d*\]", r"at::IntArrayRef"), + (r"^int", r"int64_t"), + (r"^float", r"double"), + (r"^bool\[(\d+)\]", r"::std::array"), + # post-process + (r"([\w:]+)\?", r"c10::optional<\1>"), + (r"\(\w+!\)", r"&"), + ] + ] + + output = [] + _, parameters, _ = parse_function_signature(schema) + for type, name, _ in parameters: + for pattern, replacement in MAPPING: + type = pattern.sub(replacement, type) + if type != "": # Drop '*' type + output.append((type + " " + name).strip()) + return output + + +def generate_cpp_function_name( + name: str, multiple_return_values: bool, fallback: bool +) -> str: + MAPPING = [ # Order matters + (r"_?\.from", ""), + (r"_?\.to", ""), + (r"_mode", ""), + (r"(_foreach_\w+)\.List", r"\1"), + (r"\.(Scalar)?(Tensor)?\w*_out", "_outf"), + (r"\.correction", ""), + (r"\.dim_IntList", ""), + (r"\.dim_max", "_outf"), + (r"\.dim_min", "_outf"), + (r"\.dim", ""), + (r"\.grad_input", "_outf"), + (r"\.input", ""), + (r"\.out\w*", "_outf"), + (r"\.ScalarList", ""), + (r"\.Scalar", ""), + (r"\.self", ""), + (r"\.Tensor_Scalar_out", "_outf"), + (r"\.Tensor_Scalar", ""), + (r"\.Tensor_Tensor", ""), + (r"\.Tensor", ""), + (r"\.values_stable", "_outf"), + (r"\.values", "_outf"), + (r"ctc_loss\.IntList", "ctc_loss"), + (r"\.", "_"), + ] + + if fallback: + return "custom_fallback_" + create_fun_name_from_schema(name) + + function_name = get_op_name_from_schema(name) + for pattern, replacement in MAPPING: + function_name = re.sub(pattern, replacement, function_name) + if function_name.endswith("_") and multiple_return_values > 0: + function_name = function_name.removesuffix("_") + return "at::" + function_name + + +def generate_cpp_function_return_values( + parameters: List[Tuple[str, str, str]], return_values: List[Tuple[str, str]] +) -> List[str]: + REFERENCE_PATTERN = re.compile(r"\w+(\([a-z]+!\))") + + output = [] + for i, (type, name) in enumerate(return_values): + if reference := REFERENCE_PATTERN.match(type): + label = reference.group(1) + if found := next( + (name for type, name, _ in parameters if label in type), None + ): + output.append(found) + + elif re.match(r"\w+", type): + if name == "": + name = "out" + ("" if len(return_values) == 1 else str(i)) + output.append(name) + + elif "bool" == type: + output.append("out") + + return output + + +def generate_cpp_variable_for_return_values( + return_values: List[Tuple[str, str]], *, suffix: str = "_rv" +) -> str: + output = [name + suffix for _, name in return_values] + if len(output) == 0: + return "result" + suffix + if len(output) == 1: + return output[0] + return "[" + ", ".join(output) + "]" + + +def 
generate_function_call_with_cpu_tensors( + function_name: str, + parameters: List[Tuple[str, str, str]], + return_values_count: int, + use_custom_fallback: bool, + *, + output_variable_name: str = "result_cpu", + symint_array_suffix: str = "Vector", +) -> str: + SYMINT_ARRAY_PATTERN = re.compile(r"SymInt\[\d?\]") + + function_name = generate_cpp_function_name( + function_name, return_values_count > 0, use_custom_fallback + ) + + bullets = list(filter(lambda x: x != "", (name for _, name, _ in parameters))) + + symint_parameters = [name for type, name, _ in parameters if "SymInt" == type] + symint_array_parameters = [ + name for type, name, _ in parameters if SYMINT_ARRAY_PATTERN.match(type) + ] + tensor_parameters = [name for type, name, _ in parameters if "Tensor" in type] + for index, bullet in enumerate(bullets): + if bullet in symint_parameters: + bullets[index] += ".expect_int()" + elif bullet in symint_array_parameters: + bullets[index] += symint_array_suffix + elif bullet in tensor_parameters: + bullets[index] += "_cpu" + elif bullet == "device": + bullets[index] = "at::kCPU" + + output = "" + if len(symint_array_parameters) != 0: + output += generate_code_to_process_symint_array(symint_array_parameters) + "\n" + if return_values_count != 0: + output += f"auto {output_variable_name} = " + output += function_name + "(" + ", ".join(bullets) + ");\n" + return output + + +def generate_code_to_clone_to_host_tensors( + function_name: str, parameters: List[Tuple[str, str, str]] +) -> Tuple[str, List[Tuple[int, str]], List[Tuple[int, str]]]: + TENSOR_ARRAY_PATTERN = re.compile(r"Tensor\??\[\]") + MUTABLE_TENSOR_PATTERN = re.compile(r"Tensor\([a-z]!\)") + MUTABLE_TENSOR_ARRAY_PATTERN = re.compile(r"Tensor\([a-z]!\)\[\]") + + output = [] + mutable_tensor = [] + mutable_tensor_array = [] + + maybe_inplace = [] if ".out" in function_name or "_out" in function_name else None + for index, (type, name, _) in enumerate(parameters): + if type == "Tensor": + output.append(f"auto {name}_cpu = tensor_clone_to_host({name});") + if maybe_inplace is not None: + maybe_inplace.append((name, name + "_cpu")) + + elif ( + type == "Tensor?" 
+ or "ITensorListRef" in type + or "TensorList" in type + or TENSOR_ARRAY_PATTERN.match(type) + ): + output.append(f"auto {name}_cpu = tensor_clone_to_host({name});") + + elif MUTABLE_TENSOR_ARRAY_PATTERN.match(type): + mutable_tensor_array.append((index, name + "_cpu")) + output.append(f"auto {name}_cpu = tensor_clone_to_host({name});") + + elif MUTABLE_TENSOR_PATTERN.match(type): + mutable_tensor.append((index, name + "_cpu")) + if maybe_inplace: + candidates = ", ".join(f"{{{a}, {b}}}" for a, b in maybe_inplace) + output.append( + f"auto {name}_cpu = tensor_reference_or_clone_to_host({name}, {{{candidates}}});" + ) + else: + output.append(f"auto {name}_cpu = tensor_clone_to_host({name});") + + return "\n".join(output) + "\n", mutable_tensor, mutable_tensor_array + + +def generate_code_to_copy_tensor_from_cpu_to_device( + parameters: List[Tuple[str, str, str]], + return_values: List[Tuple[str, str]], + mutable_tensor: List[Tuple[int, str]], + mutable_tensor_array: List[Tuple[int, str]], +) -> str: + TENSOR_PATTERN = re.compile(r"Tensor(\([a-z]!\))?") + TENSOR_ARRAY_PATTERN = re.compile(r"Tensor\[\]") + + HEAD = "auto current_stream = dipu::getCurrentDIPUStream();\n" + COPY_TENSOR = "tensor_copy_host_to_device({0}, {1}, current_stream);\n" + COPY_TENSOR_ARRAY = """\ +decltype(auto) {0}_vec = tensor_array_to_vector({0}); +for (auto i = std::size_t{{}}; i < {0}.size(); ++i) {{ + tensor_copy_host_to_device({0}_vec[i], {1}[i], current_stream); +}} +""" + + output = HEAD + for type, name in return_values: + if TENSOR_ARRAY_PATTERN.match(type): + output += COPY_TENSOR_ARRAY.format(name, name + "_rv") + elif TENSOR_PATTERN.match(type): + output += COPY_TENSOR.format(name, name + "_rv") + else: + output += f"{name} = {name}_rv;" + + for index, cpu_name in mutable_tensor: + type, name, _ = parameters[index] + if match := TENSOR_PATTERN.match(type): + if tag := match.group(1): + if any(True for type, _ in return_values if tag in type): + continue + output += COPY_TENSOR.format(name, cpu_name) + + for index, cpu_name in mutable_tensor_array: + _, name, _ = parameters[index] + output += COPY_TENSOR_ARRAY.format(name, cpu_name) + + return output + + +def generate_code_fallback_to_cpu(config: dict) -> str: + maybe_fallback = ( + config.get("enable_fallback_to_cpu", True) + and config.get("enable_autocompare", True) + and config.get("register_operator", True) + ) + if not maybe_fallback: + return "" + + schema = config["schema"] + use_custom_fallback = config.get("custom_fallback", False) + function_name, parameters, return_values = parse_function_signature(schema) + return_value_names = generate_cpp_function_return_values(parameters, return_values) + return_values = [(a, c) for (a, _), c in zip(return_values, return_value_names)] + + copy_to_cpu_code, mutable_tensor, mutable_tensor_array = ( + generate_code_to_clone_to_host_tensors(function_name, parameters) + ) + + function_call_code = generate_function_call_with_cpu_tensors( + function_name, + parameters, + len(return_values), + use_custom_fallback, + output_variable_name=generate_cpp_variable_for_return_values(return_values), + ) + + copy_to_device_code = generate_code_to_copy_tensor_from_cpu_to_device( + parameters, return_values, mutable_tensor, mutable_tensor_array + ) + + code = copy_to_cpu_code + "\n" + function_call_code + "\n" + copy_to_device_code + return f"""if (ret == ::diopiForceFallbackToCPU) {{ +{textwrap.indent(code, ' ')} + ret = ::diopiSuccess; +}}""" + + def get_fun_name_from_cppsignature(cppnature): return 
re.search(r"[a-zA-Z_:]+[\w\d:]+\(", cppnature).group().replace("(", "") -def get_op_name_from_schema(schema): - op_name = schema[0 : schema.find("(")] - op_name = re.sub("aten::", "", op_name) - return op_name +def get_op_name_from_schema(schema: str) -> str: + name, _, _ = parse_function_signature(schema) + name = name.strip().removeprefix("aten::") + return name -def create_fun_name_from_schema(schema): - schema = schema.strip() - op_name = schema[0 : schema.find("(")] - op_name = op_name.replace(".", "_") - op_name = "dipu_" + re.sub("aten::", "", op_name) - op_name = op_name.lower() - return op_name +def create_fun_name_from_schema(schema: str) -> str: + name = get_op_name_from_schema(schema) + name = name.replace(".", "_").lower() + return "dipu_" + name def create_return_code_frome_schema(schema, allow_return_ref=True): @@ -113,64 +472,10 @@ def create_return_code_frome_schema(schema, allow_return_ref=True): return return_code -def create_transform_input_to_cpu_code(fun_config): - input_process_code = "" - schema = fun_config["schema"] - opname = get_op_name_from_schema(schema) - inputs = re.findall("Tensor +([\w\d_]+)", schema[: schema.find("->")]) - for input in inputs: - input_process_code += ( - f"at::Tensor {input}_cpu = toCpuTensorWithoutDiopiCopy({input});\n" - ) - - optional_inputs = re.findall("Tensor *\? +([\w\d_]+)", schema[: schema.find("->")]) - for input in optional_inputs: - input_process_code += f"\nc10::optional {input}_cpu = {input}.has_value() && {input}.value().defined() ? c10::make_optional(toCpuTensorWithoutDiopiCopy({input}.value())) : {input};\n" - - optional_tensor_list_inputs = re.findall( - "Tensor *\? *\[ *\] +([\w\d_]+)", schema[: schema.find("->")] - ) - for input in optional_tensor_list_inputs: - input_process_code += f"\nc10::List> {input}_cpu;\n" - input_process_code += f"for (int i = 0; i < {input}.size();++i)" + " {\n" - input_process_code += f" {input}_cpu.push_back({input}[i].has_value() && {input}[i].value().defined() ? c10::make_optional(toCpuTensorWithoutDiopiCopy({input}[i].value())) : {input}[i]);\n" - input_process_code += "}\n" - - outputs = re.findall( - "Tensor\([a-z]!\)[ ]+([\w\d_]+){1}", schema[: schema.find("->")] - ) - for output in outputs: - input_process_code += ( - f"at::Tensor {output}_cpu = toCpuTensorWithoutDiopiCopy({output});\n" - ) - if ".out" in opname or "_out" in opname: - for i in range(len(inputs)): - input_process_code += ( - f"if (({inputs[i]}.data_ptr()) == {output}.data_ptr())" - ) - input_process_code += "{\n\t" - input_process_code += f"{inputs[i]}_cpu = {output}_cpu;\n\t" - input_process_code += "}\n" - - tensors_arrays = re.findall( - "Tensor *\[ *\] * +([\w\d_]+)", schema[: schema.find("->")] - ) - tensors_arrays += re.findall( - "ITensorListRef *&? 
+([\w\d_]+)", schema[: schema.find("->")] - ) - tensors_arrays += re.findall( - "Tensor *\([a-z]!\) *\[ *\] +([\w\d_]+)", schema[: schema.find("->")] - ) - if len(tensors_arrays) > 0: - for tensors_arg in tensors_arrays: - input_process_code += ( - f"std::vector {tensors_arg}_cpu({tensors_arg}.size());\n" - ) - input_process_code += ( - f"std::transform({tensors_arg}.begin(), {tensors_arg}.end(), {tensors_arg}_cpu.begin(), [](const at::Tensor& tensor)" - + "{return toCpuTensorWithoutDiopiCopy(tensor);});\n" - ) - return input_process_code +def create_transform_input_to_cpu_code(config: dict) -> str: + function_name, parameters, _ = parse_function_signature(config["schema"]) + code, _, _ = generate_code_to_clone_to_host_tensors(function_name, parameters) + return code def create_print_op_args_code(fun_config): @@ -190,57 +495,13 @@ def create_print_op_args_code(fun_config): return code -def create_param_list_from_schema(schema): - param_list = schema[schema.find("(") + 1 : schema.find("->")].strip() - param_list = param_list[0 : param_list.rfind(")")] - args_type_map = OrderedDict( - { - "Tensor\([a-z]\)": "Tensor", - "Scalar *\[ *\]": "at::ArrayRef", - "Tensor *\( *[a-z]\!\) *\[ *\]": "at::ArrayRef", - "[ ]*\([a-zA-Z]!\)": "&", - "MemoryFormat\?": "const c10::optional", - "str\?": "c10::optional", - "([, \(]{1})str ": R"\1c10::string_view ", - "ScalarType[ ]*\?": "c10::optional", - "ScalarType[ ]+([\w\d_]+)": R"at::ScalarType \1", - "Scalar[ ]*\? *([\w\d_]+)": R"const c10::optional& \1", - "Generator ?\?": "c10::optional", - "Device ?\?": "c10::optional", - "Device": "c10::Device", - "Layout ?\?": "c10::optional", - "Tensor ?\? *\[ *\]": R"const c10::List>&", - "Tensor ?\?": "const c10::optional&", - "int ?\?": "c10::optional", - "float ?\?": "c10::optional", - "([\(, ]*)int ([\w\d_]+)": R"\1int64_t \2", - "([\(, ]*)float ([\w\d_]+)": R"\1double \2", - "([\(, ]*)SymInt ([\w\d_]+)": R"\1c10::SymInt \2", - "([\(, ]*)SymInt *\[[ \d]*\] ([\w\d_]+)": R"\1c10::SymIntArrayRef \2", - "([\(, ]*)SymInt *\[[ \d]*\] *\? 
+([\w\d_]+)": R"\1at::OptionalSymIntArrayRef \2", - "int\[\d*\] +([\w\d_]+)": R"at::IntArrayRef \1", - "([a-zA-Z0-9]+)\?": R"c10::optional<\1>", - "Tensor *\[ *\]": "at::ArrayRef", - "Tensor[ ]*& +([\w\d_]+)": R"at::Tensor& \1", - "Tensor[ ]+([\w\d_]+)": R"const at::Tensor& \1", - "Scalar ": R"const at::Scalar& ", - "([, \(]+)int\[\d\]\?": R"\1at::OptionalIntArrayRef", - "int *\[ *\d+\ *]": "at::IntArrayRef", - "bool\[(\d+)\]": R"::std::array", - "\*[ ,]+": "", - "\=[ ]*\[ *\]": "", - "=[ ]*'?\w*-?\.?[\d ]*'?": "", - } - ) - for pattern, cpp_type in args_type_map.items(): - param_list = re.sub(str(pattern), str(cpp_type), param_list) - return param_list +def create_param_list_from_schema(schema: str) -> str: + return ", ".join(generate_parameters_from_schema(schema)) def get_function_inputs_from_schema(schema): - param_list = create_param_list_from_schema(schema) ins = [] - for args in param_list.split(","): + for args in generate_parameters_from_schema(schema): args = args.strip() tensor_match_result = re.search("Tensor[ ]*&+", args) if tensor_match_result is not None: @@ -309,9 +570,8 @@ def get_function_optional_generator_args_from_schema(schema): def get_function_int_array_args_from_schema(schema): - param_list = create_param_list_from_schema(schema) int_arrays = [] - for args in param_list.split(","): + for args in generate_parameters_from_schema(schema): args = args.strip() match_result = re.search("[^Optional]SymIntArray[\w\d]*", args) if match_result is not None: @@ -322,30 +582,9 @@ def get_function_int_array_args_from_schema(schema): return int_arrays -def get_function_return_param_from_schema(schema): - return_schema = schema[schema.find("->") + 2 :].strip() - params = [] - return_params = return_schema.split(",") - for i in range(len(return_params)): - args = return_params[i] - inplace_match = re.search("Tensor\([a-zA-Z]+!\)", args) - pure_out_match = re.search("Tensor[ ,]?", args) - bool_out_match = re.search("bool", args) - if inplace_match is not None: - arg_label = re.sub(".*(\(.*\))", r"\1", inplace_match.group()) - index = schema.find(arg_label) + len(arg_label) - param = re.search("[a-zA-Z0-9_::]+", schema[index:]).group() - params.append(param) - elif pure_out_match is not None: - name_from_schema = re.sub("\(?Tensor[ ]+([\w\d_]+)\)?", R"\1", args) - if name_from_schema == args: - name = "out" + (str(i) if len(return_params) > 1 else "") - else: - name = name_from_schema - params.append(name) - elif bool_out_match is not None: - params.append("out") - return params +def get_function_return_param_from_schema(schema: str) -> List[str]: + _, parameters, return_values = parse_function_signature(schema) + return generate_cpp_function_return_values(parameters, return_values) def create_call_diop_interface_code_from_schema(schema): @@ -416,16 +655,10 @@ def create_cpp_signature_from_schema(schema): return cppsignature -def create_args_name_list_from_schema(schema): - code = "" - param_list = create_param_list_from_schema(schema) - args_list = re.findall("([\w\d_<>:& ]+ )([\w\d_]+)", param_list) - for i in range(len(args_list)): - arg_type, arg_name = args_list[i] - code += arg_name - if i < len(args_list) - 1: - code += ", " - return code +def create_args_name_list_from_schema(schema: str) -> str: + _, output, _ = parse_function_signature(schema) + output = filter(lambda x: x != "", (name for _, name, _ in output)) + return ", ".join(output) def create_call_cpp_function_code_from_schema(schema): @@ -438,97 +671,13 @@ def create_call_cpp_function_code_from_schema(schema): 
return code -def create_call_aten_cpu_cpp_function_code_from_config(fun_config): - schema = fun_config["schema"] - opname = get_op_name_from_schema(schema) - opname = re.sub("\.ScalarList", "", opname) - opname = re.sub("(_foreach_[\w\d_]+_?)\.List", R"\1", opname) - opname = re.sub("ctc_loss\.IntList", "ctc_loss", opname) - opname = re.sub("\.(Scalar)?(Tensor)?[\w_\d]*_out", "_outf", opname) - opname = re.sub("\.out[\w_\d]*", "_outf", opname) - opname = re.sub("\.Tensor_Scalar_out", "_outf", opname) - opname = re.sub("\.Tensor_Tensor", "", opname) - opname = re.sub("\.Tensor_Scalar", "", opname) - opname = re.sub("\.Tensor", "", opname) - opname = re.sub("_?\.to", "", opname) - opname = re.sub("_?\.from", "", opname) - opname = re.sub("_mode", "", opname) - opname = re.sub("\.Scalar", "", opname) - opname = re.sub("\.self", "", opname) - opname = re.sub("\.values_stable", "_outf", opname) - opname = re.sub("\.values", "_outf", opname) - opname = re.sub("\.grad_input", "_outf", opname) - opname = re.sub("\.dim_max", "_outf", opname) - opname = re.sub("\.dim_min", "_outf", opname) - opname = re.sub("\.correction", "", opname) - opname = re.sub("\.input", "", opname) - opname = re.sub("\.dim_IntList", "", opname) - opname = re.sub("\.dim", "", opname) - opname = opname.replace(".", "_") - opname = opname.split(".")[0] - if opname[-1] == "_" and len(get_function_return_param_from_schema(schema)) > 0: - opname = opname[0 : len(opname) - 1] - - sym_int_array_params = re.findall("[ ,\)]?SymInt\[\d?\] *([\w\d_]+)", schema) - if len(sym_int_array_params) > 0: - sym_int_process_code = ( - create_int_array_process_code(sym_int_array_params) + "\n" - ) - else: - sym_int_process_code = "" - if fun_config.get("custom_fallback", False) == True: - opname = "custom_fallback_" + create_fun_name_from_schema(schema) - else: - opname = "at::" + opname - code = "" - if len(get_function_return_param_from_schema(schema)) > 0: - code = "auto " + "result_cpu = " - code += opname + "(" + create_args_name_list_from_schema(schema) + ");" - for sym_int_param in sym_int_array_params: - code = code.replace(sym_int_param, sym_int_param + "Vector") - - code = sym_int_process_code + code - - sym_int_params = re.findall("[ ,\)]?SymInt\ *([\w\d_]+)", schema) - for sym_int_param in sym_int_params: - code = re.sub( - "([ ,\(])?" + sym_int_param + "([, \)])?", - R"\1" + sym_int_param + R".expect_int()\2", - code, - ) - - if "device" in code: - code = code.replace("device", "at::kCPU") - - inputs = re.findall("Tensor +([\w\d_]+)", schema[: schema.find("->")]) - optional_inputs = re.findall("Tensor *\? +([\w\d_]+)", schema[: schema.find("->")]) - outputs = re.findall( - "Tensor\([a-z]!\)[ ]+([\w\d_]+){1}", schema[: schema.find("->")] - ) - tensors_arrays = re.findall( - "Tensor *\[ *\] * +([\w\d_]+)", schema[: schema.find("->")] - ) - tensors_arrays += re.findall( - "ITensorListRef *&? +([\w\d_]+)", schema[: schema.find("->")] - ) - tensors_arrays += re.findall( - "Tensor *\([a-z]!\) *\[ *\] +([\w\d_]+)", schema[: schema.find("->")] +def create_call_aten_cpu_cpp_function_code_from_config(config: dict) -> str: + schema = config["schema"] + function_name, parameters, return_values = parse_function_signature(schema) + use_custom_fallback = config.get("custom_fallback", False) + return generate_function_call_with_cpu_tensors( + function_name, parameters, len(return_values), use_custom_fallback ) - optional_tensor_list_inputs = re.findall( - "Tensor *\? 
*\[ *\] +([\w\d_]+)", schema[: schema.find("->")] - ) - for input in ( - inputs - + optional_inputs - + outputs - + tensors_arrays - + optional_tensor_list_inputs - ): - code = re.sub( - "([\(, ]+)" + input + "([, \)]+)", R"\1" + input + "_cpu" + R"\2", code - ) - - return code def create_call_dipu_cpp_function_code_from_schema(schema): @@ -578,20 +727,6 @@ def create_code_to_print_fun_call_info_from_schema(fun_config): return debug_code -def create_int_array_process_code(int_array_list): - if len(int_array_list) <= 0: - return "" - code = ( - R"auto symIntToInt = [](const c10::SymInt& t)-> int64_t {return t.expect_int();};" - + "\n" - ) - for int_array in int_array_list: - code += f"c10::DimVector {int_array}Vector({int_array}.size());\n" - code += f"std::transform({int_array}.cbegin(), {int_array}.cend(), {int_array}Vector.begin(), symIntToInt);\n" - code += f"::diopiSize_t {int_array}DiopiSize{{{int_array}Vector.data(), static_cast({int_array}Vector.size())}};\n" - return code - - def create_autograd_function_name(op_name): op_name = "Dipu" + op_name[0].upper() + op_name[1:] for patten in re.findall("[_\.][a-z]{1}", op_name): @@ -796,7 +931,6 @@ def functions_code_gen(fun_config): diopi_fun_call_code, ) - cppsignature_template = CodeTemplate("$return_code $fun_name($param_list)") for scalar_param in get_function_optional_scalar_args_from_schema( fun_config["schema"] ): @@ -818,7 +952,7 @@ def functions_code_gen(fun_config): ) int_array_list = get_function_int_array_args_from_schema(fun_config["schema"]) - attrs_process_code += create_int_array_process_code(int_array_list) + attrs_process_code += generate_code_to_process_symint_array(int_array_list) for int_array_param in int_array_list: diopi_fun_call_code = re.sub( "([,\(] *&? *)" + int_array_param.strip() + "( *[,\)])", @@ -899,6 +1033,7 @@ def functions_code_gen(fun_config): print_op_args=[print_op_args], record_code_before_call_diopi=[record_code_before_call_diopi], diopi_fun_call_code=[diopi_fun_call_code], + force_fallback_code=[generate_code_fallback_to_cpu(fun_config)], record_code_after_call_diopi=[record_code_after_call_diopi], check_after_call_diopi=[check_after_call_diopi], custom_code_before_return=[ @@ -978,9 +1113,9 @@ def functions_code_gen(fun_config): fbody += custom_autograd_function_code fun_name = wrapper_fun_name - if fun_config.get("autocompare") not in ["disable"] and fun_config.get( - "register_op", True - ) in [True, "True"]: + if fun_config.get("enable_autocompare", True) and fun_config.get( + "register_operator", True + ): auto_compare_fun_name = fun_name + "_autocompare" autocompare_code = autocompare_template.substitute( cppsignautre=[ @@ -1017,8 +1152,8 @@ def functions_code_gen(fun_config): # case 1: custom_fallback=False and autocompare not disabled register_body = "" if fun_config.get("custom_fallback", False) in ["False", False] and fun_config.get( - "autocompare", True - ) in ["True", True]: + "enable_autocompare", True + ): register_body = ( op_no_customfallback_with_autocompare_register_template.substitute( register_name=[get_op_name_from_schema(fun_config["schema"])], @@ -1035,7 +1170,7 @@ def functions_code_gen(fun_config): elif fun_config.get("custom_fallback", False) in [ "False", False, - ] and fun_config.get("autocompare") in ["disable"]: + ] and not fun_config.get("enable_autocompare", True): register_body = ( op_no_customfallback_no_autocompare_register_template.substitute( register_name=[get_op_name_from_schema(fun_config["schema"])], @@ -1049,8 +1184,8 @@ def 
functions_code_gen(fun_config):
         )
     # case3: custom_fallback=True and autocompare not disabled
     elif fun_config.get("custom_fallback", False) in ["True", True] and fun_config.get(
-        "autocompare", True
-    ) in ["True", True]:
+        "enable_autocompare", True
+    ):
         register_body = (
             op_with_customfallback_with_autocompare_register_template.substitute(
                 register_name=[get_op_name_from_schema(fun_config["schema"])],
@@ -1071,9 +1206,10 @@ def functions_code_gen(fun_config):
             )
         )
     # case4: custom_fallback=True and autocompare disabled
-    elif fun_config.get("custom_fallback", False) in ["True", True] and fun_config.get(
-        "autocompare", True
-    ) in ["disable"]:
+    elif fun_config.get("custom_fallback", False) in [
+        "True",
+        True,
+    ] and not fun_config.get("enable_autocompare", True):
         register_body = (
             op_with_customfallback_no_autocompare_register_template.substitute(
                 register_name=[get_op_name_from_schema(fun_config["schema"])],
@@ -1170,7 +1306,7 @@ def parse_args():
         type=json.loads,
         default=dict(),
         help="fun config for all ops",
-    )  # --fun_config_dict '{"register_op": "False", "dummy_call_diopi":"True"}'
+    )  # --fun_config_dict '{"register_operator": "false", "dummy_call_diopi":"True"}'
     parser.add_argument(
         "--enable_dipu_extra_feature",
         default=True,
@@ -1258,7 +1394,7 @@ def main():
             fun_code = memory_format_converter.convert(fun_code, fun_config)
             functions_code += fun_code

-        if merged_fun_config.get("register_op", True) in [True, "True"]:
+        if merged_fun_config.get("register_operator", True):
             if merged_fun_config.get("autograd", False) == True:
                 autograd_op_register_code += register_code
             op_register_code += register_code
diff --git a/dipu/scripts/autogen_diopi_wrapper/custom_diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/custom_diopi_functions.yaml
index 74540a5da..bda47a727 100644
--- a/dipu/scripts/autogen_diopi_wrapper/custom_diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/custom_diopi_functions.yaml
@@ -1,6 +1,12 @@
 - schema: "custom_op.overloadname(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)"
-  autocompare: disable
-  register_op: False # Whether generate registe code for this op, default value is True
+  # If the return value is ::diopiForceFallbackToCPU, it tries to fall back to CPU.
+  #
+  # Automatically disabled when 'enable_autocompare' is false or 'register_operator' is false.
+  # In most cases the fallback code is generated automatically, but a custom or special operator may disable it.
+  enable_fallback_to_cpu: false
+  # Generate autocompare code, 'true' by default.
+  enable_autocompare: false
+  register_operator: false # Whether to generate register code for this op, default value is True
   print_func_call_info: False # whether generate code that prints function call information
   print_op_args: True # whether generate code that prints op args
   dummy_call_diopi: False # Does not generate code that actually calls the diopi function, defalut value is False
@@ -11,4 +17,4 @@
   custom_code_before_return: |
     dipu::getCurrentDIPUStream().synchronize();
     std::cout << "out:" << out << std::endl;
-  interface: diopiAddScalar(ctx, out, self, other, alpha)
\ No newline at end of file
+  interface: diopiAddScalar(ctx, out, self, other, alpha)
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 2759f7fb6..7ec22f5eb 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1,10 +1,18 @@
 - schema: "exampleop.overloadname(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)"
-  autocompare: disable
+  # Generate autocompare code, 'true' by default.
+  #
+  # Note: boolean type in YAML 1.2
+  # https://yaml.org/spec/1.2.2/#10212-boolean
+  enable_autocompare: false
+  # When the return value is ::diopiForceFallbackToCPU, it tries to fall back to CPU.
+  #
+  # Automatically disabled when 'enable_autocompare' is false or 'register_operator' is false.
+  enable_fallback_to_cpu: false
   # op gen only on these torch version. use it only if op has different signature on different torch.
   # if it's only different on implementation , please use compile macro DIPU_TORCHXXX.
   # torch version number, 5 in total: {X-major}{XX-minor}{XX-patch}
   torch_ver: ["20000",]
-  register_op: False # Whether generate register code for this op, default value is True
+  register_operator: false # Whether to generate register code for this op, default value is True
   print_func_call_info: False # whether generate code that prints function call information
   print_op_args: True # whether generate code that prints op args
   dummy_call_diopi: False # Does not generate code that actually calls the diopi function, default value is False
@@ -372,7 +380,7 @@
   interface: diopiLayerNormBackward(ctx, grad_input, grad_weight, grad_bias, grad_out, input, weight, bias, mean, rstd, normalized_shape);

 - schema: "adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)"
-  #autocompare: disable # TODO: cpu impl not support half now
+  #enable_autocompare: false # TODO: cpu impl not support half now
   interface: diopiAdaptiveAvgPool2d(ctx, out, self, output_size)

 - schema: "_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor"
@@ -555,13 +563,13 @@
   interface: diopiRelu(ctx, out, self)

 - schema: "randperm.out(int n, *, Tensor(a!) out) -> Tensor(a!)"
-  autocompare: disable
+  enable_autocompare: false
   custom_code_at_the_beginning: |
     diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
   interface: diopiRandperm(ctx, out, n, generatorDiopiGenerator)

 - schema: "randperm.generator_out(int n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)"
-  autocompare: disable
+  enable_autocompare: false
   interface: diopiRandperm(ctx, out, n, generator)

 - schema: "aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!)
out) -> Tensor(a!)" @@ -575,7 +583,7 @@ ::diopiSize_t diopi_size = toDiopiSize(dim); interface: diopiSum(ctx, out, self_dtype_diopi, diopi_size) -- schema: "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, * ScalarType? dtype=None) -> Tensor" +- schema: "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor" device: [cuda, ascend] dummy_call_diopi: True custom_code_at_the_beginning: | @@ -589,14 +597,14 @@ interface: diopiAddmm(&context, out, self, mat1, mat2, beta, alpha) - schema: "cross_entropy_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor" - register_op: False + register_operator: false custom_code_at_the_beginning: | const auto reductionDiopi = static_cast<::diopiReduction_t>(reduction); at::Tensor out = nodispatch::empty_like(self); interface: diopiCrossEntropyLossBackward(ctx, out, grad_output, self, target, weight, reductionDiopi, ignore_index.expect_int(), label_smoothing) - schema: "cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor" - register_op: False + register_operator: false custom_code_at_the_beginning: | const int64_t ignore_index_int = ignore_index.expect_int(); const auto reductionDiopi = static_cast<::diopiReduction_t>(reduction); @@ -661,7 +669,7 @@ interface: diopiConvolution2dBackward(ctx, grad_input, grad_weight, grad_bias, grad_output, input, weight, bias_sizes_ptr, stride, padding, dilation, groups); - schema: "convolution_transpose_backward(Tensor grad_output, Tensor input, Tensor weight, int[] bias_sizes, int[] stride, int[] padding, int[] dilation, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)" - register_op: False + register_operator: false size_attr: [stride, padding, dilation, bias_sizes, output_padding] custom_code_at_the_beginning: | at::Tensor grad_input; @@ -744,7 +752,7 @@ interface: diopiMul(ctx, out, out, out) - schema: "bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + enable_autocompare: false interface: diopiBernoulliScalar(ctx, self, p, generatorDiopiGenerator); - schema: "log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" @@ -922,7 +930,7 @@ - schema: "max_pool2d_backward(Tensor grad_output, Tensor input, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor? indices=None) -> Tensor" device: [topsrider] - register_op: False + register_operator: false size_attr: [kernel_size, stride, padding, dilation] custom_code_at_the_beginning: | auto out = nodispatch::empty(input.sizes(), grad_output.options()); @@ -1057,7 +1065,7 @@ interface: diopiThresholdBackward(ctx, grad_input, grad_output, self, &threshold) - schema: "transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)" - register_op: False + register_operator: false custom_code_at_the_beginning: | c10::DimVector output_size(self.sizes().cbegin(), self.sizes().cend()); c10::DimVector output_stride(self.strides().cbegin(), self.strides().cend()); @@ -1177,7 +1185,7 @@ interface: diopiRsqrt(ctx, out, self) - schema: "uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? 
generator=None) -> Tensor(a!)" - autocompare: disable + enable_autocompare: false interface: diopiUniformInp(ctx, self, from, to, generator) - schema: "tril(Tensor self, int diagonal=0) -> Tensor" @@ -1259,18 +1267,19 @@ interface: diopiClamp(ctx, out, self, min, max) - schema: "random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + enable_autocompare: false interface: diopiRandomInp(ctx, self, 0, nullptr, generator) - schema: "random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + enable_autocompare: false interface: diopiRandomInp(ctx, self, 0, &to, generator) - schema: "random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + enable_autocompare: false interface: "diopiRandomInp(ctx, self, from, to.has_value() ? &to.value() : nullptr, generator)" - schema: "nonzero(Tensor self) -> Tensor" + enable_fallback_to_cpu: false custom_code_at_the_beginning: | at::Tensor out; diopiTensorHandle_t out_ptr = nullptr; @@ -1332,7 +1341,7 @@ interface: diopiProd(ctx, out, self_dtype_diopi, &dim) - schema: repeat(Tensor self, SymInt[] repeats) -> Tensor - autocompare: disable + enable_autocompare: false custom_code_at_the_beginning: | c10::DimVector output_size(repeats.size()); for (int i = 0;i< repeats.size();++i) { @@ -1355,6 +1364,7 @@ interface: diopiSub(ctx, out, other, self, alpha) - schema: "unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor out, Tensor indices, Tensor counts)" + enable_fallback_to_cpu: false custom_code_at_the_beginning: | at::Tensor out; at::Tensor counts; @@ -1377,6 +1387,7 @@ } - schema: "_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor out, Tensor indices, Tensor counts)" + enable_fallback_to_cpu: false custom_code_at_the_beginning: | at::Tensor out; at::Tensor counts; @@ -1498,6 +1509,7 @@ interface: diopiArgmax(ctx, out, self, ptr, keepdim) - schema: "masked_select(Tensor self, Tensor mask) -> Tensor" + enable_fallback_to_cpu: false custom_code_at_the_beginning: | at::Tensor out; diopiTensorHandle_t out_ptr = nullptr; @@ -1767,41 +1779,41 @@ interface: diopiReciprocal(ctx, out, self) - schema: "normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + enable_autocompare: false interface: diopiNormalInp(ctx, self, mean, std, generator) - schema: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + enable_autocompare: false interface: diopiNormalTensorScalar(ctx, out, mean, std, generator) - schema: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor - autocompare: disable + enable_autocompare: false custom_code_at_the_beginning: | auto out = nodispatch::empty_like(mean); interface: diopiNormalTensorScalar(ctx, out, mean, std, generator) - schema: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + enable_autocompare: false interface: diopiNormalScalarTensor(ctx, out, mean, std, generator) - schema: normal.float_Tensor(float mean, Tensor std, *, Generator? 
generator=None) -> Tensor - autocompare: disable + enable_autocompare: false custom_code_at_the_beginning: | auto out = nodispatch::empty_like(std); interface: diopiNormalScalarTensor(ctx, out, mean, std, generator) - schema: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + enable_autocompare: false interface: diopiNormalTensor(ctx, out, mean, std, generator) - schema: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor - autocompare: disable + enable_autocompare: false custom_code_at_the_beginning: | auto out = nodispatch::empty_like(mean); interface: diopiNormalTensor(ctx, out, mean, std, generator) - schema: normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + enable_autocompare: false interface: diopiNormal(ctx, out, mean, std, generator) - schema: "mm(Tensor self, Tensor mat2) -> Tensor" @@ -1954,8 +1966,8 @@ - schema: "ctc_loss_tensor_backward(Tensor grad_output, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, int reduction=Mean, bool zero_infinity=False) -> Tensor grad_input" device: [camb] - autocompare: disable - register_op: False + enable_autocompare: false + register_operator: false custom_code_at_the_beginning: | const auto reductionDiopi = static_cast<::diopiReduction_t>(reduction); at::Tensor grad_input = nodispatch::empty_like(log_probs); @@ -2050,8 +2062,8 @@ - schema: "ctc_loss_intlist_backward(Tensor grad_output, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, int reduction=Mean, bool zero_infinity=False) -> Tensor grad_input" device: [camb] - autocompare: disable - register_op: False + enable_autocompare: false + register_operator: false ins: [input_lengths_tensor, target_lengths_tensor] custom_code_at_the_beginning: | const auto reductionDiopi = static_cast<::diopiReduction_t>(reduction); @@ -2325,6 +2337,7 @@ interface: diopiArange(ctx, out, start, end, step) - schema: "index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)" + enable_fallback_to_cpu: false custom_fallback: True custom_code_at_the_beginning: | std::vector indices_tensor_vec(indices.size()); @@ -2942,18 +2955,10 @@ return; interface: diopiLerpScalar(ctx, out, input, end, weight) -# wrap_diopi_cast_dtype has no corresponding aten op and not registered, it's just a diopi func wrapper. -# use this tricky method to support call multiple diopi-op in one aten-op -- schema: "wrap_diopi_cast_dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)" - register_op: False - custom_code_at_the_beginning: | - auto out = nodispatch::empty_like(self, self.options().dtype(dtype)); - interface: diopiCastDtype(ctx, out, self); - # a diopi func wrapper. - schema: wrap_diopi_copy_inp(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) 
generate_device_guard: False - register_op: False + register_operator: false no_device_check_args: [self, src] interface: diopiCopyInp(ctx, src, self) diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py index d10ae7c7f..27f3c83b0 100644 --- a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py +++ b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py @@ -11,48 +11,47 @@ #include #include #include +#include -#include -#include -#include -#include #include #include #include #include #include -#include +#include +#include #include -#include +#include #include +#include #include #include #include +#include #include #include #include #include +#include #include #include #include #include -#include #include #include #include #include -#include #include #include -#include "csrc_dipu/aten/RegisterDIPU.hpp" #include "csrc_dipu/aten/ops/AutoCompareUtils.hpp" #include "csrc_dipu/aten/ops/DIPUCopy.hpp" -#include "csrc_dipu/aten/ops/NodispatchUtils.hpp" -#include "csrc_dipu/aten/ops/OpUtils.hpp" #include "csrc_dipu/aten/ops/DIPUOpInferrer.h" +#include "csrc_dipu/aten/ops/NodispatchUtils.hpp" #include "csrc_dipu/aten/ops/OpRegexMatch.hpp" +#include "csrc_dipu/aten/ops/OpUtils.hpp" +#include "csrc_dipu/aten/RegisterDIPU.hpp" #include "csrc_dipu/base/basedef.h" #include "csrc_dipu/diopirt/diopirt_impl.h" #include "csrc_dipu/profiler/profiler.h" @@ -129,6 +128,8 @@ ::diopiError_t ret = $diopi_fun_call_code $record_code_after_call_diopi + + $force_fallback_code $check_after_call_diopi diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp index 2077521dd..47be2a871 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp +++ b/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp @@ -81,7 +81,7 @@ inline std::string allclose_autocompare(const at::Tensor& tensor_cpu, constexpr double tolerance_absolute = 1e-4; constexpr double tolerance_relative = 1e-5; const at::Tensor& tensor_cpu_from_device = - toCpuTensorWithoutDiopiCopy(tensor_device); + tensor_clone_to_host(tensor_device); bool passed = at::allclose(tensor_cpu, tensor_cpu_from_device, tolerance_absolute, tolerance_relative, true); if (passed) { diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp index 713df9dde..53a757da4 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp +++ b/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp @@ -35,13 +35,13 @@ namespace dipu { namespace native { // avoid infinite recursion when dumpArg() before calling diopiCopy() -inline at::Tensor toCpuTensorWithoutDiopiCopy(const at::Tensor& in) { +inline at::Tensor tensor_clone_to_host(const at::Tensor& in) { if (in.is_cpu()) { return in; } - at::Tensor out = at::empty_strided(in.sizes(), in.strides(), - in.options().device(c10::Device("cpu"))); + auto opt = in.options().device(c10::Device("cpu")); + auto out = at::empty_strided(in.sizes(), in.strides(), opt); if (in.nbytes() > 0) { dipu::getCurrentDIPUStream().synchronize(); dipu::devapis::memCopyD2H(out.storage().nbytes(), out.data_ptr(), @@ -50,6 +50,77 @@ inline at::Tensor toCpuTensorWithoutDiopiCopy(const at::Tensor& in) { return out; } +inline c10::optional tensor_clone_to_host( + const c10::optional& in) { + if (in) { + if (auto& tensor = in.value(); tensor.defined()) { + return c10::make_optional(tensor_clone_to_host(tensor)); + } + } + return c10::nullopt; +} + +inline 
c10::List> tensor_clone_to_host( + const c10::List>& in /* Tensor?[] */) { + auto out = c10::List>(); + out.reserve(in.size()); + for (auto const& tensor : in) { + out.push_back(tensor_clone_to_host(tensor)); + } + return out; +} + +template +inline auto tensor_clone_to_host(const R& in) + -> decltype(in.begin(), in.end(), std::vector()) { + auto out = std::vector(); + out.reserve(in.size()); + for (auto const& tensor : in) { + out.push_back(tensor_clone_to_host(tensor)); + } + return out; +} + +inline at::Tensor tensor_reference_or_clone_to_host( + at::Tensor const& in, + std::initializer_list> + device_host_tensor_pairs) { + for (auto const& [device, host] : device_host_tensor_pairs) { + if (in.is_same(device)) { + return host; + } + } + return tensor_clone_to_host(in); +} + +inline void tensor_copy_host_to_device(at::Tensor& out, const at::Tensor& in, + DIPUStream stream) { + TORCH_CHECK(in.is_cpu(), "in should be cpu tensor"); + TORCH_CHECK(!out.is_cpu(), "out should not be cpu tensor"); + + stream.synchronize(); + + if (out.sizes() != in.sizes()) { + auto device = out.options().device(); + auto option = in.options().device(device); + out = at::empty_strided(in.sizes(), in.strides(), option); + } + + auto size = out.storage().nbytes(); + dipu::devapis::memCopyH2D(size, out.data_ptr(), in.data_ptr()); +} + +inline std::vector tensor_array_to_vector( + at::ArrayRef in) { + return in.vec(); +} + +// Warning: it returns reference, thus decltype(auto) is required to avoid copy. +inline std::vector& tensor_array_to_vector( + std::vector& in) { + return in; +} + inline bool checkTensorDevice() { static bool enable = []() { const char* env_ptr = std::getenv("DIPU_CHECK_TENSOR_DEVICE"); @@ -135,7 +206,7 @@ inline std::string dumpArg(const at::Tensor& tensor) { << ", storage_data_ptr: " << tensor.storage().data_ptr().get() << ", storage_offset: " << tensor.storage_offset(); if (dumpOpArgLevel() > 2) { - stream << '\n' << toCpuTensorWithoutDiopiCopy(tensor); + stream << '\n' << tensor_clone_to_host(tensor); } } else { stream << "undefined";