diff --git a/worlds/sc2/ItemGroups.py b/worlds/sc2/ItemGroups.py index c9ca50c27753..fcdd2e1d27e6 100644 --- a/worlds/sc2/ItemGroups.py +++ b/worlds/sc2/ItemGroups.py @@ -149,6 +149,14 @@ class ItemGroupNames: VANILLA_ITEMS = "Vanilla Items" + @classmethod + def get_all_group_names(cls) -> typing.Set[str]: + return { + name for identifier, name in cls.__dict__.items() + if not identifier.startswith('_') + and not identifier.startswith('get_') + } + # Terran item_name_groups[ItemGroupNames.TERRAN_ITEMS] = terran_items = [ diff --git a/worlds/sc2/Options.py b/worlds/sc2/Options.py index c3967dd06641..c6829faebac3 100644 --- a/worlds/sc2/Options.py +++ b/worlds/sc2/Options.py @@ -1,19 +1,60 @@ from dataclasses import dataclass, fields, Field from typing import * -from Options import (Choice, Toggle, DefaultOnToggle, ItemDict, OptionSet, Range, OptionDict, +from Utils import is_iterable_except_str +from Options import (Choice, Toggle, DefaultOnToggle, OptionSet, Range, PerGameCommonOptions, Option, VerifyKeys) from Utils import get_fuzzy_results from BaseClasses import PlandoOptions from .MissionTables import SC2Campaign, SC2Mission, lookup_name_to_mission, MissionPools, get_no_build_missions, \ campaign_mission_table from .MissionOrders import vanilla_shuffle_order, mini_campaign_order +from .mission_groups import mission_groups, MissionGroupNames if TYPE_CHECKING: from worlds.AutoWorld import World from . import SC2World +class Sc2MissionSet(OptionSet): + """Option set made for handling missions and expanding mission groups""" + valid_keys = [x.mission_name for x in SC2Mission] + + @classmethod + def from_any(cls, data: Any): + if is_iterable_except_str(data): + return cls(data) + return cls.from_text(str(data)) + + def verify(self, world: Type['World'], player_name: str, plando_options: PlandoOptions) -> None: + """Overridden version of function from Options.VerifyKeys for a better error message""" + new_value: set[str] = set() + case_insensitive_group_mapping = { + group_name.casefold(): group_value for group_name, group_value in mission_groups.items() + } + case_insensitive_group_mapping.update({mission.mission_name.casefold(): [mission.mission_name] for mission in SC2Mission}) + for group_name in self.value: + item_names = case_insensitive_group_mapping.get(group_name.casefold(), {group_name}) + new_value.update(item_names) + self.value = new_value + for item_name in self.value: + if item_name not in self.valid_keys: + picks = get_fuzzy_results( + item_name, + list(self.valid_keys) + list(MissionGroupNames.get_all_group_names()), + limit=1, + ) + raise Exception(f"Mission {item_name} from option {self} " + f"is not a valid mission name from {world.game}. " + f"Did you mean '{picks[0][0]}' ({picks[0][1]}% sure)") + + def __iter__(self) -> Iterator[str]: + return self.value.__iter__() + + def __len__(self) -> int: + return self.value.__len__() + + class GameDifficulty(Choice): """ The difficulty of the campaign, affects enemy AI, starting units, and game speed. @@ -633,7 +674,6 @@ def from_any(cls, data: Union[List[str], Dict[str, int]]) -> 'Sc2ItemDict': # It doesn't play nice with trigger merging dicts and lists together, though, so best not to advertise it overmuch. data = {item: 0 for item in data} if isinstance(data, dict): - cls.verify_keys(data) for key, value in data.items(): if not isinstance(value, int): raise ValueError(f"Invalid type in '{cls.display_name}': element '{key}' maps to '{value}', expected an integer") @@ -647,7 +687,9 @@ def verify(self, world: Type['World'], player_name: str, plando_options: PlandoO """Overridden version of function from Options.VerifyKeys for a better error message""" new_value: dict[str, int] = {} case_insensitive_group_mapping = { - group_name.casefold(): group_value for group_name, group_value in world.item_name_groups.items()} + group_name.casefold(): group_value for group_name, group_value in world.item_name_groups.items() + } + case_insensitive_group_mapping.update({item.casefold(): [item] for item in world.item_names}) for group_name in self.value: item_names = case_insensitive_group_mapping.get(group_name.casefold(), {group_name}) for item_name in item_names: @@ -655,7 +697,12 @@ def verify(self, world: Type['World'], player_name: str, plando_options: PlandoO self.value = new_value for item_name in self.value: if item_name not in world.item_names: - picks = get_fuzzy_results(item_name, list(world.item_names), limit=1) + from . import ItemGroups + picks = get_fuzzy_results( + item_name, + list(world.item_names) + list(ItemGroups.ItemGroupNames.get_all_group_names()), + limit=1, + ) raise Exception(f"Item {item_name} from option {self} " f"is not a valid item name from {world.game}. " f"Did you mean '{picks[0][0]}' ({picks[0][1]}% sure)") @@ -691,7 +738,7 @@ class UnexcludedItems(Sc2ItemDict): display_name = "Unexcluded Items" -class ExcludedMissions(OptionSet): +class ExcludedMissions(Sc2MissionSet): """Guarantees that these missions will not appear in the campaign Doesn't apply to vanilla mission order. It may be impossible to build a valid campaign if too many missions are excluded.""" diff --git a/worlds/sc2/mission_groups.py b/worlds/sc2/mission_groups.py new file mode 100644 index 000000000000..da3438901e4b --- /dev/null +++ b/worlds/sc2/mission_groups.py @@ -0,0 +1,192 @@ +""" +Mission group aliases for use in yaml options. +""" +from typing import Dict, List, Set +from .MissionTables import SC2Mission, MissionFlag, SC2Campaign + + +class MissionGroupNames: + ALL_MISSIONS = "All Missions" + WOL_MISSIONS = "WoL Missions" + HOTS_MISSIONS = "HotS Missions" + LOTV_MISSIONS = "LotV Missions" + NCO_MISSIONS = "NCO Missions" + PROPHECY_MISSIONS = "Prophecy Missions" + PROLOGUE_MISSIONS = "Prologue Missions" + EPILOGUE_MISSIONS = "Epilogue Missions" + + TERRAN_MISSIONS = "Terran Missions" + ZERG_MISSIONS = "Zerg Missions" + PROTOSS_MISSIONS = "Protoss Missions" + NOBUILD_MISSIONS = "No-Build Missions" + DEFENSE_MISSIONS = "Defense Missions" + AUTO_SCROLLER_MISSIONS = "Auto-Scroller Missions" + COUNTDOWN_MISSIONS = "Countdown Missions" + KERRIGAN_MISSIONS = "Kerrigan Missions" + VANILLA_SOA_MISSIONS = "Vanilla SOA Missions" + TERRAN_ALLY_MISSIONS = "Controllable Terran Ally Missions" + ZERG_ALLY_MISSIONS = "Controllable Zerg Ally Missions" + PROTOSS_ALLY_MISSIONS = "Controllable Protoss Ally Missions" + VS_TERRAN_MISSIONS = "Vs Terran Missions" + VS_ZERG_MISSIONS = "Vs Zerg Missions" + VS_PROTOSS_MISSIONS = "Vs Protoss Missions" + + # By planet + PLANET_MAR_SARA_MISSIONS = "Planet Mar Sara" + PLANET_CHAR_MISSIONS = "Planet Char" + PLANET_KORHAL_MISSIONS = "Planet Korhal" + PLANET_AIUR_MISSIONS = "Planet Aiur" + + # By quest chain + WOL_MAR_SARA_MISSIONS = "WoL Mar Sara" + WOL_COLONIST_MISSIONS = "WoL Colonist" + WOL_ARTIFACT_MISSIONS = "WoL Artifact" + WOL_COVERT_MISSIONS = "WoL Covert" + WOL_REBELLION_MISSIONS = "WoL Rebellion" + WOL_CHAR_MISSIONS = "WoL Char" + + HOTS_UMOJA_MISSIONS = "HotS Umoja" + HOTS_KALDIR_MISSIONS = "HotS Kaldir" + HOTS_CHAR_MISSIONS = "HotS Char" + HOTS_ZERUS_MISSIONS = "HotS Zerus" + HOTS_SKYGEIRR_MISSIONS = "HotS Skygeirr Station" + HOTS_DOMINION_SPACE_MISSIONS = "HotS Dominion Space" + HOTS_KORHAL_MISSIONS = "HotS Korhal" + + LOTV_AIUR_MISSIONS = "LotV Aiur" + LOTV_KORHAL_MISSIONS = "LotV Korhal" + LOTV_SHAKURAS_MISSIONS = "LotV Shakuras" + LOTV_ULNAR_MISSIONS = "LotV Ulnar" + LOTV_PURIFIER_MISSIONS = "LotV Purifier" + LOTV_TALDARIM_MISSIONS = "LotV Tal'darim" + LOTV_MOEBIUS_MISSIONS = "LotV Moebius" + LOTV_RETURN_TO_AIUR_MISSIONS = "LotV Return to Aiur" + + NCO_MISSION_PACK_1 = "NCO Mission Pack 1" + NCO_MISSION_PACK_2 = "NCO Mission Pack 2" + NCO_MISSION_PACK_3 = "NCO Mission Pack 3" + + @classmethod + def get_all_group_names(cls) -> Set[str]: + return { + name for identifier, name in cls.__dict__.items() + if not identifier.startswith('_') + and not identifier.startswith('get_') + } + + +mission_groups: Dict[str, List[str]] = {} + +mission_groups[MissionGroupNames.ALL_MISSIONS] = [ + mission.mission_name for mission in SC2Mission +] +for group_name, campaign in ( + (MissionGroupNames.WOL_MISSIONS, SC2Campaign.WOL), + (MissionGroupNames.HOTS_MISSIONS, SC2Campaign.HOTS), + (MissionGroupNames.LOTV_MISSIONS, SC2Campaign.LOTV), + (MissionGroupNames.NCO_MISSIONS, SC2Campaign.NCO), + (MissionGroupNames.PROPHECY_MISSIONS, SC2Campaign.PROPHECY), + (MissionGroupNames.PROLOGUE_MISSIONS, SC2Campaign.PROLOGUE), + (MissionGroupNames.EPILOGUE_MISSIONS, SC2Campaign.EPILOGUE), +): + mission_groups[group_name] = [ + mission.mission_name for mission in SC2Mission if mission.campaign == SC2Campaign.WOL + ] + +for group_name, flags in ( + (MissionGroupNames.TERRAN_MISSIONS, MissionFlag.Terran), + (MissionGroupNames.ZERG_MISSIONS, MissionFlag.Zerg), + (MissionGroupNames.PROTOSS_MISSIONS, MissionFlag.Protoss), + (MissionGroupNames.NOBUILD_MISSIONS, MissionFlag.NoBuild), + (MissionGroupNames.DEFENSE_MISSIONS, MissionFlag.Defense), + (MissionGroupNames.AUTO_SCROLLER_MISSIONS, MissionFlag.AutoScroller), + (MissionGroupNames.COUNTDOWN_MISSIONS, MissionFlag.Countdown), + (MissionGroupNames.KERRIGAN_MISSIONS, MissionFlag.Kerrigan), + (MissionGroupNames.VANILLA_SOA_MISSIONS, MissionFlag.VanillaSoa), + (MissionGroupNames.TERRAN_ALLY_MISSIONS, MissionFlag.AiTerranAlly), + (MissionGroupNames.ZERG_ALLY_MISSIONS, MissionFlag.AiZergAlly), + (MissionGroupNames.PROTOSS_ALLY_MISSIONS, MissionFlag.AiProtossAlly), + (MissionGroupNames.VS_TERRAN_MISSIONS, MissionFlag.VsTerran), + (MissionGroupNames.VS_ZERG_MISSIONS, MissionFlag.VsZerg), + (MissionGroupNames.VS_PROTOSS_MISSIONS, MissionFlag.VsProtoss), +): + mission_groups[group_name] = [ + mission.mission_name for mission in SC2Mission if flags in mission.flags + ] + +for group_name, campaign, chain_name in ( + (MissionGroupNames.WOL_MAR_SARA_MISSIONS, SC2Campaign.WOL, "Mar Sara"), + (MissionGroupNames.WOL_COLONIST_MISSIONS, SC2Campaign.WOL, "Colonist"), + (MissionGroupNames.WOL_ARTIFACT_MISSIONS, SC2Campaign.WOL, "Artifact"), + (MissionGroupNames.WOL_COVERT_MISSIONS, SC2Campaign.WOL, "Covert"), + (MissionGroupNames.WOL_REBELLION_MISSIONS, SC2Campaign.WOL, "Rebellion"), + (MissionGroupNames.WOL_CHAR_MISSIONS, SC2Campaign.WOL, "Char"), + (MissionGroupNames.HOTS_UMOJA_MISSIONS, SC2Campaign.HOTS, "Umoja"), + (MissionGroupNames.HOTS_KALDIR_MISSIONS, SC2Campaign.HOTS, "Kaldir"), + (MissionGroupNames.HOTS_CHAR_MISSIONS, SC2Campaign.HOTS, "Char"), + (MissionGroupNames.HOTS_ZERUS_MISSIONS, SC2Campaign.HOTS, "Zerus"), + (MissionGroupNames.HOTS_SKYGEIRR_MISSIONS, SC2Campaign.HOTS, "Skygeirr Station"), + (MissionGroupNames.HOTS_DOMINION_SPACE_MISSIONS, SC2Campaign.HOTS, "Dominion Space"), + (MissionGroupNames.HOTS_KORHAL_MISSIONS, SC2Campaign.HOTS, "Korhal"), + (MissionGroupNames.LOTV_AIUR_MISSIONS, SC2Campaign.LOTV, "Aiur"), + (MissionGroupNames.LOTV_KORHAL_MISSIONS, SC2Campaign.LOTV, "Korhal"), + (MissionGroupNames.LOTV_SHAKURAS_MISSIONS, SC2Campaign.LOTV, "Shakuras"), + (MissionGroupNames.LOTV_ULNAR_MISSIONS, SC2Campaign.LOTV, "Ulnar"), + (MissionGroupNames.LOTV_PURIFIER_MISSIONS, SC2Campaign.LOTV, "Purifier"), + (MissionGroupNames.LOTV_TALDARIM_MISSIONS, SC2Campaign.LOTV, "Tal'darim"), + (MissionGroupNames.LOTV_MOEBIUS_MISSIONS, SC2Campaign.LOTV, "Moebius"), + (MissionGroupNames.LOTV_RETURN_TO_AIUR_MISSIONS, SC2Campaign.LOTV, "Return to Aiur"), +): + mission_groups[group_name] = [ + mission.mission_name for mission in SC2Mission + if mission.campaign == campaign + and mission.area == chain_name + ] + +mission_groups[MissionGroupNames.NCO_MISSION_PACK_1] = [ + SC2Mission.THE_ESCAPE.mission_name, + SC2Mission.SUDDEN_STRIKE.mission_name, + SC2Mission.ENEMY_INTELLIGENCE.mission_name, +] +mission_groups[MissionGroupNames.NCO_MISSION_PACK_2] = [ + SC2Mission.TROUBLE_IN_PARADISE.mission_name, + SC2Mission.NIGHT_TERRORS.mission_name, + SC2Mission.FLASHPOINT.mission_name, +] +mission_groups[MissionGroupNames.NCO_MISSION_PACK_3] = [ + SC2Mission.IN_THE_ENEMY_S_SHADOW.mission_name, + SC2Mission.DARK_SKIES.mission_name, + SC2Mission.END_GAME.mission_name, +] + +mission_groups[MissionGroupNames.PLANET_MAR_SARA_MISSIONS] = [ + SC2Mission.LIBERATION_DAY.mission_name, + SC2Mission.THE_OUTLAWS.mission_name, + SC2Mission.ZERO_HOUR.mission_name, +] +mission_groups[MissionGroupNames.PLANET_CHAR_MISSIONS] = [ + SC2Mission.GATES_OF_HELL.mission_name, + SC2Mission.BELLY_OF_THE_BEAST.mission_name, + SC2Mission.SHATTER_THE_SKY.mission_name, + SC2Mission.ALL_IN.mission_name, + SC2Mission.DOMINATION.mission_name, + SC2Mission.FIRE_IN_THE_SKY.mission_name, + SC2Mission.OLD_SOLDIERS.mission_name, +] +mission_groups[MissionGroupNames.PLANET_KORHAL_MISSIONS] = [ + SC2Mission.MEDIA_BLITZ.mission_name, + SC2Mission.PLANETFALL.mission_name, + SC2Mission.DEATH_FROM_ABOVE.mission_name, + SC2Mission.THE_RECKONING.mission_name, + SC2Mission.SKY_SHIELD.mission_name, + SC2Mission.BROTHERS_IN_ARMS.mission_name, +] +mission_groups[MissionGroupNames.PLANET_AIUR_MISSIONS] = [ + SC2Mission.ECHOES_OF_THE_FUTURE.mission_name, + SC2Mission.FOR_AIUR.mission_name, + SC2Mission.THE_GROWING_SHADOW.mission_name, + SC2Mission.THE_SPEAR_OF_ADUN.mission_name, + SC2Mission.TEMPLAR_S_RETURN.mission_name, + SC2Mission.THE_HOST.mission_name, + SC2Mission.SALVATION.mission_name, +] diff --git a/worlds/sc2/test/test_generation.py b/worlds/sc2/test/test_generation.py index 07a8b2bd9d33..81d657ce8655 100644 --- a/worlds/sc2/test/test_generation.py +++ b/worlds/sc2/test/test_generation.py @@ -4,7 +4,7 @@ from typing import * from .test_base import Sc2SetupTestBase -from .. import Options, MissionTables, ItemNames, Items, ItemGroups +from .. import Options, MissionTables, ItemNames, Items, ItemGroups, mission_groups from .. import get_all_missions @@ -90,6 +90,20 @@ def test_excluding_groups_excludes_all_items_in_group(self): for item_name in ItemGroups.barracks_units: self.assertNotIn(item_name, item_names) + def test_excluding_mission_groups_excludes_all_missions_in_group(self): + options = { + 'excluded_missions': [ + mission_groups.MissionGroupNames.HOTS_ZERUS_MISSIONS, + ], + 'mission_order': Options.MissionOrder.option_grid, + } + self.generate_world(options) + missions = get_all_missions(self.world.mission_req_table) + self.assertTrue(missions) + self.assertNotIn(MissionTables.SC2Mission.WAKING_THE_ANCIENT, missions) + self.assertNotIn(MissionTables.SC2Mission.THE_CRUCIBLE, missions) + self.assertNotIn(MissionTables.SC2Mission.SUPREME, missions) + def test_excluding_campaigns_excludes_campaign_specific_items(self) -> None: options = { 'enable_wol_missions': True, diff --git a/worlds/sc2/test/test_itemgroups.py b/worlds/sc2/test/test_itemgroups.py index ba1181bd929b..43666965f397 100644 --- a/worlds/sc2/test/test_itemgroups.py +++ b/worlds/sc2/test/test_itemgroups.py @@ -26,7 +26,5 @@ def test_all_items_in_stimpack_group_are_stimpacks(self) -> None: self.assertIn("Stimpack", item_name) def test_all_item_group_names_have_a_group_defined(self) -> None: - for var_name, display_name in ItemGroups.ItemGroupNames.__dict__.items(): - if var_name.startswith("_"): - continue - assert display_name in ItemGroups.item_name_groups + for display_name in ItemGroups.ItemGroupNames.get_all_group_names(): + self.assertIn(display_name, ItemGroups.item_name_groups) diff --git a/worlds/sc2/test/test_mission_groups.py b/worlds/sc2/test/test_mission_groups.py new file mode 100644 index 000000000000..a8bb100d1dba --- /dev/null +++ b/worlds/sc2/test/test_mission_groups.py @@ -0,0 +1,9 @@ + +import unittest +from .. import mission_groups + +class TestMissionGroups(unittest.TestCase): + def test_all_mission_groups_are_defined_and_nonempty(self) -> None: + for mission_group_name in mission_groups.MissionGroupNames.get_all_group_names(): + self.assertIn(mission_group_name, mission_groups.mission_groups) + self.assertTrue(mission_groups.mission_groups[mission_group_name])