From a0d691b62bd0a1d93ad836fe5be279d2a9101e53 Mon Sep 17 00:00:00 2001 From: Jouramie Date: Sun, 17 Mar 2024 22:42:27 -0400 Subject: [PATCH] group rules before counting them use slots add group to has rules reach any checks if any are already accessible remove hardcoded true from count --- worlds/stardew_valley/logic/building_logic.py | 10 +-- worlds/stardew_valley/logic/has_logic.py | 14 +++- worlds/stardew_valley/logic/quest_logic.py | 2 +- worlds/stardew_valley/logic/region_logic.py | 3 + .../logic/special_order_logic.py | 3 +- worlds/stardew_valley/stardew_rule/base.py | 65 ++++++++++++------- worlds/stardew_valley/stardew_rule/state.py | 6 +- worlds/stardew_valley/test/TestRules.py | 14 +++- worlds/stardew_valley/test/TestStardewRule.py | 56 +++++++++++++++- 9 files changed, 135 insertions(+), 38 deletions(-) diff --git a/worlds/stardew_valley/logic/building_logic.py b/worlds/stardew_valley/logic/building_logic.py index 7be3d19ec33b..c6850f6cb2b0 100644 --- a/worlds/stardew_valley/logic/building_logic.py +++ b/worlds/stardew_valley/logic/building_logic.py @@ -15,6 +15,8 @@ from ..strings.material_names import Material from ..strings.metal_names import MetalBar +has_group = "building" + class BuildingLogicMixin(BaseLogicMixin): def __init__(self, *args, **kwargs): @@ -60,7 +62,7 @@ def has_building(self, building: str) -> StardewRule: carpenter_rule = self.logic.received(Event.can_construct_buildings) if not self.options.building_progression & BuildingProgression.option_progressive: - return Has(building, self.registry.building_rules) & carpenter_rule + return Has(building, self.registry.building_rules, has_group) & carpenter_rule count = 1 if building in [Building.coop, Building.barn, Building.shed]: @@ -86,10 +88,10 @@ def has_house(self, upgrade_level: int) -> StardewRule: return carpenter_rule & self.logic.received(f"Progressive House", upgrade_level) if upgrade_level == 1: - return carpenter_rule & Has(Building.kitchen, self.registry.building_rules) + return carpenter_rule & Has(Building.kitchen, self.registry.building_rules, has_group) if upgrade_level == 2: - return carpenter_rule & Has(Building.kids_room, self.registry.building_rules) + return carpenter_rule & Has(Building.kids_room, self.registry.building_rules, has_group) # if upgrade_level == 3: - return carpenter_rule & Has(Building.cellar, self.registry.building_rules) + return carpenter_rule & Has(Building.cellar, self.registry.building_rules, has_group) diff --git a/worlds/stardew_valley/logic/has_logic.py b/worlds/stardew_valley/logic/has_logic.py index d92d4224d7d2..9856118b087b 100644 --- a/worlds/stardew_valley/logic/has_logic.py +++ b/worlds/stardew_valley/logic/has_logic.py @@ -1,5 +1,5 @@ from .base_logic import BaseLogic -from ..stardew_rule import StardewRule, And, Or, Has, Count +from ..stardew_rule import StardewRule, And, Or, Has, Count, true_ class HasLogicMixin(BaseLogic[None]): @@ -24,6 +24,16 @@ def has_n(self, *items: str, count: int): def count(count: int, *rules: StardewRule) -> StardewRule: assert rules, "Can't create a Count conditions without rules" assert len(rules) >= count, "Count need at least as many rules as the count" + assert count > 0, "Count can't be negative" + + count -= sum(r is true_ for r in rules) + rules = list(r for r in rules if r is not true_) + + if count <= 0: + return true_ + + if len(rules) == 1: + return rules[0] if count == 1: return Or(*rules) @@ -31,4 +41,4 @@ def count(count: int, *rules: StardewRule) -> StardewRule: if count == len(rules): return And(*rules) - return Count(list(rules), count) + return Count(rules, count) diff --git a/worlds/stardew_valley/logic/quest_logic.py b/worlds/stardew_valley/logic/quest_logic.py index bc1f731429c6..144a2907a2a8 100644 --- a/worlds/stardew_valley/logic/quest_logic.py +++ b/worlds/stardew_valley/logic/quest_logic.py @@ -110,7 +110,7 @@ def update_rules(self, new_rules: Dict[str, StardewRule]): self.registry.quest_rules.update(new_rules) def can_complete_quest(self, quest: str) -> StardewRule: - return Has(quest, self.registry.quest_rules) + return Has(quest, self.registry.quest_rules, "quest") def has_club_card(self) -> StardewRule: if self.options.quest_locations < 0: diff --git a/worlds/stardew_valley/logic/region_logic.py b/worlds/stardew_valley/logic/region_logic.py index 81dabf45aac5..b5201b46632a 100644 --- a/worlds/stardew_valley/logic/region_logic.py +++ b/worlds/stardew_valley/logic/region_logic.py @@ -42,6 +42,9 @@ def can_reach(self, region_name: str) -> StardewRule: @cache_self1 def can_reach_any(self, region_names: Tuple[str, ...]) -> StardewRule: + if any(r in always_regions_by_setting[self.options.entrance_randomization] for r in region_names): + return true_ + return Or(*(self.logic.region.can_reach(spot) for spot in region_names)) @cache_self1 diff --git a/worlds/stardew_valley/logic/special_order_logic.py b/worlds/stardew_valley/logic/special_order_logic.py index e0b1a7e2fb27..f368e2e1d96f 100644 --- a/worlds/stardew_valley/logic/special_order_logic.py +++ b/worlds/stardew_valley/logic/special_order_logic.py @@ -35,7 +35,6 @@ from ..strings.region_names import Region from ..strings.season_names import Season from ..strings.special_order_names import SpecialOrder -from ..strings.tool_names import Tool from ..strings.villager_names import NPC @@ -105,7 +104,7 @@ def update_rules(self, new_rules: Dict[str, StardewRule]): self.registry.special_order_rules.update(new_rules) def can_complete_special_order(self, special_order: str) -> StardewRule: - return Has(special_order, self.registry.special_order_rules) + return Has(special_order, self.registry.special_order_rules, "special order") def has_island_transport(self) -> StardewRule: return self.logic.received(Transportation.island_obelisk) | self.logic.received(Transportation.boat_repair) diff --git a/worlds/stardew_valley/stardew_rule/base.py b/worlds/stardew_valley/stardew_rule/base.py index 007d2b64dc41..48092d76e265 100644 --- a/worlds/stardew_valley/stardew_rule/base.py +++ b/worlds/stardew_valley/stardew_rule/base.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections import deque +from collections import deque, Counter +from dataclasses import dataclass, field from functools import cached_property from itertools import chain from threading import Lock @@ -295,7 +296,10 @@ def __eq__(self, other): self.simplification_state.original_simplifiable_rules == self.simplification_state.original_simplifiable_rules) def __hash__(self): - return hash((id(self.combinable_rules), self.simplification_state.original_simplifiable_rules)) + if len(self.combinable_rules) + len(self.simplification_state.original_simplifiable_rules) > 5: + return id(self) + + return hash((*self.combinable_rules.values(), self.simplification_state.original_simplifiable_rules)) class Or(AggregatingStardewRule): @@ -361,25 +365,43 @@ class Count(BaseStardewRule): rules: List[StardewRule] def __init__(self, rules: List[StardewRule], count: int): - self.rules = rules self.count = count + self.counter = Counter(rules) + self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True) + self.rule_mapping = {} - def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]: + assert self.counter.total() == len(rules) + + def __call__(self, state: CollectionState) -> bool: c = 0 + t = self.counter.total() + for i in range(self.rules_count): - self.rules[i], value = self.rules[i].evaluate_while_simplifying(state) - if value: - c += 1 + original_rule = self.rules[i] + evaluation_value = self.call_evaluate_while_simplifying_cached(original_rule, state) + rule_value = self.counter[original_rule] + + if evaluation_value: + c += rule_value + else: + t -= rule_value if c >= self.count: - return self, True - if c + self.rules_count - i < self.count: + return True + elif t < self.count: break - return self, False + return False - def __call__(self, state: CollectionState) -> bool: - return self.evaluate_while_simplifying(state)[1] + def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: CollectionState) -> bool: + try: + return self.rule_mapping[rule].evaluate_while_simplifying(state) + except KeyError: + self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state) + return value + + def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]: + return self, self(state) @cached_property def rules_count(self): @@ -395,14 +417,12 @@ def __repr__(self): return f"Received {self.count} {repr(self.rules)}" +@dataclass(frozen=True, slots=True) class Has(BaseStardewRule): item: str # For sure there is a better way than just passing all the rules everytime - other_rules: Dict[str, StardewRule] - - def __init__(self, item: str, other_rules: Dict[str, StardewRule]): - self.item = item - self.other_rules = other_rules + other_rules: Dict[str, StardewRule] = field(repr=False, hash=False, compare=False) + group: str = "item" def __call__(self, state: CollectionState) -> bool: return self.evaluate_while_simplifying(state)[1] @@ -415,16 +435,13 @@ def get_difficulty(self): def __str__(self): if self.item not in self.other_rules: - return f"Has {self.item} -> {MISSING_ITEM}" - return f"Has {self.item}" + return f"Has {self.item} {self.group} -> {MISSING_ITEM}" + return f"Has {self.item} {self.group}" def __repr__(self): if self.item not in self.other_rules: - return f"Has {self.item} -> {MISSING_ITEM}" - return f"Has {self.item} -> {repr(self.other_rules[self.item])}" - - def __hash__(self): - return hash(self.item) + return f"Has {self.item} {self.group} -> {MISSING_ITEM}" + return f"Has {self.item} {self.group} -> {repr(self.other_rules[self.item])}" class RepeatableChain(Iterable, Sized): diff --git a/worlds/stardew_valley/stardew_rule/state.py b/worlds/stardew_valley/stardew_rule/state.py index a0fce7c7c19e..178d4b79467e 100644 --- a/worlds/stardew_valley/stardew_rule/state.py +++ b/worlds/stardew_valley/stardew_rule/state.py @@ -47,7 +47,7 @@ def __repr__(self): return f"Received {self.count} {self.items}" -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Received(CombinableStardewRule): item: str player: int @@ -80,7 +80,7 @@ def get_difficulty(self): return self.count -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Reach(BaseStardewRule): spot: str resolution_hint: str @@ -101,7 +101,7 @@ def get_difficulty(self): return 1 -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class HasProgressionPercent(CombinableStardewRule): player: int percent: int diff --git a/worlds/stardew_valley/test/TestRules.py b/worlds/stardew_valley/test/TestRules.py index 0d2fc38a19a3..07a8df16856d 100644 --- a/worlds/stardew_valley/test/TestRules.py +++ b/worlds/stardew_valley/test/TestRules.py @@ -5,7 +5,7 @@ from ..data.craftable_data import all_crafting_recipes_by_name from ..locations import locations_by_tag, LocationTags, location_table from ..options import ToolProgression, BuildingProgression, ExcludeGingerIsland, Chefsanity, Craftsanity, Shipsanity, SeasonRandomization, Friendsanity, \ - FriendsanityHeartSize, BundleRandomization, SkillProgression + FriendsanityHeartSize, BundleRandomization, SkillProgression, Museumsanity from ..strings.entrance_names import Entrance from ..strings.region_names import Region @@ -92,6 +92,18 @@ def test_old_master_cannoli(self): self.remove(friday) +class TestMuseumMilestones(SVTestBase): + options = { + Museumsanity.internal_name: Museumsanity.option_milestones + } + + def test_50_milestone(self): + self.multiworld.state.prog_items = {1: Counter()} + + milestone_rule = self.world.logic.museum.can_find_museum_items(50) + self.assert_rule_false(milestone_rule, self.multiworld.state) + + class TestBundlesLogic(SVTestBase): options = { BundleRandomization.internal_name: BundleRandomization.option_vanilla diff --git a/worlds/stardew_valley/test/TestStardewRule.py b/worlds/stardew_valley/test/TestStardewRule.py index 89317d90e4e2..189514eca815 100644 --- a/worlds/stardew_valley/test/TestStardewRule.py +++ b/worlds/stardew_valley/test/TestStardewRule.py @@ -1,7 +1,8 @@ import unittest from unittest.mock import MagicMock, Mock -from ..stardew_rule import Received, And, Or, HasProgressionPercent, false_, true_ +from .. import StardewRule +from ..stardew_rule import Received, And, Or, HasProgressionPercent, false_, true_, Count class TestSimplification(unittest.TestCase): @@ -242,3 +243,56 @@ def call_to_already_simplified(state): self.assertTrue(called_once) self.assertTrue(internal_call_result) self.assertTrue(actual_result) + + +class TestCount(unittest.TestCase): + + def test_duplicate_rule_count_double(self): + expected_result = True + collection_state = MagicMock() + simplified_rule = MagicMock() + other_rule = MagicMock(spec=StardewRule) + other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result)) + rule = Count([other_rule, other_rule, other_rule], 2) + + actual_result = rule(collection_state) + + other_rule.evaluate_while_simplifying.assert_called_once_with(collection_state) + self.assertEqual(expected_result, actual_result) + + def test_simplified_rule_is_reused(self): + expected_result = True + collection_state = MagicMock() + simplified_rule = MagicMock(return_value=expected_result) + other_rule = MagicMock(spec=StardewRule) + other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result)) + rule = Count([other_rule, other_rule, other_rule], 2) + + actual_result = rule(collection_state) + + other_rule.evaluate_while_simplifying.assert_called_once_with(collection_state) + self.assertEqual(expected_result, actual_result) + + other_rule.evaluate_while_simplifying.reset_mock() + + actual_result = rule(collection_state) + + other_rule.evaluate_while_simplifying.assert_not_called() + simplified_rule.assert_not_called() + self.assertEqual(expected_result, actual_result) + + def test_break_if_not_enough_rule_to_complete(self): + expected_result = False + collection_state = MagicMock() + simplified_rule = MagicMock() + never_called_rule = MagicMock() + other_rule = MagicMock(spec=StardewRule) + other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result)) + rule = Count([other_rule, other_rule, other_rule, never_called_rule], 2) + + actual_result = rule(collection_state) + + other_rule.evaluate_while_simplifying.assert_called_once_with(collection_state) + never_called_rule.assert_not_called() + never_called_rule.evaluate_while_simplifying.assert_not_called() + self.assertEqual(expected_result, actual_result)