From 6ff811335984a3b973ee73e576a52ff289cb98f8 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sun, 4 Jun 2023 21:33:07 +0800
Subject: [PATCH 01/18] add float16/32/64 to TestBatchNormTrainOp

---
 python/tests/ops/test_batch_norm_op.py | 76 ++++++++++++++------------
 1 file changed, 40 insertions(+), 36 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index e4d202b750..232e70acf8 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -31,61 +31,51 @@ def setUp(self):
 
     def init_case(self):
         self.num_channels = 16
-        self.inputs = {
+        self.inputs = [{
             "x":
             self.random([2, self.num_channels, 8, 8], "float32", 0.0, 1.0),
             "dout":
             self.random([2, self.num_channels, 8, 8], "float32", 1e-7, 1e-6),
-        }
+        }]
 
     def build_paddle_program(self, target):
-        x = paddle.to_tensor(self.inputs["x"])
-        batch_norm = paddle.nn.BatchNorm(
-            self.num_channels, act=None, is_test=False)
-        out = batch_norm(x)
+        for inputs in self.inputs:
+            x = paddle.to_tensor(inputs["x"])
+            batch_norm = paddle.nn.BatchNorm(
+                self.num_channels, act=None, is_test=False)
+            out = batch_norm(x)
 
-        self.paddle_outputs = [out]
+            self.paddle_outputs = [out]
 
     # Note: If the forward and backward operators are run in the same program,
     # the forward result will be incorrect.
     def build_cinn_program(self, target):
-        builder = NetBuilder("batch_norm")
-        x = builder.create_input(
-            self.nptype2cinntype(self.inputs["x"].dtype),
-            self.inputs["x"].shape, "x")
-        scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
-                                      'float32')
-        bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
-                                     'float32')
-        mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
-                                     'float32')
-        variance = builder.fill_constant([self.num_channels], 1.0, 'variance',
+        for inputs in self.inputs:
+            builder = NetBuilder("batch_norm")
+            x = builder.create_input(
+                self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
+                "x")
+            scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
+                                          'float32')
+            bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
                                          'float32')
+            mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
+                                         'float32')
+            variance = builder.fill_constant([self.num_channels], 1.0,
+                                             'variance', 'float32')
 
-        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
+            out = builder.batchnorm(
+                x, scale, bias, mean, variance, is_test=False)
 
-        prog = builder.build()
-        forward_res = self.get_cinn_output(
-            prog, target, [x], [self.inputs["x"]], out, passes=[])
-        self.cinn_outputs = [forward_res[0]]
+            prog = builder.build()
+            forward_res = self.get_cinn_output(
+                prog, target, [x], [inputs["x"]], out, passes=[])
+            self.cinn_outputs.extend(forward_res)
 
     def test_check_results(self):
         self.check_outputs_and_grads()
 
 
-# Reopen after decomposer infer dtype fixed
-class TestBatchNormTrainFP16(TestBatchNormTrainOp):
-    def init_case(self):
-        self.num_channels = 16
-        self.inputs = {
-            "x": self.random([2, self.num_channels, 8, 8], "float16"),
-            "dout": self.random([2, self.num_channels, 8, 8], "float16"),
-        }
-
-    def test_check_results(self):
-        self.check_outputs_and_grads(max_relative_error=1e-3)
-
-
 @OpTestTool.skip_if(not is_compiled_with_cuda(),
                     "x86 test will be skipped due to timeout.")
 class TestBatchNormBackwardOp(OpTest):
@@ -227,5 +217,19 @@ def test_check_results(self):
         self.check_outputs_and_grads()
 
 
+class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
+    def init_case(self):
+        for x_shape in [
+            [2, 16, 8, 8],
+        ]:
+            for x_type in ["float16", "float32", "float64"]:
+                self.inputs.append({
+                    "x":
+                    self.random(x_shape, x_type, 0.0, 1.0),
+                    "dout":
+                    self.random(x_shape, x_type, 1e-7, 1e-6),
+                })
+
+
 if __name__ == "__main__":
     unittest.main()
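
Patch 01 turns self.inputs into a list of case dicts and loops over it, but the Paddle side still assigns self.paddle_outputs = [out] inside the loop (keeping only the last case) while the CINN side extends its list with every forward output; patches 03 and 05 later align both sides on appending exactly one result per case. A minimal, self-contained sketch of that per-case accumulation pattern, with placeholder strings standing in for the real OpTest results:

    cases = [
        {"x_shape": [2, 16, 8, 8], "x_dtype": "float32"},
        {"x_shape": [2, 16, 8, 8], "x_dtype": "float16"},
    ]

    paddle_outputs, cinn_outputs = [], []
    for case in cases:
        ref = f"paddle result for {case['x_dtype']}"  # stand-in for batch_norm(x)
        res = f"cinn result for {case['x_dtype']}"    # stand-in for get_cinn_output(...)[0]
        paddle_outputs.append(ref)                    # append, not "= [ref]"
        cinn_outputs.append(res)

    assert len(paddle_outputs) == len(cinn_outputs) == len(cases)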

From a9b7e3a02a6b73d363f02c30a2b447d1c295e259 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sun, 4 Jun 2023 21:38:39 +0800
Subject: [PATCH 02/18] add float16/32/64 to TestBatchNormTrainOp

---
 python/tests/ops/test_batch_norm_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 232e70acf8..3ca4b1b03d 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -73,7 +73,7 @@ def build_cinn_program(self, target):
             self.cinn_outputs.extend(forward_res)
 
     def test_check_results(self):
-        self.check_outputs_and_grads()
+        self.check_outputs_and_grads(max_relative_error=1e-3)
 
 
 @OpTestTool.skip_if(not is_compiled_with_cuda(),

From 5108dd4a5fd9fed8f57739103fb1e247827d60ec Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Mon, 5 Jun 2023 05:20:05 +0800
Subject: [PATCH 03/18] add float16/32/64 to TestBatchNormTrainOp

---
 python/tests/ops/test_batch_norm_op.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 3ca4b1b03d..b230d55e75 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -45,7 +45,7 @@ def build_paddle_program(self, target):
                 self.num_channels, act=None, is_test=False)
             out = batch_norm(x)
 
-            self.paddle_outputs = [out]
+            self.paddle_outputs.append(out)
 
     # Note: If the forward and backward operators are run in the same program,
     # the forward result will be incorrect.
@@ -219,6 +219,7 @@ def test_check_results(self):
 
 class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
     def init_case(self):
+        self.inputs = []
         for x_shape in [
             [2, 16, 8, 8],
         ]:

From fb821afde2b5106aa21634a794bd2a68208c6267 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Mon, 5 Jun 2023 17:53:19 +0800
Subject: [PATCH 04/18] add float16/32/64 to TestBatchNormTrainOp

---
 python/tests/ops/test_batch_norm_op.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index b230d55e75..6a86f5fdd5 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -219,9 +219,10 @@ def test_check_results(self):
 
 class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
     def init_case(self):
+        self.num_channels = 16
         self.inputs = []
         for x_shape in [
-            [2, 16, 8, 8],
+            [2, self.num_channels, 8, 8],
         ]:
             for x_type in ["float16", "float32", "float64"]:
                 self.inputs.append({

From 2fc0973939980d7fe2f1974c09e897e4126412e5 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Tue, 6 Jun 2023 17:41:21 +0800
Subject: [PATCH 05/18] add float16/32/64 to TestBatchNormTrainOp

---
 python/tests/ops/test_batch_norm_op.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 6a86f5fdd5..cc18d94e7d 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -36,6 +36,8 @@ def init_case(self):
             self.random([2, self.num_channels, 8, 8], "float32", 0.0, 1.0),
             "dout":
             self.random([2, self.num_channels, 8, 8], "float32", 1e-7, 1e-6),
+            "dtype":
+            "float32",
         }]
 
     def build_paddle_program(self, target):
@@ -56,13 +58,13 @@ def build_cinn_program(self, target):
                 self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
                 "x")
             scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
-                                          'float32')
+                                          inputs["dtype"])
             bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
-                                         'float32')
+                                         inputs["dtype"])
             mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
-                                         'float32')
+                                         inputs["dtype"])
             variance = builder.fill_constant([self.num_channels], 1.0,
-                                             'variance', 'float32')
+                                             'variance', inputs["dtype"])
 
             out = builder.batchnorm(
                 x, scale, bias, mean, variance, is_test=False)
@@ -70,7 +72,7 @@ def build_cinn_program(self, target):
             prog = builder.build()
             forward_res = self.get_cinn_output(
                 prog, target, [x], [inputs["x"]], out, passes=[])
-            self.cinn_outputs.extend(forward_res)
+            self.cinn_outputs.append(forward_res[0])
 
     def test_check_results(self):
         self.check_outputs_and_grads(max_relative_error=1e-3)
@@ -230,6 +232,8 @@ def init_case(self):
                     self.random(x_shape, x_type, 0.0, 1.0),
                     "dout":
                     self.random(x_shape, x_type, 1e-7, 1e-6),
+                    "dtype":
+                    x_type,
                 })
 
 

From eca5a860b3499bcf00b2cc3c30db8a5617c85660 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Tue, 6 Jun 2023 19:39:50 +0800
Subject: [PATCH 06/18] delete float64 from TestBatchNormTrainOp

---
 python/tests/ops/test_batch_norm_op.py | 46 ++++++++++++--------------
 1 file changed, 21 insertions(+), 25 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index cc18d94e7d..ab0046cb2c 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -36,8 +36,6 @@ def init_case(self):
             self.random([2, self.num_channels, 8, 8], "float32", 0.0, 1.0),
             "dout":
             self.random([2, self.num_channels, 8, 8], "float32", 1e-7, 1e-6),
-            "dtype":
-            "float32",
         }]
 
     def build_paddle_program(self, target):
@@ -58,13 +56,13 @@ def build_cinn_program(self, target):
                 self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
                 "x")
             scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
-                                          inputs["dtype"])
+                                          "float32")
             bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
-                                         inputs["dtype"])
+                                         "float32")
             mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
-                                         inputs["dtype"])
+                                         "float32")
             variance = builder.fill_constant([self.num_channels], 1.0,
-                                             'variance', inputs["dtype"])
+                                             'variance', "float32")
 
             out = builder.batchnorm(
                 x, scale, bias, mean, variance, is_test=False)
@@ -75,7 +73,23 @@ def build_cinn_program(self, target):
             self.cinn_outputs.append(forward_res[0])
 
     def test_check_results(self):
-        self.check_outputs_and_grads(max_relative_error=1e-3)
+        self.check_outputs_and_grads()
+
+
+class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
+    def init_case(self):
+        self.num_channels = 16
+        self.inputs = []
+        for x_shape in [
+            [2, self.num_channels, 8, 8],
+        ]:
+            for x_type in ["float16", "float32"]:
+                self.inputs.append({
+                    "x":
+                    self.random(x_shape, x_type, 0.0, 1.0),
+                    "dout":
+                    self.random(x_shape, x_type, 1e-7, 1e-6),
+                })
 
 
 @OpTestTool.skip_if(not is_compiled_with_cuda(),
@@ -219,23 +233,5 @@ def test_check_results(self):
         self.check_outputs_and_grads()
 
 
-class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
-    def init_case(self):
-        self.num_channels = 16
-        self.inputs = []
-        for x_shape in [
-            [2, self.num_channels, 8, 8],
-        ]:
-            for x_type in ["float16", "float32", "float64"]:
-                self.inputs.append({
-                    "x":
-                    self.random(x_shape, x_type, 0.0, 1.0),
-                    "dout":
-                    self.random(x_shape, x_type, 1e-7, 1e-6),
-                    "dtype":
-                    x_type,
-                })
-
-
 if __name__ == "__main__":
     unittest.main()
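
Patch 05 threads the per-case dtype into the fill_constant parameters and patch 06 reverts them to 'float32' (and drops float64 from the case list). The commits do not say why; one plausible motivation for keeping the batch-norm scale/bias/mean/variance parameters in float32 even for float16 inputs is accumulation precision. The NumPy snippet below is only a hedged illustration of that general effect, not of CINN's kernel or of the author's actual reasoning:

    import numpy as np

    # float16 input in the same range as the test cases; compare channel means
    # accumulated in float16 vs float32. The gap is small but nonzero, which is
    # why mixed-precision setups usually keep normalization statistics in float32.
    x = np.random.uniform(0.0, 1.0, size=(2, 16, 8, 8)).astype("float16")
    mean_fp16 = x.mean(axis=(0, 2, 3), dtype=np.float16)
    mean_fp32 = x.mean(axis=(0, 2, 3), dtype=np.float32)
    print("max abs difference:", np.abs(mean_fp16.astype("float32") - mean_fp32).max())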

From a3d3c9c2d428166d1e773dc2d3110075918b6e24 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Tue, 6 Jun 2023 20:17:37 +0800
Subject: [PATCH 07/18] add more shapes

---
 python/tests/ops/test_batch_norm_op.py | 28 +++++++++++++-------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index ab0046cb2c..a9f9558c31 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -30,19 +30,20 @@ def setUp(self):
         self.init_case()
 
     def init_case(self):
-        self.num_channels = 16
         self.inputs = [{
             "x":
-            self.random([2, self.num_channels, 8, 8], "float32", 0.0, 1.0),
+            self.random([2, 16, 8, 8], "float32", 0.0, 1.0),
             "dout":
-            self.random([2, self.num_channels, 8, 8], "float32", 1e-7, 1e-6),
+            self.random([2, 16, 8, 8], "float32", 1e-7, 1e-6),
+            "num_channels":
+            16
         }]
 
     def build_paddle_program(self, target):
         for inputs in self.inputs:
             x = paddle.to_tensor(inputs["x"])
             batch_norm = paddle.nn.BatchNorm(
-                self.num_channels, act=None, is_test=False)
+                inputs["num_channels"], act=None, is_test=False)
             out = batch_norm(x)
 
             self.paddle_outputs.append(out)
@@ -55,13 +56,13 @@ def build_cinn_program(self, target):
             x = builder.create_input(
                 self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
                 "x")
-            scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
-                                          "float32")
-            bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
+            scale = builder.fill_constant([inputs["num_channels"]], 1.0,
+                                          'scale', "float32")
+            bias = builder.fill_constant([inputs["num_channels"]], 0.0, 'bias',
                                          "float32")
-            mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
+            mean = builder.fill_constant([inputs["num_channels"]], 0.0, 'mean',
                                          "float32")
-            variance = builder.fill_constant([self.num_channels], 1.0,
+            variance = builder.fill_constant([inputs["num_channels"]], 1.0,
                                              'variance', "float32")
 
             out = builder.batchnorm(
@@ -73,22 +74,21 @@ def build_cinn_program(self, target):
             self.cinn_outputs.append(forward_res[0])
 
     def test_check_results(self):
-        self.check_outputs_and_grads()
+        self.check_outputs_and_grads(max_relative_error=1e-3)
 
 
 class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
     def init_case(self):
-        self.num_channels = 16
         self.inputs = []
-        for x_shape in [
-            [2, self.num_channels, 8, 8],
-        ]:
+        for x_shape in [[2, 16, 8, 8], [2, 16, 8, 1], [2, 16, 2048, 8]]:
             for x_type in ["float16", "float32"]:
                 self.inputs.append({
                     "x":
                     self.random(x_shape, x_type, 0.0, 1.0),
                     "dout":
                     self.random(x_shape, x_type, 1e-7, 1e-6),
+                    "num_channels":
+                    x_shape[1]
                 })
 
 

From 3a31828d92d6d7b7e04792f3effb08d4d9f17485 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Tue, 6 Jun 2023 20:54:24 +0800
Subject: [PATCH 08/18] add backward

---
 python/tests/ops/test_batch_norm_op.py | 131 +++++++++++++------------
 1 file changed, 69 insertions(+), 62 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index a9f9558c31..31a093661d 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -99,87 +99,94 @@ def setUp(self):
         self.init_case()
 
     def init_case(self):
-        self.num_channels = 16
-        self.inputs = {
+        self.inputs = [{
             "x":
-            self.random([2, self.num_channels, 8, 8], "float32", 0.0, 10.0),
+            self.random([2, 16, 8, 8], "float32", 0.0, 10.0),
             "dout":
-            self.random([2, self.num_channels, 8, 8], "float32", 1e-7, 1e-6),
-        }
+            self.random([2, 16, 8, 8], "float32", 1e-7, 1e-6),
+            "num_channels":
+            16
+        }]
 
     def build_paddle_program(self, target):
-        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
-        batch_norm = paddle.nn.BatchNorm(
-            self.num_channels, act=None, is_test=False)
-        out = batch_norm(x)
+        for inputs in self.inputs:
+            x = paddle.to_tensor(inputs["x"], stop_gradient=False)
+            batch_norm = paddle.nn.BatchNorm(
+                inputs["num_channels"], act=None, is_test=False)
+            out = batch_norm(x)
 
-        self.paddle_outputs = [out]
-        self.paddle_grads = self.get_paddle_grads([out], [x],
-                                                  [self.inputs["dout"]])
+            self.paddle_outputs.append(out)
+            grad = self.get_paddle_grads([out], [x], [inputs["dout"]])
+            self.paddle_grads.append(grad[0])
 
     # Note: If the forward and backward operators are run in the same program,
     # the forward result will be incorrect.
     def build_cinn_program(self, target):
-        builder = NetBuilder("batch_norm")
-        x = builder.create_input(
-            self.nptype2cinntype(self.inputs["x"].dtype),
-            self.inputs["x"].shape, "x")
-        scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
-                                      'float32')
-        bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
-                                     'float32')
-        mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
-                                     'float32')
-        variance = builder.fill_constant([self.num_channels], 1.0, 'variance',
+        for inputs in self.inputs:
+            builder = NetBuilder("batch_norm")
+            x = builder.create_input(
+                self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
+                "x")
+            scale = builder.fill_constant([inputs["num_channels"]], 1.0,
+                                          'scale', 'float32')
+            bias = builder.fill_constant([inputs["num_channels"]], 0.0, 'bias',
+                                         'float32')
+            mean = builder.fill_constant([inputs["num_channels"]], 0.0, 'mean',
                                          'float32')
+            variance = builder.fill_constant([inputs["num_channels"]], 1.0,
+                                             'variance', 'float32')
 
-        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
+            out = builder.batchnorm(
+                x, scale, bias, mean, variance, is_test=False)
 
-        prog = builder.build()
-        forward_res = self.get_cinn_output(
-            prog, target, [x], [self.inputs["x"]], out, passes=[])
-        self.cinn_outputs = [forward_res[0]]
+            prog = builder.build()
+            forward_res = self.get_cinn_output(
+                prog, target, [x], [inputs["x"]], out, passes=[])
+            self.cinn_outputs.append(forward_res[0])
 
-        builder_grad = NetBuilder("batch_norm_grad")
-        dout = builder_grad.create_input(
-            self.nptype2cinntype(self.inputs["dout"].dtype),
-            self.inputs["dout"].shape, "dout")
-        x_g = builder_grad.create_input(
-            self.nptype2cinntype(self.inputs["x"].dtype),
-            self.inputs["x"].shape, "x_g")
-        scale_g = builder_grad.fill_constant(scale.shape(), 1.0, 'scale_g',
-                                             'float32')
-        save_mean = builder_grad.create_input(
-            self.nptype2cinntype('float32'), out[1].shape(), "save_mean")
-        save_variance = builder_grad.create_input(
-            self.nptype2cinntype('float32'), out[2].shape(), "save_variance")
-
-        out_grad = builder_grad.batch_norm_grad(dout, x_g, scale_g, save_mean,
-                                                save_variance)
-        prog = builder_grad.build()
-        backward_res = self.get_cinn_output(
-            prog,
-            target, [dout, x_g, save_mean, save_variance], [
-                self.inputs["dout"], self.inputs["x"], forward_res[1],
-                forward_res[2]
-            ],
-            out_grad,
-            passes=[])
-        self.cinn_grads = [backward_res[0]]
+            builder_grad = NetBuilder("batch_norm_grad")
+            dout = builder_grad.create_input(
+                self.nptype2cinntype(inputs["dout"].dtype),
+                inputs["dout"].shape, "dout")
+            x_g = builder_grad.create_input(
+                self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
+                "x_g")
+            scale_g = builder_grad.fill_constant(scale.shape(), 1.0, 'scale_g',
+                                                 'float32')
+            save_mean = builder_grad.create_input(
+                self.nptype2cinntype('float32'), out[1].shape(), "save_mean")
+            save_variance = builder_grad.create_input(
+                self.nptype2cinntype('float32'), out[2].shape(),
+                "save_variance")
+
+            out_grad = builder_grad.batch_norm_grad(dout, x_g, scale_g,
+                                                    save_mean, save_variance)
+            prog = builder_grad.build()
+            backward_res = self.get_cinn_output(
+                prog,
+                target, [dout, x_g, save_mean, save_variance],
+                [inputs["dout"], inputs["x"], forward_res[1], forward_res[2]],
+                out_grad,
+                passes=[])
+            self.cinn_grads.append(backward_res[0])
 
     def test_check_results(self):
         self.check_outputs_and_grads()
 
 
-class TestBatchNormBackwardFP16(TestBatchNormBackwardOp):
+class TestBatchNormBackwardAll(TestBatchNormBackwardOp):
     def init_case(self):
-        self.num_channels = 16
-        self.inputs = {
-            "x":
-            self.random([2, self.num_channels, 8, 8], "float16", 0.0, 10.0),
-            "dout":
-            self.random([2, self.num_channels, 8, 8], "float16", 1e-7, 1e-6),
-        }
+        self.inputs = []
+        for x_shape in [[2, 16, 8, 8], [2, 16, 8, 1], [2, 16, 2048, 8]]:
+            for x_type in ["float16", "float32"]:
+                self.inputs.append({
+                    "x":
+                    self.random(x_shape, x_type, 0.0, 1.0),
+                    "dout":
+                    self.random(x_shape, x_type, 1e-7, 1e-6),
+                    "num_channels":
+                    x_shape[1]
+                })
 
     def test_check_results(self):
         self.check_outputs_and_grads(max_relative_error=1e-3)
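
Patch 08 restructures the backward test so the forward program's extra outputs (forward_res[1] and forward_res[2], the saved mean and variance) are fed into a second batch_norm_grad program. For reference, the training-mode batch-norm math these tests exercise can be written in a few lines of NumPy; this is the textbook formulation with scale=1 and bias=0 (matching the fill_constant values in the test), not CINN's or Paddle's implementation:

    import numpy as np

    def bn_train_forward(x, eps=1e-5):
        # Per-channel statistics over the N, H, W axes of an NCHW tensor.
        mean = x.mean(axis=(0, 2, 3), keepdims=True)
        var = x.var(axis=(0, 2, 3), keepdims=True)
        x_hat = (x - mean) / np.sqrt(var + eps)
        return x_hat, var

    def bn_grad_x(dout, x_hat, var, eps=1e-5):
        # Standard gradient of training-mode batch norm w.r.t. x when scale=1.
        return (dout
                - dout.mean(axis=(0, 2, 3), keepdims=True)
                - x_hat * (dout * x_hat).mean(axis=(0, 2, 3), keepdims=True)
                ) / np.sqrt(var + eps)

    x = np.random.uniform(0.0, 10.0, size=(2, 16, 8, 8)).astype("float32")
    dout = np.random.uniform(1e-7, 1e-6, size=x.shape).astype("float32")
    x_hat, var = bn_train_forward(x)
    dx = bn_grad_x(dout, x_hat, var)
    print(dx.shape)  # (2, 16, 8, 8)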

From 8574f5a3a80f1c2a89e90267b40d3a062b4d17a7 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sun, 11 Jun 2023 06:56:20 +0800
Subject: [PATCH 09/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 64 ++++++++++++--------------
 1 file changed, 29 insertions(+), 35 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 3e3dbb302b..73028ebeed 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -17,6 +17,7 @@
 import unittest, sys
 import numpy as np
 from op_test import OpTest, OpTestTool
+from op_test_helper import TestCaseHelper
 import paddle
 import cinn
 from cinn.frontend import *
@@ -30,55 +31,48 @@ def setUp(self):
         self.init_case()
 
     def init_case(self):
-        self.inputs = [{
-            "x":
-            self.random([2, 16, 8, 8], "float32", 0.0, 1.0),
-            "dout":
-            self.random([2, 16, 8, 8], "float32", 1e-7, 1e-6),
-            "num_channels":
-            16
-        }]
+        self.inputs = self.case
 
     def build_paddle_program(self, target):
-        for inputs in self.inputs:
-            x = paddle.to_tensor(inputs["x"])
-            batch_norm = paddle.nn.BatchNorm(
-                inputs["num_channels"], act=None, is_test=False)
-            out = batch_norm(x)
+        x = paddle.to_tensor(self.inputs["x"])
+        batch_norm = paddle.nn.BatchNorm(
+            self.inputs["num_channels"], act=None, is_test=False)
+        out = batch_norm(x)
 
-            self.paddle_outputs.append(out)
+        self.paddle_outputs = [out]
 
     # Note: If the forward and backward operators are run in the same program,
     # the forward result will be incorrect.
     def build_cinn_program(self, target):
-        for inputs in self.inputs:
-            builder = NetBuilder("batch_norm")
-            x = builder.create_input(
-                self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
-                "x")
-            scale = builder.fill_constant([inputs["num_channels"]], 1.0,
-                                          'scale', "float32")
-            bias = builder.fill_constant([inputs["num_channels"]], 0.0, 'bias',
-                                         "float32")
-            mean = builder.fill_constant([inputs["num_channels"]], 0.0, 'mean',
-                                         "float32")
-            variance = builder.fill_constant([inputs["num_channels"]], 1.0,
-                                             'variance', "float32")
+        builder = NetBuilder("batch_norm")
+        x = builder.create_input(
+            self.nptype2cinntype(self.inputs["x"].dtype),
+            self.inputs["x"].shape, "x")
+        scale = builder.fill_constant([self.inputs["num_channels"]], 1.0,
+                                      'scale', "float32")
+        bias = builder.fill_constant([self.inputs["num_channels"]], 0.0,
+                                     'bias', "float32")
+        mean = builder.fill_constant([self.inputs["num_channels"]], 0.0,
+                                     'mean', "float32")
+        variance = builder.fill_constant([self.inputs["num_channels"]], 1.0,
+                                         'variance', "float32")
 
-            out = builder.batchnorm(
-                x, scale, bias, mean, variance, is_test=False)
+        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
 
-            prog = builder.build()
-            forward_res = self.get_cinn_output(
-                prog, target, [x], [inputs["x"]], out, passes=[])
-            self.cinn_outputs.append(forward_res[0])
+        prog = builder.build()
+        forward_res = self.get_cinn_output(
+            prog, target, [x], [self.inputs["x"]], out, passes=[])
+        self.cinn_outputs = [forward_res[0]]
 
     def test_check_results(self):
         self.check_outputs_and_grads(max_relative_error=1e-3)
 
 
-class TestBatchNormTrainOpAll(TestBatchNormTrainOp):
-    def init_case(self):
+class TestBatchNormTrainOpAll(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestBatchNormTrainOpBase"
+        self.cls = TestBatchNormTrainOp
+
         self.inputs = []
         for x_shape in [[2, 16, 8, 8], [2, 16, 8, 1], [2, 16, 2048, 8]]:
             for x_type in ["float16", "float32"]:

From c24ce6bda087b93156f3c515ba253fdd621249d5 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Fri, 16 Jun 2023 22:29:43 +0800
Subject: [PATCH 10/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 79 ++++++++++++++++----------
 1 file changed, 49 insertions(+), 30 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 73028ebeed..7f27e5d721 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -28,15 +28,18 @@
                     "x86 test will be skipped due to timeout.")
 class TestBatchNormTrainOp(OpTest):
     def setUp(self):
-        self.init_case()
+        print(f"\nRunning {self.__class__.__name__}: {self.case}")
+        self.prepare_inputs()
 
-    def init_case(self):
-        self.inputs = self.case
+    def prepare_inputs(self):
+        self.x_np = self.random(
+            shape=self.case["x_shape"], dtype=self.case["x_dtype"])
 
     def build_paddle_program(self, target):
-        x = paddle.to_tensor(self.inputs["x"])
-        batch_norm = paddle.nn.BatchNorm(
-            self.inputs["num_channels"], act=None, is_test=False)
+        x = paddle.to_tensor(self.x_np)
+        batch_norm = paddle.nn.BatchNorm([self.case["x_shape"][1]],
+                                         act=None,
+                                         is_test=False)
         out = batch_norm(x)
 
         self.paddle_outputs = [out]
@@ -46,44 +49,60 @@ def build_paddle_program(self, target):
     def build_cinn_program(self, target):
         builder = NetBuilder("batch_norm")
         x = builder.create_input(
-            self.nptype2cinntype(self.inputs["x"].dtype),
-            self.inputs["x"].shape, "x")
-        scale = builder.fill_constant([self.inputs["num_channels"]], 1.0,
-                                      'scale', "float32")
-        bias = builder.fill_constant([self.inputs["num_channels"]], 0.0,
-                                     'bias', "float32")
-        mean = builder.fill_constant([self.inputs["num_channels"]], 0.0,
-                                     'mean', "float32")
-        variance = builder.fill_constant([self.inputs["num_channels"]], 1.0,
-                                         'variance', "float32")
+            self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
+            "x")
+        scale = builder.fill_constant([self.case["x_shape"][1]], 1.0, 'scale',
+                                      'float32')
+        bias = builder.fill_constant([self.case["x_shape"][1]], 0.0, 'bias',
+                                     'float32')
+        mean = builder.fill_constant([self.case["x_shape"][1]], 0.0, 'mean',
+                                     'float32')
+        variance = builder.fill_constant([self.case["x_shape"][1]], 1.0,
+                                         'variance', 'float32')
 
         out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
 
         prog = builder.build()
         forward_res = self.get_cinn_output(
-            prog, target, [x], [self.inputs["x"]], out, passes=[])
+            prog, target, [x], [self.x_np], out, passes=[])
         self.cinn_outputs = [forward_res[0]]
 
     def test_check_results(self):
-        self.check_outputs_and_grads(max_relative_error=1e-3)
+        max_relative_error = self.case[
+            "max_relative_error"] if "max_relative_error" in self.case else 1e-5
+        self.check_outputs_and_grads(max_relative_error=max_relative_error)
 
 
 class TestBatchNormTrainOpAll(TestCaseHelper):
     def init_attrs(self):
-        self.class_name = "TestBatchNormTrainOpBase"
+        self.class_name = "TestBatchNormTrainOpCase"
         self.cls = TestBatchNormTrainOp
 
-        self.inputs = []
-        for x_shape in [[2, 16, 8, 8], [2, 16, 8, 1], [2, 16, 2048, 8]]:
-            for x_type in ["float16", "float32"]:
-                self.inputs.append({
-                    "x":
-                    self.random(x_shape, x_type, 0.0, 1.0),
-                    "dout":
-                    self.random(x_shape, x_type, 1e-7, 1e-6),
-                    "num_channels":
-                    x_shape[1]
-                })
+        self.inputs = [
+            {
+                "x_shape": [2, 16, 8, 8],
+            },
+            {
+                "x_shape": [2, 16, 8, 1],
+            },
+            {
+                "x_shape": [2, 16, 2048, 8],
+            },
+        ]
+        self.dtype = [
+            {
+                "x_dtype": "float16",
+                "max_relative_error": 1e-5
+            },
+            {
+                "x_dtype": "float32",
+                "max_relative_error": 1e-3
+            },
+            {
+                "x_dtype": "bfloat16",
+                "max_relative_error": 1e-2
+            },
+        ]
 
 
 class TestBatchNormTrainBF16(TestBatchNormTrainOp):
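
Patches 09 and 10 move case generation out of init_case and into a TestCaseHelper subclass: init_attrs declares inputs and dtypes (patch 15 later adds attrs as well), while the test class reads one merged dict from self.case. The sketch below shows what a helper of this shape typically does; the real op_test_helper.TestCaseHelper in the repository may differ in details, so treat it as an assumption about its behavior rather than its actual code:

    import itertools
    import unittest

    class CaseHelperSketch:
        """Hypothetical stand-in for op_test_helper.TestCaseHelper."""

        def init_attrs(self):
            # Subclasses set class_name, cls, inputs, dtypes and attrs here.
            raise NotImplementedError

        def run(self):
            self.init_attrs()
            loader = unittest.TestLoader()
            suite = unittest.TestSuite()
            # Cross product of shape cases, dtype cases and extra attributes;
            # an empty attrs list contributes a single empty dict.
            combos = itertools.product(self.inputs, self.dtypes, self.attrs or [{}])
            for i, (inp, dtype, attr) in enumerate(combos):
                case = {**inp, **dtype, **attr}
                generated = type(f"{self.class_name}{i}", (self.cls,), {"case": case})
                suite.addTests(loader.loadTestsFromTestCase(generated))
            unittest.TextTestRunner().run(suite)

Under that assumption, the three shapes and three dtypes declared in patch 10 would expand into nine generated test classes per run.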

From 1d04d2932b76d42a0d9a02240029c17678894472 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Fri, 16 Jun 2023 22:30:54 +0800
Subject: [PATCH 11/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 7f27e5d721..982b008399 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -268,4 +268,4 @@ def test_check_results(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    TestBatchNormTrainOpAll.run()

From 13d1f3621719828ea18bef03240edbf531141f7e Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Fri, 16 Jun 2023 22:32:45 +0800
Subject: [PATCH 12/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 982b008399..3a64ba5a1a 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -105,20 +105,6 @@ def init_attrs(self):
         ]
 
 
-class TestBatchNormTrainBF16(TestBatchNormTrainOp):
-    def init_case(self):
-        self.num_channels = 16
-        x = self.random([2, self.num_channels, 8, 8], "bfloat16")
-        dout = self.random([2, self.num_channels, 8, 8], "bfloat16")
-        self.inputs = {
-            "x": x,
-            "dout": dout,
-        }
-
-    def test_check_results(self):
-        self.check_outputs_and_grads(max_relative_error=1e-2)
-
-
 @OpTestTool.skip_if(not is_compiled_with_cuda(),
                     "x86 test will be skipped due to timeout.")
 class TestBatchNormBackwardOp(OpTest):

From e0c444ff39732a1f1711cfd9eb03084065e895e5 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sat, 17 Jun 2023 10:45:24 +0800
Subject: [PATCH 13/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 3a64ba5a1a..b5407c7163 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -254,4 +254,4 @@ def test_check_results(self):
 
 
 if __name__ == "__main__":
-    TestBatchNormTrainOpAll.run()
+    TestBatchNormTrainOpAll().run()
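
Patch 11 called run() on the class itself; patch 13 instantiates first, which is required because run needs an instance (init_attrs populates instance attributes). A tiny standalone illustration, unrelated to the real helper:

    class Helper:
        def init_attrs(self):
            self.cases = ["case0", "case1"]

        def run(self):
            self.init_attrs()
            return self.cases

    # Helper.run() raises TypeError: run() missing 1 required positional
    # argument: 'self'. Creating an instance first, as patch 13 does, works:
    print(Helper().run())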

From b67eb4a74b1f73b0d38b83ce4947135bf0ef7cb2 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sat, 17 Jun 2023 13:40:44 +0800
Subject: [PATCH 14/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index b5407c7163..043e9ade70 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -89,7 +89,7 @@ def init_attrs(self):
                 "x_shape": [2, 16, 2048, 8],
             },
         ]
-        self.dtype = [
+        self.dtypes = [
             {
                 "x_dtype": "float16",
                 "max_relative_error": 1e-5

From e75a133b81ad4ee66f935155978fc32379919f0e Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sat, 17 Jun 2023 15:31:21 +0800
Subject: [PATCH 15/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 043e9ade70..f83f1214c4 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -103,6 +103,7 @@ def init_attrs(self):
                 "max_relative_error": 1e-2
             },
         ]
+        self.attrs = []
 
 
 @OpTestTool.skip_if(not is_compiled_with_cuda(),

From f40cbe18c84a0232604532c28b0fe2bd5bf0a902 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sat, 17 Jun 2023 16:22:37 +0800
Subject: [PATCH 16/18] add TestCaseHelper

---
 python/tests/ops/test_batch_norm_op.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index f83f1214c4..4f87badb11 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -37,9 +37,8 @@ def prepare_inputs(self):
 
     def build_paddle_program(self, target):
         x = paddle.to_tensor(self.x_np)
-        batch_norm = paddle.nn.BatchNorm([self.case["x_shape"][1]],
-                                         act=None,
-                                         is_test=False)
+        batch_norm = paddle.nn.BatchNorm(
+            self.case["x_shape"][1], act=None, is_test=False)
         out = batch_norm(x)
 
         self.paddle_outputs = [out]

From 5701dcd807d5037210137376998e355f3c326123 Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Sat, 17 Jun 2023 18:31:46 +0800
Subject: [PATCH 17/18] use cpphelper

---
 python/tests/ops/test_batch_norm_op.py | 172 +++++++++++++------------
 1 file changed, 89 insertions(+), 83 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 4f87badb11..603936b16e 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -91,11 +91,11 @@ def init_attrs(self):
         self.dtypes = [
             {
                 "x_dtype": "float16",
-                "max_relative_error": 1e-5
+                "max_relative_error": 1e-3
             },
             {
                 "x_dtype": "float32",
-                "max_relative_error": 1e-3
+                "max_relative_error": 1e-5
             },
             {
                 "x_dtype": "bfloat16",
@@ -109,100 +109,105 @@ def init_attrs(self):
                     "x86 test will be skipped due to timeout.")
 class TestBatchNormBackwardOp(OpTest):
     def setUp(self):
-        self.init_case()
+        print(f"\nRunning {self.__class__.__name__}: {self.case}")
+        self.prepare_inputs()
 
-    def init_case(self):
-        self.inputs = [{
-            "x":
-            self.random([2, 16, 8, 8], "float32", 0.0, 10.0),
-            "dout":
-            self.random([2, 16, 8, 8], "float32", 1e-7, 1e-6),
-            "num_channels":
-            16
-        }]
+    def prepare_inputs(self):
+        self.x_np = self.random(
+            shape=self.case["x_shape"], dtype=self.case["x_dtype"])
+        self.y_np = self.random(
+            shape=self.case["x_shape"], dtype=self.case["x_dtype"])
 
     def build_paddle_program(self, target):
-        for inputs in self.inputs:
-            x = paddle.to_tensor(inputs["x"], stop_gradient=False)
-            batch_norm = paddle.nn.BatchNorm(
-                inputs["num_channels"], act=None, is_test=False)
-            out = batch_norm(x)
+        x = paddle.to_tensor(self.x_np, stop_gradient=False)
+        batch_norm = paddle.nn.BatchNorm(
+            self.case["x_shape"][1], act=None, is_test=False)
+        out = batch_norm(x)
 
-            self.paddle_outputs.append(out)
-            grad = self.get_paddle_grads([out], [x], [inputs["dout"]])
-            self.paddle_grads.append(grad[0])
+        self.paddle_outputs = [out]
+        self.paddle_grads = self.get_paddle_grads([out], [x], [self.y_np])
 
     # Note: If the forward and backward operators are run in the same program,
     # the forward result will be incorrect.
     def build_cinn_program(self, target):
-        for inputs in self.inputs:
-            builder = NetBuilder("batch_norm")
-            x = builder.create_input(
-                self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
-                "x")
-            scale = builder.fill_constant([inputs["num_channels"]], 1.0,
-                                          'scale', 'float32')
-            bias = builder.fill_constant([inputs["num_channels"]], 0.0, 'bias',
-                                         'float32')
-            mean = builder.fill_constant([inputs["num_channels"]], 0.0, 'mean',
-                                         'float32')
-            variance = builder.fill_constant([inputs["num_channels"]], 1.0,
-                                             'variance', 'float32')
-
-            out = builder.batchnorm(
-                x, scale, bias, mean, variance, is_test=False)
-
-            prog = builder.build()
-            forward_res = self.get_cinn_output(
-                prog, target, [x], [inputs["x"]], out, passes=[])
-            self.cinn_outputs.append(forward_res[0])
-
-            builder_grad = NetBuilder("batch_norm_grad")
-            dout = builder_grad.create_input(
-                self.nptype2cinntype(inputs["dout"].dtype),
-                inputs["dout"].shape, "dout")
-            x_g = builder_grad.create_input(
-                self.nptype2cinntype(inputs["x"].dtype), inputs["x"].shape,
-                "x_g")
-            scale_g = builder_grad.fill_constant(scale.shape(), 1.0, 'scale_g',
-                                                 'float32')
-            save_mean = builder_grad.create_input(
-                self.nptype2cinntype('float32'), out[1].shape(), "save_mean")
-            save_variance = builder_grad.create_input(
-                self.nptype2cinntype('float32'), out[2].shape(),
-                "save_variance")
-
-            out_grad = builder_grad.batch_norm_grad(dout, x_g, scale_g,
-                                                    save_mean, save_variance)
-            prog = builder_grad.build()
-            backward_res = self.get_cinn_output(
-                prog,
-                target, [dout, x_g, save_mean, save_variance],
-                [inputs["dout"], inputs["x"], forward_res[1], forward_res[2]],
-                out_grad,
-                passes=[])
-            self.cinn_grads.append(backward_res[0])
+        builder = NetBuilder("batch_norm")
+        x = builder.create_input(
+            self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
+            "x")
+        scale = builder.fill_constant([self.case["x_shape"][1]], 1.0, 'scale',
+                                      'float32')
+        bias = builder.fill_constant([self.case["x_shape"][1]], 0.0, 'bias',
+                                     'float32')
+        mean = builder.fill_constant([self.case["x_shape"][1]], 0.0, 'mean',
+                                     'float32')
+        variance = builder.fill_constant([self.case["x_shape"][1]], 1.0,
+                                         'variance', 'float32')
 
-    def test_check_results(self):
-        self.check_outputs_and_grads()
+        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
 
+        prog = builder.build()
+        forward_res = self.get_cinn_output(
+            prog, target, [x], [self.x_np], out, passes=[])
+        self.cinn_outputs = [forward_res[0]]
 
-class TestBatchNormBackwardAll(TestBatchNormBackwardOp):
-    def init_case(self):
-        self.inputs = []
-        for x_shape in [[2, 16, 8, 8], [2, 16, 8, 1], [2, 16, 2048, 8]]:
-            for x_type in ["float16", "float32"]:
-                self.inputs.append({
-                    "x":
-                    self.random(x_shape, x_type, 0.0, 1.0),
-                    "dout":
-                    self.random(x_shape, x_type, 1e-7, 1e-6),
-                    "num_channels":
-                    x_shape[1]
-                })
+        builder_grad = NetBuilder("batch_norm_grad")
+        dout = builder_grad.create_input(
+            self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
+            "dout")
+        x_g = builder_grad.create_input(
+            self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
+            "x_g")
+        scale_g = builder_grad.fill_constant(scale.shape(), 1.0, 'scale_g',
+                                             'float32')
+        save_mean = builder_grad.create_input(
+            self.nptype2cinntype('float32'), out[1].shape(), "save_mean")
+        save_variance = builder_grad.create_input(
+            self.nptype2cinntype('float32'), out[2].shape(), "save_variance")
+
+        out_grad = builder_grad.batch_norm_grad(dout, x_g, scale_g, save_mean,
+                                                save_variance)
+        prog = builder_grad.build()
+        backward_res = self.get_cinn_output(
+            prog,
+            target, [dout, x_g, save_mean, save_variance],
+            [self.y_np, self.x_np, forward_res[1], forward_res[2]],
+            out_grad,
+            passes=[])
+        self.cinn_grads = [backward_res[0]]
 
     def test_check_results(self):
-        self.check_outputs_and_grads(max_relative_error=1e-3)
+        max_relative_error = self.case[
+            "max_relative_error"] if "max_relative_error" in self.case else 1e-5
+        self.check_outputs_and_grads(max_relative_error=max_relative_error)
+
+
+class TestBatchNormBackwardOpAll(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestBatchNormBackwardOpCase"
+        self.cls = TestBatchNormBackwardOp
+
+        self.inputs = [
+            {
+                "x_shape": [2, 16, 8, 8],
+            },
+            {
+                "x_shape": [2, 16, 8, 1],
+            },
+            {
+                "x_shape": [2, 16, 2048, 8],
+            },
+        ]
+        self.dtypes = [
+            {
+                "x_dtype": "float16",
+                "max_relative_error": 1e-3
+            },
+            {
+                "x_dtype": "float32",
+                "max_relative_error": 1e-5
+            },
+        ]
+        self.attrs = []
 
 
 @OpTestTool.skip_if(not is_compiled_with_cuda(),
@@ -255,3 +260,4 @@ def test_check_results(self):
 
 if __name__ == "__main__":
     TestBatchNormTrainOpAll().run()
+    TestBatchNormBackwardOpAll().run()

From f723c4c7cc5b2a7d25cb8062465df06fe9bc866c Mon Sep 17 00:00:00 2001
From: Liyulingyue <852433440@qq.com>
Date: Mon, 19 Jun 2023 19:51:05 +0800
Subject: [PATCH 18/18] use cpphelper

---
 python/tests/ops/test_batch_norm_op.py | 58 ++++++++++++++++++--------
 1 file changed, 41 insertions(+), 17 deletions(-)

diff --git a/python/tests/ops/test_batch_norm_op.py b/python/tests/ops/test_batch_norm_op.py
index 603936b16e..7226a36f5e 100644
--- a/python/tests/ops/test_batch_norm_op.py
+++ b/python/tests/ops/test_batch_norm_op.py
@@ -214,19 +214,17 @@ def init_attrs(self):
                     "x86 test will be skipped due to timeout.")
 class TestBatchNormInferOp(OpTest):
     def setUp(self):
-        self.init_case()
+        print(f"\nRunning {self.__class__.__name__}: {self.case}")
+        self.prepare_inputs()
 
-    def init_case(self):
-        self.num_channels = 16
-        self.inputs = {
-            "x": self.random([2, self.num_channels, 8, 8], "float32", 0.0,
-                             1.0),
-        }
+    def prepare_inputs(self):
+        self.x_np = self.random(
+            shape=self.case["x_shape"], dtype=self.case["x_dtype"])
 
     def build_paddle_program(self, target):
-        x = paddle.to_tensor(self.inputs["x"])
+        x = paddle.to_tensor(self.x_np)
         batch_norm = paddle.nn.BatchNorm(
-            self.num_channels, act=None, is_test=True)
+            self.case["x_shape"][1], act=None, is_test=True)
         out = batch_norm(x)
 
         self.paddle_outputs = [out]
@@ -236,28 +234,54 @@ def build_paddle_program(self, target):
     def build_cinn_program(self, target):
         builder = NetBuilder("batch_norm")
         x = builder.create_input(
-            self.nptype2cinntype(self.inputs["x"].dtype),
-            self.inputs["x"].shape, "x")
-        scale = builder.fill_constant([self.num_channels], 1.0, 'scale',
+            self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
+            "x")
+        scale = builder.fill_constant([self.case["x_shape"][1]], 1.0, 'scale',
                                       'float32')
-        bias = builder.fill_constant([self.num_channels], 0.0, 'bias',
+        bias = builder.fill_constant([self.case["x_shape"][1]], 0.0, 'bias',
                                      'float32')
-        mean = builder.fill_constant([self.num_channels], 0.0, 'mean',
+        mean = builder.fill_constant([self.case["x_shape"][1]], 0.0, 'mean',
                                      'float32')
-        variance = builder.fill_constant([self.num_channels], 1.0, 'variance',
-                                         'float32')
+        variance = builder.fill_constant([self.case["x_shape"][1]], 1.0,
+                                         'variance', 'float32')
 
         out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
 
         prog = builder.build()
         forward_res = self.get_cinn_output(
-            prog, target, [x], [self.inputs["x"]], out, passes=[])
+            prog, target, [x], [self.x_np], out, passes=[])
         self.cinn_outputs = [forward_res[0]]
 
     def test_check_results(self):
         self.check_outputs_and_grads()
 
 
+class TestBatchNormInferOpAll(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestBatchNormInferOpCase"
+        self.cls = TestBatchNormInferOp
+
+        self.inputs = [
+            {
+                "x_shape": [2, 16, 8, 8],
+            },
+            {
+                "x_shape": [2, 16, 8, 1],
+            },
+            {
+                "x_shape": [2, 16, 2048, 8],
+            },
+        ]
+        self.dtypes = [
+            {
+                "x_dtype": "float32",
+                "max_relative_error": 1e-5
+            },
+        ]
+        self.attrs = []
+
+
 if __name__ == "__main__":
     TestBatchNormTrainOpAll().run()
     TestBatchNormBackwardOpAll().run()
+    TestBatchNormInferOpAll().run()