From 802450d709343f33391fef0046008d0945947e7e Mon Sep 17 00:00:00 2001 From: Mark Rucker Date: Sun, 8 Oct 2023 23:42:04 -0400 Subject: [PATCH 1/5] Fixed a bug with bounded memory trees. The bug randomly caused a segfault. --- .../core/src/reductions/eigen_memory_tree.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 37c856644fa..f38a340cbfc 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -741,13 +741,14 @@ void node_split(emt_tree& b, emt_node& cn) cn.examples.clear(); } -void node_insert(emt_node& cn, std::unique_ptr ex) +void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr ex) { for (auto& cn_ex : cn.examples) { if (cn_ex->full == ex->full) { return; } } cn.examples.push_back(std::move(ex)); + tree_bound(b, cn.examples.back().get()); } emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex) @@ -779,16 +780,18 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; + if (closest_ex != nullptr) { + tree_bound(b, closest_ex); + } + } void emt_predict(emt_tree& b, learner& base, VW::example& ec) { b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); - emt_node& cn = *tree_route(b, ex); node_predict(b, base, cn, ex, ec); - tree_bound(b, &ex); } void emt_learn(emt_tree& b, learner& base, VW::example& ec) @@ -797,10 +800,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); - scorer_learn(b, base, cn, *ex, ec.weight); node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn - tree_bound(b, ex.get()); - node_insert(cn, std::move(ex)); + scorer_learn(b, base, cn, *ex, ec.weight); + node_insert(b, cn, std::move(ex)); node_split(b, cn); } From 54636fd9d29a470f2de5b791beeda5c0e7b15b04 Mon Sep 17 00:00:00 2001 From: Mark Rucker Date: Sun, 8 Oct 2023 23:42:04 -0400 Subject: [PATCH 2/5] fix: prevent feature corruption in LRU cache --- .../core/src/reductions/eigen_memory_tree.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 37c856644fa..f38a340cbfc 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -741,13 +741,14 @@ void node_split(emt_tree& b, emt_node& cn) cn.examples.clear(); } -void node_insert(emt_node& cn, std::unique_ptr ex) +void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr ex) { for (auto& cn_ex : cn.examples) { if (cn_ex->full == ex->full) { return; } } cn.examples.push_back(std::move(ex)); + tree_bound(b, cn.examples.back().get()); } emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex) @@ -779,16 +780,18 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; + if (closest_ex != nullptr) { + tree_bound(b, closest_ex); + } + } void emt_predict(emt_tree& b, learner& base, VW::example& ec) { b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); - emt_node& cn = *tree_route(b, ex); node_predict(b, base, cn, ex, ec); - tree_bound(b, &ex); } void emt_learn(emt_tree& b, learner& base, VW::example& ec) @@ -797,10 +800,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); - scorer_learn(b, base, cn, *ex, ec.weight); node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn - tree_bound(b, ex.get()); - node_insert(cn, std::move(ex)); + scorer_learn(b, base, cn, *ex, ec.weight); + node_insert(b, cn, std::move(ex)); node_split(b, cn); } From 1f3d0ab2ccd8d682244956def7709ab15fed2db7 Mon Sep 17 00:00:00 2001 From: Mark Rucker Date: Mon, 9 Oct 2023 15:03:08 -0400 Subject: [PATCH 3/5] fix: prevent feature corruption in LRU cache. --- vowpalwabbit/core/src/reductions/eigen_memory_tree.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index f38a340cbfc..a496d8fc905 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -780,10 +780,7 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; - if (closest_ex != nullptr) { - tree_bound(b, closest_ex); - } - + if (closest_ex != nullptr) { tree_bound(b, closest_ex); } } void emt_predict(emt_tree& b, learner& base, VW::example& ec) From 28e96d760d216f155c8029a328c49b7970f57417 Mon Sep 17 00:00:00 2001 From: Mark Rucker Date: Wed, 1 Nov 2023 06:08:47 -0400 Subject: [PATCH 4/5] Added one unittest to catch the bug being fixed and one unittest to ensure expected lru cache behavior. --- .../core/tests/eigen_memory_tree_test.cc | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index 82624c846df..920701c6cbf 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest) } } -TEST(EigenMemoryTree, Bounding) +TEST(EigenMemoryTree, BoundingDrop) { auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5")); auto* tree = get_emt_tree(*vw); @@ -148,6 +148,46 @@ TEST(EigenMemoryTree, Bounding) EXPECT_EQ(tree->root->router_weights.size(), 0); } +TEST(EigenMemoryTree, BoundingPredict) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + auto* ex = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex); + vw->finish_example(*ex); + + EXPECT_EQ(tree->bounder->list.size(), 0); +} + +TEST(EigenMemoryTree, BoundingRecency) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + for (int i = 0; i < 3; i++) + { + auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i)); + vw->learn(*ex); + vw->finish_example(*ex); + } + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2); + + auto* ex1 = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex1); + vw->finish_example(*ex1); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1); + + auto* ex2 = VW::read_example(*vw, "1 | 0"); + vw->predict(*ex2); + vw->finish_example(*ex2); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0); + +} + TEST(EigenMemoryTree, Split) { auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3"); From ba3ef5b22d9331cd40d73aed0acd84f1af837654 Mon Sep 17 00:00:00 2001 From: Mark Rucker Date: Wed, 1 Nov 2023 06:15:35 -0400 Subject: [PATCH 5/5] Fixed code formatting. --- vowpalwabbit/core/tests/eigen_memory_tree_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index 920701c6cbf..da9aeafc0cc 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -185,7 +185,6 @@ TEST(EigenMemoryTree, BoundingRecency) vw->finish_example(*ex2); EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0); - } TEST(EigenMemoryTree, Split)