From a443d9b4a98f090fba4bc21655df0c1c462bfd9f Mon Sep 17 00:00:00 2001 From: Jouramie Date: Tue, 19 Mar 2024 01:08:13 -0400 Subject: [PATCH] fix count that was completely broken --- worlds/stardew_valley/stardew_rule/base.py | 37 +++++++++++++++++-- worlds/stardew_valley/test/TestStardewRule.py | 23 +++++++----- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/worlds/stardew_valley/stardew_rule/base.py b/worlds/stardew_valley/stardew_rule/base.py index b6b8e26702e3..3b517d541eb2 100644 --- a/worlds/stardew_valley/stardew_rule/base.py +++ b/worlds/stardew_valley/stardew_rule/base.py @@ -363,15 +363,44 @@ def get_difficulty(self): class Count(BaseStardewRule): count: int rules: List[StardewRule] + counter: Counter[StardewRule] + evaluate: Callable[[CollectionState], bool] + + total: Optional[int] + rule_mapping: Optional[Dict[StardewRule, StardewRule]] def __init__(self, rules: List[StardewRule], count: int): self.count = count self.counter = Counter(rules) - self.total = sum(self.counter.values()) - self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True) - self.rule_mapping = {} + + if len(self.counter) / len(rules) < .66: + # Checking if it's worth using the count operation with shortcircuit or not. Value should be fine-tuned when Count has more usage. + self.total = sum(self.counter.values()) + self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True) + self.rule_mapping = {} + self.evaluate = self.evaluate_with_shortcircuit + else: + self.rules = rules + self.evaluate = self.evaluate_without_shortcircuit def __call__(self, state: CollectionState) -> bool: + return self.evaluate(state) + + def evaluate_without_shortcircuit(self, state: CollectionState) -> bool: + c = 0 + for i in range(self.rules_count): + self.rules[i], value = self.rules[i].evaluate_while_simplifying(state) + if value: + c += 1 + + if c >= self.count: + return True + if c + self.rules_count - i < self.count: + break + + return False + + def evaluate_with_shortcircuit(self, state: CollectionState) -> bool: c = 0 t = self.total @@ -395,7 +424,7 @@ def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: Colle try: # A mapping table with the original rule is used here because two rules could resolve to the same rule. # This would require to change the counter to merge both rules, and quickly become complicated. - return self.rule_mapping[rule].evaluate_while_simplifying(state) + return self.rule_mapping[rule](state) except KeyError: self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state) return value diff --git a/worlds/stardew_valley/test/TestStardewRule.py b/worlds/stardew_valley/test/TestStardewRule.py index 189514eca815..c20783b49308 100644 --- a/worlds/stardew_valley/test/TestStardewRule.py +++ b/worlds/stardew_valley/test/TestStardewRule.py @@ -250,8 +250,8 @@ class TestCount(unittest.TestCase): def test_duplicate_rule_count_double(self): expected_result = True collection_state = MagicMock() - simplified_rule = MagicMock() - other_rule = MagicMock(spec=StardewRule) + simplified_rule = Mock() + other_rule = Mock(spec=StardewRule) other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result)) rule = Count([other_rule, other_rule, other_rule], 2) @@ -261,10 +261,10 @@ def test_duplicate_rule_count_double(self): self.assertEqual(expected_result, actual_result) def test_simplified_rule_is_reused(self): - expected_result = True + expected_result = False collection_state = MagicMock() - simplified_rule = MagicMock(return_value=expected_result) - other_rule = MagicMock(spec=StardewRule) + simplified_rule = Mock(return_value=expected_result) + other_rule = Mock(spec=StardewRule) other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result)) rule = Count([other_rule, other_rule, other_rule], 2) @@ -278,15 +278,15 @@ def test_simplified_rule_is_reused(self): actual_result = rule(collection_state) other_rule.evaluate_while_simplifying.assert_not_called() - simplified_rule.assert_not_called() + simplified_rule.assert_called() self.assertEqual(expected_result, actual_result) def test_break_if_not_enough_rule_to_complete(self): expected_result = False collection_state = MagicMock() - simplified_rule = MagicMock() - never_called_rule = MagicMock() - other_rule = MagicMock(spec=StardewRule) + simplified_rule = Mock() + never_called_rule = Mock() + other_rule = Mock(spec=StardewRule) other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result)) rule = Count([other_rule, other_rule, other_rule, never_called_rule], 2) @@ -296,3 +296,8 @@ def test_break_if_not_enough_rule_to_complete(self): never_called_rule.assert_not_called() never_called_rule.evaluate_while_simplifying.assert_not_called() self.assertEqual(expected_result, actual_result) + + def test_evaluate_without_shortcircuit_when_rules_are_all_different(self): + rule = Count([Mock(), Mock(), Mock(), Mock()], 2) + + self.assertEqual(rule.evaluate, rule.evaluate_without_shortcircuit)