-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add erfinv FP16 test and BF16 test #53101
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Sorry to inform you that baa619d's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
@Vvsmile 麻烦review~下, 谢谢 |
@@ -15,6 +15,7 @@ | |||
#pragma once | |||
|
|||
#include "paddle/phi/core/dense_tensor.h" | |||
#include "paddle/phi/core/device_context.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个头文件有必要加吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/elementwise_add_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个头文件为什么加
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
template <> | ||
struct ErfinvFunctor<float16> { | ||
HOSTDEVICE inline float16 operator()(const float16 x) const { | ||
auto x_ = static_cast<float>(x); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最好不要用这种下划线后缀命名变量
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
@@ -29,7 +32,7 @@ void ErfinvGradKernel(const Context& ctx, | |||
auto eigen_dout = EigenVector<T>::Flatten(out_grad); | |||
auto eigen_dx = EigenVector<T>::Flatten(*x_grad); | |||
auto& place = *ctx.eigen_device(); | |||
constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI); | |||
const T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里修改的原因是?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
根据CI错误提示这里应该修改为const常量
@@ -110,5 +114,67 @@ def run(place): | |||
run(place) | |||
|
|||
|
|||
class TestErfinvFP16OP(OpTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
继承TestErfinv,减少重复代码
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
self.python_api = paddle.erfinv | ||
self.dtype = np.uint16 | ||
self.shape = [11, 17] | ||
self.x = np.random.uniform(-1, 1, size=self.shape).astype(self.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x应初始化为float32
self.shape = [11, 17] | ||
self.x = np.random.uniform(-1, 1, size=self.shape).astype(self.dtype) | ||
self.x_s = convert_uint16_to_float(self.x) | ||
self.res_ref = erfinv(self.x_s).astype(np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考结果应该由float32的输入计算得到
Sorry to inform you that 26dc177's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'erfinv') | ||
check_variable_and_dtype( | ||
x, 'x', ['float32', 'float64', 'float16', 'uint16'], 'erfinv' | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件4358行文档可以同步更新下支持的数据类型,另外可以解决下代码冲突,应该就可以合入了
close due to the following PR is merged: |
PR types
Others
PR changes
Others
Description
add erfinv FP16 test and BF16 test