From a4644c507e65ec3f38a790d9057cd82b0ee4c6cf Mon Sep 17 00:00:00 2001
From: zhwesky2010 <1183042833@qq.com>
Date: Wed, 26 Jul 2023 14:23:31 +0800
Subject: [PATCH] [BUG] fix bug of float/int/long/index Tensor (#55568)

---
 python/paddle/fluid/dygraph/math_op_patch.py    | 16 ++++++++++++----
 test/legacy_test/test_math_op_patch_var_base.py | 15 +++++++++++++++
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py
index b82434d0588ac..1a73e6b5e9ae0 100644
--- a/python/paddle/fluid/dygraph/math_op_patch.py
+++ b/python/paddle/fluid/dygraph/math_op_patch.py
@@ -114,21 +114,27 @@ def _float_(var):
         ), "only one element variable can be converted to float."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        return float(np.array(var).flatten()[0])
+        if var.dtype == core.VarDesc.VarType.BF16:
+            var = var.astype('float32')
+        return float(np.array(var))
 
     def _long_(var):
         numel = np.prod(var.shape)
         assert numel == 1, "only one element variable can be converted to long."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        return int(np.array(var).flatten()[0])
+        if var.dtype == core.VarDesc.VarType.BF16:
+            var = var.astype('float32')
+        return int(np.array(var))
 
     def _int_(var):
         numel = np.prod(var.shape)
         assert numel == 1, "only one element variable can be converted to int."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        return int(np.array(var).flatten()[0])
+        if var.dtype == core.VarDesc.VarType.BF16:
+            var = var.astype('float32')
+        return int(np.array(var))
 
     def _len_(var):
         assert var.ndim > 0, "len() of a 0-D tensor is wrong"
@@ -146,7 +152,9 @@ def _index_(var):
         ), "only one element variable can be converted to python index."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        return int(np.array(var).flatten()[0])
+        if var.dtype == core.VarDesc.VarType.BF16:
+            var = var.astype('float32')
+        return int(np.array(var))
 
     @property
     def _ndim_(var):
diff --git a/test/legacy_test/test_math_op_patch_var_base.py b/test/legacy_test/test_math_op_patch_var_base.py
index 0100f364bcddf..c392fb972697d 100644
--- a/test/legacy_test/test_math_op_patch_var_base.py
+++ b/test/legacy_test/test_math_op_patch_var_base.py
@@ -242,6 +242,11 @@ def test_float_int_long(self):
             self.assertTrue(int(a) == 100)
             self.assertTrue(int(a) == 100)
 
+            a = paddle.to_tensor(1000000.0, dtype='bfloat16')
+            self.assertTrue(float(a) == 999424.0)
+            self.assertTrue(int(a) == 999424)
+            self.assertTrue(int(a) == 999424)
+
     def test_len(self):
         a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
         with fluid.dygraph.guard():
@@ -260,6 +265,16 @@ def test_index(self):
             str1 = "just test"
             self.assertTrue(str1[var1] == 's')
 
+            var1 = paddle.to_tensor(2.0, dtype='bfloat16')
+            i_tmp = 0
+            for i in range(var1):
+                self.assertTrue(i == i_tmp)
+                i_tmp = i_tmp + 1
+            list1 = [1, 2, 3, 4, 5]
+            self.assertTrue(list1[var1] == 3)
+            str1 = "just test"
+            self.assertTrue(str1[var1] == 's')
+
     def test_np_left_mul(self):
         with fluid.dygraph.guard():
             t = np.sqrt(2.0 * np.pi)
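
Note (editorial, not applied by git am): a minimal sketch of the behavior this patch
targets, assuming a Paddle build that includes the change. The old path read scalars via
np.array(var).flatten()[0]; NumPy has no native bfloat16 dtype, so that path mis-read BF16
tensors. With this patch, float()/int()/long()/__index__() upcast BF16 to float32 first.

    import paddle

    # bfloat16 keeps only 8 significant bits of precision,
    # so 1000000.0 rounds to 999424.0 (matches the test above)
    a = paddle.to_tensor(1000000.0, dtype='bfloat16')
    print(float(a))  # 999424.0
    print(int(a))    # 999424

    # __index__ also upcasts, so a BF16 scalar now works as a Python index
    idx = paddle.to_tensor(2.0, dtype='bfloat16')
    print("just test"[idx])      # 's'
    print([1, 2, 3, 4, 5][idx])  # 3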