Skip to content

Commit

Permalink
Add gather_nd tests (PaddlePaddle#1389)
Browse files Browse the repository at this point in the history
* add gather_nd tests

* fix bug

* fix bug
  • Loading branch information
zrr1999 authored and jiahy0825 committed May 25, 2023
1 parent 32abbf1 commit 86fc301
Showing 1 changed file with 54 additions and 45 deletions.
99 changes: 54 additions & 45 deletions python/tests/ops/test_gather_nd_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,76 +14,85 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
import numpy as np
from op_test import OpTest, OpTestTool
import paddle
import cinn
from cinn.frontend import *
from cinn.common import *
import logging
import os
from itertools import product

logging.basicConfig(level=os.environ.get('LOG_LEVEL', 'INFO').upper())
logger = logging.getLogger(name="gather_nd")


@OpTestTool.skip_if(not is_compiled_with_cuda(),
"x86 test will be skipped due to timeout.")
class TestGatherNdOp(OpTest):
def setUp(self):
self.data = []
self.init_case()

def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float32'),
'index': np.array([[1]], dtype='int32')
}
self.inputs = [{"x": [3, 4, 3], "index": [4, 1]}]
self.dtypes = ["float32"]

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
index = paddle.to_tensor(self.inputs["index"], stop_gradient=False)
out = paddle.gather_nd(x, index)
self.paddle_outputs = [out]
for inputs, dtype in product(self.inputs, self.dtypes):
x_shape = inputs["x"]
index_shape = inputs["index"]
x = np.random.randn(*x_shape).astype(dtype)
index = np.random.randint(0, min(x_shape),
index_shape).astype("int32")
self.data.append([x, index])
x = paddle.to_tensor(x, stop_gradient=False)
index = paddle.to_tensor(index, stop_gradient=False)
out = paddle.gather_nd(x, index)
logger.debug(" -- The output of Paddle:\n{}".format(out))
self.paddle_outputs.append(out)

def build_cinn_program(self, target):
builder = NetBuilder("GatherNd")
x = builder.create_input(
self.nptype2cinntype(self.inputs["x"].dtype),
self.inputs["x"].shape, "x")
index = builder.create_input(
self.nptype2cinntype(self.inputs["index"].dtype),
self.inputs["index"].shape, "index")
out = builder.gather_nd(x, index)

prog = builder.build()
res = self.get_cinn_output(prog, target, [x, index],
[self.inputs["x"], self.inputs["index"]],
[out])

self.cinn_outputs = [res[0]]
for i, (inputs, dtype) in enumerate(product(self.inputs, self.dtypes)):
builder = NetBuilder("gather")
x = builder.create_input(
self.nptype2cinntype(dtype), inputs["x"], "x")
index = builder.create_input(Int(32), inputs["index"], "index")
out = builder.gather_nd(x, index)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x, index], self.data[i],
[out])
logger.debug(" -- The output of CINN:\n{}".format(res))
self.cinn_outputs.extend(res)

def test_check_results(self):
self.check_outputs_and_grads(all_equal=True)


class TestGatherNdCase1(TestGatherNdOp):
def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float32'),
'index': np.array([[0, 2]], dtype='int32')
}


class TestGatherNdCase2(TestGatherNdOp):
def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float32'),
'index': np.array([[1, 2, 3]], dtype='int32')
}


class TestGatherNdCase3(TestGatherNdOp):
class TestGatherOpAll(TestGatherNdOp):
def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float64'),
'index': np.array([[1, 2, 3]], dtype='int64')
}
self.inputs = []
for x_shape in [
[16],
[8, 16],
[4, 8, 16],
[2, 4, 8, 16],
[2, 4, 8, 1],
[2, 4, 8, 1024],
]:
for j in range(1, len(x_shape)):
self.inputs.append({"x": x_shape, "index": [8, j]})

self.dtypes = [
"float32",
"float64",
"int16",
"int32",
"int64",
# "uint8" # note: some types is not supported in paddle now.
]


if __name__ == "__main__":
Expand Down

0 comments on commit 86fc301

Please sign in to comment.