Skip to content

Commit

Permalink
Tests: add world load benchmark (ArchipelagoMW#2768)
Browse files Browse the repository at this point in the history
  • Loading branch information
Berserker66 authored and Jouramie committed Feb 28, 2024
1 parent 69519d4 commit dad4c7c
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 128 deletions.
132 changes: 6 additions & 126 deletions test/benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,7 @@
import time


class TimeIt:
def __init__(self, name: str, time_logger=None):
self.name = name
self.logger = time_logger
self.timer = None
self.end_timer = None

def __enter__(self):
self.timer = time.perf_counter()
return self

@property
def dif(self):
return self.end_timer - self.timer

def __exit__(self, exc_type, exc_val, exc_tb):
if not self.end_timer:
self.end_timer = time.perf_counter()
if self.logger:
self.logger.info(f"{self.dif:.4f} seconds in {self.name}.")


if __name__ == "__main__":
import argparse
import logging
import gc
import collections
import typing

# makes this module runnable from its folder.
import sys
import os
sys.path.remove(os.path.dirname(__file__))
new_home = os.path.normpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
os.chdir(new_home)
sys.path.append(new_home)

from Utils import init_logging, local_path
local_path.cached_path = new_home
from BaseClasses import MultiWorld, CollectionState, Location
from worlds import AutoWorld
from worlds.AutoWorld import call_all

init_logging("Benchmark Runner")
logger = logging.getLogger("Benchmark")


class BenchmarkRunner:
gen_steps: typing.Tuple[str, ...] = (
"generate_early", "create_regions", "create_items", "set_rules", "generate_basic", "pre_fill")
rule_iterations: int = 100_000

if sys.version_info >= (3, 9):
@staticmethod
def format_times_from_counter(counter: collections.Counter[str], top: int = 5) -> str:
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top))
else:
@staticmethod
def format_times_from_counter(counter: collections.Counter, top: int = 5) -> str:
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top))

def location_test(self, test_location: Location, state: CollectionState, state_name: str) -> float:
with TimeIt(f"{test_location.game} {self.rule_iterations} "
f"runs of {test_location}.access_rule({state_name})", logger) as t:
for _ in range(self.rule_iterations):
test_location.access_rule(state)
# if time is taken to disentangle complex ref chains,
# this time should be attributed to the rule.
gc.collect()
return t.dif

def main(self):
for game in sorted(AutoWorld.AutoWorldRegister.world_types):
summary_data: typing.Dict[str, collections.Counter[str]] = {
"empty_state": collections.Counter(),
"all_state": collections.Counter(),
}
try:
multiworld = MultiWorld(1)
multiworld.game[1] = game
multiworld.player_name = {1: "Tester"}
multiworld.set_seed(0)
multiworld.state = CollectionState(multiworld)
args = argparse.Namespace()
for name, option in AutoWorld.AutoWorldRegister.world_types[game].options_dataclass.type_hints.items():
setattr(args, name, {
1: option.from_any(getattr(option, "default"))
})
multiworld.set_options(args)

gc.collect()
for step in self.gen_steps:
with TimeIt(f"{game} step {step}", logger):
call_all(multiworld, step)
gc.collect()

locations = sorted(multiworld.get_unfilled_locations())
if not locations:
continue

all_state = multiworld.get_all_state(False)
for location in locations:
time_taken = self.location_test(location, multiworld.state, "empty_state")
summary_data["empty_state"][location.name] = time_taken

time_taken = self.location_test(location, all_state, "all_state")
summary_data["all_state"][location.name] = time_taken

total_empty_state = sum(summary_data["empty_state"].values())
total_all_state = sum(summary_data["all_state"].values())

logger.info(f"{game} took {total_empty_state/len(locations):.4f} "
f"seconds per location in empty_state and {total_all_state/len(locations):.4f} "
f"in all_state. (all times summed for {self.rule_iterations} runs.)")
logger.info(f"Top times in empty_state:\n"
f"{self.format_times_from_counter(summary_data['empty_state'])}")
logger.info(f"Top times in all_state:\n"
f"{self.format_times_from_counter(summary_data['all_state'])}")

except Exception as e:
logger.exception(e)

runner = BenchmarkRunner()
runner.main()
import path_change
path_change.change_home()
import load_worlds
load_worlds.run_load_worlds_benchmark()
import locations
locations.run_locations_benchmark()
27 changes: 27 additions & 0 deletions test/benchmark/load_worlds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
def run_load_worlds_benchmark():
"""List worlds and their load time.
Note that any first-time imports will be attributed to that world, as it is cached afterwards.
Likely best used with isolated worlds to measure their time alone."""
import logging

from Utils import init_logging

# get some general imports cached, to prevent it from being attributed to one world.
import orjson
orjson.loads("{}") # orjson runs initialization on first use

import BaseClasses, Launcher, Fill

from worlds import world_sources

init_logging("Benchmark Runner")
logger = logging.getLogger("Benchmark")

for module in world_sources:
logger.info(f"{module} took {module.time_taken:.4f} seconds.")


if __name__ == "__main__":
from path_change import change_home
change_home()
run_load_worlds_benchmark()
101 changes: 101 additions & 0 deletions test/benchmark/locations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
def run_locations_benchmark():
import argparse
import logging
import gc
import collections
import typing
import sys

from time_it import TimeIt

