From f7c3c84b99e0c7bcc9e28633a8bc7158951144e3 Mon Sep 17 00:00:00 2001 From: zzk0 Date: Tue, 9 May 2023 08:03:18 +0000 Subject: [PATCH 01/13] fix typo select unittest --- python/tests/ops/test_select_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/ops/test_select_op.py b/python/tests/ops/test_select_op.py index 68f2b653c0..2c870c47f9 100644 --- a/python/tests/ops/test_select_op.py +++ b/python/tests/ops/test_select_op.py @@ -29,7 +29,7 @@ def setUp(self): def prepare_inputs(self): self.inputs = { - "Condition": self.random(self.case["shape"], "bool", 0, 2), + "Condition": self.random(self.case["shape"], "bool"), "X": self.random(self.case["shape"], self.case["dtype"]), "Y": self.random(self.case["shape"], self.case["dtype"]) } From 0d551d62fe9e1c20a9475d55a60b8aadb90cdbf7 Mon Sep 17 00:00:00 2001 From: zzk0 Date: Tue, 9 May 2023 08:31:12 +0000 Subject: [PATCH 02/13] op unittest for sort --- python/tests/ops/test_sort_op.py | 255 +++++++++++++++++++++++++------ 1 file changed, 207 insertions(+), 48 deletions(-) diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index 9230fcc4fb..a28d5400b0 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -14,45 +14,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn from cinn.frontend import * from cinn.common import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestSortOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): + def prepare_inputs(self): self.inputs = { - "x1": np.random.random([ - 2, - 4, - ]).astype("float32") + "x": self.random(self.case["shape"], self.case["dtype"]) } - self.axis = 1 - self.descending = False + self.axis = self.case["axis"] + self.descending = self.case["descending"] def build_paddle_program(self, target): - x1 = paddle.to_tensor(self.inputs["x1"], stop_gradient=True) + x1 = paddle.to_tensor(self.inputs["x"], stop_gradient=True) out = paddle.sort(x1, self.axis, self.descending) self.paddle_outputs = [out] def build_cinn_program(self, target): - builder = NetBuilder("sum") - x1 = builder.create_input(Float(32), self.inputs["x1"].shape, "x1") + builder = NetBuilder("sort") + x1 = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") out = builder.sort(x1, self.axis, not self.descending) prog = builder.build() forward_res = self.get_cinn_output(prog, target, [x1], - [self.inputs["x1"]], [out]) + [self.inputs["x"]], [out]) self.cinn_outputs = forward_res @@ -60,41 +58,202 @@ def test_check_results(self): self.check_outputs_and_grads() -class TestSortCase1(TestSortOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([ - 2, - 4, - ]).astype("float32") - } - self.axis = 0 - self.descending = False +class TestSortOpShapeTest(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestSortOpShapeTest" + self.cls = TestSortOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + # F0509 08:18:53.483060 2316861 cuda_util.cc:110] CUDA Driver Error: cuLaunchKernel(static_cast(kernel_fn), grid_x, grid_y, grid_z, block_x, block_y, block_z, 0, static_cast(stream), kernel_args.data(), nullptr) failed with error: invalid argument + # { + # "shape": [80, 40, 5, 7], + # }, + { + "shape": [80, 1, 5, 7], + }, + # { + # "shape": [80, 3, 1024, 7], + # }, + # { + # "shape": [10, 5, 1024, 2048], + # }, + { + "shape": [1], + }, + { + "shape": [1, 1, 1, 1], + }, + { + "shape": [1, 1, 1, 1, 1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + # { + # "shape": [512, 256], + # }, + # { + # "shape": [128, 64, 32], + # }, + { + "shape": [16, 8, 4, 2], + }, + { + "shape": [16, 8, 4, 2, 1], + } + ] + self.dtypes = [{"dtype": "float32"}] + self.attrs = [{"axis": 0, "descending": False}] -class TestSortCase2(TestSortOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([ - 2, - 4, - ]).astype("float32") - } - self.axis = 0 - self.descending = True +class TestSortOpDtypeTest(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestSortOpDtypeTest" + self.cls = TestSortOp + self.inputs = [ + { + "shape": [16, 8, 4, 2], + }, + # { + # "shape": [80, 40, 5, 7], + # }, + # { + # "shape": [16, 8, 4, 2, 1], + # } + ] + self.dtypes = [ + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [{"axis": 0, "descending": False}] -class TestSortCase3(TestSortOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([ - 2, - 4, - ]).astype("float32") - } - self.axis = 1 - self.descending = True +class TestSortOpAxisTest(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestSortOpAttrsTest" + self.cls = TestSortOp + self.inputs = [ + { + "shape": [16, 8, 4, 2], + }, + # { + # "shape": [80, 40, 5, 7], + # }, + # { + # "shape": [16, 8, 4, 2, 1], + # } + ] + self.dtypes = [{"dtype": "float32"}] + self.attrs = [{ + "axis": 0, + "descending": False + }, { + "axis": 1, + "descending": False + }, { + "axis": 2, + "descending": False + }, { + "axis": 3, + "descending": False + }] + + +class TestSortOpDescedingTest(TestSortOpShapeTest): + def init_attrs(self): + self.class_name = "TestSortOpDescedingTest" + self.cls = TestSortOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + # F0509 08:18:53.483060 2316861 cuda_util.cc:110] CUDA Driver Error: cuLaunchKernel(static_cast(kernel_fn), grid_x, grid_y, grid_z, block_x, block_y, block_z, 0, static_cast(stream), kernel_args.data(), nullptr) failed with error: invalid argument + # { + # "shape": [80, 40, 5, 7], + # }, + { + "shape": [80, 1, 5, 7], + }, + # { + # "shape": [80, 3, 1024, 7], + # }, + # { + # "shape": [10, 5, 1024, 2048], + # }, + { + "shape": [1], + }, + { + "shape": [1, 1, 1, 1], + }, + { + "shape": [1, 1, 1, 1, 1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + # { + # "shape": [512, 256], + # }, + # { + # "shape": [128, 64, 32], + # }, + { + "shape": [16, 8, 4, 2], + }, + { + "shape": [16, 8, 4, 2, 1], + } + ] + self.dtypes = [{"dtype": "float32"}] + self.attrs = [ + # NOTE: TestSortOpShapeTest has already tested with + # the parameter 'descending=False', so just skip + { + "axis": 0, + "descending": True + } + ] if __name__ == "__main__": - unittest.main() + # TestSortOpShapeTest().run() + TestSortOpDtypeTest().run() + TestSortOpAxisTest().run() + TestSortOpDescedingTest().run() From 50463a4e5e1aac9ac406843884724e880581eeec Mon Sep 17 00:00:00 2001 From: zzk0 Date: Wed, 10 May 2023 03:39:33 +0000 Subject: [PATCH 03/13] op unittest for sort --- python/tests/ops/test_sort_op.py | 141 ++++++++++--------------------- 1 file changed, 43 insertions(+), 98 deletions(-) diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index a28d5400b0..e9ba730522 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -64,57 +64,49 @@ def init_attrs(self): self.cls = TestSortOp self.inputs = [ { - "shape": [10], + "shape": [512], }, { - "shape": [8, 5], + "shape": [1024], }, { - "shape": [10, 3, 5], + "shape": [2048], }, - # F0509 08:18:53.483060 2316861 cuda_util.cc:110] CUDA Driver Error: cuLaunchKernel(static_cast(kernel_fn), grid_x, grid_y, grid_z, block_x, block_y, block_z, 0, static_cast(stream), kernel_args.data(), nullptr) failed with error: invalid argument - # { - # "shape": [80, 40, 5, 7], - # }, { - "shape": [80, 1, 5, 7], + "shape": [128, 64], }, - # { - # "shape": [80, 3, 1024, 7], - # }, - # { - # "shape": [10, 5, 1024, 2048], - # }, { - "shape": [1], + "shape": [8, 32, 16], }, { - "shape": [1, 1, 1, 1], + "shape": [16, 8, 4, 2], }, { - "shape": [1, 1, 1, 1, 1], + "shape": [16, 8, 4, 2, 5], }, { - "shape": [512], + "shape": [16, 8, 1, 2, 32], }, { - "shape": [1024], + "shape": [1], }, { - "shape": [2048], + "shape": [1, 1, 1, 1], }, + { + "shape": [1, 1, 1, 1, 1], + }, + # TODO: known issue cinn/hlir/op/contrib/sort.cc:201 + # the array will exceed the cuda kernel stack size limit # { - # "shape": [512, 256], + # "shape": [32768], # }, # { - # "shape": [128, 64, 32], + # "shape": [65536], + # }, + # { + # "shape": [131072], # }, - { - "shape": [16, 8, 4, 2], - }, - { - "shape": [16, 8, 4, 2, 1], - } ] self.dtypes = [{"dtype": "float32"}] self.attrs = [{"axis": 0, "descending": False}] @@ -125,15 +117,18 @@ def init_attrs(self): self.class_name = "TestSortOpDtypeTest" self.cls = TestSortOp self.inputs = [ + { + "shape": [2048], + }, + { + "shape": [128, 64], + }, + { + "shape": [8, 32, 16], + }, { "shape": [16, 8, 4, 2], }, - # { - # "shape": [80, 40, 5, 7], - # }, - # { - # "shape": [16, 8, 4, 2, 1], - # } ] self.dtypes = [ { @@ -160,12 +155,6 @@ def init_attrs(self): { "shape": [16, 8, 4, 2], }, - # { - # "shape": [80, 40, 5, 7], - # }, - # { - # "shape": [16, 8, 4, 2, 1], - # } ] self.dtypes = [{"dtype": "float32"}] self.attrs = [{ @@ -188,72 +177,28 @@ def init_attrs(self): self.class_name = "TestSortOpDescedingTest" self.cls = TestSortOp self.inputs = [ - { - "shape": [10], - }, - { - "shape": [8, 5], - }, - { - "shape": [10, 3, 5], - }, - # F0509 08:18:53.483060 2316861 cuda_util.cc:110] CUDA Driver Error: cuLaunchKernel(static_cast(kernel_fn), grid_x, grid_y, grid_z, block_x, block_y, block_z, 0, static_cast(stream), kernel_args.data(), nullptr) failed with error: invalid argument - # { - # "shape": [80, 40, 5, 7], - # }, - { - "shape": [80, 1, 5, 7], - }, - # { - # "shape": [80, 3, 1024, 7], - # }, - # { - # "shape": [10, 5, 1024, 2048], - # }, - { - "shape": [1], - }, - { - "shape": [1, 1, 1, 1], - }, - { - "shape": [1, 1, 1, 1, 1], - }, - { - "shape": [512], - }, - { - "shape": [1024], - }, - { - "shape": [2048], - }, - # { - # "shape": [512, 256], - # }, - # { - # "shape": [128, 64, 32], - # }, { "shape": [16, 8, 4, 2], }, - { - "shape": [16, 8, 4, 2, 1], - } ] self.dtypes = [{"dtype": "float32"}] - self.attrs = [ - # NOTE: TestSortOpShapeTest has already tested with - # the parameter 'descending=False', so just skip - { - "axis": 0, - "descending": True - } - ] + self.attrs = [{ + "axis": 0, + "descending": True + }, { + "axis": 1, + "descending": True + }, { + "axis": 2, + "descending": True + }, { + "axis": 3, + "descending": True + }] if __name__ == "__main__": - # TestSortOpShapeTest().run() + TestSortOpShapeTest().run() TestSortOpDtypeTest().run() TestSortOpAxisTest().run() TestSortOpDescedingTest().run() From 703c340b6f9b6877db967711935dedd9bcf3235c Mon Sep 17 00:00:00 2001 From: zzk0 Date: Wed, 10 May 2023 09:45:31 +0000 Subject: [PATCH 04/13] enhance TestCaseHelper & add special case for sort op --- python/tests/ops/op_test_helper.py | 11 +++++++++++ python/tests/ops/test_sort_op.py | 19 ++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/tests/ops/op_test_helper.py b/python/tests/ops/op_test_helper.py index 445ab046ac..3d61aa7d24 100644 --- a/python/tests/ops/op_test_helper.py +++ b/python/tests/ops/op_test_helper.py @@ -18,6 +18,8 @@ import unittest import re +from typing import Union, List, Type + parser = argparse.ArgumentParser(description="Argparse for op test helper") parser.add_argument( "--case", @@ -102,4 +104,13 @@ def run(self): test_suite.addTests(test_loader.loadTestsFromTestCase(x)) runner = unittest.TextTestRunner() res = runner.run(test_suite) + if not res.wasSuccessful(): + sys.exit(not res.wasSuccessful()) + + +def run_test(test_class: Union[int, List[Type]]): + test_suite = unittest.TestLoader().loadTestsFromTestCase(test_class) + runner = unittest.TextTestRunner() + res = runner.run(test_suite) + if not res.wasSuccessful(): sys.exit(not res.wasSuccessful()) diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index e9ba730522..85cd969902 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -15,10 +15,12 @@ # limitations under the License. import paddle +import numpy as np +import unittest from cinn.frontend import * from cinn.common import * from op_test import OpTest, OpTestTool -from op_test_helper import TestCaseHelper +from op_test_helper import TestCaseHelper, run_test @OpTestTool.skip_if(not is_compiled_with_cuda(), @@ -58,6 +60,19 @@ def test_check_results(self): self.check_outputs_and_grads() +class TestSortOpDumpicateElement(TestSortOp): + def setUp(self): + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.inputs = { + "x": np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]).astype("int64") + } + self.axis = 0 + self.descending = False + + class TestSortOpShapeTest(TestCaseHelper): def init_attrs(self): self.class_name = "TestSortOpShapeTest" @@ -198,6 +213,8 @@ def init_attrs(self): if __name__ == "__main__": + run_test(TestSortOpDumpicateElement) + TestSortOpShapeTest().run() TestSortOpDtypeTest().run() TestSortOpAxisTest().run() From e6da5908f6addf4ee578a432f54d834afd028fdc Mon Sep 17 00:00:00 2001 From: zzk0 Date: Thu, 11 May 2023 01:21:32 +0000 Subject: [PATCH 05/13] fix sort bug for duplicate element --- cinn/hlir/op/contrib/sort.cc | 4 ++-- cinn/runtime/cpu/host_intrinsics.cc | 22 +++++++++++++++++++ .../runtime/cuda/cinn_cuda_runtime_source.cuh | 13 +++++++++++ cinn/runtime/cuda/cuda_intrinsics.cc | 9 ++++++++ python/tests/ops/op_test_helper.py | 13 ++++++++--- python/tests/ops/test_sort_op.py | 10 ++------- 6 files changed, 58 insertions(+), 13 deletions(-) diff --git a/cinn/hlir/op/contrib/sort.cc b/cinn/hlir/op/contrib/sort.cc index 927949f13b..e8471bcb4a 100644 --- a/cinn/hlir/op/contrib/sort.cc +++ b/cinn/hlir/op/contrib/sort.cc @@ -56,9 +56,9 @@ std::vector ArgSort(const ir::Tensor &A, std::string find_func_name; std::string index_func_name; if (target.arch == common::Target::Arch::NVGPU) { - find_func_name.assign("cinn_cuda_find_int_nd"); + find_func_name.assign("cinn_nvgpu_next_smallest_int32"); } else if (target.arch == common::Target::Arch::X86) { - find_func_name.assign("cinn_host_find_int_nd"); + find_func_name.assign("cinn_host_next_smallest_int32"); } else { LOG(FATAL) << "ArgSort only supports X86 and NVGPU ! Please Check.\n"; } diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index 2ce67f7b86..d31d810045 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -64,6 +64,19 @@ inline int cinn_host_find_float_nd(const cinn_buffer_t* buf, int size, float num #undef __cinn_host_find_kernel +inline int cinn_host_next_smallest_int32(cinn_buffer_t* buf, int size, int num, int begin, int stride) { + int id = -1; + for (int i = begin; i < begin + size * stride; i += stride) { + if (id == -1 || reinterpret_cast(buf->memory)[i] < reinterpret_cast(buf->memory)[id]) { + id = i; + } + } + if (id != -1) { + reinterpret_cast(buf->memory)[id] = 2147483647; + } + return (id - begin) / stride; +} + #define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ inline int cinn_host_lt_num_##TYPE_SUFFIX( \ const cinn_buffer_t* buf, const int size, const TYPE num, const int offset, const int stride) { \ @@ -349,6 +362,15 @@ CINN_REGISTER_HELPER(host_intrinsics) { .AddInputType() .End(); + REGISTER_EXTERN_FUNC_HELPER(cinn_host_next_smallest_int32, host_target) + .SetRetType() + .AddInputType() + .AddInputType() + .AddInputType() + .AddInputType() + .AddInputType() + .End(); + #define _REGISTER_CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ REGISTER_EXTERN_FUNC_HELPER(cinn_host_lt_num_##TYPE_SUFFIX, host_target) \ .SetRetType() \ diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index a5d751cfa4..f4faf63681 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -525,6 +525,19 @@ __device__ inline int cinn_cuda_find_float_nd(const float *buf, int size, float #undef __cinn_cuda_find_kernel +__device__ inline int cinn_nvgpu_next_smallest_int32(int *buf, int size, int num, int begin, int stride) { + int id = -1; + for (int i = begin; i < begin + size * stride; i += stride) { + if (id == -1 || buf[i] < buf[id]) { + id = i; + } + } + if (id != -1) { + buf[id] = 2147483647; + } + return (id - begin) / stride; +} + #define __cinn_cuda_find_from_kernel(buf, size, num, begin) \ do { \ for (int i = begin; i < size; ++i) { \ diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc index 698395f9fc..071a3eb136 100644 --- a/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/cinn/runtime/cuda/cuda_intrinsics.cc @@ -224,6 +224,15 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { .AddInputType() .End(); + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_next_smallest_int32, target) + .SetRetType() + .AddInputType() + .AddInputType() + .AddInputType() + .AddInputType() + .AddInputType() + .End(); + #define _REGISTER_CINN_CUDA_LT_NUM(TYPE_SUFFIX, TYPE) \ REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_lt_num_##TYPE_SUFFIX, target) \ .SetRetType() \ diff --git a/python/tests/ops/op_test_helper.py b/python/tests/ops/op_test_helper.py index 3d61aa7d24..1f1708a4bd 100644 --- a/python/tests/ops/op_test_helper.py +++ b/python/tests/ops/op_test_helper.py @@ -18,7 +18,8 @@ import unittest import re -from typing import Union, List, Type +from unittest import suite +from typing import Union, List parser = argparse.ArgumentParser(description="Argparse for op test helper") parser.add_argument( @@ -108,8 +109,14 @@ def run(self): sys.exit(not res.wasSuccessful()) -def run_test(test_class: Union[int, List[Type]]): - test_suite = unittest.TestLoader().loadTestsFromTestCase(test_class) +def run_test(test_class: Union[suite.TestSuite, List[suite.TestSuite]]): + test_suite = unittest.TestSuite() + test_loader = unittest.TestLoader() + if isinstance(test_class, type): + test_suite.addTests(test_loader.loadTestsFromTestCase(test_class)) + else: + for cls in test_class: + test_suite.addTests(test_loader.loadTestsFromTestCase(cls)) runner = unittest.TextTestRunner() res = runner.run(test_suite) if not res.wasSuccessful(): diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index 85cd969902..29bbede0d6 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -15,16 +15,12 @@ # limitations under the License. import paddle -import numpy as np -import unittest from cinn.frontend import * from cinn.common import * -from op_test import OpTest, OpTestTool +from op_test import OpTest from op_test_helper import TestCaseHelper, run_test -@OpTestTool.skip_if(not is_compiled_with_cuda(), - "x86 test will be skipped due to timeout.") class TestSortOp(OpTest): def setUp(self): print(f"\nRunning {self.__class__.__name__}: {self.case}") @@ -66,9 +62,7 @@ def setUp(self): self.prepare_inputs() def prepare_inputs(self): - self.inputs = { - "x": np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]).astype("int64") - } + self.inputs = {"x": self.random([128], "int64", -10, 10)} self.axis = 0 self.descending = False From 1ced8849815d26565357aa71a00ee51a8a02446a Mon Sep 17 00:00:00 2001 From: zzk0 Date: Thu, 11 May 2023 01:42:32 +0000 Subject: [PATCH 06/13] fix index typo --- cinn/runtime/cpu/host_intrinsics.cc | 3 ++- cinn/runtime/cuda/cinn_cuda_runtime_source.cuh | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index d31d810045..a9d1cd0e38 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -73,8 +73,9 @@ inline int cinn_host_next_smallest_int32(cinn_buffer_t* buf, int size, int num, } if (id != -1) { reinterpret_cast(buf->memory)[id] = 2147483647; + return (id - begin) / stride; } - return (id - begin) / stride; + return -1; } #define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index 63f1f277cf..fedfa0d674 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -531,8 +531,9 @@ __device__ inline int cinn_nvgpu_next_smallest_int32(int *buf, int size, int num } if (id != -1) { buf[id] = 2147483647; + return (id - begin) / stride; } - return (id - begin) / stride; + return -1; } #define __cinn_cuda_find_from_kernel(buf, size, num, begin) \ From 5a1fd74110cb84fa71badabe4700ecd66b3d7cd1 Mon Sep 17 00:00:00 2001 From: zzk0 Date: Fri, 12 May 2023 00:41:02 +0000 Subject: [PATCH 07/13] refine hard code testcase --- cinn/hlir/op/contrib/argmax_test.cc | 2 +- cinn/hlir/op/contrib/argmin_test.cc | 2 +- cinn/hlir/op/contrib/sort_test.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cinn/hlir/op/contrib/argmax_test.cc b/cinn/hlir/op/contrib/argmax_test.cc index 419fb081c1..3b42e73f1f 100644 --- a/cinn/hlir/op/contrib/argmax_test.cc +++ b/cinn/hlir/op/contrib/argmax_test.cc @@ -94,7 +94,7 @@ void TestGenerateCodeCpu_Argmax_Keep(void* _args, int32_t num_args) for (int32_t j = 0; j < 3; j += 1) { for (int32_t k = 0; k < 28; k += 1) { for (int32_t a = 0; a < 28; a += 1) { - test_argmax_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_find_int_nd(_test_argmax_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784); + test_argmax_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_next_smallest_int32(_test_argmax_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784); }; }; }; diff --git a/cinn/hlir/op/contrib/argmin_test.cc b/cinn/hlir/op/contrib/argmin_test.cc index 98e9bdb8d5..bfe053f101 100644 --- a/cinn/hlir/op/contrib/argmin_test.cc +++ b/cinn/hlir/op/contrib/argmin_test.cc @@ -93,7 +93,7 @@ void TestGenerateCodeCpu_Argmin_Keep(void* _args, int32_t num_args) for (int32_t j = 0; j < 3; j += 1) { for (int32_t k = 0; k < 28; k += 1) { for (int32_t a = 0; a < 28; a += 1) { - test_argmin_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_find_int_nd(_test_argmin_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784); + test_argmin_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_next_smallest_int32(_test_argmin_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784); }; }; }; diff --git a/cinn/hlir/op/contrib/sort_test.cc b/cinn/hlir/op/contrib/sort_test.cc index 860eef32e3..e5f990ba11 100644 --- a/cinn/hlir/op/contrib/sort_test.cc +++ b/cinn/hlir/op/contrib/sort_test.cc @@ -112,7 +112,7 @@ void TestGenerateCodeCpu_Sort(void* _args, int32_t num_args) }; for (int32_t i = 0; i < 4; i += 1) { for (int32_t j = 0; j < 28; j += 1) { - test_sort_out_index[((28 * i) + j)] = cinn_host_find_int_nd(_test_sort_out_index_temp, 28, j, (28 * i), 1); + test_sort_out_index[((28 * i) + j)] = cinn_host_next_smallest_int32(_test_sort_out_index_temp, 28, j, (28 * i), 1); }; }; for (int32_t i = 0; i < 4; i += 1) { From 2ffc6e20ffd029bff81140a98e925d379cdd96ab Mon Sep 17 00:00:00 2001 From: zzk0 Date: Fri, 12 May 2023 13:00:39 +0000 Subject: [PATCH 08/13] reduce array size to avoid large cuda memory occupation --- python/tests/ops/test_sort_op.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index 29bbede0d6..2ac70030fc 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -82,19 +82,19 @@ def init_attrs(self): "shape": [2048], }, { - "shape": [128, 64], + "shape": [64, 32], }, { - "shape": [8, 32, 16], + "shape": [4, 32, 16], }, { "shape": [16, 8, 4, 2], }, { - "shape": [16, 8, 4, 2, 5], + "shape": [2, 8, 4, 2, 5], }, { - "shape": [16, 8, 1, 2, 32], + "shape": [4, 8, 1, 2, 32], }, { "shape": [1], @@ -130,10 +130,10 @@ def init_attrs(self): "shape": [2048], }, { - "shape": [128, 64], + "shape": [64, 32], }, { - "shape": [8, 32, 16], + "shape": [4, 32, 16], }, { "shape": [16, 8, 4, 2], From 47078c27701dac18846266e0990468ecdbcf6279 Mon Sep 17 00:00:00 2001 From: zzk0 Date: Fri, 12 May 2023 13:05:38 +0000 Subject: [PATCH 09/13] add not passed test case --- python/tests/ops/test_sort_op.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index 2ac70030fc..5a31d05158 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -67,6 +67,18 @@ def prepare_inputs(self): self.descending = False +# This test case will cause CINN to allocate a large amount of GPU memory, nearly 10 GB. +class TestSortOpLargeCudaMemoryOccupation(TestSortOp): + def setUp(self): + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.inputs = {"x": self.random([128, 64], "float64")} + self.axis = 0 + self.descending = False + + class TestSortOpShapeTest(TestCaseHelper): def init_attrs(self): self.class_name = "TestSortOpShapeTest" @@ -82,7 +94,7 @@ def init_attrs(self): "shape": [2048], }, { - "shape": [64, 32], + "shape": [128, 64], }, { "shape": [4, 32, 16], @@ -208,6 +220,7 @@ def init_attrs(self): if __name__ == "__main__": run_test(TestSortOpDumpicateElement) + # run_test(TestSortOpLargeCudaMemoryOccupation) TestSortOpShapeTest().run() TestSortOpDtypeTest().run() From 6d82a0b2e1fa1c90736bbd3239fd7747988970cc Mon Sep 17 00:00:00 2001 From: zzk0 Date: Sat, 13 May 2023 00:31:01 +0000 Subject: [PATCH 10/13] reduce array size again --- python/tests/ops/test_sort_op.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tests/ops/test_sort_op.py b/python/tests/ops/test_sort_op.py index 5a31d05158..7708d58e86 100644 --- a/python/tests/ops/test_sort_op.py +++ b/python/tests/ops/test_sort_op.py @@ -74,7 +74,7 @@ def setUp(self): self.prepare_inputs() def prepare_inputs(self): - self.inputs = {"x": self.random([128, 64], "float64")} + self.inputs = {"x": self.random([8192], "float64")} self.axis = 0 self.descending = False @@ -91,13 +91,13 @@ def init_attrs(self): "shape": [1024], }, { - "shape": [2048], + "shape": [1200], }, { - "shape": [128, 64], + "shape": [64, 16], }, { - "shape": [4, 32, 16], + "shape": [4, 32, 8], }, { "shape": [16, 8, 4, 2], @@ -106,7 +106,7 @@ def init_attrs(self): "shape": [2, 8, 4, 2, 5], }, { - "shape": [4, 8, 1, 2, 32], + "shape": [4, 8, 1, 2, 16], }, { "shape": [1], @@ -139,13 +139,13 @@ def init_attrs(self): self.cls = TestSortOp self.inputs = [ { - "shape": [2048], + "shape": [1024], }, { - "shape": [64, 32], + "shape": [64, 16], }, { - "shape": [4, 32, 16], + "shape": [4, 32, 8], }, { "shape": [16, 8, 4, 2], From 111fc49206ee2f7f956459355952f92647d3e7e7 Mon Sep 17 00:00:00 2001 From: zzk0 Date: Mon, 15 May 2023 12:19:19 +0000 Subject: [PATCH 11/13] fix magic number --- cinn/runtime/cuda/cinn_cuda_runtime_source.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index f44898d3a3..3b04456544 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -5,6 +5,9 @@ #include #include +constexpr int CINN_INT32_MAX = 2147483647; +constexpr int CINN_INT32_MIN = -2147483648; + extern "C" { // *************************************************************** // // float32 unary and binary operator @@ -258,8 +261,8 @@ __device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { #define EXPAND_REDUCE_INT32_MARCO(MARCO, ...) \ MARCO(sum_int32, 0, int, ##__VA_ARGS__) \ MARCO(prod_int32, 1, int, ##__VA_ARGS__) \ - MARCO(max_int32, -2147483648, int, ##__VA_ARGS__) \ - MARCO(min_int32, 2147483647, int, ##__VA_ARGS__) + MARCO(max_int32, CINN_INT32_MIN, int, ##__VA_ARGS__) \ + MARCO(min_int32, CINN_INT32_MAX, int, ##__VA_ARGS__) __device__ inline int cinn_sum_int32(const int left, const int right) { return left + right; } __device__ inline int cinn_prod_int32(const int left, const int right) { return left * right; } @@ -551,7 +554,7 @@ __device__ inline int cinn_nvgpu_next_smallest_int32(int *buf, int size, int num } } if (id != -1) { - buf[id] = 2147483647; + buf[id] = CINN_INT32_MAX; return (id - begin) / stride; } return -1; From 4539f01d22714806b97c35536ec21253b212ef2d Mon Sep 17 00:00:00 2001 From: zzk0 Date: Wed, 17 May 2023 03:47:09 +0000 Subject: [PATCH 12/13] remove headers --- cinn/runtime/cuda/cinn_cuda_runtime_source.cuh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index 8ce642ab5d..374e644217 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -1,12 +1,7 @@ -#include - /** * \file This file contains all the intrinsics available to be used in CUDA code generated by CodeGen. */ -#include -#include - constexpr int CINN_INT32_MAX = 2147483647; constexpr int CINN_INT32_MIN = -2147483648; From 8aeb9f5bccb255e0edeb8f522ab5e65dfab821ff Mon Sep 17 00:00:00 2001 From: zzk0 Date: Tue, 23 May 2023 07:36:53 +0000 Subject: [PATCH 13/13] remove cpp style code from .cuh --- cinn/runtime/cuda/cinn_cuda_runtime_source.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index e5689417b2..dc416c127f 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -2,10 +2,11 @@ * \file This file contains all the intrinsics available to be used in CUDA code generated by CodeGen. */ -constexpr int CINN_INT32_MAX = 2147483647; -constexpr int CINN_INT32_MIN = -2147483648; - extern "C" { + +#define CINN_INT32_MAX 2147483647 +#define CINN_INT32_MIN -2147483648 + // *************************************************************** // // bool unary and binary operator #define FN_BOOL(func) cinn_nvgpu_##func##_bool @@ -738,6 +739,8 @@ __device__ int cinn_cuda_resize_bicubic(const int *buf, // *************************************************************** // // end of macro undef +#undef CINN_INT32_MAX +#undef CINN_INT32_MIN #undef FN_BOOL #undef FN_UINT8 #undef FN_INT8