Skip to content

Commit

Permalink
run fp16 test when place is gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 committed Feb 24, 2023
1 parent 3b62561 commit 33ff7c9
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,15 @@ def test_argmin_dtype_type():

class TestArgMaxOpFp16(unittest.TestCase):
def test_fp16(self):
paddle.enable_static()
x_np = np.random.random((10, 16)).astype('float16')
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmax(x)
exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np}, fetch_list=[out])
paddle.disable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmax(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np}, fetch_list=[out])


if __name__ == '__main__':
Expand Down

0 comments on commit 33ff7c9

Please sign in to comment.