【AMP OP&Test】unit test for accuracy_op (#51009)
* test_accuracy_op

* add create_test_fp/bf16_class

* cast after calculation

* change convert_uint16_to_float_ifneed (the float32/uint16 conversion is sketched just after this list)

* delete TestAccuracyOpFp32 according to PR comment

* fix the rtol setting rules in bfloat16 forward
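Both the convert_uint16_to_float_ifneed change mentioned above and the convert_float_to_uint16 calls in the new test move values between float32 and the bfloat16 bit pattern. A minimal sketch of that round trip, assuming the usual "keep the upper 16 bits of the float32 word" encoding; it is not the code from op_test.py, whose helpers may round rather than truncate:

import numpy as np

# Hedged illustration of a float32 <-> bfloat16-bits round trip. bfloat16 keeps
# the sign bit, all 8 exponent bits, and the top 7 mantissa bits of float32, so
# its bit pattern is simply the upper half of the 32-bit word.
def float_to_uint16(x):
    """Truncate float32 values to their bfloat16 bit patterns (uint16)."""
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)


def uint16_to_float(x):
    """Expand bfloat16 bit patterns back to float32 values."""
    return (np.asarray(x, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)


vals = np.array([0.747, 1.0, 3.14159], dtype=np.float32)
print(uint16_to_float(float_to_uint16(vals)))  # values truncated to bfloat16 precision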
zhangbopd authored Mar 22, 2023
1 parent 320a5b2 commit 8c61a95
Showing 2 changed files with 49 additions and 2 deletions.
6 changes: 5 additions & 1 deletion paddle/phi/kernels/gpu/accuracy_kernel.cu
@@ -20,6 +20,8 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

@@ -34,6 +36,7 @@ __global__ void AccuracyCudaKernel(const int N,
                                   int* correct_data,
                                   T* accuracy,
                                   int* total_data) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  int count = 0;
  __shared__ int total[BlockSize];

@@ -64,7 +67,7 @@ __global__ void AccuracyCudaKernel(const int N,
#endif
  if (threadIdx.x == 0) {
    *correct_data = result;
-    *accuracy = static_cast<T>(result) / static_cast<T>(N);
+    *accuracy = static_cast<T>(static_cast<MT>(result) / static_cast<MT>(N));
    *total_data = N;
  }
}
@@ -136,6 +139,7 @@ PD_REGISTER_KERNEL(accuracy,
                   ALL_LAYOUT,
                   phi::AccuracyRawKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   float,
                   double) {
  kernel->InputAt(1).SetDataType(phi::DataType::INT64);
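The kernel change above is the "cast after calculation" item from the commit message: the ratio is now formed in the higher-precision MT type chosen by MPTypeTrait (float for float16/bfloat16) and only the final value is cast back to T. A rough illustration of why the order matters, using numpy's float16 as a stand-in since numpy has no bfloat16 type (bfloat16 keeps fewer mantissa bits, so its error is larger still); the counts are made up for the demonstration:

import numpy as np

# Not part of the commit: compare "cast, then divide" with "divide, then cast".
correct, total = 33333, 50000            # hypothetical correct/total counts

cast_then_divide = np.float16(correct) / np.float16(total)               # analogous to the old line
divide_then_cast = np.float16(np.float32(correct) / np.float32(total))   # analogous to the new line

print(float(cast_then_divide))   # ~0.667   (counts were rounded before dividing)
print(float(divide_then_cast))   # ~0.6665  (closest float16 to the exact ratio)
print(correct / total)           # 0.66666  exact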
45 changes: 44 additions & 1 deletion python/paddle/fluid/tests/unittests/test_accuracy_op.py
@@ -19,7 +19,8 @@

import paddle
import paddle.fluid as fluid
-from paddle.fluid import Program, program_guard
+from paddle.fluid import Program, core, program_guard
+from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16


def accuracy_wrapper(infer, indices, label):
@@ -64,6 +65,48 @@ def test_check_output(self):
        self.check_output(atol=1e-3)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestAccuracyOpBf16(OpTest):
    def setUp(self):
        self.op_type = "accuracy"
        self.python_api = accuracy_wrapper
        self.init_dtype()
        n = 8192
        infer = np.random.random((n, 1)).astype(np.float32)
        indices = np.random.randint(0, 2, (n, 1)).astype('int64')
        label = np.random.randint(0, 2, (n, 1)).astype('int64')
        self.inputs = {
            'Out': convert_float_to_uint16(infer),
            'Indices': indices,
            "Label": label,
        }
        num_correct = 0
        for rowid in range(n):
            for ele in indices[rowid]:
                if ele == label[rowid]:
                    num_correct += 1
                    break
        self.outputs = {
            'Accuracy': convert_float_to_uint16(
                np.array([num_correct / float(n)]).astype(np.float32)
            ),
            'Correct': np.array([num_correct]).astype("int32"),
            'Total': np.array([n]).astype("int32"),
        }

    def init_dtype(self):
        self.dtype = np.uint16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-2)


class TestAccuracyOpError(unittest.TestCase):
    def test_type_errors(self):
        with program_guard(Program(), Program()):
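For readers following the new test: the nested loop in TestAccuracyOpBf16.setUp computes the reference accuracy by marking a row correct when any of its top-k indices equals the row's label. A vectorized NumPy equivalent, for illustration only and not part of the test file:

import numpy as np

# Same counting rule as the reference loop, written with array operations.
n = 8192
indices = np.random.randint(0, 2, (n, 1)).astype('int64')   # top-k predictions per row
label = np.random.randint(0, 2, (n, 1)).astype('int64')     # one label per row

num_correct = int((indices == label).any(axis=1).sum())
accuracy = num_correct / float(n)
print(num_correct, accuracy)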
