Skip to content

Commit

Permalink
group rules before counting them
Browse files Browse the repository at this point in the history
use slots

add group to has rules

reach any checks if any are already accessible

remove hardcoded true from count
  • Loading branch information
Jouramie committed Mar 18, 2024
1 parent 8a8263f commit a0d691b
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 38 deletions.
10 changes: 6 additions & 4 deletions worlds/stardew_valley/logic/building_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ..strings.material_names import Material
from ..strings.metal_names import MetalBar

has_group = "building"


class BuildingLogicMixin(BaseLogicMixin):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -60,7 +62,7 @@ def has_building(self, building: str) -> StardewRule:

carpenter_rule = self.logic.received(Event.can_construct_buildings)
if not self.options.building_progression & BuildingProgression.option_progressive:
return Has(building, self.registry.building_rules) & carpenter_rule
return Has(building, self.registry.building_rules, has_group) & carpenter_rule

count = 1
if building in [Building.coop, Building.barn, Building.shed]:
Expand All @@ -86,10 +88,10 @@ def has_house(self, upgrade_level: int) -> StardewRule:
return carpenter_rule & self.logic.received(f"Progressive House", upgrade_level)

if upgrade_level == 1:
return carpenter_rule & Has(Building.kitchen, self.registry.building_rules)
return carpenter_rule & Has(Building.kitchen, self.registry.building_rules, has_group)

if upgrade_level == 2:
return carpenter_rule & Has(Building.kids_room, self.registry.building_rules)
return carpenter_rule & Has(Building.kids_room, self.registry.building_rules, has_group)

# if upgrade_level == 3:
return carpenter_rule & Has(Building.cellar, self.registry.building_rules)
return carpenter_rule & Has(Building.cellar, self.registry.building_rules, has_group)
14 changes: 12 additions & 2 deletions worlds/stardew_valley/logic/has_logic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base_logic import BaseLogic
from ..stardew_rule import StardewRule, And, Or, Has, Count
from ..stardew_rule import StardewRule, And, Or, Has, Count, true_


class HasLogicMixin(BaseLogic[None]):
Expand All @@ -24,11 +24,21 @@ def has_n(self, *items: str, count: int):
def count(count: int, *rules: StardewRule) -> StardewRule:
assert rules, "Can't create a Count conditions without rules"
assert len(rules) >= count, "Count need at least as many rules as the count"
assert count > 0, "Count can't be negative"

count -= sum(r is true_ for r in rules)
rules = list(r for r in rules if r is not true_)

if count <= 0:
return true_

if len(rules) == 1:
return rules[0]

if count == 1:
return Or(*rules)

if count == len(rules):
return And(*rules)

return Count(list(rules), count)
return Count(rules, count)
2 changes: 1 addition & 1 deletion worlds/stardew_valley/logic/quest_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def update_rules(self, new_rules: Dict[str, StardewRule]):
self.registry.quest_rules.update(new_rules)

def can_complete_quest(self, quest: str) -> StardewRule:
return Has(quest, self.registry.quest_rules)
return Has(quest, self.registry.quest_rules, "quest")

def has_club_card(self) -> StardewRule:
if self.options.quest_locations < 0:
Expand Down
3 changes: 3 additions & 0 deletions worlds/stardew_valley/logic/region_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def can_reach(self, region_name: str) -> StardewRule:

@cache_self1
def can_reach_any(self, region_names: Tuple[str, ...]) -> StardewRule:
if any(r in always_regions_by_setting[self.options.entrance_randomization] for r in region_names):
return true_

return Or(*(self.logic.region.can_reach(spot) for spot in region_names))

@cache_self1
Expand Down
3 changes: 1 addition & 2 deletions worlds/stardew_valley/logic/special_order_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from ..strings.region_names import Region
from ..strings.season_names import Season
from ..strings.special_order_names import SpecialOrder
from ..strings.tool_names import Tool
from ..strings.villager_names import NPC


Expand Down Expand Up @@ -105,7 +104,7 @@ def update_rules(self, new_rules: Dict[str, StardewRule]):
self.registry.special_order_rules.update(new_rules)

