Skip to content

Commit

Permalink
Add __*_finite versions of derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanradanov committed Apr 28, 2024
1 parent 5468479 commit 76567b6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
13 changes: 13 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/hypot.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,26 @@ entry:
ret double %call
}

define double @tester2(double %x, double %y) {
entry:
%call = tail call double @__hypot_finite(double %x, double %y)
ret double %call
}

define double @test_derivative(double %x, double %y) {
entry:
%0 = tail call double (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.000000e+00, double %y, double 1.000000e+00)
ret double %0
}

define double @test_derivative2(double %x, double %y) {
entry:
%0 = tail call double (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester2, double %x, double 1.000000e+00, double %y, double 1.000000e+00)
ret double %0
}

declare double @hypot(double, double)
declare double @__hypot_finite(double, double)

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(...)
Expand Down
15 changes: 11 additions & 4 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1726,10 +1726,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
bool prev = false;
for (auto *nameI :
*cast<ListInit>(pattern->getValueAsListInit("names"))) {
if (prev)
os << " ||\n ";
os << "funcName == " << cast<StringInit>(nameI)->getAsString() << "";
prev = true;
auto nameIStr = cast<StringInit>(nameI)->getAsString();
auto nameIStrFinite = "\"__" +
std::string(std::next(nameIStr.begin()),
std::prev(nameIStr.end())) +
"_finite\"";
for (auto nameIStrAll : {nameIStr, nameIStrFinite}) {
if (prev)
os << " ||\n ";
os << "funcName == " << nameIStrAll << "";
prev = true;
}
}
origName = "call";
#if LLVM_VERSION_MAJOR >= 14
Expand Down

0 comments on commit 76567b6

Please sign in to comment.