diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 37c856644fa..a496d8fc905 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -741,13 +741,14 @@ void node_split(emt_tree& b, emt_node& cn) cn.examples.clear(); } -void node_insert(emt_node& cn, std::unique_ptr ex) +void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr ex) { for (auto& cn_ex : cn.examples) { if (cn_ex->full == ex->full) { return; } } cn.examples.push_back(std::move(ex)); + tree_bound(b, cn.examples.back().get()); } emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex) @@ -779,16 +780,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; + if (closest_ex != nullptr) { tree_bound(b, closest_ex); } } void emt_predict(emt_tree& b, learner& base, VW::example& ec) { b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); - emt_node& cn = *tree_route(b, ex); node_predict(b, base, cn, ex, ec); - tree_bound(b, &ex); } void emt_learn(emt_tree& b, learner& base, VW::example& ec) @@ -797,10 +797,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); - scorer_learn(b, base, cn, *ex, ec.weight); node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn - tree_bound(b, ex.get()); - node_insert(cn, std::move(ex)); + scorer_learn(b, base, cn, *ex, ec.weight); + node_insert(b, cn, std::move(ex)); node_split(b, cn); } diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index 82624c846df..da9aeafc0cc 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest) } } -TEST(EigenMemoryTree, Bounding) +TEST(EigenMemoryTree, BoundingDrop) { auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5")); auto* tree = get_emt_tree(*vw); @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding) EXPECT_EQ(tree->root->router_weights.size(), 0); } +TEST(EigenMemoryTree, BoundingPredict) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + auto* ex = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex); + vw->finish_example(*ex); + + EXPECT_EQ(tree->bounder->list.size(), 0); +} + +TEST(EigenMemoryTree, BoundingRecency) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + for (int i = 0; i < 3; i++) + { + auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i)); + vw->learn(*ex); + vw->finish_example(*ex); + } + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2); + + auto* ex1 = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex1); + vw->finish_example(*ex1); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1); + + auto* ex2 = VW::read_example(*vw, "1 | 0"); + vw->predict(*ex2); + vw->finish_example(*ex2); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0); +} + TEST(EigenMemoryTree, Split) { auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3");