diff --git a/worlds/stardew_valley/__init__.py b/worlds/stardew_valley/__init__.py index 6ba0e35e0a3a..6d8d40db094f 100644 --- a/worlds/stardew_valley/__init__.py +++ b/worlds/stardew_valley/__init__.py @@ -1,6 +1,6 @@ import logging from random import Random -from typing import Dict, Any, Iterable, Optional, Union, List, TextIO +from typing import Dict, Any, Iterable, Optional, List, TextIO from BaseClasses import Region, Entrance, Location, Item, Tutorial, ItemClassification, MultiWorld, CollectionState from Options import PerGameCommonOptions @@ -94,7 +94,6 @@ class StardewValleyWorld(World): randomized_entrances: Dict[str, str] total_progression_items: int - excluded_from_total_progression_items: List[str] = [Event.received_walnuts] def __init__(self, multiworld: MultiWorld, player: int): super().__init__(multiworld, player) @@ -182,7 +181,7 @@ def precollect_starting_season(self): if self.options.season_randomization == SeasonRandomization.option_disabled: for season in season_pool: - self.multiworld.push_precollected(self.create_starting_item(season)) + self.multiworld.push_precollected(self.create_item(season)) return if [item for item in self.multiworld.precollected_items[self.player] @@ -192,12 +191,12 @@ def precollect_starting_season(self): if self.options.season_randomization == SeasonRandomization.option_randomized_not_winter: season_pool = [season for season in season_pool if season.name != "Winter"] - starting_season = self.create_starting_item(self.random.choice(season_pool)) + starting_season = self.create_item(self.random.choice(season_pool)) self.multiworld.push_precollected(starting_season) def precollect_farm_type_items(self): if self.options.farm_type == FarmType.option_meadowlands and self.options.building_progression & BuildingProgression.option_progressive: - self.multiworld.push_precollected(self.create_starting_item("Progressive Coop")) + self.multiworld.push_precollected(self.create_item("Progressive Coop")) def setup_player_events(self): self.setup_action_events() @@ -291,7 +290,7 @@ def setup_victory(self): def get_all_location_names(self) -> List[str]: return list(location.name for location in self.multiworld.get_locations(self.player)) - def create_item(self, item: Union[str, ItemData], override_classification: ItemClassification = None) -> StardewItem: + def create_item(self, item: str | ItemData, override_classification: ItemClassification = None) -> StardewItem: if isinstance(item, str): item = item_table[item] @@ -300,12 +299,6 @@ def create_item(self, item: Union[str, ItemData], override_classification: ItemC return StardewItem(item.name, override_classification, item.code, self.player) - def create_starting_item(self, item: Union[str, ItemData]) -> StardewItem: - if isinstance(item, str): - item = item_table[item] - - return StardewItem(item.name, item.classification, item.code, self.player) - def create_event_location(self, location_data: LocationData, rule: StardewRule = None, item: Optional[str] = None): if rule is None: rule = True_() @@ -413,9 +406,19 @@ def collect(self, state: CollectionState, item: StardewItem) -> bool: if not change: return False + player_state = state.prog_items[self.player] + + received_progression_count = player_state[Event.received_progression_item] + received_progression_count += 1 + if self.total_progression_items: + # Total progression items is not set until all items are created, but collect will be called during the item creation when an item is precollected. + # We can't update the percentage if we don't know the total progression items, can't divide by 0. + player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items + player_state[Event.received_progression_item] = received_progression_count + walnut_amount = self.get_walnut_amount(item.name) if walnut_amount: - state.prog_items[self.player][Event.received_walnuts] += walnut_amount + player_state[Event.received_walnuts] += walnut_amount return True @@ -424,9 +427,18 @@ def remove(self, state: CollectionState, item: StardewItem) -> bool: if not change: return False + player_state = state.prog_items[self.player] + + received_progression_count = player_state[Event.received_progression_item] + received_progression_count -= 1 + if self.total_progression_items: + # We can't update the percentage if we don't know the total progression items, can't divide by 0. + player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items + player_state[Event.received_progression_item] = received_progression_count + walnut_amount = self.get_walnut_amount(item.name) if walnut_amount: - state.prog_items[self.player][Event.received_walnuts] -= walnut_amount + player_state[Event.received_walnuts] -= walnut_amount return True diff --git a/worlds/stardew_valley/stardew_rule/state.py b/worlds/stardew_valley/stardew_rule/state.py index 6fc349a6274d..d60f08ac4c94 100644 --- a/worlds/stardew_valley/stardew_rule/state.py +++ b/worlds/stardew_valley/stardew_rule/state.py @@ -4,6 +4,7 @@ from BaseClasses import CollectionState from .base import BaseStardewRule, CombinableStardewRule from .protocol import StardewRule +from ..strings.ap_names.event_names import Event if TYPE_CHECKING: from .. import StardewValleyWorld @@ -87,45 +88,13 @@ def __repr__(self): return f"Reach {self.resolution_hint} {self.spot}" -@dataclass(frozen=True) -class HasProgressionPercent(CombinableStardewRule): - player: int - percent: int +class HasProgressionPercent(Received): + def __init__(self, player: int, percent: int): + super().__init__(Event.received_progression_percent, player, percent, event=True) def __post_init__(self): - assert self.percent > 0, "HasProgressionPercent rule must be above 0%" - assert self.percent <= 100, "HasProgressionPercent rule can't require more than 100% of items" - - @property - def combination_key(self) -> Hashable: - return HasProgressionPercent.__name__ - - @property - def value(self): - return self.percent - - def __call__(self, state: CollectionState) -> bool: - stardew_world: "StardewValleyWorld" = state.multiworld.worlds[self.player] - total_count = stardew_world.total_progression_items - needed_count = (total_count * self.percent) // 100 - player_state = state.prog_items[self.player] - - if needed_count <= len(player_state) - len(stardew_world.excluded_from_total_progression_items): - return True - - total_count = 0 - for item, item_count in player_state.items(): - if item in stardew_world.excluded_from_total_progression_items: - continue - - total_count += item_count - if total_count >= needed_count: - return True - - return False - - def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]: - return self, self(state) + assert self.count > 0, "HasProgressionPercent rule must be above 0%" + assert self.count <= 100, "HasProgressionPercent rule can't require more than 100% of items" def __repr__(self): - return f"Received {self.percent}% progression items" + return f"Received {self.count}% progression items" diff --git a/worlds/stardew_valley/strings/ap_names/event_names.py b/worlds/stardew_valley/strings/ap_names/event_names.py index 449bb6720964..8ee69e178a4f 100644 --- a/worlds/stardew_valley/strings/ap_names/event_names.py +++ b/worlds/stardew_valley/strings/ap_names/event_names.py @@ -14,3 +14,5 @@ class Event: winter_farming = event("Winter Farming") received_walnuts = event("Received Walnuts") + received_progression_item = event("Received Progression Item") + received_progression_percent = event("Received Progression Percent") diff --git a/worlds/stardew_valley/test/rules/TestShipping.py b/worlds/stardew_valley/test/rules/TestShipping.py index b26d1e94ee2c..125b7f31d0d9 100644 --- a/worlds/stardew_valley/test/rules/TestShipping.py +++ b/worlds/stardew_valley/test/rules/TestShipping.py @@ -69,14 +69,17 @@ class TestShipsanityEverything(SVTestBase): def test_all_shipsanity_locations_require_shipping_bin(self): bin_name = "Shipping Bin" self.collect_all_except(bin_name) - shipsanity_locations = [location for location in self.get_real_locations() if - LocationTags.SHIPSANITY in location_table[location.name].tags] + shipsanity_locations = [location + for location in self.get_real_locations() + if LocationTags.SHIPSANITY in location_table[location.name].tags] bin_item = self.create_item(bin_name) + for location in shipsanity_locations: with self.subTest(location.name): - self.remove(bin_item) self.assertFalse(self.world.logic.region.can_reach_location(location.name)(self.multiworld.state)) - self.multiworld.state.collect(bin_item) + + self.collect(bin_item) shipsanity_rule = self.world.logic.region.can_reach_location(location.name) self.assert_rule_true(shipsanity_rule, self.multiworld.state) + self.remove(bin_item)