From a08db4898d2730105e34208d487777a17f5ffae0 Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Tue, 17 Oct 2023 15:30:39 +0800
Subject: [PATCH] [pir] add unittest for pir test of Dropout (#58106)

* tmp

* modify unsqueeze

* modify part complex bug

* modify

* modify cpu

* [PIR]Migrate maximum into pir

* Polish code

* add ir_grad of static_gradient

* add test

* modify bug

* modify

* test_with_pir

* close one test

* add_math_op_patch

* modify

* modify

* add mean fill_constant test

* modify cpu int32 test

* get_shape_tensor

* delete

* add default place

* add dropout unittest

* Update test/legacy_test/test_elementwise_div_op.py

* modify review comment

* modify add test

---------

Co-authored-by: 0x45f
---
 python/paddle/nn/functional/common.py       |   4 +-
 python/paddle/pir/math_op_patch.py          | 126 ++++++++++++++------
 python/paddle/tensor/random.py              |   4 +
 python/paddle/utils/layers_utils.py         |   2 +
 test/legacy_test/test_activation_op.py      |   1 +
 test/legacy_test/test_dropout_op.py         | 107 +++++++++++------
 test/legacy_test/test_elementwise_add_op.py |   6 +-
 7 files changed, 173 insertions(+), 77 deletions(-)

diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 74468e719eaca..62050410b9c1a 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1196,7 +1196,7 @@ def get_attrs(prog, dropout_prob, is_test, seed):
 
         # get mask shape
         input_shape = x.shape
-        if not in_dynamic_or_pir_mode():
+        if not in_dynamic_mode():
             input_shape_tensor = paddle.shape(x)
         drop_axes = [axis] if isinstance(axis, int) else list(axis)
         if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
@@ -1212,7 +1212,7 @@ def get_attrs(prog, dropout_prob, is_test, seed):
                 )
             )
         mask_shape = [1] * len(input_shape)
-        if not in_dynamic_or_pir_mode():
+        if not in_dynamic_mode():
             for i in drop_axes:
                 mask_shape[i] = input_shape_tensor[i]
         else:
diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py
index 1f9e0058b7752..ad7ecbc4a1cd2 100644
--- a/python/paddle/pir/math_op_patch.py
+++ b/python/paddle/pir/math_op_patch.py
@@ -153,6 +153,56 @@ def _item(self):
             )
         return self
 
+    def astype(self, dtype):
+        """
+        **Notes**:
+
+        Cast an OpResult to a specified data type.
+
+        Args:
+
+            self(OpResult): The source OpResult
+
+            dtype: The target data type
+
+        Returns:
+            OpResult: OpResult with new dtype
+
+        Examples:
+            In Static Graph Mode:
+
+            .. code-block:: python
+
+                >>> import paddle
+                >>> paddle.enable_static()
+                >>> startup_prog = paddle.static.Program()
+                >>> main_prog = paddle.static.Program()
+                >>> with paddle.static.program_guard(startup_prog, main_prog):
+                ...     original_value = paddle.static.data(name = "new_value", shape=[2,2], dtype='float32')
+                ...     new_value = original_value.astype('int64')
+                ...     print("new value's dtype is: {}".format(new_value.dtype))
+                ...
+                new value's dtype is: paddle.int64
+
+        """
+        from paddle import _C_ops
+
+        if not isinstance(dtype, DataType):
+            dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype)
+        return _C_ops.cast(self, dtype)
+
+    def _scalar_add_(var, value):
+        return paddle.scale(var, 1.0, value)
+
+    def _scalar_sub_(var, value):
+        return paddle.scale(var, 1.0, -value)
+
+    def _scalar_rsub_(var, value):
+        return paddle.scale(var, -1.0, value)
+
+    def _scalar_mul_(var, value):
+        return paddle.scale(var, value, 0.0)
+
     def _scalar_div_(var, value):
         return paddle.scale(var, 1.0 / value, 0.0)
 
@@ -168,7 +218,7 @@ def __impl__(self, other_var):
             if isinstance(other_var, float):
                 # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float
                 if self.dtype in _supported_int_dtype_:
-                    paddle.cast(self, DataType.FLOAT32)
+                    self = astype(self, DataType.FLOAT32)
                 # here use `scale` replace `elementwise` to get better performance
                 # but only +, -, *, / can use this method
                 if scalar_method is not None:
@@ -253,44 +303,6 @@ def __impl__(self, other_var):
         __impl__.__name__ = method_name
         return __impl__
 
-    def astype(self, dtype):
-        """
-        **Notes**:
-
-        Cast a OpResult to a specified data type.
-
-        Args:
-
-            self(OpResult): The source OpResult
-
-            dtype: The target data type
-
-        Returns:
-            OpResult: OpResult with new dtype
-
-        Examples:
-            In Static Graph Mode:
-
-            .. code-block:: python
-
-                >>> import paddle
-                >>> paddle.enable_static()
-                >>> startup_prog = paddle.static.Program()
-                >>> main_prog = paddle.static.Program()
-                >>> with paddle.static.program_guard(startup_prog, main_prog):
-                ...     original_value = paddle.static.data(name = "new_value", shape=[2,2], dtype='float32')
-                ...     new_value = original_value.astype('int64')
-                ...     print("new value's dtype is: {}".format(new_value.dtype))
-                ...
-                new OpResult's dtype is: paddle.int64
-
-        """
-        from paddle import _C_ops
-
-        if not isinstance(dtype, DataType):
-            dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype)
-        return _C_ops.cast(self, dtype)
-
     import paddle
 
     opresult_methods = [
@@ -300,6 +312,42 @@ def astype(self, dtype):
         ('item', _item),
         ('ndimension', ndimension),
         ('ndim', _ndim),
        ('astype', astype),
+        (
+            '__add__',
+            _binary_creator_('__add__', paddle.tensor.add, False, _scalar_add_),
+        ),
+        # a+b == b+a. Do not need to reverse explicitly
+        (
+            '__radd__',
+            _binary_creator_(
+                '__radd__', paddle.tensor.add, False, _scalar_add_
+            ),
+        ),
+        (
+            '__sub__',
+            _binary_creator_(
+                '__sub__', paddle.tensor.subtract, False, _scalar_sub_
+            ),
+        ),
+        (
+            '__rsub__',
+            _binary_creator_(
+                '__rsub__', paddle.tensor.subtract, True, _scalar_rsub_
+            ),
+        ),
+        (
+            '__mul__',
+            _binary_creator_(
+                '__mul__', paddle.tensor.multiply, False, _scalar_mul_
+            ),
+        ),
+        # a*b == b*a. Do not need to reverse explicitly
+        (
+            '__rmul__',
+            _binary_creator_(
+                '__rmul__', paddle.tensor.multiply, False, _scalar_mul_
+            ),
+        ),
         (
             '__div__',
             _binary_creator_(
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index f87e669cf198e..feda36c2e85d4 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -796,6 +796,10 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
 
     if in_dynamic_or_pir_mode():
         shape = paddle.utils.convert_shape_to_list(shape)
+        if in_pir_mode() and paddle.utils._contain_var(shape):
+            shape = paddle.utils.get_pir_shape_tensor(
+                shape, _current_expected_place()
+            )
         return _C_ops.uniform(
             shape,
             dtype,
diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py
index 3d0050f1ce506..88d19c3798874 100644
--- a/python/paddle/utils/layers_utils.py
+++ b/python/paddle/utils/layers_utils.py
@@ -393,6 +393,8 @@ def get_pir_shape_tensor(list_shape, place=_current_expected_place()):
             dim.stop_gradient = True
             if convert_dtype(dim.dtype) != 'int32':
                 dim = paddle.cast(x=dim, dtype='int32')
+            if dim.shape == []:
+                dim = paddle.reshape(dim, [-1])
             shape_tensor_list.append(dim)
         else:
             temp_out = paddle.full([1], dim, core.DataType.INT32, place)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index b91230523ff5e..f9c06af18792a 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -2672,6 +2672,7 @@ def setUp(self):
         self.rev_comp_rtol = 1e-8
         self.rev_comp_atol = 1e-8
 
+    @test_with_pir_api
     def test_static_api(self):
         with static_guard():
             with paddle.static.program_guard(paddle.static.Program()):
diff --git a/test/legacy_test/test_dropout_op.py b/test/legacy_test/test_dropout_op.py
index 3ef733101a671..433b9eeff7056 100644
--- a/test/legacy_test/test_dropout_op.py
+++ b/test/legacy_test/test_dropout_op.py
@@ -26,6 +26,7 @@
 from paddle.base.executor import scope_guard
 from paddle.decomposition import decompose
 from paddle.incubate.autograd import primapi
+from paddle.pir_utils import test_with_pir_api
 
 
 def dropout_wapper(
@@ -523,9 +524,11 @@ def setUp(self):
         if core.is_compiled_with_cuda():
             self.places.append(base.CUDAPlace(0))
 
+    @test_with_pir_api
     def check_static_result(self, place):
         paddle.enable_static()
-        with base.program_guard(base.Program(), base.Program()):
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog):
             input = paddle.static.data(
                 name="input", shape=[-1, -1], dtype="float32"
             )
@@ -574,7 +577,6 @@ def check_static_result(self, place):
                 training=False,
                 mode='downscale_in_infer',
             )
-            res10 = paddle.nn.functional.dropout(x=input, p=1.0, training=True)
             res11 = paddle.nn.functional.dropout(x=input, p=0.0)
             res12 = paddle.nn.functional.dropout(
                 x=input,
@@ -584,13 +586,8 @@ def check_static_result(self, place):
                 mode='upscale_in_train',
             )
 
-            res13 = paddle.nn.functional.dropout(
-                x=input, p=0.7, axis=1, training=True, mode='upscale_in_train'
-            )
-
             in_np = np.ones([40, 40]).astype("float32")
             res_np = in_np
-            res_np2 = np.zeros_like(in_np)
 
             exe = base.Executor(place)
             res_list = [
@@ -608,26 +605,39 @@ def check_static_result(self, place):
             ]
             for res in res_list:
                 fetches = exe.run(
-                    base.default_main_program(),
+                    main_prog,
                     feed={"input": in_np},
                     fetch_list=[res],
                 )
                 np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05)
+
+    @test_with_pir_api
+    def check_static_result2(self, place):
+        paddle.enable_static()
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog):
+            input = paddle.static.data(
+                name="input", shape=[-1, -1], dtype="float32"
+            )
+            res10 = paddle.nn.functional.dropout(x=input, p=1.0, training=True)
+            res13 = paddle.nn.functional.dropout(
+                x=input, p=0.7, axis=1, training=True, mode='upscale_in_train'
+            )
+            in_np = np.ones([40, 40]).astype("float32")
+            res_np2 = np.zeros_like(in_np)
+
+            exe = base.Executor(place)
             fetches2 = exe.run(
-                base.default_main_program(),
+                main_prog,
                 feed={"input": in_np},
-                fetch_list=[res10],
+                fetch_list=[res10, res13],
             )
             np.testing.assert_allclose(fetches2[0], res_np2, rtol=1e-05)
-            fetches3 = exe.run(
-                base.default_main_program(),
-                feed={"input": in_np},
-                fetch_list=[res13],
-            )
 
     def test_static(self):
         for place in self.places:
             self.check_static_result(place=place)
+            self.check_static_result2(place=place)
 
     def test_dygraph(self):
         for place in self.places:
@@ -769,6 +779,13 @@ def test_dtype():
 
         self.assertRaises(TypeError, test_dtype)
 
+    @test_with_pir_api
+    def test_errors2(self):
+        paddle.enable_static()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
+
             def test_pdtype():
                 # p should be int or float
                 x2 = paddle.static.data(
@@ -861,9 +878,12 @@ def setUp(self):
         if core.is_compiled_with_cuda():
             self.places.append(base.CUDAPlace(0))
 
+    @test_with_pir_api
     def check_static_result(self, place):
         paddle.enable_static()
-        with base.program_guard(base.Program(), base.Program()):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
             input = paddle.static.data(
                 name="input", shape=[2, 3, 4, 5], dtype="float32"
             )
@@ -881,7 +901,7 @@ def check_static_result(self, place):
             res_list = [res1, res2]
             for res in res_list:
                 fetches = exe.run(
-                    base.default_main_program(),
+                    main_prog,
                     feed={"input": in_np},
                     fetch_list=[res],
                 )
@@ -911,9 +931,12 @@ def test_dygraph(self):
 
 
 class TestDropout2DFAPIError(unittest.TestCase):
+    @test_with_pir_api
     def test_errors(self):
         paddle.enable_static()
-        with program_guard(Program(), Program()):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
 
             def test_xdim():
                 # dimentions of x should be 4
@@ -954,6 +977,7 @@ def test_dygraph(self):
                     result.numpy(), result_np, rtol=1e-05
                 )
 
+    @test_with_pir_api
     def test_static_fp16_with_gpu(self):
         if paddle.base.core.is_compiled_with_cuda():
             place = paddle.CUDAPlace(0)
@@ -986,9 +1010,12 @@ def setUp(self):
         if core.is_compiled_with_cuda():
             self.places.append(base.CUDAPlace(0))
 
+    @test_with_pir_api
     def check_static_result(self, place):
         paddle.enable_static()
-        with base.program_guard(base.Program(), base.Program()):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
             input = paddle.static.data(
                 name="input", shape=[2, 3, 4, 5, 6], dtype="float32"
             )
@@ -1006,7 +1033,7 @@ def check_static_result(self, place):
             res_list = [res1, res2]
             for res in res_list:
                 fetches = exe.run(
-                    base.default_main_program(),
+                    main_prog,
                     feed={"input": in_np},
                     fetch_list=[res],
                 )
@@ -1036,9 +1063,12 @@ def test_dygraph(self):
 
 
 class TestDropout3DFAPIError(unittest.TestCase):
+    @test_with_pir_api
     def test_errors(self):
         paddle.enable_static()
-        with program_guard(Program(), Program()):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
 
             def test_xdim():
                 # dimentions of x should be 5
@@ -1087,8 +1117,12 @@ def setUp(self):
         if core.is_compiled_with_cuda():
             self.places.append(base.CUDAPlace(0))
 
+    @test_with_pir_api
     def check_static_result(self, place):
-        with base.program_guard(base.Program(), base.Program()):
+        paddle.enable_static()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
             input = paddle.static.data(
                 name="input", shape=[40, 40], dtype="float32"
             )
@@ -1103,20 +1137,15 @@ def check_static_result(self, place):
             res_np3 = np.zeros_like(in_np)
 
             exe = base.Executor(place)
-            res_list = [res1, res2]
-            for res in res_list:
-                fetches = exe.run(
-                    base.default_main_program(),
-                    feed={"input": in_np},
-                    fetch_list=[res],
-                )
-                np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05)
+
             fetches = exe.run(
-                base.default_main_program(),
+                main_prog,
                 feed={"input": in_np},
-                fetch_list=[res3],
+                fetch_list=[res1, res2, res3],
             )
-            np.testing.assert_allclose(fetches[0], res_np3, rtol=1e-05)
+            np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05)
+            np.testing.assert_allclose(fetches[1], res_np, rtol=1e-05)
+            np.testing.assert_allclose(fetches[2], res_np3, rtol=1e-05)
 
     def test_static(self):
         for place in self.places:
@@ -1155,6 +1184,13 @@ def test_Variable():
 
         self.assertRaises(TypeError, test_Variable)
 
+    @test_with_pir_api
+    def test_errors2(self):
+        paddle.enable_static()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
+
             def test_dtype():
                 # the input dtype of dropout must be float32 or float64
                 xr = paddle.static.data(
@@ -1203,6 +1239,7 @@ def test_dygraph(self):
                     result.numpy(), result_np, rtol=1e-05
                 )
 
+    @test_with_pir_api
     def test_static_fp16_gpu(self):
         if paddle.base.core.is_compiled_with_cuda():
             place = paddle.CUDAPlace(0)
@@ -1362,9 +1399,9 @@ def api_case(self, x):
 
     def run_static(self, x):
         paddle.seed(2022)
-        main_program = Program()
         paddle.enable_static()
-        with program_guard(main_program):
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
             input = paddle.static.data(shape=x.shape, name='x', dtype='float32')
             out = self.api_case(input)
             sgd = paddle.optimizer.SGD(learning_rate=0.1)
diff --git a/test/legacy_test/test_elementwise_add_op.py b/test/legacy_test/test_elementwise_add_op.py
index 34e5e264aa3d9..d3039ca365d34 100644
--- a/test/legacy_test/test_elementwise_add_op.py
+++ b/test/legacy_test/test_elementwise_add_op.py
@@ -772,7 +772,11 @@ def test_static_add(self):
             b = paddle.full([4, 5, 6], True, dtype='bool')
             c = a + b
             self.assertTrue(c.dtype == core.VarDesc.VarType.FP32)
-        paddle.enable_static()
+        with paddle.pir_utils.IrGuard():
+            a = 1.5
+            b = paddle.full([4, 5, 6], True, dtype='bool')
+            c = a + b
+            self.assertTrue(c.dtype == core.DataType.FLOAT32)
 
     def test_dygraph_add(self):
         paddle.disable_static()
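
Usage sketch (illustrative, not part of the patch): with the operator overloads and astype registered on OpResult above, scalar arithmetic composes in PIR static graphs the same way it does in the legacy static mode. Scalar operands take the paddle.scale fast path via _scalar_add_/_scalar_mul_, while OpResult operands dispatch to the elementwise paddle.tensor.* ops. Assuming the patched behavior matches the legacy math_op_patch, a minimal end-to-end example, mirroring the IrGuard pattern used in test_static_add, looks like:

    # Illustrative sketch only -- not code from this patch.
    import paddle

    paddle.enable_static()
    with paddle.pir_utils.IrGuard():
        x = paddle.full([2, 2], 2.0, dtype='float32')
        y = 1.5 + x             # __radd__: scalar path, paddle.scale(x, 1.0, 1.5)
        z = (x * 2.0) - y       # __mul__ scalar path, then elementwise subtract
        w = z.astype('int64')   # patched astype: _C_ops.cast under the hood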