Skip to content

Commit

Permalink
Enable AbsComplex in exhaustive_binary_f32_f64_test
Browse files Browse the repository at this point in the history
Merges the two `AbsComplex` variants into one macro invocation that enables for both F32 and F64 tests (for `complex64` and `complex128`). Tightens tolerances and re-enables the tests on CPU.

PiperOrigin-RevId: 666267443
  • Loading branch information
Gregory Pataky authored and copybara-github committed Aug 22, 2024
1 parent 72b4bd7 commit 2ce52ad
Showing 1 changed file with 76 additions and 34 deletions.
110 changes: 76 additions & 34 deletions xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <cstdlib>
#include <tuple>
#include <type_traits>

#include "absl/log/check.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -123,26 +124,6 @@ BINARY_TEST_FLOAT_32(Min, {
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
})

// It is more convenient to implement Abs(complex) as a binary op than a unary
// op, as the operations we currently support all have the same data type for
// the source operands and the results.
// TODO(bixia): May want to move this test to unary test if we will be able to
// implement Abs(complex) as unary conveniently.
//
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(AbsComplex), {
// TODO(timshen): see b/162664705.
known_incorrect_fn_ = [this](int64_t val) {
return std::isnan(this->ConvertValue(val));
};
auto host_abs_complex = [](float x, float y) {
return std::abs(std::complex<float>(x, y));
};
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };

Run(device_abs_complex, host_abs_complex);
})

INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF32BinaryTest,
::testing::Combine(
Expand Down Expand Up @@ -217,20 +198,6 @@ BINARY_TEST_FLOAT_64(Min, {
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<double>);
})

// TODO(bixia): Need to investigate the failure on CPU and file bugs.
BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), {
// TODO(timshen): see b/162664705.
known_incorrect_fn_ = [this](int64_t val) {
return std::isnan(this->ConvertValue(val));
};
auto host_abs_complex = [](double x, double y) {
return std::abs(std::complex<double>(x, y));
};
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };

Run(device_abs_complex, host_abs_complex);
})

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF64BinaryTest,
Expand Down Expand Up @@ -267,6 +234,81 @@ INSTANTIATE_TEST_SUITE_P(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)

#define BINARY_TEST_FLOAT_BOTH(test_name, ...) \
BINARY_TEST_FLOAT_32(test_name, __VA_ARGS__) \
BINARY_TEST_FLOAT_64(test_name, __VA_ARGS__)

// Can be thought of as an absolute error of
// `<= |std::numeric_limits::<float>::min()|`.
template <typename NativeRefT>
double AbsComplexCpuAbsErr(NativeRefT real, NativeRefT imag) {
// absolute value (distance) short circuits if the first component is
// subnormal.
if (!std::isnan(real) && IsSubnormal(real)) {
return std::abs(real);
}
return 0.0;
}

template <typename NativeRefT>
bool AbsComplexSkip(NativeRefT real, NativeRefT imag) {
// TODO(timshen): see b/162664705.
if (std::isnan(real) || std::isnan(imag)) {
return true;
}
return false;
}

// It is more convenient to implement Abs(complex) as a binary op than a unary
// op, as the operations we currently support all have the same data type for
// the source operands and the results.
// TODO(bixia): May want to move this test to unary test if we will be able to
// implement Abs(complex) as unary conveniently.
BINARY_TEST_FLOAT_BOTH(AbsComplex, {
ErrorSpecGen error_spec_gen = +[](NativeRefT, NativeRefT) {
return ErrorSpec::Builder().strict_signed_zeros().build();
};

if (IsCpu(platform_)) {
if constexpr (std::is_same_v<NativeT, float> ||
std::is_same_v<NativeT, double>) {
error_spec_gen = +[](NativeRefT real, NativeRefT imag) {
return ErrorSpec::Builder()
.abs_err(AbsComplexCpuAbsErr(real, imag))
.distance_err(2)
.skip_comparison(AbsComplexSkip(real, imag))
.build();
};
}
}

if (IsGpu(platform_)) {
if constexpr (std::is_same_v<NativeT, float>) {
error_spec_gen = +[](NativeRefT real, NativeRefT imag) {
return ErrorSpec::Builder()
.distance_err(3)
.skip_comparison(AbsComplexSkip(real, imag))
.build();
};
} else if constexpr (std::is_same_v<NativeT, double>) {
error_spec_gen = +[](NativeRefT real, NativeRefT imag) {
return ErrorSpec::Builder()
.distance_err(2)
.skip_comparison(AbsComplexSkip(real, imag))
.build();
};
}
}

EnableDebugLoggingForScope([this, error_spec_gen]() {
Run([](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); },
[](NativeRefT x, NativeRefT y) {
return std::abs(std::complex<NativeRefT>(x, y));
},
error_spec_gen);
});
})

} // namespace
} // namespace exhaustive_op_test
} // namespace xla

0 comments on commit 2ce52ad

Please sign in to comment.