def can_complete_special_order(self, special_order: str) -> StardewRule:
return Has(special_order, self.registry.special_order_rules)
return Has(special_order, self.registry.special_order_rules, "special order")

def has_island_transport(self) -> StardewRule:
return self.logic.received(Transportation.island_obelisk) | self.logic.received(Transportation.boat_repair)
65 changes: 41 additions & 24 deletions worlds/stardew_valley/stardew_rule/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import deque
from collections import deque, Counter
from dataclasses import dataclass, field
from functools import cached_property
from itertools import chain
from threading import Lock
Expand Down Expand Up @@ -295,7 +296,10 @@ def __eq__(self, other):
self.simplification_state.original_simplifiable_rules == self.simplification_state.original_simplifiable_rules)

def __hash__(self):
return hash((id(self.combinable_rules), self.simplification_state.original_simplifiable_rules))
if len(self.combinable_rules) + len(self.simplification_state.original_simplifiable_rules) > 5:
return id(self)

return hash((*self.combinable_rules.values(), self.simplification_state.original_simplifiable_rules))


class Or(AggregatingStardewRule):
Expand Down Expand Up @@ -361,25 +365,43 @@ class Count(BaseStardewRule):
rules: List[StardewRule]

def __init__(self, rules: List[StardewRule], count: int):
self.rules = rules
self.count = count
self.counter = Counter(rules)
self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
self.rule_mapping = {}

def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
assert self.counter.total() == len(rules)

def __call__(self, state: CollectionState) -> bool:
c = 0
t = self.counter.total()

for i in range(self.rules_count):
self.rules[i], value = self.rules[i].evaluate_while_simplifying(state)
if value:
c += 1
original_rule = self.rules[i]
evaluation_value = self.call_evaluate_while_simplifying_cached(original_rule, state)
rule_value = self.counter[original_rule]

if evaluation_value:
c += rule_value
else:
t -= rule_value

if c >= self.count:
return self, True
if c + self.rules_count - i < self.count:
return True
elif t < self.count:
break

return self, False
return False

def __call__(self, state: CollectionState) -> bool:
return self.evaluate_while_simplifying(state)[1]
def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: CollectionState) -> bool:
try:
return self.rule_mapping[rule].evaluate_while_simplifying(state)
except KeyError:
self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state)
return value

def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
return self, self(state)

@cached_property
def rules_count(self):
Expand All @@ -395,14 +417,12 @@ def __repr__(self):
return f"Received {self.count} {repr(self.rules)}"


@dataclass(frozen=True, slots=True)
class Has(BaseStardewRule):
item: str
# For sure there is a better way than just passing all the rules everytime
other_rules: Dict[str, StardewRule]

def __init__(self, item: str, other_rules: Dict[str, StardewRule]):
self.item = item
self.other_rules = other_rules
other_rules: Dict[str, StardewRule] = field(repr=False, hash=False, compare=False)
group: str = "item"

def __call__(self, state: CollectionState) -> bool:
return self.evaluate_while_simplifying(state)[1]
Expand All @@ -415,16 +435,13 @@ def get_difficulty(self):

def __str__(self):
if self.item not in self.other_rules:
return f"Has {self.item} -> {MISSING_ITEM}"
return f"Has {self.item}"
return f"Has {self.item} {self.group} -> {MISSING_ITEM}"
return f"Has {self.item} {self.group}"

def __repr__(self):
if self.item not in self.other_rules:
return f"Has {self.item} -> {MISSING_ITEM}"
return f"Has {self.item} -> {repr(self.other_rules[self.item])}"

def __hash__(self):
return hash(self.item)
return f"Has {self.item} {self.group} -> {MISSING_ITEM}"
return f"Has {self.item} {self.group} -> {repr(self.other_rules[self.item])}"


class RepeatableChain(Iterable, Sized):
Expand Down
6 changes: 3 additions & 3 deletions worlds/stardew_valley/stardew_rule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __repr__(self):
return f"Received {self.count} {self.items}"


