Skip to content

Commit

Permalink
fix count that was completely broken
Browse files Browse the repository at this point in the history
  • Loading branch information
Jouramie committed Mar 19, 2024
1 parent 8fdfefc commit d88c0a7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
37 changes: 33 additions & 4 deletions worlds/stardew_valley/stardew_rule/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,44 @@ def get_difficulty(self):
class Count(BaseStardewRule):
count: int
rules: List[StardewRule]
counter: Counter[StardewRule]
evaluate: Callable[[CollectionState], bool]

total: Optional[int]
rule_mapping: Optional[Dict[StardewRule, StardewRule]]

def __init__(self, rules: List[StardewRule], count: int):
self.count = count
self.counter = Counter(rules)
self.total = sum(self.counter.values())
self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
self.rule_mapping = {}

if len(self.counter) / len(rules) < .66:
# Checking if it's worth using the count operation with shortcircuit or not. Value should be fine-tuned when Count has more usage.
self.total = sum(self.counter.values())
self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
self.rule_mapping = {}
self.evaluate = self.evaluate_with_shortcircuit
else:
self.rules = rules
self.evaluate = self.evaluate_without_shortcircuit

def __call__(self, state: CollectionState) -> bool:
return self.evaluate(state)

def evaluate_without_shortcircuit(self, state: CollectionState) -> bool:
c = 0
for i in range(self.rules_count):
self.rules[i], value = self.rules[i].evaluate_while_simplifying(state)
if value:
c += 1

if c >= self.count:
return True
if c + self.rules_count - i < self.count:
break

return False

def evaluate_with_shortcircuit(self, state: CollectionState) -> bool:
c = 0
t = self.total

Expand All @@ -395,7 +424,7 @@ def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: Colle
try:
# A mapping table with the original rule is used here because two rules could resolve to the same rule.
# This would require to change the counter to merge both rules, and quickly become complicated.
return self.rule_mapping[rule].evaluate_while_simplifying(state)
return self.rule_mapping[rule](state)
except KeyError:
self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state)
return value
Expand Down
23 changes: 14 additions & 9 deletions worlds/stardew_valley/test/TestStardewRule.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ class TestCount(unittest.TestCase):
def test_duplicate_rule_count_double(self):
expected_result = True
collection_state = MagicMock()
simplified_rule = MagicMock()
other_rule = MagicMock(spec=StardewRule)
simplified_rule = Mock()
other_rule = Mock(spec=StardewRule)
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
rule = Count([other_rule, other_rule, other_rule], 2)

Expand All @@ -261,10 +261,10 @@ def test_duplicate_rule_count_double(self):
self.assertEqual(expected_result, actual_result)

def test_simplified_rule_is_reused(self):
expected_result = True
expected_result = False
collection_state = MagicMock()
simplified_rule = MagicMock(return_value=expected_result)
other_rule = MagicMock(spec=StardewRule)
simplified_rule = Mock(return_value=expected_result)
other_rule = Mock(spec=StardewRule)
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
rule = Count([other_rule, other_rule, other_rule], 2)

Expand All @@ -278,15 +278,15 @@ def test_simplified_rule_is_reused(self):
actual_result = rule(collection_state)

other_rule.evaluate_while_simplifying.assert_not_called()
simplified_rule.assert_not_called()
simplified_rule.assert_called()
self.assertEqual(expected_result, actual_result)

def test_break_if_not_enough_rule_to_complete(self):
expected_result = False
collection_state = MagicMock()
simplified_rule = MagicMock()
never_called_rule = MagicMock()
other_rule = MagicMock(spec=StardewRule)
simplified_rule = Mock()
never_called_rule = Mock()
other_rule = Mock(spec=StardewRule)
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
rule = Count([other_rule, other_rule, other_rule, never_called_rule], 2)

Expand All @@ -296,3 +296,8 @@ def test_break_if_not_enough_rule_to_complete(self):
never_called_rule.assert_not_called()
never_called_rule.evaluate_while_simplifying.assert_not_called()
self.assertEqual(expected_result, actual_result)

def test_evaluate_without_shortcircuit_when_rules_are_all_different(self):
rule = Count([Mock(), Mock(), Mock(), Mock()], 2)

self.assertEqual(rule.evaluate, rule.evaluate_without_shortcircuit)

0 comments on commit d88c0a7

Please sign in to comment.