Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/team -> adds team generation and validation #52

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion examples/poke_1v1_fighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from poke_env import PlayerConfiguration
from poke_env.player import SimpleHeuristicsPlayer

from p2lab.pokemon.battles import run_battles
from p2lab.battling.battles import run_battles
from p2lab.pokemon.premade import gen_1_pokemon
from p2lab.pokemon.teams import generate_teams, import_pool

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ pokemon-showdown/

# Outputs
outputs/*
simple_match.py
18 changes: 12 additions & 6 deletions src/p2lab/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@

from p2lab.genetic.genetic import genetic_algorithm
from p2lab.pokemon.premade import gen_1_pokemon
from p2lab.pokemon.teams import generate_teams, import_pool
from p2lab.pokemon.teams import generate_pool, generate_teams, import_pool


async def main_loop(num_teams, team_size, num_generations, unique):
async def main_loop(num_teams, team_size, num_generations, unique, use_premade=False):
# generate the pool
pool = import_pool(gen_1_pokemon())
pool = (
import_pool(gen_1_pokemon())
if use_premade
else generate_pool(151, use_showdown=False, dexids=np.arange(1, 152))
)
seed_teams = generate_teams(pool, num_teams, team_size, unique=unique)
# crossover_fn = build_crossover_fn(locus_swap, locus=0)
# run the genetic algorithm
Expand Down Expand Up @@ -74,11 +78,13 @@ def parse_args():
def main():
args = parse_args()

if args["s"] is not None:
np.random.seed(args["s"])
if args["seed"] is not None:
np.random.seed(args["seed"])

asyncio.get_event_loop().run_until_complete(
main_loop(args["n"], args["t"], args["g"], args["u"])
main_loop(
args["numteams"], args["teamsize"], args["generations"], args["unique"]
)
)


Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/p2lab/genetic/genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from poke_env import PlayerConfiguration
from poke_env.player import SimpleHeuristicsPlayer

from p2lab.battling.battles import run_battles
from p2lab.genetic.fitness import win_percentages
from p2lab.genetic.matching import dense
from p2lab.genetic.operations import fitness_mutate, mutate, selection
from p2lab.pokemon.battles import run_battles

if TYPE_CHECKING:
from p2lab.pokemon.teams import Team
Expand Down
93 changes: 93 additions & 0 deletions src/p2lab/pokemon/pokefactory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
This class is used to generate Pokemon teams using the pokedex provided from
Pokemon Showdown (via poke-env)

This is directly inspired by poke-env's diagostic_tools folder
"""
from __future__ import annotations

import numpy as np
from poke_env.data import GenData as POKEDEX
from poke_env.teambuilder import TeambuilderPokemon


class PokeFactory:
def __init__(self, gen=1, drop_forms=True):
self.gen = (
gen if gen > 4 else 3
) # because this seems to be the minimum gen for the pokedex
self.dex = POKEDEX(self.gen)
if drop_forms:
self.dex2mon = {
int(self.dex.pokedex[m]["num"]): m
for m in self.dex.pokedex
if "forme" not in self.dex.pokedex[m].keys()
}
else:
self.dex2mon = {
int(self.dex.pokedex[m]["num"]): m for m in self.dex.pokedex
}

def get_pokemon_by_dexnum(self, dexnum):
return self.dex.pokedex[self.dex2mon[dexnum]]

def get_allowed_moves(self, dexnum, level=100):
pot_moves = self.dex.learnset[self.dex2mon[dexnum]]["learnset"]
allowed_moves = []
for move, lims in pot_moves.items():
gens = [int(lim[0]) for lim in lims]
# TODO: write logic here to check if level is allowed
if level != 100:
msg = "Level checking not implemented yet"
raise NotImplementedError(msg)
# lvls = [int(l[1:]) for l in lims if l[1:].isdigit()]
if self.gen in gens:
allowed_moves.append(move)
return allowed_moves

def get_allowed_abilities(self, dexnum):
return self.dex.pokedex[self.dex2mon[dexnum]]["abilities"]

def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs):
"""
kwargs are passed to the TeambuilderPokemon constructor and can include:
- nickname
- item
- ability
- moves
- nature
- evs
- ivs
- level
- happiness
- hiddenpowertype
- gmax
"""
if dexnum < 1:
msg = "Dex number must be greater than 0"
raise ValueError(msg)
if dexnum is None:
dexnum = np.random.choice(list(self.dex2mon.keys()))
if generate_moveset or "moves" not in kwargs.keys():
poss_moves = self.get_allowed_moves(dexnum)
moves = (
np.random.choice(poss_moves, 4, replace=False)
if len(poss_moves) > 3
else poss_moves
)
kwargs["moves"] = moves
if "ivs" not in kwargs.keys():
ivs = [31] * 6
kwargs["ivs"] = ivs
if "evs" not in kwargs.keys():
# TODO: implement EV generation better
evs = [510 // 6] * 6
kwargs["evs"] = evs
if "level" not in kwargs.keys():
kwargs["level"] = 100
if "ability" not in kwargs.keys():
kwargs["ability"] = np.random.choice(
list(self.get_allowed_abilities(dexnum).values())
)

return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs)
82 changes: 56 additions & 26 deletions src/p2lab/pokemon/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"import_pool",
)

import subprocess
import sys
from pathlib import Path
from subprocess import check_output
Expand All @@ -17,6 +18,8 @@
from poke_env.teambuilder import Teambuilder
from tqdm import tqdm

from p2lab.pokemon.pokefactory import PokeFactory


class Team:
def __init__(self, pokemon) -> None:
Expand All @@ -32,41 +35,68 @@ def yield_team(self):
pass


def validate(poke_str, format="gen7anythinggoes"):
try:
check_output(
f"pokemon-showdown validate-team {format}",
shell=True,
input=poke_str,
stderr=subprocess.DEVNULL,
)
except Exception:
return False
return True


def generate_pool(
num_pokemon, format="gen7anythinggoes", export=False, filename="pool.txt"
num_pokemon,
format="gen7anythinggoes",
export=False,
filename="pool.txt",
use_showdown=True,
dexids=None,
):
teams = []
print("Generating pokemon in batches of 6 to form pool...")
# teams are produced in batches of 6, so we need to generate
# a multiple of 6 teams that's greater than the number of pokemon
N_seed_teams = num_pokemon // 6 + 1
for _ in tqdm(range(N_seed_teams), desc="Generating teams!"):
poss_team = check_output(f"pokemon-showdown generate-team {format}", shell=True)
try:
check_output(
f"pokemon-showdown validate-team {format} ",
shell=True,
input=poss_team,

if use_showdown:
for _ in tqdm(range(N_seed_teams), desc="Generating teams!"):
poss_team = check_output(
f"pokemon-showdown generate-team {format}", shell=True
)
except Exception as e:
print("Error validating team... skipping to next")
print(f"Error: {e}")
continue
n_team = _Builder().parse_showdown_team(
check_output(
"pokemon-showdown export-team ", input=poss_team, shell=True
).decode(sys.stdout.encoding)
)
if len(n_team) != 6:
msg = "pokemon showdown generated a team not of length 6"
if not validate(poss_team, format):
continue
n_team = _Builder().parse_showdown_team(
check_output(
"pokemon-showdown export-team ", input=poss_team, shell=True
).decode(sys.stdout.encoding)
)
if len(n_team) != 6:
msg = "pokemon showdown generated a team not of length 6"
raise Exception(msg)
teams.append(n_team)
pool = np.array(teams).flatten()
# trim the pool to the desired number of pokemon
pool = pool[:num_pokemon]
else: # assumption is that we homegrow some teams
# TODO: Perform validation here for this!
# TODO: Set gens here for later
pokefactory = PokeFactory(7)
if dexids is None:
msg = "dexids must be provided if not using showdown"
raise Exception()
if len(dexids) != num_pokemon:
msg = "dexids must be the same length as num_pokemon"
raise Exception(msg)
teams.append(n_team)

pool = np.array(teams).flatten()

# trim the pool to the desired number of pokemon
pool = pool[:num_pokemon]

pool = []
for dexid in tqdm(dexids):
pot_poke = pokefactory.make_pokemon(dexid)
while not validate(bytes(pot_poke.formatted, "utf-8"), format):
pot_poke = pokefactory.make_pokemon(dexid)
pool.append(pot_poke)
if export:
with Path.open(filename, "w") as f:
packed = "\n".join([mon.formatted for mon in pool])
Expand Down
81 changes: 81 additions & 0 deletions tests/test_battles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import pytest
from poke_env import PlayerConfiguration
from poke_env.player import SimpleHeuristicsPlayer

from p2lab.battling.battles import run_battles
from p2lab.pokemon import pokefactory
from p2lab.pokemon.teams import Team

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio()
async def test_battle_eevee_pikachu_pokes():
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133, moves=["tackle", "growl"], level=5)
pikachu = p.make_pokemon(25, moves=["thundershock", "growl"], level=5)
team1 = Team([eevee])
team2 = Team([pikachu])
teams = [team1, team2]
matches = [[0, 1]]

player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration("Player 1", None), battle_format="gen7anythinggoes"
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration("Player 2", None), battle_format="gen7anythinggoes"
)
res = await run_battles(matches, teams, player_1, player_2, battles_per_match=1)
assert res is not None


@pytest.mark.asyncio()
async def test_battle_mewtwo_obliterates_eevee():
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133, moves=["tackle", "growl"], level=5)
mewtwo = p.make_pokemon(150, moves=["psychic"], level=100)
team1 = Team([eevee])
team2 = Team([mewtwo])
teams = [team1, team2]
matches = [[0, 1]]
player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration("Player 3", None), battle_format="gen7anythinggoes"
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration("Player 4", None), battle_format="gen7anythinggoes"
)
res = await run_battles(matches, teams, player_1, player_2, battles_per_match=10)
mewtwo_wins = res[0][1]
eevee_wins = res[0][0]
assert mewtwo_wins > eevee_wins


@pytest.mark.asyncio()
async def test_battle_eevee_pikachu_formats():
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133, moves=["tackle", "growl"], level=5)
pikachu = p.make_pokemon(25, moves=["thundershock", "growl"], level=5)
team1 = Team([eevee])
team2 = Team([pikachu])
teams = [team1, team2]
matches = [[0, 1]]
counter = iter(range(10, 20))
battle_formats = [
"gen4anythinggoes",
"gen6anythinggoes",
"gen7anythinggoes",
"gen8anythinggoes",
]
for battle_format in battle_formats:
player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(f"Player {next(counter)}", None),
battle_format=battle_format,
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(f"Player {next(counter)}", None),
battle_format=battle_format,
)
res = await run_battles(matches, teams, player_1, player_2, battles_per_match=1)
assert res is not None
Loading