Skip to content

Commit

Permalink
Add predict_random_selection
Browse files Browse the repository at this point in the history
  • Loading branch information
davidstone committed May 14, 2024
1 parent 11300a3 commit b8f513f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 62 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ target_sources(tm_common PUBLIC
source/tm/evaluate/extreme_element_value.cpp
source/tm/evaluate/load_evaluate.cpp
source/tm/evaluate/possible_executed_moves.cpp
source/tm/evaluate/predict_random_selection.cpp
source/tm/evaluate/predict_selection.cpp
source/tm/evaluate/predicted.cpp
source/tm/evaluate/random_selection.cpp
Expand Down
86 changes: 86 additions & 0 deletions source/tm/evaluate/predict_random_selection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright David Stone 2024.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)

export module tm.evaluate.predict_random_selection;

import tm.evaluate.predicted;

import tm.move.legal_selections;
import tm.move.move_name;
import tm.move.selection;
import tm.move.switch_;

import bounded;
import containers;
import tv;

namespace technicalmachine {

using namespace bounded::literal;

export constexpr auto predict_random_selection(LegalSelections const selections) -> AllPredicted {
auto const total_size = double(containers::size(selections));
return AllPredicted(containers::transform(selections, [=](Selection const selection) {
return Predicted(selection, 1.0 / total_size);
}));
}

constexpr auto is_switch = [](Selection const selection) {
return selection.index() == bounded::type<Switch>;
};

export constexpr auto predict_random_selection(LegalSelections const selections, double const general_switch_probability) -> AllPredicted {
auto const total_size = containers::size(selections);
auto const switches = containers::count_if(selections, is_switch);
auto const switch_probability =
switches == 0 ? 0.0 :
switches == total_size ? 1.0 :
general_switch_probability;
return AllPredicted(containers::transform(selections, [=](Selection const selection) {
return Predicted(
selection,
is_switch(selection) ?
switch_probability / double(switches) :
(1.0 - switch_probability) / double(total_size - switches)
);
}));
}

static_assert(
predict_random_selection(LegalSelections({MoveName::Tackle})) ==
AllPredicted({{MoveName::Tackle, 1.0}})
);

static_assert(
predict_random_selection(LegalSelections({MoveName::Tackle, MoveName::Thunder})) ==
AllPredicted({{MoveName::Tackle, 0.5}, {MoveName::Thunder, 0.5}})
);

static_assert(
predict_random_selection(LegalSelections({MoveName::Tackle, Switch(0_bi)})) ==
AllPredicted({{MoveName::Tackle, 0.5}, {Switch(0_bi), 0.5}})
);

static_assert(
predict_random_selection(LegalSelections({MoveName::Tackle}), 0.2) ==
AllPredicted({{MoveName::Tackle, 1.0}})
);

static_assert(
predict_random_selection(LegalSelections({MoveName::Tackle, MoveName::Thunder}), 0.2) ==
AllPredicted({{MoveName::Tackle, 0.5}, {MoveName::Thunder, 0.5}})
);

static_assert(
predict_random_selection(LegalSelections({MoveName::Tackle, Switch(0_bi)}), 0.2) ==
AllPredicted({{MoveName::Tackle, 0.8}, {Switch(0_bi), 0.2}})
);

static_assert(
predict_random_selection(LegalSelections({Switch(0_bi)}), 0.2) ==
AllPredicted({{Switch(0_bi), 1.0}})
);

} // namespace technicalmachine
65 changes: 3 additions & 62 deletions source/tm/test/ps_usage_stats/score_predict_action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import tm.clients.party;

import tm.evaluate.all_evaluate;
import tm.evaluate.predicted;
import tm.evaluate.predict_random_selection;
import tm.evaluate.predict_selection;

import tm.move.move_name;
Expand Down Expand Up @@ -94,20 +95,6 @@ struct PredictedSelection {
SlotMemory slot_memory;
};

struct SelectionTypeCount {
double switch_ = 0.0;
double move = 0.0;
double pass = 0.0;

friend constexpr auto operator+(SelectionTypeCount const lhs, SelectionTypeCount const rhs) -> SelectionTypeCount {
return SelectionTypeCount(
lhs.switch_ + rhs.switch_,
lhs.move + rhs.move,
lhs.pass + rhs.pass
);
}
};

auto get_predicted_selection(
BattleManager & battle,
AllUsageStats const & all_usage_stats,
Expand All @@ -129,55 +116,9 @@ auto get_predicted_selection(
state.environment
);
if constexpr (false) {
auto const selection_size = containers::size(selections);
return AllPredicted(containers::transform(
selections,
[=](Selection const selection) {
return Predicted(selection, 1.0 / double(selection_size));
}
));
return predict_random_selection(selections);
} else {
auto const type_count = containers::sum(containers::transform(
selections,
[](Selection const selection) {
return tv::visit(selection, tv::overload(
[](Switch) {
return SelectionTypeCount(1.0, 0.0, 0.0);
},
[](MoveName) {
return SelectionTypeCount(0.0, 1.0, 0.0);
},
[](Pass) {
return SelectionTypeCount(0.0, 0.0, 1.0);
}
));
}
));
auto const local_switch_probability = type_count.move == 0.0 ? 1.0 : switch_probability;
return AllPredicted(containers::transform(
selections,
[=](Selection const selection) {
auto const probability = tv::visit(selection, tv::overload(
[&](Switch) {
return local_switch_probability / type_count.switch_;
},
[&](MoveName) {
return (1.0 - local_switch_probability) / type_count.move;
},
[&](Pass) {
auto const valid =
type_count.switch_ == 0.0 and
type_count.move == 0.0 and
type_count.pass == 1.0;
if (!valid) {
throw std::runtime_error("Yikes");
}
return 1.0;
}
));
return Predicted(selection, probability);
}
));
return predict_random_selection(selections, switch_probability);
}
}
};
Expand Down

0 comments on commit b8f513f

Please sign in to comment.