From b8f513fd17ca9247f408e45ee3c17bcd7efb4564 Mon Sep 17 00:00:00 2001 From: David Stone Date: Tue, 14 May 2024 09:20:51 -0600 Subject: [PATCH] Add `predict_random_selection` --- CMakeLists.txt | 1 + .../tm/evaluate/predict_random_selection.cpp | 86 +++++++++++++++++++ .../ps_usage_stats/score_predict_action.cpp | 65 +------------- 3 files changed, 90 insertions(+), 62 deletions(-) create mode 100644 source/tm/evaluate/predict_random_selection.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 64742943f..70815be6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,7 @@ target_sources(tm_common PUBLIC source/tm/evaluate/extreme_element_value.cpp source/tm/evaluate/load_evaluate.cpp source/tm/evaluate/possible_executed_moves.cpp + source/tm/evaluate/predict_random_selection.cpp source/tm/evaluate/predict_selection.cpp source/tm/evaluate/predicted.cpp source/tm/evaluate/random_selection.cpp diff --git a/source/tm/evaluate/predict_random_selection.cpp b/source/tm/evaluate/predict_random_selection.cpp new file mode 100644 index 000000000..ee4db7d2b --- /dev/null +++ b/source/tm/evaluate/predict_random_selection.cpp @@ -0,0 +1,86 @@ +// Copyright David Stone 2024. +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +export module tm.evaluate.predict_random_selection; + +import tm.evaluate.predicted; + +import tm.move.legal_selections; +import tm.move.move_name; +import tm.move.selection; +import tm.move.switch_; + +import bounded; +import containers; +import tv; + +namespace technicalmachine { + +using namespace bounded::literal; + +export constexpr auto predict_random_selection(LegalSelections const selections) -> AllPredicted { + auto const total_size = double(containers::size(selections)); + return AllPredicted(containers::transform(selections, [=](Selection const selection) { + return Predicted(selection, 1.0 / total_size); + })); +} + +constexpr auto is_switch = [](Selection const selection) { + return selection.index() == bounded::type; +}; + +export constexpr auto predict_random_selection(LegalSelections const selections, double const general_switch_probability) -> AllPredicted { + auto const total_size = containers::size(selections); + auto const switches = containers::count_if(selections, is_switch); + auto const switch_probability = + switches == 0 ? 0.0 : + switches == total_size ? 1.0 : + general_switch_probability; + return AllPredicted(containers::transform(selections, [=](Selection const selection) { + return Predicted( + selection, + is_switch(selection) ? + switch_probability / double(switches) : + (1.0 - switch_probability) / double(total_size - switches) + ); + })); +} + +static_assert( + predict_random_selection(LegalSelections({MoveName::Tackle})) == + AllPredicted({{MoveName::Tackle, 1.0}}) +); + +static_assert( + predict_random_selection(LegalSelections({MoveName::Tackle, MoveName::Thunder})) == + AllPredicted({{MoveName::Tackle, 0.5}, {MoveName::Thunder, 0.5}}) +); + +static_assert( + predict_random_selection(LegalSelections({MoveName::Tackle, Switch(0_bi)})) == + AllPredicted({{MoveName::Tackle, 0.5}, {Switch(0_bi), 0.5}}) +); + +static_assert( + predict_random_selection(LegalSelections({MoveName::Tackle}), 0.2) == + AllPredicted({{MoveName::Tackle, 1.0}}) +); + +static_assert( + predict_random_selection(LegalSelections({MoveName::Tackle, MoveName::Thunder}), 0.2) == + AllPredicted({{MoveName::Tackle, 0.5}, {MoveName::Thunder, 0.5}}) +); + +static_assert( + predict_random_selection(LegalSelections({MoveName::Tackle, Switch(0_bi)}), 0.2) == + AllPredicted({{MoveName::Tackle, 0.8}, {Switch(0_bi), 0.2}}) +); + +static_assert( + predict_random_selection(LegalSelections({Switch(0_bi)}), 0.2) == + AllPredicted({{Switch(0_bi), 1.0}}) +); + +} // namespace technicalmachine diff --git a/source/tm/test/ps_usage_stats/score_predict_action.cpp b/source/tm/test/ps_usage_stats/score_predict_action.cpp index 25f49c746..2fb5bf89e 100644 --- a/source/tm/test/ps_usage_stats/score_predict_action.cpp +++ b/source/tm/test/ps_usage_stats/score_predict_action.cpp @@ -18,6 +18,7 @@ import tm.clients.party; import tm.evaluate.all_evaluate; import tm.evaluate.predicted; +import tm.evaluate.predict_random_selection; import tm.evaluate.predict_selection; import tm.move.move_name; @@ -94,20 +95,6 @@ struct PredictedSelection { SlotMemory slot_memory; }; -struct SelectionTypeCount { - double switch_ = 0.0; - double move = 0.0; - double pass = 0.0; - - friend constexpr auto operator+(SelectionTypeCount const lhs, SelectionTypeCount const rhs) -> SelectionTypeCount { - return SelectionTypeCount( - lhs.switch_ + rhs.switch_, - lhs.move + rhs.move, - lhs.pass + rhs.pass - ); - } -}; - auto get_predicted_selection( BattleManager & battle, AllUsageStats const & all_usage_stats, @@ -129,55 +116,9 @@ auto get_predicted_selection( state.environment ); if constexpr (false) { - auto const selection_size = containers::size(selections); - return AllPredicted(containers::transform( - selections, - [=](Selection const selection) { - return Predicted(selection, 1.0 / double(selection_size)); - } - )); + return predict_random_selection(selections); } else { - auto const type_count = containers::sum(containers::transform( - selections, - [](Selection const selection) { - return tv::visit(selection, tv::overload( - [](Switch) { - return SelectionTypeCount(1.0, 0.0, 0.0); - }, - [](MoveName) { - return SelectionTypeCount(0.0, 1.0, 0.0); - }, - [](Pass) { - return SelectionTypeCount(0.0, 0.0, 1.0); - } - )); - } - )); - auto const local_switch_probability = type_count.move == 0.0 ? 1.0 : switch_probability; - return AllPredicted(containers::transform( - selections, - [=](Selection const selection) { - auto const probability = tv::visit(selection, tv::overload( - [&](Switch) { - return local_switch_probability / type_count.switch_; - }, - [&](MoveName) { - return (1.0 - local_switch_probability) / type_count.move; - }, - [&](Pass) { - auto const valid = - type_count.switch_ == 0.0 and - type_count.move == 0.0 and - type_count.pass == 1.0; - if (!valid) { - throw std::runtime_error("Yikes"); - } - return 1.0; - } - )); - return Predicted(selection, probability); - } - )); + return predict_random_selection(selections, switch_probability); } } };