diff --git a/shared/libebm/PartitionOneDimensionalBoosting.cpp b/shared/libebm/PartitionOneDimensionalBoosting.cpp index 23355c16a..4de6a050c 100644 --- a/shared/libebm/PartitionOneDimensionalBoosting.cpp +++ b/shared/libebm/PartitionOneDimensionalBoosting.cpp @@ -704,19 +704,12 @@ template class CompareNodeGain final { template class CompareBin final { bool m_bHessianRuntime; - FloatCalc m_regAlpha; - FloatCalc m_regLambda; - FloatCalc m_deltaStepMax; + FloatCalc m_categoricalSmoothing; public: - INLINE_ALWAYS CompareBin(const bool bHessianRuntime, - const FloatCalc regAlpha, - const FloatCalc regLambda, - const FloatCalc deltaStepMax) { + INLINE_ALWAYS CompareBin(const bool bHessianRuntime, FloatCalc categoricalSmoothing) { m_bHessianRuntime = bHessianRuntime; - m_regAlpha = regAlpha; - m_regLambda = regLambda; - m_deltaStepMax = deltaStepMax; + m_categoricalSmoothing = categoricalSmoothing; } INLINE_ALWAYS bool operator()( @@ -729,19 +722,13 @@ template class CompareBin final { const FloatCalc hess1 = static_cast(bUpdateWithHessian ? lhs->GetGradientPairs()[0].GetHess() : lhs->GetWeight()); - const FloatCalc val1 = CalcNegUpdate(static_cast(lhs->GetGradientPairs()[0].m_sumGradients), - hess1, - m_regAlpha, - m_regLambda, - m_deltaStepMax); + const FloatCalc val1 = + static_cast(lhs->GetGradientPairs()[0].m_sumGradients) / (hess1 + m_categoricalSmoothing); const FloatCalc hess2 = static_cast(bUpdateWithHessian ? rhs->GetGradientPairs()[0].GetHess() : rhs->GetWeight()); - const FloatCalc val2 = CalcNegUpdate(static_cast(rhs->GetGradientPairs()[0].m_sumGradients), - hess2, - m_regAlpha, - m_regLambda, - m_deltaStepMax); + const FloatCalc val2 = + static_cast(rhs->GetGradientPairs()[0].m_sumGradients) / (hess2 + m_categoricalSmoothing); if(val1 == val2) { return lhs < rhs; @@ -835,7 +822,7 @@ template class PartitionOneDimensionalBoo std::sort(apBins, ppBinsEnd, CompareBin( - !(TermBoostFlags_DisableNewtonUpdate & flags), regAlpha, regLambda, deltaStepMax)); + !(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing)); } pRootTreeNode->BEFORE_SetBinFirst(apBins); diff --git a/shared/libebm/tests/boosting_unusual_inputs.cpp b/shared/libebm/tests/boosting_unusual_inputs.cpp index 71cf9d040..8aa683e96 100644 --- a/shared/libebm/tests/boosting_unusual_inputs.cpp +++ b/shared/libebm/tests/boosting_unusual_inputs.cpp @@ -2175,7 +2175,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) { } TEST_CASE("stress test, boosting") { - const double expected = 26838942758406.215; + const double expected = 26758407585917.129; double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE); CHECK(validationMetricExact == expected);