diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index f5cb975019c98..3b588780944a8 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -366,5 +366,17 @@ def test_argmin_dtype_type(): self.assertRaises(ValueError, test_argmin_dtype_type) +class TestArgMaxOpFp16(unittest.TestCase): + def test_fp16(self): + paddle.enable_static() + x_np = np.random.random((10, 16)).astype('float16') + x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') + out = paddle.argmax(x) + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': x_np}, fetch_list=[out]) + paddle.disable_static() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index e16ac89953fb1..2da63d90afd07 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): check_variable_and_dtype( x, 'x', - ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], + [ + 'float16', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + 'uint8', + ], 'paddle.argmax', ) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')