forked from ArchipelagoMW/Archipelago
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests: add world load benchmark (ArchipelagoMW#2768)
- Loading branch information
1 parent
69519d4
commit dad4c7c
Showing
6 changed files
with
181 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,7 @@ | ||
import time | ||
|
||
|
||
class TimeIt: | ||
def __init__(self, name: str, time_logger=None): | ||
self.name = name | ||
self.logger = time_logger | ||
self.timer = None | ||
self.end_timer = None | ||
|
||
def __enter__(self): | ||
self.timer = time.perf_counter() | ||
return self | ||
|
||
@property | ||
def dif(self): | ||
return self.end_timer - self.timer | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
if not self.end_timer: | ||
self.end_timer = time.perf_counter() | ||
if self.logger: | ||
self.logger.info(f"{self.dif:.4f} seconds in {self.name}.") | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
import logging | ||
import gc | ||
import collections | ||
import typing | ||
|
||
# makes this module runnable from its folder. | ||
import sys | ||
import os | ||
sys.path.remove(os.path.dirname(__file__)) | ||
new_home = os.path.normpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) | ||
os.chdir(new_home) | ||
sys.path.append(new_home) | ||
|
||
from Utils import init_logging, local_path | ||
local_path.cached_path = new_home | ||
from BaseClasses import MultiWorld, CollectionState, Location | ||
from worlds import AutoWorld | ||
from worlds.AutoWorld import call_all | ||
|
||
init_logging("Benchmark Runner") | ||
logger = logging.getLogger("Benchmark") | ||
|
||
|
||
class BenchmarkRunner: | ||
gen_steps: typing.Tuple[str, ...] = ( | ||
"generate_early", "create_regions", "create_items", "set_rules", "generate_basic", "pre_fill") | ||
rule_iterations: int = 100_000 | ||
|
||
if sys.version_info >= (3, 9): | ||
@staticmethod | ||
def format_times_from_counter(counter: collections.Counter[str], top: int = 5) -> str: | ||
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top)) | ||
else: | ||
@staticmethod | ||
def format_times_from_counter(counter: collections.Counter, top: int = 5) -> str: | ||
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top)) | ||
|
||
def location_test(self, test_location: Location, state: CollectionState, state_name: str) -> float: | ||
with TimeIt(f"{test_location.game} {self.rule_iterations} " | ||
f"runs of {test_location}.access_rule({state_name})", logger) as t: | ||
for _ in range(self.rule_iterations): | ||
test_location.access_rule(state) | ||
# if time is taken to disentangle complex ref chains, | ||
# this time should be attributed to the rule. | ||
gc.collect() | ||
return t.dif | ||
|
||
def main(self): | ||
for game in sorted(AutoWorld.AutoWorldRegister.world_types): | ||
summary_data: typing.Dict[str, collections.Counter[str]] = { | ||
"empty_state": collections.Counter(), | ||
"all_state": collections.Counter(), | ||
} | ||
try: | ||
multiworld = MultiWorld(1) | ||
multiworld.game[1] = game | ||
multiworld.player_name = {1: "Tester"} | ||
multiworld.set_seed(0) | ||
multiworld.state = CollectionState(multiworld) | ||
args = argparse.Namespace() | ||
for name, option in AutoWorld.AutoWorldRegister.world_types[game].options_dataclass.type_hints.items(): | ||
setattr(args, name, { | ||
1: option.from_any(getattr(option, "default")) | ||
}) | ||
multiworld.set_options(args) | ||
|
||
gc.collect() | ||
for step in self.gen_steps: | ||
with TimeIt(f"{game} step {step}", logger): | ||
call_all(multiworld, step) | ||
gc.collect() | ||
|
||
locations = sorted(multiworld.get_unfilled_locations()) | ||
if not locations: | ||
continue | ||
|
||
all_state = multiworld.get_all_state(False) | ||
for location in locations: | ||
time_taken = self.location_test(location, multiworld.state, "empty_state") | ||
summary_data["empty_state"][location.name] = time_taken | ||
|
||
time_taken = self.location_test(location, all_state, "all_state") | ||
summary_data["all_state"][location.name] = time_taken | ||
|
||
total_empty_state = sum(summary_data["empty_state"].values()) | ||
total_all_state = sum(summary_data["all_state"].values()) | ||
|
||
logger.info(f"{game} took {total_empty_state/len(locations):.4f} " | ||
f"seconds per location in empty_state and {total_all_state/len(locations):.4f} " | ||
f"in all_state. (all times summed for {self.rule_iterations} runs.)") | ||
logger.info(f"Top times in empty_state:\n" | ||
f"{self.format_times_from_counter(summary_data['empty_state'])}") | ||
logger.info(f"Top times in all_state:\n" | ||
f"{self.format_times_from_counter(summary_data['all_state'])}") | ||
|
||
except Exception as e: | ||
logger.exception(e) | ||
|
||
runner = BenchmarkRunner() | ||
runner.main() | ||
import path_change | ||
path_change.change_home() | ||
import load_worlds | ||
load_worlds.run_load_worlds_benchmark() | ||
import locations | ||
locations.run_locations_benchmark() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
def run_load_worlds_benchmark(): | ||
"""List worlds and their load time. | ||
Note that any first-time imports will be attributed to that world, as it is cached afterwards. | ||
Likely best used with isolated worlds to measure their time alone.""" | ||
import logging | ||
|
||
from Utils import init_logging | ||
|
||
# get some general imports cached, to prevent it from being attributed to one world. | ||
import orjson | ||
orjson.loads("{}") # orjson runs initialization on first use | ||
|
||
import BaseClasses, Launcher, Fill | ||
|
||
from worlds import world_sources | ||
|
||
init_logging("Benchmark Runner") | ||
logger = logging.getLogger("Benchmark") | ||
|
||
for module in world_sources: | ||
logger.info(f"{module} took {module.time_taken:.4f} seconds.") | ||
|
||
|
||
if __name__ == "__main__": | ||
from path_change import change_home | ||
change_home() | ||
run_load_worlds_benchmark() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
def run_locations_benchmark(): | ||
import argparse | ||
import logging | ||
import gc | ||
import collections | ||
import typing | ||
import sys | ||
|
||
from time_it import TimeIt | ||
|
||
from Utils import init_logging | ||
from BaseClasses import MultiWorld, CollectionState, Location | ||
from worlds import AutoWorld | ||
from worlds.AutoWorld import call_all | ||
|
||
init_logging("Benchmark Runner") | ||
logger = logging.getLogger("Benchmark") | ||
|
||
class BenchmarkRunner: | ||
gen_steps: typing.Tuple[str, ...] = ( | ||
"generate_early", "create_regions", "create_items", "set_rules", "generate_basic", "pre_fill") | ||
rule_iterations: int = 100_000 | ||
|
||
if sys.version_info >= (3, 9): | ||
@staticmethod | ||
def format_times_from_counter(counter: collections.Counter[str], top: int = 5) -> str: | ||
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top)) | ||
else: | ||
@staticmethod | ||
def format_times_from_counter(counter: collections.Counter, top: int = 5) -> str: | ||
return "\n".join(f" {time:.4f} in {name}" for name, time in counter.most_common(top)) | ||
|
||
def location_test(self, test_location: Location, state: CollectionState, state_name: str) -> float: | ||
with TimeIt(f"{test_location.game} {self.rule_iterations} " | ||
f"runs of {test_location}.access_rule({state_name})", logger) as t: | ||
for _ in range(self.rule_iterations): | ||
test_location.access_rule(state) | ||
# if time is taken to disentangle complex ref chains, | ||
# this time should be attributed to the rule. | ||
gc.collect() | ||
return t.dif | ||
|
||
def main(self): | ||
for game in sorted(AutoWorld.AutoWorldRegister.world_types): | ||
summary_data: typing.Dict[str, collections.Counter[str]] = { | ||
"empty_state": collections.Counter(), | ||
"all_state": collections.Counter(), | ||
} | ||
try: | ||
multiworld = MultiWorld(1) | ||
multiworld.game[1] = game | ||
multiworld.player_name = {1: "Tester"} | ||
multiworld.set_seed(0) | ||
multiworld.state = CollectionState(multiworld) | ||
args = argparse.Namespace() | ||
for name, option in AutoWorld.AutoWorldRegister.world_types[game].options_dataclass.type_hints.items(): | ||
setattr(args, name, { | ||
1: option.from_any(getattr(option, "default")) | ||
}) | ||
multiworld.set_options(args) | ||
|
||
gc.collect() | ||
for step in self.gen_steps: | ||
with TimeIt(f"{game} step {step}", logger): | ||
call_all(multiworld, step) | ||
gc.collect() | ||
|
||
locations = sorted(multiworld.get_unfilled_locations()) | ||
if not locations: | ||
continue | ||
|
||
all_state = multiworld.get_all_state(False) | ||
for location in locations: | ||
time_taken = self.location_test(location, multiworld.state, "empty_state") | ||
summary_data["empty_state"][location.name] = time_taken | ||
|
||
time_taken = self.location_test(location, all_state, "all_state") | ||
summary_data["all_state"][location.name] = time_taken | ||
|
||
total_empty_state = sum(summary_data["empty_state"].values()) | ||
total_all_state = sum(summary_data["all_state"].values()) | ||
|
||
logger.info(f"{game} took {total_empty_state/len(locations):.4f} " | ||
f"seconds per location in empty_state and {total_all_state/len(locations):.4f} " | ||
f"in all_state. (all times summed for {self.rule_iterations} runs.)") | ||
logger.info(f"Top times in empty_state:\n" | ||
f"{self.format_times_from_counter(summary_data['empty_state'])}") | ||
logger.info(f"Top times in all_state:\n" | ||
f"{self.format_times_from_counter(summary_data['all_state'])}") | ||
|
||
except Exception as e: | ||
logger.exception(e) | ||
|
||
runner = BenchmarkRunner() | ||
runner.main() | ||
|
||
|
||
if __name__ == "__main__": | ||
from path_change import change_home | ||
change_home() | ||
run_locations_benchmark() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import sys | ||
import os | ||
|
||
|
||
def change_home(): | ||
"""Allow scripts to run from "this" folder.""" | ||
old_home = os.path.dirname(__file__) | ||
sys.path.remove(old_home) | ||
new_home = os.path.normpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) | ||
os.chdir(new_home) | ||
sys.path.append(new_home) | ||
# fallback to local import | ||
sys.path.append(old_home) | ||
|
||
from Utils import local_path | ||
local_path.cached_path = new_home |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import time | ||
|
||
|
||
class TimeIt: | ||
def __init__(self, name: str, time_logger=None): | ||
self.name = name | ||
self.logger = time_logger | ||
self.timer = None | ||
self.end_timer = None | ||
|
||
def __enter__(self): | ||
self.timer = time.perf_counter() | ||
return self | ||
|
||
@property | ||
def dif(self): | ||
return self.end_timer - self.timer | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
if not self.end_timer: | ||
self.end_timer = time.perf_counter() | ||
if self.logger: | ||
self.logger.info(f"{self.dif:.4f} seconds in {self.name}.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters