From 3b625611f3ee7de96c89eaa4a0a60e1fbc0c27e9 Mon Sep 17 00:00:00 2001 From: zhangting2020 Date: Thu, 23 Feb 2023 06:33:39 +0000 Subject: [PATCH 1/3] fix fp16 dtype checking for argmax op --- .../fluid/tests/unittests/test_arg_min_max_v2_op.py | 12 ++++++++++++ python/paddle/tensor/search.py | 10 +++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index f5cb975019c98..3b588780944a8 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -366,5 +366,17 @@ def test_argmin_dtype_type(): self.assertRaises(ValueError, test_argmin_dtype_type) +class TestArgMaxOpFp16(unittest.TestCase): + def test_fp16(self): + paddle.enable_static() + x_np = np.random.random((10, 16)).astype('float16') + x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') + out = paddle.argmax(x) + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': x_np}, fetch_list=[out]) + paddle.disable_static() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index e16ac89953fb1..2da63d90afd07 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): check_variable_and_dtype( x, 'x', - ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], + [ + 'float16', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + 'uint8', + ], 'paddle.argmax', ) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') From 33ff7c95b72ae93b433f8199dec33cdebee94f0a Mon Sep 17 00:00:00 2001 From: zhangting2020 Date: Fri, 24 Feb 2023 03:29:03 +0000 Subject: [PATCH 2/3] run fp16 test when place is gpu --- .../tests/unittests/test_arg_min_max_v2_op.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index 3b588780944a8..d23648ba65fe3 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -368,14 +368,15 @@ def test_argmin_dtype_type(): class TestArgMaxOpFp16(unittest.TestCase): def test_fp16(self): - paddle.enable_static() x_np = np.random.random((10, 16)).astype('float16') - x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') - out = paddle.argmax(x) - exe = paddle.static.Executor() - exe.run(paddle.static.default_startup_program()) - out = exe.run(feed={'x': x_np}, fetch_list=[out]) - paddle.disable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') + out = paddle.argmax(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': x_np}, fetch_list=[out]) if __name__ == '__main__': From 80a62aa7fcc8cdf9d4acf51807fa7627ec2624b0 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 24 Feb 2023 14:56:20 +0800 Subject: [PATCH 3/3] Update search.py fix doc --- python/paddle/tensor/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 2da63d90afd07..ae26b2927c6c8 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -127,7 +127,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): element along the provided axis. Args: - x(Tensor): An input N-D Tensor with type float32, float64, int16, + x(Tensor): An input N-D Tensor with type float16, float32, float64, int16, int32, int64, uint8. axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. when axis < 0, it works the same way