@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class Received(CombinableStardewRule):
item: str
player: int
Expand Down Expand Up @@ -80,7 +80,7 @@ def get_difficulty(self):
return self.count


@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class Reach(BaseStardewRule):
spot: str
resolution_hint: str
Expand All @@ -101,7 +101,7 @@ def get_difficulty(self):
return 1


@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class HasProgressionPercent(CombinableStardewRule):
player: int
percent: int
Expand Down
14 changes: 13 additions & 1 deletion worlds/stardew_valley/test/TestRules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..data.craftable_data import all_crafting_recipes_by_name
from ..locations import locations_by_tag, LocationTags, location_table
from ..options import ToolProgression, BuildingProgression, ExcludeGingerIsland, Chefsanity, Craftsanity, Shipsanity, SeasonRandomization, Friendsanity, \
FriendsanityHeartSize, BundleRandomization, SkillProgression
FriendsanityHeartSize, BundleRandomization, SkillProgression, Museumsanity
from ..strings.entrance_names import Entrance
from ..strings.region_names import Region

Expand Down Expand Up @@ -92,6 +92,18 @@ def test_old_master_cannoli(self):
self.remove(friday)


class TestMuseumMilestones(SVTestBase):
options = {
Museumsanity.internal_name: Museumsanity.option_milestones
}

def test_50_milestone(self):
self.multiworld.state.prog_items = {1: Counter()}

milestone_rule = self.world.logic.museum.can_find_museum_items(50)
self.assert_rule_false(milestone_rule, self.multiworld.state)


class TestBundlesLogic(SVTestBase):
options = {
BundleRandomization.internal_name: BundleRandomization.option_vanilla
Expand Down
56 changes: 55 additions & 1 deletion worlds/stardew_valley/test/TestStardewRule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import MagicMock, Mock

from ..stardew_rule import Received, And, Or, HasProgressionPercent, false_, true_
from .. import StardewRule
from ..stardew_rule import Received, And, Or, HasProgressionPercent, false_, true_, Count


class TestSimplification(unittest.TestCase):
Expand Down Expand Up @@ -242,3 +243,56 @@ def call_to_already_simplified(state):
self.assertTrue(called_once)
self.assertTrue(internal_call_result)
self.assertTrue(actual_result)


class TestCount(unittest.TestCase):

def test_duplicate_rule_count_double(self):
expected_result = True
collection_state = MagicMock()
simplified_rule = MagicMock()
other_rule = MagicMock(spec=StardewRule)
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
rule = Count([other_rule, other_rule, other_rule], 2)

actual_result = rule(collection_state)

other_rule.evaluate_while_simplifying.assert_called_once_with(collection_state)
self.assertEqual(expected_result, actual_result)

def test_simplified_rule_is_reused(self):
expected_result = True
collection_state = MagicMock()
simplified_rule = MagicMock(return_value=expected_result)
other_rule = MagicMock(spec=StardewRule)
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
rule = Count([other_rule, other_rule, other_rule], 2)

actual_result = rule(collection_state)

other_rule.evaluate_while_simplifying.assert_called_once_with(collection_state)
self.assertEqual(expected_result, actual_result)

other_rule.evaluate_while_simplifying.reset_mock()

actual_result = rule(collection_state)

other_rule.evaluate_while_simplifying.assert_not_called()
simplified_rule.assert_not_called()
self.assertEqual(expected_result, actual_result)

def test_break_if_not_enough_rule_to_complete(self):
expected_result = False
collection_state = MagicMock()
simplified_rule = MagicMock()
never_called_rule = MagicMock()
other_rule = MagicMock(spec=StardewRule)
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
rule = Count([other_rule, other_rule, other_rule, never_called_rule], 2)

actual_result = rule(collection_state)

other_rule.evaluate_while_simplifying.assert_called_once_with(collection_state)
never_called_rule.assert_not_called()
never_called_rule.evaluate_while_simplifying.assert_not_called()
self.assertEqual(expected_result, actual_result)

0 comments on commit a0d691b

Please sign in to comment.