from Utils import init_logging
from BaseClasses import MultiWorld, CollectionState, Location
from worlds import AutoWorld
from worlds.AutoWorld import call_all

init_logging("Benchmark Runner")
logger = logging.getLogger("Benchmark")

class BenchmarkRunner:
gen_steps: typing.Tuple[str, ...] = (
"generate_early", "create_regions", "create_items", "set_rules", "generate_basic", "pre_fill")
rule_iterations: int = 100_000

if sys.version_info >= (3, 9):
@staticmethod
def format_times_from_counter(counter: collections.Counter[str], top: int = 5) -> str:
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top))
else:
@staticmethod
def format_times_from_counter(counter: collections.Counter, top: int = 5) -> str:
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top))

def location_test(self, test_location: Location, state: CollectionState, state_name: str) -> float:
with TimeIt(f"{test_location.game} {self.rule_iterations} "
f"runs of {test_location}.access_rule({state_name})", logger) as t:
for _ in range(self.rule_iterations):
test_location.access_rule(state)
# if time is taken to disentangle complex ref chains,
# this time should be attributed to the rule.
gc.collect()
return t.dif

def main(self):
for game in sorted(AutoWorld.AutoWorldRegister.world_types):
summary_data: typing.Dict[str, collections.Counter[str]] = {
"empty_state": collections.Counter(),
"all_state": collections.Counter(),
}
try:
multiworld = MultiWorld(1)
multiworld.game[1] = game
multiworld.player_name = {1: "Tester"}
multiworld.set_seed(0)
multiworld.state = CollectionState(multiworld)
args = argparse.Namespace()
for name, option in AutoWorld.AutoWorldRegister.world_types[game].options_dataclass.type_hints.items():
setattr(args, name, {
1: option.from_any(getattr(option, "default"))
})
multiworld.set_options(args)

gc.collect()
for step in self.gen_steps:
with TimeIt(f"{game} step {step}", logger):
call_all(multiworld, step)
gc.collect()

locations = sorted(multiworld.get_unfilled_locations())
if not locations:
continue

all_state = multiworld.get_all_state(False)
for location in locations:
time_taken = self.location_test(location, multiworld.state, "empty_state")
summary_data["empty_state"][location.name] = time_taken

time_taken = self.location_test(location, all_state, "all_state")
summary_data["all_state"][location.name] = time_taken

total_empty_state = sum(summary_data["empty_state"].values())
total_all_state = sum(summary_data["all_state"].values())

logger.info(f"{game} took {total_empty_state/len(locations):.4f} "
f"seconds per location in empty_state and {total_all_state/len(locations):.4f} "
f"in all_state. (all times summed for {self.rule_iterations} runs.)")
logger.info(f"Top times in empty_state:\n"
f"{self.format_times_from_counter(summary_data['empty_state'])}")
logger.info(f"Top times in all_state:\n"
f"{self.format_times_from_counter(summary_data['all_state'])}")

except Exception as e:
logger.exception(e)

runner = BenchmarkRunner()
runner.main()


if __name__ == "__main__":
from path_change import change_home
change_home()
run_locations_benchmark()
16 changes: 16 additions & 0 deletions test/benchmark/path_change.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import sys
import os


def change_home():
"""Allow scripts to run from "this" folder."""
old_home = os.path.dirname(__file__)
sys.path.remove(old_home)
new_home = os.path.normpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
os.chdir(new_home)
sys.path.append(new_home)
# fallback to local import
sys.path.append(old_home)

from Utils import local_path
local_path.cached_path = new_home
23 changes: 23 additions & 0 deletions test/benchmark/time_it.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import time


class TimeIt:
def __init__(self, name: str, time_logger=None):
self.name = name
self.logger = time_logger
self.timer = None
self.end_timer = None

def __enter__(self):
self.timer = time.perf_counter()
return self

@property
def dif(self):
return self.end_timer - self.timer

def __exit__(self, exc_type, exc_val, exc_tb):
if not self.end_timer:
self.end_timer = time.perf_counter()
if self.logger:
self.logger.info(f"{self.dif:.4f} seconds in {self.name}.")
10 changes: 8 additions & 2 deletions worlds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import sys
import warnings
import zipimport
from typing import Dict, List, NamedTuple, TypedDict
import time
import dataclasses
from typing import Dict, List, TypedDict, Optional

from Utils import local_path, user_path

Expand Down Expand Up @@ -34,10 +36,12 @@ class DataPackage(TypedDict):
games: Dict[str, GamesPackage]


class WorldSource(NamedTuple):
@dataclasses.dataclass(order=True)
class WorldSource:
path: str # typically relative path from this module
is_zip: bool = False
relative: bool = True # relative to regular world import folder
time_taken: Optional[float] = None

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.path}, is_zip={self.is_zip}, relative={self.relative})"
Expand All @@ -50,6 +54,7 @@ def resolved_path(self) -> str:

def load(self) -> bool:
try:
start = time.perf_counter()
if self.is_zip:
importer = zipimport.zipimporter(self.resolved_path)
if hasattr(importer, "find_spec"): # new in Python 3.10
Expand All @@ -69,6 +74,7 @@ def load(self) -> bool:
importer.exec_module(mod)
else:
importlib.import_module(f".{self.path}", "worlds")
self.time_taken = time.perf_counter()-start
return True

except Exception:
Expand Down

0 comments on commit dad4c7c

Please sign in to comment.