diff --git a/python/tests/ops/test_one_hot_op.py b/python/tests/ops/test_one_hot_op.py index 4dd01e07d9..5cebb51260 100755 --- a/python/tests/ops/test_one_hot_op.py +++ b/python/tests/ops/test_one_hot_op.py @@ -17,6 +17,7 @@ import unittest import numpy as np from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper import paddle import paddle.nn.functional as F import cinn @@ -28,19 +29,17 @@ "x86 test will be skipped due to timeout.") class TestOneHotOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.prepare_inputs() - def init_case(self): - self.inputs = { - "X": np.random.random_integers(0, 9, (10)).astype("int64") - } - self.depth = 10 - self.axis = -1 + def prepare_inputs(self): + self.x_np = self.random( + shape=self.case["x_shape"], dtype=self.case["x_dtype"]) self.dtype = "float32" def build_paddle_program(self, target): - x = paddle.to_tensor(self.inputs["X"]) - out = F.one_hot(x, self.depth) + x = paddle.to_tensor(self.x_np, stop_gradient=True) + out = F.one_hot(x, num_classes=self.case["depth"]) self.paddle_outputs = [out] @@ -48,24 +47,79 @@ def build_paddle_program(self, target): # the forward result will be incorrect. def build_cinn_program(self, target): builder = NetBuilder("one_hot") - x = builder.create_input(Int(64), self.inputs["X"].shape, "X") - on_value = builder.fill_constant([1], 1, 'on_value', 'int64') - off_value = builder.fill_constant([1], 0, 'off_value', 'int64') + x = builder.create_input( + self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"], + "x") + on_value = builder.fill_constant([1], + 1, + 'on_value', + dtype=self.case["x_dtype"]) + off_value = builder.fill_constant([1], + 0, + 'off_value', + dtype=self.case["x_dtype"]) + out = builder.one_hot( + x, + on_value, + off_value, + depth=self.case["depth"], + axis=self.case["axis"], + dtype=self.dtype) - out = builder.one_hot(x, on_value, off_value, self.depth, self.axis, - self.dtype) prog = builder.build() - forward_res = self.get_cinn_output(prog, target, [x], - [self.inputs["X"]], [out]) + res = self.get_cinn_output(prog, target, [x], [self.x_np], [out]) - self.cinn_outputs = forward_res + self.cinn_outputs = [res[0]] def test_check_results(self): - self.build_paddle_program(self.target) - self.build_cinn_program(self.target) - self.check_results(self.paddle_outputs, self.cinn_outputs, 1e-5, False, - False) + max_relative_error = self.case[ + "max_relative_error"] if "max_relative_error" in self.case else 1e-5 + self.check_outputs_and_grads(max_relative_error=max_relative_error) + + +class TestOneHotOpTest(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestOneHotOpTest" + self.cls = TestOneHotOp + self.inputs = [ + { + "x_shape": [1], + "depth": 10, + "axis": -1, + }, + { + "x_shape": [1024], + "depth": 10, + "axis": -1, + }, + { + "x_shape": [32, 64], + "depth": 10, + "axis": -1, + }, + { + "x_shape": [16, 8, 4], + "depth": 10, + "axis": -1, + }, + { + "x_shape": [16, 8, 4, 2], + "depth": 10, + "axis": -1, + }, + { + "x_shape": [16, 8, 4, 2, 1], + "depth": 10, + "axis": -1, + }, + ] + self.dtypes = [{ + "x_dtype": "int32", + }, { + "x_dtype": "int64", + }] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestOneHotOpTest().run()