From 003de290d2e9e8a28c2ed7e8cc01aace6d68266b Mon Sep 17 00:00:00 2001 From: Brant-Skywalker Date: Fri, 8 Nov 2024 17:53:47 -0600 Subject: [PATCH] dp solver fix up --- enzyme/Enzyme/Herbie.cpp | 93 +++++++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 39 deletions(-) diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index 295066394bd..f24d95c9ae7 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -2420,6 +2420,7 @@ InstructionCost getCompCost(FPCC &component, const TargetTransformInfo &TTI, pt.apply(component, &VMap); // output values in VMap are changed to the new casted values + // llvm::errs() << "\nDEBUG: " << pt.desc << "\n"; // FClone->print(llvm::errs()); SmallPtrSet clonedInputs; @@ -2458,9 +2459,6 @@ InstructionCost getCompCost(FPCC &component, const TargetTransformInfo &TTI, } } - // llvm::errs() << "DEBUG: " << pt.desc << "\n"; - // FClone->print(llvm::errs()); - FClone->eraseFromParent(); return cost; @@ -2844,6 +2842,10 @@ class ApplicableFPCC { InstructionCost adjustedCostDelta = (candidateCompCost - initialCompCost) * executions; + // llvm::errs() << "Initial cost: " << initialCompCost << "\n"; + // llvm::errs() << "Candidate cost: " << candidateCompCost << "\n"; + // llvm::errs() << "Num executions: " << executions << "\n"; + // llvm::errs() << "Adjusted cost delta: " << adjustedCostDelta << "\n\n"; compCostDeltaCache[key] = adjustedCostDelta; return adjustedCostDelta; @@ -3577,30 +3579,37 @@ bool accuracyDPSolver( costToAccuracyMap[0] = 0; SolutionMap costToSolutionMap; costToSolutionMap[0] = {}; + CostMap newCostToAccuracyMap; + SolutionMap newCostToSolutionMap; + CostMap prunedCostToAccuracyMap; + SolutionMap prunedCostToSolutionMap; int AOCounter = 0; for (auto &AO : AOs) { - CostMap newCostToAccuracyMap; - SolutionMap newCostToSolutionMap; + // It is possible to apply zero candidate for an AO. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + llvm::errs() << "DP table sizes: " << costToAccuracyMap.size() << " (Acc) " + << costToSolutionMap.size() << " (Sol)\n"; for (const auto &pair : costToAccuracyMap) { InstructionCost currCompCost = pair.first; double currAccCost = pair.second; - // It is possible to apply zero candidate for an AO - if (newCostToAccuracyMap.find(currCompCost) == - newCostToAccuracyMap.end() || - newCostToAccuracyMap[currCompCost] > currAccCost) { - newCostToAccuracyMap[currCompCost] = currAccCost; - newCostToSolutionMap[currCompCost] = costToSolutionMap[currCompCost]; - } - for (auto &candidate : enumerate(AO.candidates)) { size_t i = candidate.index(); auto candCompCost = AO.getCompCostDelta(i); auto candAccCost = AO.getAccCostDelta(i); + // Don't ever try to apply a strictly useless candidate + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + InstructionCost newCompCost = currCompCost + candCompCost; double newAccCost = currAccCost + candAccCost; @@ -3625,17 +3634,14 @@ bool accuracyDPSolver( // TODO: Do not prune AO parts of the DP table since AOs influence ACCs if (!FPOptEarlyPrune) { - costToAccuracyMap.swap(newCostToAccuracyMap); - costToSolutionMap.swap(newCostToSolutionMap); + costToAccuracyMap = newCostToAccuracyMap; + costToSolutionMap = newCostToSolutionMap; llvm::errs() << "##### Finished processing " << ++AOCounter << " of " << AOs.size() << " AOs #####\n"; continue; } - CostMap prunedCostToAccuracyMap; - SolutionMap prunedCostToSolutionMap; - for (const auto &l : newCostToAccuracyMap) { InstructionCost currCompCost = l.first; double currAccCost = l.second; @@ -3670,8 +3676,10 @@ bool accuracyDPSolver( } } - costToAccuracyMap.swap(prunedCostToAccuracyMap); - costToSolutionMap.swap(prunedCostToSolutionMap); + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); llvm::errs() << "##### Finished processing " << ++AOCounter << " of " << AOs.size() << " AOs #####\n"; @@ -3680,21 +3688,19 @@ bool accuracyDPSolver( int ACCCounter = 0; for (auto &ACC : ACCs) { - CostMap newCostToAccuracyMap; - SolutionMap newCostToSolutionMap; + // It is possible to apply zero candidate for an ACC. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + llvm::errs() << "DP table sizes: " << costToAccuracyMap.size() << " (Acc) " + << costToSolutionMap.size() << " (Sol)\n"; for (const auto &pair : costToAccuracyMap) { InstructionCost currCompCost = pair.first; double currAccCost = pair.second; - // It is possible to apply zero candidate for an ACC - if (newCostToAccuracyMap.find(currCompCost) == - newCostToAccuracyMap.end() || - newCostToAccuracyMap[currCompCost] > currAccCost) { - newCostToAccuracyMap[currCompCost] = currAccCost; - newCostToSolutionMap[currCompCost] = costToSolutionMap[currCompCost]; - } - for (auto &candidate : enumerate(ACC.candidates)) { size_t i = candidate.index(); auto candCompCost = @@ -3703,6 +3709,11 @@ bool accuracyDPSolver( ACC.getAdjustedAccCostDelta(i, costToSolutionMap[currCompCost], valueToNodeMap, symbolToValueMap); + // Don't ever try to apply a strictly useless candidate + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + InstructionCost newCompCost = currCompCost + candCompCost; double newAccCost = currAccCost + candAccCost; @@ -3718,17 +3729,19 @@ bool accuracyDPSolver( newCostToAccuracyMap[newCompCost] = newAccCost; newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; newCostToSolutionMap[newCompCost].emplace_back(&ACC, i); - // if (EnzymePrintFPOpt) - // llvm::errs() << "Updating accuracy map (ACC candidate " << i - // << "): computation cost " << newCompCost - // << " -> accuracy cost " << newAccCost << "\n"; + if (EnzymePrintFPOpt) { + // llvm::errs() << "ACC candidate " << i << " (" + // << candidate.value().desc + // << ") added; has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; + // llvm::errs() << "Updating accuracy map (ACC candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + } } } } - CostMap prunedCostToAccuracyMap; - SolutionMap prunedCostToSolutionMap; - for (const auto &l : newCostToAccuracyMap) { InstructionCost currCompCost = l.first; double currAccCost = l.second; @@ -3763,8 +3776,10 @@ bool accuracyDPSolver( } } - costToAccuracyMap.swap(prunedCostToAccuracyMap); - costToSolutionMap.swap(prunedCostToSolutionMap); + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); llvm::errs() << "##### Finished processing " << ++ACCCounter << " of " << ACCs.size() << " ACCs #####\n";