Skip to content

Commit

Permalink
dp solver fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Nov 8, 2024
1 parent 50abba3 commit 003de29
Showing 1 changed file with 54 additions and 39 deletions.
93 changes: 54 additions & 39 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2420,6 +2420,7 @@ InstructionCost getCompCost(FPCC &component, const TargetTransformInfo &TTI,

pt.apply(component, &VMap);
// output values in VMap are changed to the new casted values
// llvm::errs() << "\nDEBUG: " << pt.desc << "\n";
// FClone->print(llvm::errs());

SmallPtrSet<Value *, 8> clonedInputs;
Expand Down Expand Up @@ -2458,9 +2459,6 @@ InstructionCost getCompCost(FPCC &component, const TargetTransformInfo &TTI,
}
}

// llvm::errs() << "DEBUG: " << pt.desc << "\n";
// FClone->print(llvm::errs());

FClone->eraseFromParent();

return cost;
Expand Down Expand Up @@ -2844,6 +2842,10 @@ class ApplicableFPCC {

InstructionCost adjustedCostDelta =
(candidateCompCost - initialCompCost) * executions;
// llvm::errs() << "Initial cost: " << initialCompCost << "\n";
// llvm::errs() << "Candidate cost: " << candidateCompCost << "\n";
// llvm::errs() << "Num executions: " << executions << "\n";
// llvm::errs() << "Adjusted cost delta: " << adjustedCostDelta << "\n\n";

compCostDeltaCache[key] = adjustedCostDelta;
return adjustedCostDelta;
Expand Down Expand Up @@ -3577,30 +3579,37 @@ bool accuracyDPSolver(
costToAccuracyMap[0] = 0;
SolutionMap costToSolutionMap;
costToSolutionMap[0] = {};
CostMap newCostToAccuracyMap;
SolutionMap newCostToSolutionMap;
CostMap prunedCostToAccuracyMap;
SolutionMap prunedCostToSolutionMap;

int AOCounter = 0;

for (auto &AO : AOs) {
CostMap newCostToAccuracyMap;
SolutionMap newCostToSolutionMap;
// It is possible to apply zero candidate for an AO.
// When no candidate is applied, the resulting accuracy cost
// and solution steps remain the same.
newCostToAccuracyMap = costToAccuracyMap;
newCostToSolutionMap = costToSolutionMap;

llvm::errs() << "DP table sizes: " << costToAccuracyMap.size() << " (Acc) "
<< costToSolutionMap.size() << " (Sol)\n";

for (const auto &pair : costToAccuracyMap) {
InstructionCost currCompCost = pair.first;
double currAccCost = pair.second;

// It is possible to apply zero candidate for an AO
if (newCostToAccuracyMap.find(currCompCost) ==
newCostToAccuracyMap.end() ||
newCostToAccuracyMap[currCompCost] > currAccCost) {
newCostToAccuracyMap[currCompCost] = currAccCost;
newCostToSolutionMap[currCompCost] = costToSolutionMap[currCompCost];
}

for (auto &candidate : enumerate(AO.candidates)) {
size_t i = candidate.index();
auto candCompCost = AO.getCompCostDelta(i);
auto candAccCost = AO.getAccCostDelta(i);

// Don't ever try to apply a strictly useless candidate
if (candCompCost >= 0 && candAccCost >= 0.) {
continue;
}

InstructionCost newCompCost = currCompCost + candCompCost;
double newAccCost = currAccCost + candAccCost;

Expand All @@ -3625,17 +3634,14 @@ bool accuracyDPSolver(

// TODO: Do not prune AO parts of the DP table since AOs influence ACCs
if (!FPOptEarlyPrune) {
costToAccuracyMap.swap(newCostToAccuracyMap);
costToSolutionMap.swap(newCostToSolutionMap);
costToAccuracyMap = newCostToAccuracyMap;
costToSolutionMap = newCostToSolutionMap;

llvm::errs() << "##### Finished processing " << ++AOCounter << " of "
<< AOs.size() << " AOs #####\n";
continue;
}

CostMap prunedCostToAccuracyMap;
SolutionMap prunedCostToSolutionMap;

for (const auto &l : newCostToAccuracyMap) {
InstructionCost currCompCost = l.first;
double currAccCost = l.second;
Expand Down Expand Up @@ -3670,8 +3676,10 @@ bool accuracyDPSolver(
}
}

costToAccuracyMap.swap(prunedCostToAccuracyMap);
costToSolutionMap.swap(prunedCostToSolutionMap);
costToAccuracyMap = prunedCostToAccuracyMap;
costToSolutionMap = prunedCostToSolutionMap;
prunedCostToAccuracyMap.clear();
prunedCostToSolutionMap.clear();

llvm::errs() << "##### Finished processing " << ++AOCounter << " of "
<< AOs.size() << " AOs #####\n";
Expand All @@ -3680,21 +3688,19 @@ bool accuracyDPSolver(
int ACCCounter = 0;

for (auto &ACC : ACCs) {
CostMap newCostToAccuracyMap;
SolutionMap newCostToSolutionMap;
// It is possible to apply zero candidate for an ACC.
// When no candidate is applied, the resulting accuracy cost
// and solution steps remain the same.
newCostToAccuracyMap = costToAccuracyMap;
newCostToSolutionMap = costToSolutionMap;

llvm::errs() << "DP table sizes: " << costToAccuracyMap.size() << " (Acc) "
<< costToSolutionMap.size() << " (Sol)\n";

for (const auto &pair : costToAccuracyMap) {
InstructionCost currCompCost = pair.first;
double currAccCost = pair.second;

// It is possible to apply zero candidate for an ACC
if (newCostToAccuracyMap.find(currCompCost) ==
newCostToAccuracyMap.end() ||
newCostToAccuracyMap[currCompCost] > currAccCost) {
newCostToAccuracyMap[currCompCost] = currAccCost;
newCostToSolutionMap[currCompCost] = costToSolutionMap[currCompCost];
}

for (auto &candidate : enumerate(ACC.candidates)) {
size_t i = candidate.index();
auto candCompCost =
Expand All @@ -3703,6 +3709,11 @@ bool accuracyDPSolver(
ACC.getAdjustedAccCostDelta(i, costToSolutionMap[currCompCost],
valueToNodeMap, symbolToValueMap);

// Don't ever try to apply a strictly useless candidate
if (candCompCost >= 0 && candAccCost >= 0.) {
continue;
}

InstructionCost newCompCost = currCompCost + candCompCost;
double newAccCost = currAccCost + candAccCost;

Expand All @@ -3718,17 +3729,19 @@ bool accuracyDPSolver(
newCostToAccuracyMap[newCompCost] = newAccCost;
newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost];
newCostToSolutionMap[newCompCost].emplace_back(&ACC, i);
// if (EnzymePrintFPOpt)
// llvm::errs() << "Updating accuracy map (ACC candidate " << i
// << "): computation cost " << newCompCost
// << " -> accuracy cost " << newAccCost << "\n";
if (EnzymePrintFPOpt) {
// llvm::errs() << "ACC candidate " << i << " ("
// << candidate.value().desc
// << ") added; has accuracy cost: " << candAccCost
// << " and computation cost: " << candCompCost << "\n";
// llvm::errs() << "Updating accuracy map (ACC candidate " << i
// << "): computation cost " << newCompCost
// << " -> accuracy cost " << newAccCost << "\n";
}
}
}
}

CostMap prunedCostToAccuracyMap;
SolutionMap prunedCostToSolutionMap;

for (const auto &l : newCostToAccuracyMap) {
InstructionCost currCompCost = l.first;
double currAccCost = l.second;
Expand Down Expand Up @@ -3763,8 +3776,10 @@ bool accuracyDPSolver(
}
}

costToAccuracyMap.swap(prunedCostToAccuracyMap);
costToSolutionMap.swap(prunedCostToSolutionMap);
costToAccuracyMap = prunedCostToAccuracyMap;
costToSolutionMap = prunedCostToSolutionMap;
prunedCostToAccuracyMap.clear();
prunedCostToSolutionMap.clear();

llvm::errs() << "##### Finished processing " << ++ACCCounter << " of "
<< ACCs.size() << " ACCs #####\n";
Expand Down

0 comments on commit 003de29

Please sign in to comment.