diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index 3b588780944a8..d23648ba65fe3 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -368,14 +368,15 @@ def test_argmin_dtype_type(): class TestArgMaxOpFp16(unittest.TestCase): def test_fp16(self): - paddle.enable_static() x_np = np.random.random((10, 16)).astype('float16') - x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') - out = paddle.argmax(x) - exe = paddle.static.Executor() - exe.run(paddle.static.default_startup_program()) - out = exe.run(feed={'x': x_np}, fetch_list=[out]) - paddle.disable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') + out = paddle.argmax(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': x_np}, fetch_list=[out]) if __name__ == '__main__':