diff --git a/ examples/poke_1v1_fighter.py b/ examples/poke_1v1_fighter.py index e9a6428..1f6d003 100644 --- a/ examples/poke_1v1_fighter.py +++ b/ examples/poke_1v1_fighter.py @@ -21,7 +21,7 @@ from poke_env import PlayerConfiguration from poke_env.player import SimpleHeuristicsPlayer -from p2lab.pokemon.battles import run_battles +from p2lab.battling.battles import run_battles from p2lab.pokemon.premade import gen_1_pokemon from p2lab.pokemon.teams import generate_teams, import_pool diff --git a/.gitignore b/.gitignore index 8c35753..28963d3 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ pokemon-showdown/ # Outputs outputs/* +simple_match.py diff --git a/src/p2lab/__main__.py b/src/p2lab/__main__.py index c45c241..a43d08f 100644 --- a/src/p2lab/__main__.py +++ b/src/p2lab/__main__.py @@ -7,12 +7,16 @@ from p2lab.genetic.genetic import genetic_algorithm from p2lab.pokemon.premade import gen_1_pokemon -from p2lab.pokemon.teams import generate_teams, import_pool +from p2lab.pokemon.teams import generate_pool, generate_teams, import_pool -async def main_loop(num_teams, team_size, num_generations, unique): +async def main_loop(num_teams, team_size, num_generations, unique, use_premade=False): # generate the pool - pool = import_pool(gen_1_pokemon()) + pool = ( + import_pool(gen_1_pokemon()) + if use_premade + else generate_pool(151, use_showdown=False, dexids=np.arange(1, 152)) + ) seed_teams = generate_teams(pool, num_teams, team_size, unique=unique) # crossover_fn = build_crossover_fn(locus_swap, locus=0) # run the genetic algorithm @@ -74,11 +78,13 @@ def parse_args(): def main(): args = parse_args() - if args["s"] is not None: - np.random.seed(args["s"]) + if args["seed"] is not None: + np.random.seed(args["seed"]) asyncio.get_event_loop().run_until_complete( - main_loop(args["n"], args["t"], args["g"], args["u"]) + main_loop( + args["numteams"], args["teamsize"], args["generations"], args["unique"] + ) ) diff --git a/src/p2lab/pokemon/battles.py b/src/p2lab/battling/battles.py similarity index 100% rename from src/p2lab/pokemon/battles.py rename to src/p2lab/battling/battles.py diff --git a/src/p2lab/genetic/genetic.py b/src/p2lab/genetic/genetic.py index 9c11d1b..64049f5 100644 --- a/src/p2lab/genetic/genetic.py +++ b/src/p2lab/genetic/genetic.py @@ -7,10 +7,10 @@ from poke_env import PlayerConfiguration from poke_env.player import SimpleHeuristicsPlayer +from p2lab.battling.battles import run_battles from p2lab.genetic.fitness import win_percentages from p2lab.genetic.matching import dense from p2lab.genetic.operations import fitness_mutate, mutate, selection -from p2lab.pokemon.battles import run_battles if TYPE_CHECKING: from p2lab.pokemon.teams import Team diff --git a/src/p2lab/pokemon/pokefactory.py b/src/p2lab/pokemon/pokefactory.py new file mode 100644 index 0000000..9837e49 --- /dev/null +++ b/src/p2lab/pokemon/pokefactory.py @@ -0,0 +1,93 @@ +""" +This class is used to generate Pokemon teams using the pokedex provided from +Pokemon Showdown (via poke-env) + +This is directly inspired by poke-env's diagostic_tools folder +""" +from __future__ import annotations + +import numpy as np +from poke_env.data import GenData as POKEDEX +from poke_env.teambuilder import TeambuilderPokemon + + +class PokeFactory: + def __init__(self, gen=1, drop_forms=True): + self.gen = ( + gen if gen > 4 else 3 + ) # because this seems to be the minimum gen for the pokedex + self.dex = POKEDEX(self.gen) + if drop_forms: + self.dex2mon = { + int(self.dex.pokedex[m]["num"]): m + for m in self.dex.pokedex + if "forme" not in self.dex.pokedex[m].keys() + } + else: + self.dex2mon = { + int(self.dex.pokedex[m]["num"]): m for m in self.dex.pokedex + } + + def get_pokemon_by_dexnum(self, dexnum): + return self.dex.pokedex[self.dex2mon[dexnum]] + + def get_allowed_moves(self, dexnum, level=100): + pot_moves = self.dex.learnset[self.dex2mon[dexnum]]["learnset"] + allowed_moves = [] + for move, lims in pot_moves.items(): + gens = [int(lim[0]) for lim in lims] + # TODO: write logic here to check if level is allowed + if level != 100: + msg = "Level checking not implemented yet" + raise NotImplementedError(msg) + # lvls = [int(l[1:]) for l in lims if l[1:].isdigit()] + if self.gen in gens: + allowed_moves.append(move) + return allowed_moves + + def get_allowed_abilities(self, dexnum): + return self.dex.pokedex[self.dex2mon[dexnum]]["abilities"] + + def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs): + """ + kwargs are passed to the TeambuilderPokemon constructor and can include: + - nickname + - item + - ability + - moves + - nature + - evs + - ivs + - level + - happiness + - hiddenpowertype + - gmax + """ + if dexnum < 1: + msg = "Dex number must be greater than 0" + raise ValueError(msg) + if dexnum is None: + dexnum = np.random.choice(list(self.dex2mon.keys())) + if generate_moveset or "moves" not in kwargs.keys(): + poss_moves = self.get_allowed_moves(dexnum) + moves = ( + np.random.choice(poss_moves, 4, replace=False) + if len(poss_moves) > 3 + else poss_moves + ) + kwargs["moves"] = moves + if "ivs" not in kwargs.keys(): + ivs = [31] * 6 + kwargs["ivs"] = ivs + if "evs" not in kwargs.keys(): + # TODO: implement EV generation better + evs = [510 // 6] * 6 + kwargs["evs"] = evs + if "level" not in kwargs.keys(): + kwargs["level"] = 100 + if "ability" not in kwargs.keys(): + kwargs["ability"] = np.random.choice( + list(self.get_allowed_abilities(dexnum).values()) + ) + + return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs) diff --git a/src/p2lab/pokemon/teams.py b/src/p2lab/pokemon/teams.py index e11f0fc..f656461 100644 --- a/src/p2lab/pokemon/teams.py +++ b/src/p2lab/pokemon/teams.py @@ -9,6 +9,7 @@ "import_pool", ) +import subprocess import sys from pathlib import Path from subprocess import check_output @@ -17,6 +18,8 @@ from poke_env.teambuilder import Teambuilder from tqdm import tqdm +from p2lab.pokemon.pokefactory import PokeFactory + class Team: def __init__(self, pokemon) -> None: @@ -32,41 +35,68 @@ def yield_team(self): pass +def validate(poke_str, format="gen7anythinggoes"): + try: + check_output( + f"pokemon-showdown validate-team {format}", + shell=True, + input=poke_str, + stderr=subprocess.DEVNULL, + ) + except Exception: + return False + return True + + def generate_pool( - num_pokemon, format="gen7anythinggoes", export=False, filename="pool.txt" + num_pokemon, + format="gen7anythinggoes", + export=False, + filename="pool.txt", + use_showdown=True, + dexids=None, ): teams = [] print("Generating pokemon in batches of 6 to form pool...") # teams are produced in batches of 6, so we need to generate # a multiple of 6 teams that's greater than the number of pokemon N_seed_teams = num_pokemon // 6 + 1 - for _ in tqdm(range(N_seed_teams), desc="Generating teams!"): - poss_team = check_output(f"pokemon-showdown generate-team {format}", shell=True) - try: - check_output( - f"pokemon-showdown validate-team {format} ", - shell=True, - input=poss_team, + + if use_showdown: + for _ in tqdm(range(N_seed_teams), desc="Generating teams!"): + poss_team = check_output( + f"pokemon-showdown generate-team {format}", shell=True ) - except Exception as e: - print("Error validating team... skipping to next") - print(f"Error: {e}") - continue - n_team = _Builder().parse_showdown_team( - check_output( - "pokemon-showdown export-team ", input=poss_team, shell=True - ).decode(sys.stdout.encoding) - ) - if len(n_team) != 6: - msg = "pokemon showdown generated a team not of length 6" + if not validate(poss_team, format): + continue + n_team = _Builder().parse_showdown_team( + check_output( + "pokemon-showdown export-team ", input=poss_team, shell=True + ).decode(sys.stdout.encoding) + ) + if len(n_team) != 6: + msg = "pokemon showdown generated a team not of length 6" + raise Exception(msg) + teams.append(n_team) + pool = np.array(teams).flatten() + # trim the pool to the desired number of pokemon + pool = pool[:num_pokemon] + else: # assumption is that we homegrow some teams + # TODO: Perform validation here for this! + # TODO: Set gens here for later + pokefactory = PokeFactory(7) + if dexids is None: + msg = "dexids must be provided if not using showdown" + raise Exception() + if len(dexids) != num_pokemon: + msg = "dexids must be the same length as num_pokemon" raise Exception(msg) - teams.append(n_team) - - pool = np.array(teams).flatten() - - # trim the pool to the desired number of pokemon - pool = pool[:num_pokemon] - + pool = [] + for dexid in tqdm(dexids): + pot_poke = pokefactory.make_pokemon(dexid) + while not validate(bytes(pot_poke.formatted, "utf-8"), format): + pot_poke = pokefactory.make_pokemon(dexid) + pool.append(pot_poke) if export: with Path.open(filename, "w") as f: packed = "\n".join([mon.formatted for mon in pool]) diff --git a/tests/test_battles.py b/tests/test_battles.py new file mode 100644 index 0000000..340185b --- /dev/null +++ b/tests/test_battles.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import pytest +from poke_env import PlayerConfiguration +from poke_env.player import SimpleHeuristicsPlayer + +from p2lab.battling.battles import run_battles +from p2lab.pokemon import pokefactory +from p2lab.pokemon.teams import Team + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio() +async def test_battle_eevee_pikachu_pokes(): + p = pokefactory.PokeFactory() + eevee = p.make_pokemon(133, moves=["tackle", "growl"], level=5) + pikachu = p.make_pokemon(25, moves=["thundershock", "growl"], level=5) + team1 = Team([eevee]) + team2 = Team([pikachu]) + teams = [team1, team2] + matches = [[0, 1]] + + player_1 = SimpleHeuristicsPlayer( + PlayerConfiguration("Player 1", None), battle_format="gen7anythinggoes" + ) + player_2 = SimpleHeuristicsPlayer( + PlayerConfiguration("Player 2", None), battle_format="gen7anythinggoes" + ) + res = await run_battles(matches, teams, player_1, player_2, battles_per_match=1) + assert res is not None + + +@pytest.mark.asyncio() +async def test_battle_mewtwo_obliterates_eevee(): + p = pokefactory.PokeFactory() + eevee = p.make_pokemon(133, moves=["tackle", "growl"], level=5) + mewtwo = p.make_pokemon(150, moves=["psychic"], level=100) + team1 = Team([eevee]) + team2 = Team([mewtwo]) + teams = [team1, team2] + matches = [[0, 1]] + player_1 = SimpleHeuristicsPlayer( + PlayerConfiguration("Player 3", None), battle_format="gen7anythinggoes" + ) + player_2 = SimpleHeuristicsPlayer( + PlayerConfiguration("Player 4", None), battle_format="gen7anythinggoes" + ) + res = await run_battles(matches, teams, player_1, player_2, battles_per_match=10) + mewtwo_wins = res[0][1] + eevee_wins = res[0][0] + assert mewtwo_wins > eevee_wins + + +@pytest.mark.asyncio() +async def test_battle_eevee_pikachu_formats(): + p = pokefactory.PokeFactory() + eevee = p.make_pokemon(133, moves=["tackle", "growl"], level=5) + pikachu = p.make_pokemon(25, moves=["thundershock", "growl"], level=5) + team1 = Team([eevee]) + team2 = Team([pikachu]) + teams = [team1, team2] + matches = [[0, 1]] + counter = iter(range(10, 20)) + battle_formats = [ + "gen4anythinggoes", + "gen6anythinggoes", + "gen7anythinggoes", + "gen8anythinggoes", + ] + for battle_format in battle_formats: + player_1 = SimpleHeuristicsPlayer( + PlayerConfiguration(f"Player {next(counter)}", None), + battle_format=battle_format, + ) + player_2 = SimpleHeuristicsPlayer( + PlayerConfiguration(f"Player {next(counter)}", None), + battle_format=battle_format, + ) + res = await run_battles(matches, teams, player_1, player_2, battles_per_match=1) + assert res is not None diff --git a/tests/test_pokedex.py b/tests/test_pokedex.py new file mode 100644 index 0000000..94acce4 --- /dev/null +++ b/tests/test_pokedex.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import numpy as np + +from p2lab.pokemon import pokefactory + + +def test_pokedex(): + p = pokefactory.PokeFactory() + assert p is not None + + +def test_eevee_fetch(): + p = pokefactory.PokeFactory() + eevee = p.get_pokemon_by_dexnum(133) + assert eevee["baseSpecies"].lower() == "eevee" + + +def test_bulbasaur_fetch(): + p = pokefactory.PokeFactory() + bulb = p.get_pokemon_by_dexnum(1) + assert bulb["baseSpecies"].lower() == "bulbasaur" + + +def test_eevee_moves(): + p = pokefactory.PokeFactory() + eevee_moves = p.get_allowed_moves(133) + assert len(eevee_moves) > 0 + + +def test_bulbasaur_moves(): + p = pokefactory.PokeFactory() + bulb_moves = p.get_allowed_moves(1) + assert len(bulb_moves) > 0 + + +def test_eevee_is_created(): + p = pokefactory.PokeFactory() + eevee = p.make_pokemon(133) + assert eevee is not None + + +def test_eevee_is_created_with_moves(): + p = pokefactory.PokeFactory() + eevee = p.make_pokemon(133, moves=["tackle", "growl"]) + assert eevee is not None + + +def test_random_pokemon_is_created_with_moves(): + p = pokefactory.PokeFactory() + dexnum = np.random.randint(1, 151) + while dexnum == 132: + dexnum = np.random.randint(1, 151) + poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True) + assert len(poke.moves) == 4 + + +def test_ditto_is_created_with_moves(): + p = pokefactory.PokeFactory() + ditto = p.make_pokemon(132) + assert len(ditto.moves) == 1 + + +def test_all_gen1_pokemon_can_be_created(): + p = pokefactory.PokeFactory() + for dexnum in range(1, 152): + poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True) + assert len(poke.moves) > 0 + + +def test_invalid_dex_raised(): + p = pokefactory.PokeFactory() + try: + p.make_pokemon(dexnum=0) + except ValueError: + assert True + else: + raise AssertionError() + + +def test_adding_item_to_pokemon(): + p = pokefactory.PokeFactory() + ditto = p.make_pokemon(132, item="choice scarf") + assert ditto.item == "choice scarf"