PR #8589: Improve the accuracy of asin(x) and asinh(x) for complex x using modified Hull et al algorithm.

Imported from GitHub PR #8589

As in the title.

Originally marked as fixing #8553, but PR #9802 disabled that fix.

Update: the fix to #8553 will be available via openxla/stablehlo#2357
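
For context, here is a rough numpy sketch (illustrative only, not part of this change) of the kind of accuracy problem being addressed: the textbook formula asinh(z) = log(z + sqrt(z*z + 1)) overflows in the intermediate z*z once |z| is large, even though the true result is modest in size.

```python
import numpy as np

z = np.complex128(1e200 + 1e200j)

# Textbook formula: the intermediate z*z overflows, so the result is
# ruined by inf/nan artifacts.
naive = np.log(z + np.sqrt(z * z + 1.0))

# For |z| >> 1, asinh(z) ~= log(2*z) = log(2) + log(|z|) + i*arg(z),
# which stays finite and well conditioned.
expansion = np.log(2.0) + np.log(np.abs(z)) + 1j * np.angle(z)

print("naive:    ", naive)      # overflow artifacts
print("expansion:", expansion)  # roughly 461.6 + 0.785j
```
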
Copybara import of the project:

--
547f5f7 by Pearu Peterson <[email protected]>:

Improve the accuracy of asinh(z) for complex z with large absolute value.

--
00c96e8 by Pearu Peterson <[email protected]>:

Implement the modified Hull et al algorithm for Asin and Asinh.

--
94eb9ad by Pearu Peterson <[email protected]>:

Use functional_algorithms to generate Asin implementation

--
ec7334f by Pearu Peterson <[email protected]>:

Eliminate static Hypot as not used

--
0f8cd5e by Pearu Peterson <[email protected]>:

Apply clang-format

Merging this change closes #8589

COPYBARA_INTEGRATE_REVIEW=#8589 from pearu:pearu/asinh 0f8cd5e
PiperOrigin-RevId: 666419461
pearu authored and copybara-github committed Aug 22, 2024
1 parent d8a8089 commit 6942179
Showing 8 changed files with 2,285 additions and 596 deletions.
5 changes: 4 additions & 1 deletion xla/client/lib/BUILD
@@ -204,7 +204,10 @@ cc_library(
cc_library(
    name = "math",
    srcs = ["math.cc"],
    hdrs = ["math.h"],
    hdrs = [
        "math.h",
        "math_impl.h",
    ],
    deps = [
        ":arithmetic",
        ":constants",
126 changes: 126 additions & 0 deletions xla/client/lib/generate_math_impl.py
@@ -0,0 +1,126 @@
"""A script to generate math_impl.h.
Prerequisites:
python 3.11 or newer
functional_algorithms 0.3.1 or newer
Usage:
Running
python /path/to/generate_math_impl.py [xla | tensorflow]
will create
/path/to/math_impl.cc
"""

import os
import sys
import warnings

try:
import functional_algorithms as fa # pylint: disable=g-import-not-at-top
except ImportError as msg:
warnings.warn(f"Skipping: {msg}")
fa = None


def main():
if fa is None:
return
target = (sys.argv[1] if len(sys.argv) > 1 else "xla").lower()
assert target in {"xla", "tensorflow"}, target
header_file_define = dict(
xla="XLA_CLIENT_MATH_IMPL_H_",
tensorflow="TENSORFLOW_COMPILER_XLA_CLIENT_MATH_IMPL_H_",
)[target]

fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3]))
if fa_version < (0, 3, 1):
warnings.warn(
"functional_algorithm version 0.3.1 or newer is required,"
f" got {fa.__version__}"
)
return

output_file = os.path.join(os.path.dirname(__file__), "math_impl.h")

sources = []
target = fa.targets.xla_client
for xlaname, fname, args in [
("AsinComplex", "complex_asin", ("z:complex",)),
("AsinReal", "real_asin", ("x:float",)),
]:
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
f"{fa.algorithms.__name__} does not define {fname}. Skipping."
)
continue
ctx = fa.Context(
paths=[fa.algorithms],
enable_alt=True,
default_constant_type="FloatType",
)
graph = ctx.trace(func, *args).implement_missing(target).simplify()
graph.props.update(name=xlaname)
src = graph.tostring(target)
if func.__doc__:
sources.append(target.make_comment(func.__doc__))
sources[-1] += src
source = "\n\n".join(sources) + "\n"

if os.path.isfile(output_file):
f = open(output_file, "r", encoding="UTF-8")
content = f.read()
f.close()
if content.endswith(source) and 0:
warnings.warn(f"{output_file} is up-to-date.")
return

f = open(output_file, "w", encoding="UTF-8")
f.write("""/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""")
f.write(target.make_comment(f"""\
This file is generated using functional_algorithms tool ({fa.__version__}), see
https://github.com/pearu/functional_algorithms
for more information.""") + "\n")
f.write(f"""\
#ifndef {header_file_define}
#define {header_file_define}
#include "xla/client/lib/constants.h"
#include "xla/client/xla_builder.h"
namespace xla {{
namespace math_impl {{
// NOLINTBEGIN(whitespace/line_length)
// clang-format off
""")
f.write(source)
f.write(f"""
// clang-format on
// NOLINTEND(whitespace/line_length)
}} // namespace math_impl
}} // namespace xla
#endif // {header_file_define}
""")
f.close()
warnings.warn(f"Created {output_file}")


if __name__ == "__main__":
main()
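
As a usage note, a minimal sketch (under assumptions, not part of the commit: it assumes functional_algorithms is installed and the working directory is the repository root) of regenerating and inspecting the header:

```python
import pathlib
import subprocess
import sys

# Assumed repo-relative location of the generator added by this commit.
script = pathlib.Path("xla/client/lib/generate_math_impl.py")

# Regenerate math_impl.h for the "xla" target (the default).
subprocess.run([sys.executable, str(script), "xla"], check=True)

# The header is written next to the script; show its generated banner.
header = script.with_name("math_impl.h")
print(header.read_text(encoding="UTF-8")[:400])
```
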
48 changes: 42 additions & 6 deletions xla/client/lib/math.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "xla/client/lib/arithmetic.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/math_impl.h"
#include "xla/client/xla_builder.h"
#include "xla/primitive_util.h"
#include "xla/shape.h"
@@ -1188,8 +1189,31 @@ XlaOp Acos(XlaOp x) {

// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XlaOp Asin(XlaOp x) {
  return ScalarLike(x, 2.0) *
         Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x));
  XlaBuilder* b = x.builder();
  auto do_it = [&](XlaOp z) -> absl::StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(z));
    auto elem_ty = shape.element_type();
    switch (elem_ty) {
      case C128:
        return math_impl::AsinComplex<double>(z);
      case C64:
        return math_impl::AsinComplex<float>(z);
      case F64:
        return math_impl::AsinReal<double>(z);
      case F32:
        return math_impl::AsinReal<float>(z);
      // TODO(pearu): Add implementations for BF16 and F16 to avoid
      // the upcast below.
      default:
        return InvalidArgument("Asin got unsupported element type %s",
                               PrimitiveType_Name(elem_ty));
    }
  };
  // These upcasts are not strictly necessary on all platforms to get within
  // our error tolerances, so we could relax this if it ever mattered.
  return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
    return b->ReportErrorOrReturn(do_it(x));
  });
}

XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
@@ -1256,11 +1280,23 @@ XlaOp Asinh(XlaOp x) {
  //
  // y * sign(x).
  //
  // TODO(jlebar): For now, we ignore the question of overflow if x is a
  // complex type, because we don't yet have exhaustive tests for complex trig
  // functions.
  if (primitive_util::IsComplexType(shape.element_type())) {
    return Log(x + Sqrt(x * x + one));
    // Asinh(x) = I * Asin(-I * x)
    //
    // We use mixed-mode arithmetic instead of complex arithmetic to
    // ensure that multiplying I by complex infinities does not produce
    // spurious NaNs:
    auto x_re = Real(x);
    auto x_im = Imag(x);
    auto z = Asin(Complex(x_im, -x_re));
    auto z_im = Imag(z);
    // When abs(x.imag) > 1 and x.real == 0, select the correct branch
    // from the Asin(Complex(x.imag, -0)) result (assuming x.real is +0,
    // the imaginary part of the argument to Asin approaches 0 from the
    // negative side):
    auto on_branch_cut = And(Eq(x_re, ScalarLike(x_re, 0)),
                             Gt(Abs(x_im), ScalarLike(x_im, 1)));
    return Complex(Select(on_branch_cut, z_im, -z_im), Real(z));
  }
  // For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
  // arithmetic. However, we would like to retain the low order term of this,
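
To make the complex Asinh path above concrete, here is a small numpy sketch (illustrative only; plain numpy stands in for the XLA ops, and it does not reproduce the branch-cut Select). It shows the identity Asinh(x) = I * Asin(-I * x) on generic inputs, and the spurious NaN that naive multiplication by -I produces when a component is infinite, which is why the C++ code builds Complex(x_im, -x_re) from components instead.

```python
import numpy as np

# The identity asinh(x) == i * asin(-i * x) on generic complex inputs.
rng = np.random.default_rng(0)
x = rng.standard_normal(8) + 1j * rng.standard_normal(8)
print(np.max(np.abs(np.arcsinh(x) - 1j * np.arcsin(-1j * x))))  # tiny, near machine epsilon

# With an infinite component, literally multiplying by -i hits 0 * inf
# inside the complex product and yields a NaN...
x = np.complex128(complex(np.inf, 2.0))
print(np.complex128(-1j) * x)       # typically (nan-infj)

# ...whereas assembling the rotated value from real/imag components
# ("mixed-mode" arithmetic, as in the C++ code) avoids it.
print(complex(x.imag, -x.real))     # (2-infj)
```
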
