Skip to content

Commit

Permalink
accuracy cost evaluation: arithmetic avg --> geometric avg
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Nov 8, 2024
1 parent 60a6ddd commit fe0b354
Showing 1 changed file with 75 additions and 68 deletions.
143 changes: 75 additions & 68 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2925,12 +2925,12 @@ void setUnifiedAccuracyCost(
// llvm::errs() << "DEBUG AO real value: " << realVal << "\n";

if (!std::isnan(goldVal) && !std::isnan(realVal)) {
initAC += std::fabs(goldVal - realVal);
initAC += std::log1p(std::fabs(goldVal - realVal));
numValidSamples++;
}
}

AO.initialAccCost = initAC / numValidSamples * std::fabs(AO.grad);
AO.initialAccCost = std::expm1(initAC / numValidSamples) * std::fabs(AO.grad);
// llvm::errs() << "DEBUG calculated AO initial accuracy cost: "
// << AO.initialAccCost << "\n";
assert(numValidSamples && "No valid samples for AO -- try increasing the "
Expand Down Expand Up @@ -2968,13 +2968,14 @@ void setUnifiedAccuracyCost(
// llvm::errs() << "Real value: " << realVal << "\n";
double goldVal = goldVals[pair.index()];
if (!std::isnan(goldVal) && !std::isnan(realVal)) {
ac += std::fabs(goldVal - realVal);
ac += std::log1p(std::fabs(goldVal - realVal));
numValidSamples++;
}
}
assert(numValidSamples && "No valid samples for AO -- try increasing the "
"number of samples");
candidate.accuracyCost = ac / numValidSamples * std::fabs(AO.grad);
candidate.accuracyCost =
std::expm1(ac / numValidSamples) * std::fabs(AO.grad);
assert(!std::isnan(candidate.accuracyCost));
}
}
Expand Down Expand Up @@ -3024,7 +3025,7 @@ void setUnifiedAccuracyCost(
double goldVal = goldVals[output][pair.index()];
if (!std::isnan(goldVal) && !std::isnan(result)) {
double diff = std::fabs(goldVal - result);
ACC.perOutputInitialAccCost[output] += diff;
ACC.perOutputInitialAccCost[output] += std::log1p(diff);
numValidSamplesPerOutput[output]++;
}
}
Expand All @@ -3036,9 +3037,10 @@ void setUnifiedAccuracyCost(
unsigned numValidSamples = numValidSamplesPerOutput[output];
assert(numValidSamples && "No valid samples for at least one output node "
"-- try increasing the number of samples");
ACC.perOutputInitialAccCost[output] /= numValidSamples;
// Local error --> global error
ACC.perOutputInitialAccCost[output] *= std::fabs(output->grad);
ACC.perOutputInitialAccCost[output] =
std::expm1(ACC.perOutputInitialAccCost[output] / numValidSamples) *
std::fabs(output->grad);
// llvm::errs() << "DEBUG calculated ACC per output initial accuracy cost: "
// << ACC.perOutputInitialAccCost[output] << "\n";
ACC.initialAccCost += ACC.perOutputInitialAccCost[output];
Expand All @@ -3062,7 +3064,7 @@ void setUnifiedAccuracyCost(
if (!std::isnan(goldVal) && !std::isnan(result)) {
double diff = std::fabs(goldVal - result);
// Sum up local errors
candidate.perOutputAccCost[output] += diff;
candidate.perOutputAccCost[output] += std::log1p(diff);
numValidSamplesPerOutput[output]++;
}
}
Expand All @@ -3074,9 +3076,10 @@ void setUnifiedAccuracyCost(
unsigned numValidSamples = numValidSamplesPerOutput[output];
assert(numValidSamples && "No valid samples for output -- try increasing "
"the number of samples");
candidate.perOutputAccCost[output] /= numValidSamples;
// Local error --> global error
candidate.perOutputAccCost[output] *= std::fabs(output->grad);
candidate.perOutputAccCost[output] =
std::expm1(candidate.perOutputAccCost[output] / numValidSamples) *
std::fabs(output->grad);
// llvm::errs()
// << "DEBUG calculated ACC per output candidate accuracy cost: "
// << candidate.perOutputAccCost[output] << "\n";
Expand Down Expand Up @@ -3585,21 +3588,21 @@ bool accuracyDPSolver(
InstructionCost newCompCost = currCompCost + candCompCost;
double newAccCost = currAccCost + candAccCost;

if (EnzymePrintFPOpt)
llvm::errs() << "AO candidate " << i
<< " has accuracy cost: " << candAccCost
<< " and computation cost: " << candCompCost << "\n";
// if (EnzymePrintFPOpt)
// llvm::errs() << "AO candidate " << i
// << " has accuracy cost: " << candAccCost
// << " and computation cost: " << candCompCost << "\n";

if (newCostToAccuracyMap.find(newCompCost) ==
newCostToAccuracyMap.end() ||
newCostToAccuracyMap[newCompCost] > newAccCost) {
newCostToAccuracyMap[newCompCost] = newAccCost;
newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost];
newCostToSolutionMap[newCompCost].emplace_back(&AO, i);
if (EnzymePrintFPOpt)
llvm::errs() << "Updating accuracy map (AO candidate " << i
<< "): computation cost " << newCompCost
<< " -> accuracy cost " << newAccCost << "\n";
// if (EnzymePrintFPOpt)
// llvm::errs() << "Updating accuracy map (AO candidate " << i
// << "): computation cost " << newCompCost
// << " -> accuracy cost " << newAccCost << "\n";
}
}
}
Expand Down Expand Up @@ -3629,13 +3632,14 @@ bool accuracyDPSolver(
otherCompCost.getValue().getValue()) &&
currAccCost - otherAccCost >
std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) {
if (EnzymePrintFPOpt)
llvm::errs() << "AO candidate with computation cost: "
<< currCompCost
<< " and accuracy cost: " << currAccCost
<< " is dominated by candidate with computation cost:"
<< otherCompCost
<< " and accuracy cost: " << otherAccCost << "\n";
// if (EnzymePrintFPOpt)
// llvm::errs() << "AO candidate with computation cost: "
// << currCompCost
// << " and accuracy cost: " << currAccCost
// << " is dominated by candidate with computation
// cost:"
// << otherCompCost
// << " and accuracy cost: " << otherAccCost << "\n";
dominated = true;
break;
}
Expand Down Expand Up @@ -3679,22 +3683,22 @@ bool accuracyDPSolver(
InstructionCost newCompCost = currCompCost + candCompCost;
double newAccCost = currAccCost + candAccCost;

if (EnzymePrintFPOpt)
llvm::errs() << "ACC candidate " << i << " ("
<< candidate.value().desc
<< ") has accuracy cost: " << candAccCost
<< " and computation cost: " << candCompCost << "\n";
// if (EnzymePrintFPOpt)
// llvm::errs() << "ACC candidate " << i << " ("
// << candidate.value().desc
// << ") has accuracy cost: " << candAccCost
// << " and computation cost: " << candCompCost << "\n";

if (newCostToAccuracyMap.find(newCompCost) ==
newCostToAccuracyMap.end() ||
newCostToAccuracyMap[newCompCost] > newAccCost) {
newCostToAccuracyMap[newCompCost] = newAccCost;
newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost];
newCostToSolutionMap[newCompCost].emplace_back(&ACC, i);
if (EnzymePrintFPOpt)
llvm::errs() << "Updating accuracy map (ACC candidate " << i
<< "): computation cost " << newCompCost
<< " -> accuracy cost " << newAccCost << "\n";
// if (EnzymePrintFPOpt)
// llvm::errs() << "Updating accuracy map (ACC candidate " << i
// << "): computation cost " << newCompCost
// << " -> accuracy cost " << newAccCost << "\n";
}
}
}
Expand All @@ -3716,13 +3720,14 @@ bool accuracyDPSolver(
otherCompCost.getValue().getValue()) &&
currAccCost - otherAccCost >
std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) {
if (EnzymePrintFPOpt)
llvm::errs() << "ACC candidate with computation cost: "
<< currCompCost
<< " and accuracy cost: " << currAccCost
<< " is dominated by candidate with computation cost:"
<< otherCompCost
<< " and accuracy cost: " << otherAccCost << "\n";
// if (EnzymePrintFPOpt)
// llvm::errs() << "ACC candidate with computation cost: "
// << currCompCost
// << " and accuracy cost: " << currAccCost
// << " is dominated by candidate with computation
// cost:"
// << otherCompCost
// << " and accuracy cost: " << otherAccCost << "\n";
dominated = true;
break;
}
Expand All @@ -3740,33 +3745,35 @@ bool accuracyDPSolver(
}

