Skip to content

Commit

Permalink
fix fp16 dtype checking for argmax op
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 committed Feb 23, 2023
1 parent 8d325d8 commit 3b62561
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,5 +366,17 @@ def test_argmin_dtype_type():
self.assertRaises(ValueError, test_argmin_dtype_type)


class TestArgMaxOpFp16(unittest.TestCase):
def test_fp16(self):
paddle.enable_static()
x_np = np.random.random((10, 16)).astype('float16')
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmax(x)
exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np}, fetch_list=[out])
paddle.disable_static()


if __name__ == '__main__':
unittest.main()
10 changes: 9 additions & 1 deletion python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
[
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'uint8',
],
'paddle.argmax',
)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
Expand Down

0 comments on commit 3b62561

Please sign in to comment.