Skip to content

Commit

Permalink
change metric for categorical sorting to match LightGBM
Browse files (browse the repository at this point in the history)
  • Loading branch information
paulbkoch committed Dec 24, 2024
1 parent 745c561 commit 73d82fd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 22 deletions.
29 changes: 8 additions & 21 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,19 +704,12 @@ template<bool bHessian> class CompareNodeGain final {

template<bool bHessian, size_t cCompilerScores> class CompareBin final {
bool m_bHessianRuntime;
FloatCalc m_regAlpha;
FloatCalc m_regLambda;
FloatCalc m_deltaStepMax;
FloatCalc m_categoricalSmoothing;

public:
INLINE_ALWAYS CompareBin(const bool bHessianRuntime,
const FloatCalc regAlpha,
const FloatCalc regLambda,
const FloatCalc deltaStepMax) {
INLINE_ALWAYS CompareBin(const bool bHessianRuntime, FloatCalc categoricalSmoothing) {
m_bHessianRuntime = bHessianRuntime;
m_regAlpha = regAlpha;
m_regLambda = regLambda;
m_deltaStepMax = deltaStepMax;
m_categoricalSmoothing = categoricalSmoothing;
}

INLINE_ALWAYS bool operator()(
Expand All @@ -729,19 +722,13 @@ template<bool bHessian, size_t cCompilerScores> class CompareBin final {

const FloatCalc hess1 =
static_cast<FloatCalc>(bUpdateWithHessian ? lhs->GetGradientPairs()[0].GetHess() : lhs->GetWeight());
const FloatCalc val1 = CalcNegUpdate<true>(static_cast<FloatCalc>(lhs->GetGradientPairs()[0].m_sumGradients),
hess1,
m_regAlpha,
m_regLambda,
m_deltaStepMax);
const FloatCalc val1 =
static_cast<FloatCalc>(lhs->GetGradientPairs()[0].m_sumGradients) / (hess1 + m_categoricalSmoothing);

const FloatCalc hess2 =
static_cast<FloatCalc>(bUpdateWithHessian ? rhs->GetGradientPairs()[0].GetHess() : rhs->GetWeight());
const FloatCalc val2 = CalcNegUpdate<true>(static_cast<FloatCalc>(rhs->GetGradientPairs()[0].m_sumGradients),
hess2,
m_regAlpha,
m_regLambda,
m_deltaStepMax);
const FloatCalc val2 =
static_cast<FloatCalc>(rhs->GetGradientPairs()[0].m_sumGradients) / (hess2 + m_categoricalSmoothing);

if(val1 == val2) {
return lhs < rhs;
Expand Down Expand Up @@ -835,7 +822,7 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
std::sort(apBins,
ppBinsEnd,
CompareBin<bHessian, cCompilerScores>(
!(TermBoostFlags_DisableNewtonUpdate & flags), regAlpha, regLambda, deltaStepMax));
!(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing));
}

pRootTreeNode->BEFORE_SetBinFirst(apBins);
Expand Down
2 changes: 1 addition & 1 deletion shared/libebm/tests/boosting_unusual_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
}

TEST_CASE("stress test, boosting") {
const double expected = 26838942758406.215;
const double expected = 26758407585917.129;

double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
CHECK(validationMetricExact == expected);
Expand Down

0 comments on commit 73d82fd

Please sign in to comment.