if (EnzymePrintFPOpt) {
llvm::errs() << "\n*** DP Table ***\n";
for (const auto &pair : costToAccuracyMap) {
llvm::errs() << "Computation cost: " << pair.first
<< ", Accuracy cost: " << pair.second << "\n";
llvm::errs() << "\tSolution steps: \n";
for (const auto &step : costToSolutionMap[pair.first]) {
std::visit(
[&](auto *item) {
using T = std::decay_t<decltype(*item)>;
if constexpr (std::is_same_v<T, ApplicableOutput>) {
llvm::errs()
<< "\t\t" << item->expr << " --(" << step.candidateIndex
<< ")-> " << item->candidates[step.candidateIndex].expr
<< "\n";
} else if constexpr (std::is_same_v<T, ApplicableFPCC>) {
llvm::errs()
<< "\t\tACC: " << item->candidates[step.candidateIndex].desc
<< " (#" << step.candidateIndex << ")\n";
} else {
llvm_unreachable(
"accuracyDPSolver: Unexpected type of solution step");
}
},
step.item);
}
}
llvm::errs() << "*** End of DP Table ***\n\n";
// llvm::errs() << "\n*** DP Table ***\n";
// for (const auto &pair : costToAccuracyMap) {
// llvm::errs() << "Computation cost: " << pair.first
// << ", Accuracy cost: " << pair.second << "\n";
// llvm::errs() << "\tSolution steps: \n";
// for (const auto &step : costToSolutionMap[pair.first]) {
// std::visit(
// [&](auto *item) {
// using T = std::decay_t<decltype(*item)>;
// if constexpr (std::is_same_v<T, ApplicableOutput>) {
// llvm::errs()
// << "\t\t" << item->expr << " --(" <<
// step.candidateIndex
// << ")-> " << item->candidates[step.candidateIndex].expr
// << "\n";
// } else if constexpr (std::is_same_v<T, ApplicableFPCC>) {
// llvm::errs()
// << "\t\tACC: " <<
// item->candidates[step.candidateIndex].desc
// << " (#" << step.candidateIndex << ")\n";
// } else {
// llvm_unreachable(
// "accuracyDPSolver: Unexpected type of solution step");
// }
// },
// step.item);
// }
// }
// llvm::errs() << "*** End of DP Table ***\n\n";
llvm::errs() << "*** Critical Computation Costs ***\n";
// Just print all computation costs in the DP table
for (const auto &pair : costToAccuracyMap) {
Expand Down

0 comments on commit fe0b354

Please sign in to comment.