From 86fc30147367708badafd0f150e30670fa0d0d2b Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <46243324+zrr1999@users.noreply.github.com> Date: Fri, 12 May 2023 11:40:15 +0800 Subject: [PATCH] Add gather_nd tests (#1389) * add gather_nd tests * fix bug * fix bug --- python/tests/ops/test_gather_nd_op.py | 99 +++++++++++++++------------ 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/python/tests/ops/test_gather_nd_op.py b/python/tests/ops/test_gather_nd_op.py index 592d5da2dc..bc2f7f1947 100644 --- a/python/tests/ops/test_gather_nd_op.py +++ b/python/tests/ops/test_gather_nd_op.py @@ -14,76 +14,85 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import unittest import numpy as np from op_test import OpTest, OpTestTool import paddle +import cinn from cinn.frontend import * from cinn.common import * +import logging +import os +from itertools import product + +logging.basicConfig(level=os.environ.get('LOG_LEVEL', 'INFO').upper()) +logger = logging.getLogger(name="gather_nd") @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestGatherNdOp(OpTest): def setUp(self): + self.data = [] self.init_case() def init_case(self): - self.inputs = { - 'x': self.random([2, 3, 4], 'float32'), - 'index': np.array([[1]], dtype='int32') - } + self.inputs = [{"x": [3, 4, 3], "index": [4, 1]}] + self.dtypes = ["float32"] def build_paddle_program(self, target): - x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) - index = paddle.to_tensor(self.inputs["index"], stop_gradient=False) - out = paddle.gather_nd(x, index) - self.paddle_outputs = [out] + for inputs, dtype in product(self.inputs, self.dtypes): + x_shape = inputs["x"] + index_shape = inputs["index"] + x = np.random.randn(*x_shape).astype(dtype) + index = np.random.randint(0, min(x_shape), + index_shape).astype("int32") + self.data.append([x, index]) + x = paddle.to_tensor(x, stop_gradient=False) + index = paddle.to_tensor(index, stop_gradient=False) + out = paddle.gather_nd(x, index) + logger.debug(" -- The output of Paddle:\n{}".format(out)) + self.paddle_outputs.append(out) def build_cinn_program(self, target): - builder = NetBuilder("GatherNd") - x = builder.create_input( - self.nptype2cinntype(self.inputs["x"].dtype), - self.inputs["x"].shape, "x") - index = builder.create_input( - self.nptype2cinntype(self.inputs["index"].dtype), - self.inputs["index"].shape, "index") - out = builder.gather_nd(x, index) - - prog = builder.build() - res = self.get_cinn_output(prog, target, [x, index], - [self.inputs["x"], self.inputs["index"]], - [out]) - - self.cinn_outputs = [res[0]] + for i, (inputs, dtype) in enumerate(product(self.inputs, self.dtypes)): + builder = NetBuilder("gather") + x = builder.create_input( + self.nptype2cinntype(dtype), inputs["x"], "x") + index = builder.create_input(Int(32), inputs["index"], "index") + out = builder.gather_nd(x, index) + prog = builder.build() + res = self.get_cinn_output(prog, target, [x, index], self.data[i], + [out]) + logger.debug(" -- The output of CINN:\n{}".format(res)) + self.cinn_outputs.extend(res) def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestGatherNdCase1(TestGatherNdOp): - def init_case(self): - self.inputs = { - 'x': self.random([2, 3, 4], 'float32'), - 'index': np.array([[0, 2]], dtype='int32') - } - - -class TestGatherNdCase2(TestGatherNdOp): - def init_case(self): - self.inputs = { - 'x': self.random([2, 3, 4], 'float32'), - 'index': np.array([[1, 2, 3]], dtype='int32') - } - - -class TestGatherNdCase3(TestGatherNdOp): +class TestGatherOpAll(TestGatherNdOp): def init_case(self): - self.inputs = { - 'x': self.random([2, 3, 4], 'float64'), - 'index': np.array([[1, 2, 3]], dtype='int64') - } + self.inputs = [] + for x_shape in [ + [16], + [8, 16], + [4, 8, 16], + [2, 4, 8, 16], + [2, 4, 8, 1], + [2, 4, 8, 1024], + ]: + for j in range(1, len(x_shape)): + self.inputs.append({"x": x_shape, "index": [8, j]}) + + self.dtypes = [ + "float32", + "float64", + "int16", + "int32", + "int64", + # "uint8" # note: some types is not supported in paddle now. + ] if __name__ == "__main__":