Skip to content

Commit

Permalink
BCLoad: use enzyme_math (#2149)
Browse files Browse the repository at this point in the history
* BCLoad: use enzyme_math

* Fix

* fix
  • Loading branch information
wsmoses authored Nov 4, 2024
1 parent 53886d6 commit b38b240
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions enzyme/BCLoad/BCLoader.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
Expand Down Expand Up @@ -38,6 +39,7 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
std::vector<StringRef> todo;
bool seen32 = false;
bool seen64 = false;
std::vector<std::pair<StringRef, llvm::Function *>> name_rewrites;
for (auto &F : M) {
if (!F.empty())
continue;
Expand All @@ -48,17 +50,19 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
for (auto postfix : {"", "_", "_64_"}) {
std::string str;
if (strlen(postfix) == 0) {
str = F.getName().str();
str = name.str();
} else if (endsWith(name, postfix)) {
auto blasName =
name.substr(0, name.size() - strlen(postfix)).str();
auto blasName = name.substr(0, name.size() - strlen(postfix)).str();
str = "cblas_" + blasName;
}

auto found = EnzymeBlasBC.find(str);
if (found != EnzymeBlasBC.end()) {
replaced.push_back(name.str());
todo.push_back(found->second);
if (name != F.getName()) {
name_rewrites.emplace_back(name, &F);
}
if (index == 1)
seen32 = true;
if (index == 2)
Expand All @@ -69,6 +73,21 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
}
}

for (auto &&[realname, F] : name_rewrites) {
auto decl = M.getOrInsertFunction(realname, F->getFunctionType());
auto entry = BasicBlock::Create(F->getContext(), "entry",
cast<Function>(decl.getCallee()));
IRBuilder<> B(entry);
SmallVector<Value *, 1> vals;
for (auto &arg : F->args())
vals.push_back(&arg);
auto rt = B.CreateCall(decl, vals);
if (rt->getType()->isVoidTy())
B.CreateRetVoid();
else
B.CreateRet(rt);
}

// Push fortran wrapper libs before all the other blas
// to ensure the fortran injections have their code
// replaced
Expand Down Expand Up @@ -102,7 +121,8 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
for (auto &F : *BC) {
if (F.empty())
continue;
if (ignoreFunctions.count(F.getName().str())) {
auto name = getFuncName(&F);
if (ignoreFunctions.count(name.str())) {
F.dropAllReferences();
#if LLVM_VERSION_MAJOR >= 16
F.erase(F.begin(), F.end());
Expand All @@ -111,7 +131,7 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
#endif
continue;
}
toReplace.push_back(F.getName().str());
toReplace.push_back(name.str());
}
BC->setTargetTriple("");
Linker L(M);
Expand Down

0 comments on commit b38b240

Please sign in to comment.