Skip to content

Commit

Permalink
Replace uses of known_incorrect_fn_ with ErrorSpec::skip_comparison
Browse files Browse the repository at this point in the history
`ErrorSpec::skip_comparison` fulfills all of the same features, but is slightly more ergonomic for unary tests and provides the ability to filter binary inputs as a pair instead of one at a time.

PiperOrigin-RevId: 666453655
  • Loading branch information
Gregory Pataky authored and copybara-github committed Aug 22, 2024
1 parent 38b8a4a commit 260b9c9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 67 deletions.
2 changes: 1 addition & 1 deletion xla/tests/exhaustive/exhaustive_op_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct ComponentStringifyFormat {

template <>
constexpr absl::string_view ComponentStringifyFormat<double>::value =
"%0.17g (0x%16x)";
"%0.17g (0x%016x)";

template <>
constexpr absl::string_view ComponentStringifyFormat<float>::value =
Expand Down
3 changes: 2 additions & 1 deletion xla/tests/exhaustive/exhaustive_op_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
bool passed = abs_err <= spec.abs_err || rel_err <= spec.rel_err ||
distance_err <= spec.distance_err;
if (should_emit_debug_logging_ && !passed) {
LOG(INFO) << std::setprecision(std::numeric_limits<NativeT>::max_digits10)
LOG(INFO) << std::setprecision(
std::numeric_limits<ComponentNativeT>::max_digits10)
<< "actual: " << actual << "; expected: " << expected
<< std::setprecision(std::numeric_limits<double>::max_digits10)
<< "\n\tabs_err: " << abs_err
Expand Down
151 changes: 86 additions & 65 deletions xla/tests/exhaustive/exhaustive_unary_complex_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ class ExhaustiveComplexUnaryTestBase
protected:
using typename ExhaustiveUnaryTest<T>::NativeT;

void SetParamsForTanh() {
// TODO(b/138126045): Current libc++ implementation of the complex tanh
// function returns (NaN, NaN) when the imaginary
// component is more than half of the max value.
// TODO(b/138750327): Current libc++ implementation of the complex tanh
// function returns (1, 0) when the real component is
// negative infinity, when it should return (-1, 0).
// We only need to set the former as incorrect values for C128 because when
// testing with C64, we first cast our input to a C128 value.
this->known_incorrect_fn_ = [&](int64_t v) {
double f = this->ConvertValue(v);
return (T == C128 &&
std::abs(f) > std::numeric_limits<double>::max() / 2) ||
f == -std::numeric_limits<double>::infinity();
};
}

private:
// Generates the input complex literal given the FpValues representation for
// the real and imaginary components.
Expand Down Expand Up @@ -110,23 +93,25 @@ using ExhaustiveC128UnaryTest = ExhaustiveComplexUnaryTestBase<C128>;
__VA_ARGS__

UNARY_TEST_COMPLEX_64(Log, {
// TODO(rmlarsen): see b/162664705 and b/138578594
known_incorrect_fn_ = [this](int64_t val) {
complex64 x = this->ConvertValue(val);
return std::isnan(x.real()) || std::isnan(x.imag()) ||
(platform_ == "Host" &&
std::abs(x) < std::numeric_limits<float>::min());
};
ErrorSpecGen error_spec_gen = +[](complex64 x) {
double abs_err, rel_err;
// The reference implementation overflows to infinity for arguments near
// FLT_MAX.
if (std::abs(x) >= std::numeric_limits<float>::max()) {
float inf = std::numeric_limits<float>::infinity();
return ErrorSpec::Builder().abs_err(inf).rel_err(inf).build();
abs_err = rel_err = std::numeric_limits<float>::infinity();
} else {
double eps = std::numeric_limits<float>::epsilon();
abs_err = eps;
rel_err = 50 * eps;
}

// TODO(rmlarsen): see b/162664705 and b/138578594
bool should_skip = std::isnan(x.real()) || std::isnan(x.imag());

return ErrorSpec::Builder()
.abs_err(std::numeric_limits<float>::epsilon())
.rel_err(50 * std::numeric_limits<float>::epsilon())
.abs_err(abs_err)
.rel_err(rel_err)
.skip_comparison(should_skip)
.build();
};
Run(Log, [](complex64 x) { return std::log(x); }, error_spec_gen);
Expand All @@ -152,40 +137,56 @@ UNARY_TEST_COMPLEX_64(Sqrt, {
Run(Sqrt, [](complex64 x) { return std::sqrt(x); }, error_spec_gen);
})

double RsqrtCpuGpuAbsErr(complex64 x) {
return std::sqrt(std::numeric_limits<float>::min());
}

double RsqrtCpuGpuRelErr(complex64 x) {
// As noted above for Sqrt, the accuracy of sqrt degrades severely for
// inputs with inputs with subnormals entries.
constexpr double eps = std::numeric_limits<float>::epsilon();
constexpr double norm_min = std::numeric_limits<float>::min();
constexpr double denorm_min = std::numeric_limits<float>::denorm_min();
if (std::abs(x) < norm_min) {
// Gradually loosen the relative tolerance as abs(x) becomes smaller
// than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min.
return 10 * denorm_min / std::abs(x);
}
return 50 * eps;
}

UNARY_TEST_COMPLEX_64(Rsqrt, {
known_incorrect_fn_ = [this](int64_t val) {
complex64 x = this->ConvertValue(val);
return (platform_ == "Host" && (x.imag() == 0.0f || x.real() == 0.0f));
ErrorSpecGen error_spec_gen = +[](complex64) {
return ErrorSpec::Builder().strict_signed_zeros().build();
};
ErrorSpecGen error_spec_gen = +[](complex64 x) {
// As noted above for Sqrt, the accuracy of sqrt degrades severely for
// inputs with inputs with subnormals entries.
constexpr double norm_min = std::numeric_limits<float>::min();
constexpr double denorm_min = std::numeric_limits<float>::denorm_min();
if (std::abs(x) < norm_min) {
// Gradually loosen the relative tolerance as abs(x) becomes smaller
// than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min.

if (IsCpu(platform_)) {
error_spec_gen = +[](complex64 x) {
return ErrorSpec::Builder()
.abs_err(std::sqrt(std::numeric_limits<float>::min()))
.rel_err(10 * denorm_min / std::abs(x))
.abs_err(RsqrtCpuGpuAbsErr(x))
.rel_err(RsqrtCpuGpuRelErr(x))
.skip_comparison(x.real() == 0.0f)
.strict_signed_zeros(false)
.build();
}
return ErrorSpec::Builder()
.abs_err(std::sqrt(std::numeric_limits<float>::min()))
.rel_err(50 * std::numeric_limits<float>::epsilon())
.build();
};
};
}

if (IsGpu(platform_)) {
error_spec_gen = +[](complex64 x) {
return ErrorSpec::Builder()
.abs_err(RsqrtCpuGpuAbsErr(x))
.rel_err(RsqrtCpuGpuRelErr(x))
.strict_signed_zeros(false)
.build();
};
}

Run(
Rsqrt, [](complex64 x) { return complex64(1, 0) / std::sqrt(x); },
error_spec_gen);
})

// The current libc++ implementation of the complex tanh function provides
// less accurate results when the denominator of a complex tanh is small, due
// to floating point precision loss. To avoid this issue for complex64 numbers,
// we cast it to and from a complex128 when computing tanh.
UNARY_TEST_COMPLEX_64(Tanh, {
SetParamsForTanh();
ErrorSpecGen error_spec_gen = +[](complex64 x) {
// This implementation of Tanh becomes less accurate when the denominator
// is small.
Expand All @@ -198,6 +199,11 @@ UNARY_TEST_COMPLEX_64(Tanh, {
Run(
Tanh,
+[](complex64 x) {
// The current libc++ implementation of the complex tanh function
// provides less accurate results when the denominator of a complex tanh
// is small, due to floating point precision loss. To avoid this issue
// for complex64 numbers, we cast it to and from a complex128 when
// computing tanh.
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
},
error_spec_gen);
Expand Down Expand Up @@ -241,13 +247,23 @@ INSTANTIATE_TEST_SUITE_P(
__VA_ARGS__

UNARY_TEST_COMPLEX_128(Log, {
// TODO(rmlarsen): see b/162664705 and b/138578594
known_incorrect_fn_ = [&](int64_t v) {
double f = this->ConvertValue(v);
return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
std::abs(f) < 1.0e-300;
ErrorSpecGen error_spec_gen = +[](complex128 x) {
return ErrorSpec::Builder().strict_signed_zeros().build();
};
Run(Log, [](complex128 x) { return std::log(x); });

if (IsCpu(platform_) || IsGpu(platform_)) {
error_spec_gen = +[](complex128 x) {
// TODO(rmlarsen): see b/162664705 and b/138578594
bool should_skip = std::isnan(x.real()) || std::isnan(x.imag());
return ErrorSpec::Builder()
.distance_err(1)
.skip_comparison(should_skip)
.strict_signed_zeros()
.build();
};
}

Run(Log, [](complex128 x) { return std::log(x); }, error_spec_gen);
})

UNARY_TEST_COMPLEX_128(Sqrt, {
Expand All @@ -261,13 +277,10 @@ UNARY_TEST_COMPLEX_128(Sqrt, {
return ErrorSpec::Builder()
.abs_err(std::sqrt(std::numeric_limits<double>::denorm_min()))
.rel_err(50 * std::numeric_limits<double>::epsilon())
// TODO(b/138126045): Similar to the Tanh bug.
.skip_comparison(std::abs(x) > std::numeric_limits<double>::max() / 2)
.build();
};
// Similar to the Tanh bug.
known_incorrect_fn_ = [&](int64_t v) {
double f = this->ConvertValue(v);
return std::abs(f) > std::numeric_limits<double>::max() / 2;
};
Run(Sqrt, [](complex128 x) { return std::sqrt(x); }, error_spec_gen);
})

Expand Down Expand Up @@ -296,16 +309,24 @@ UNARY_TEST_COMPLEX_128(Rsqrt, {
})

UNARY_TEST_COMPLEX_128(Tanh, {
ErrorSpecGen error_spec_gen = [](complex128 x) {
ErrorSpecGen error_spec_gen = +[](complex128 x) {
// TODO(b/138126045): Current libc++ implementation of the complex tanh
// function returns (NaN, NaN) when the imaginary
// component is more than half of the max value.
// TODO(b/138750327): Current libc++ implementation of the complex tanh
// function returns (1, 0) when the real component is
// negative infinity, when it should return (-1, 0).
bool should_skip = std::abs(x) > std::numeric_limits<double>::max() / 2;

// TODO(rmlarsen): Investigate why we only get slightly better than
// float accuracy here.
return ErrorSpec::Builder()
.abs_err(2 * std::numeric_limits<double>::min())
.rel_err(2e-8)
.skip_comparison(should_skip)
.build();
};

SetParamsForTanh();
Run(Tanh, +[](complex128 x) { return std::tanh(x); }, error_spec_gen);
})

Expand Down

0 comments on commit 260b9c9

Please sign in to comment.