Add FP16 & BF16 for erfinv #55287

Merged
merged 14 commits on Aug 2, 2023
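For reference, a minimal dynamic-graph sketch of what this change enables (illustrative only; it assumes a CUDA build, since the FP16/BF16 erfinv kernels are registered for GPU only):

import paddle

paddle.set_device("gpu")  # the FP16/BF16 erfinv kernels added here are GPU-only

x = paddle.to_tensor([-0.5, 0.0, 0.5], dtype="float16")
y = paddle.erfinv(x)   # dispatches to the new FP16 kernel
print(y.dtype)         # paddle.float16

xb = paddle.to_tensor([-0.5, 0.0, 0.5], dtype="bfloat16")
yb = paddle.erfinv(xb)  # new BF16 path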
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/erfinv_grad_kernel.cu
@@ -22,5 +22,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
erfinv_grad, GPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}
PD_REGISTER_KERNEL(erfinv_grad,
GPU,
ALL_LAYOUT,
phi::ErfinvGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
23 changes: 22 additions & 1 deletion paddle/phi/kernels/gpu/erfinv_kernel.cu
@@ -23,7 +23,21 @@ template <typename T>
struct ErfinvFunctor {
HOSTDEVICE inline T operator()(const T x) const { return erfinv(x); }
};
template <>
struct ErfinvFunctor<float16> {
HOSTDEVICE inline float16 operator()(const float16 x) const {
auto x_ = static_cast<float>(x);
return static_cast<float16>(erfinv(x_));
}
};

template <>
struct ErfinvFunctor<bfloat16> {
HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
auto x_ = static_cast<float>(x);
return static_cast<bfloat16>(erfinv(x_));
}
};
template <typename T, typename Context>
void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
@@ -34,4 +48,11 @@ void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {

} // namespace phi

PD_REGISTER_KERNEL(erfinv, GPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
PD_REGISTER_KERNEL(erfinv,
GPU,
ALL_LAYOUT,
phi::ErfinvKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
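A note on the ErfinvFunctor specializations above (inferred from the code, not stated in the PR): the FP16/BF16 functors upcast each element to float, evaluate erfinv there, and narrow the result back, so the intermediate computation stays in float32. A rough Python analogue of this compute-in-float32 pattern:

import numpy as np
from scipy.special import erfinv

x16 = np.float16(0.5)
# Evaluate in float32, then cast the result back to the storage dtype.
y16 = np.float16(erfinv(np.float32(x16)))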
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h
@@ -29,7 +29,7 @@ void ErfinvGradKernel(const Context& ctx,
auto eigen_dout = EigenVector<T>::Flatten(out_grad);
auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
auto& place = *ctx.eigen_device();
constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
eigen_dx.device(place) = half_sqrt_pi * eigen_dout * eigen_out.square().exp();
}

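For context, the expression in ErfinvGradKernel above implements the derivative of the inverse error function; differentiating erf(erfinv(x)) = x gives

$$\frac{d}{dx}\,\operatorname{erfinv}(x) = \frac{\sqrt{\pi}}{2}\, e^{\operatorname{erfinv}(x)^{2}},$$

and since M_2_SQRTPI is 2/sqrt(pi), the factor 1 / M_2_SQRTPI is exactly sqrt(pi)/2 (the constexpr qualifier is dropped, presumably because the float16/bfloat16 types cannot be constructed in a constant expression).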
4 changes: 3 additions & 1 deletion python/paddle/tensor/math.py
@@ -4607,7 +4607,9 @@ def erfinv(x, name=None):
if in_dynamic_mode():
return _C_ops.erfinv(x)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'erfinv')
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'float16', 'uint16'], 'erfinv'
)
helper = LayerHelper('erfinv', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='erfinv', inputs={'X': x}, outputs={'Out': out})
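In the dtype check above, 'uint16' is how Paddle's static graph represents bfloat16 data, so this change is what lets the FP16/BF16 static-graph path through. A minimal static-graph sketch (hypothetical shape; assumes a CUDA build with FP16 support):

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name="x", shape=[3], dtype="float16")
    y = paddle.erfinv(x)  # passes check_variable_and_dtype now that 'float16' is allowed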
52 changes: 47 additions & 5 deletions test/legacy_test/test_erfinv_op.py
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from scipy.special import erfinv

import paddle
@@ -25,7 +25,7 @@
np.random.seed(0)


class TestErfinv(OpTest):
class TestErfinvOp(OpTest):
def setUp(self):
self.op_type = "erfinv"
self.python_api = paddle.erfinv
@@ -55,12 +55,12 @@ def test_check_grad(self):
)


class TestErfinvFP32(TestErfinv):
class TestErfinvFP64Op(TestErfinvOp):
def init_dtype(self):
self.dtype = np.float32
self.dtype = np.float64


class TestErfinvAPI(unittest.TestCase):
class TestErfinvAPIOp(unittest.TestCase):
def init_dtype(self):
self.dtype = 'float32'

@@ -110,5 +110,47 @@ def run(place):
run(place)


class TestErfinvFP16Op(TestErfinvOp):
def init_dtype(self):
self.dtype = np.float16


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestErfinvBF16Op(OpTest):
def setUp(self):
self.op_type = "erfinv"
self.public_python_api = paddle.erfinv
self.python_api = paddle.erfinv
self.dtype = np.uint16
self.shape = [11, 17]
self.datatype = np.float32
x = np.random.uniform(-1, 1, size=self.shape).astype(self.datatype)
out_ref = erfinv(x).astype(self.datatype)
self.grad_out = np.ones(self.shape, self.datatype)
self.gradient = (
np.sqrt(np.pi) / 2 * np.exp(np.square(out_ref)) * self.grad_out
)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out_ref)}

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X'],
'Out',
user_defined_grads=[self.gradient],
user_defined_grad_outputs=self.grad_out,
)


if __name__ == "__main__":
unittest.main()
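In TestErfinvBF16Op above, convert_float_to_uint16 (from the shared test utilities) packs the float32 reference data into the uint16 bit pattern that Paddle uses to carry bfloat16 values through numpy. A rough standalone sketch of that kind of truncation (not necessarily Paddle's exact helper, which may also apply rounding):

import numpy as np

def float32_to_bf16_bits(a):
    # bfloat16 keeps the top 16 bits of the float32 representation
    return np.right_shift(a.astype(np.float32).view(np.uint32), 16).astype(np.uint16)

x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32)
x_bf16_bits = float32_to_bf16_bits(x)  # dtype uint16, same shape as x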