Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core: Allow Option References and Math in YAML Option Values #4049

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 305 additions & 20 deletions Generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import urllib.parse
import urllib.request
from collections import Counter
from typing import Any, Dict, Tuple, Union
from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union, Callable, List, Literal, Mapping, Sequence
from typing_extensions import TypeGuard, assert_type
from itertools import chain

import ModuleUpdate
Expand All @@ -23,6 +25,81 @@
from Utils import parse_yamls, version_tuple, __version__, tuplize_version


class ChoiceRecord:
def __init__(self, names_seen: list = None, values: dict = None, type_hint: dict = None):
if names_seen is None:
self.names_seen = []
else:
self.names_seen = names_seen
if values is None:
self.values = {}
else:
self.values = values
self.type_hint = type_hint

def get_random(self, option, root, value):
if self.type_hint is not None and option in self.type_hint:
world_option = self.type_hint[option]
if option in root:
if not world_option.supports_weighting:
temp_result = world_option.from_any(root[option])
else:
self.values.update({option: None})
temp_result = world_option.from_any(
get_choice(option, root, record=self))
else:
temp_result = world_option.from_any(option.default)
elif value.startswith("random-range-"):
if len(value.split("-")) == 4:
value = value.split("-")
value_min = int(value[2])
value_max = int(value[3])
temp_result = random.randrange(value_min, value_max)
else:
value = value.split("-")
value_min = int(value[3])
value_max = int(value[4])
if value[2] == "low":
temp_result = int(round(random.triangular(value_min, value_max, value_min)))
elif value[2] == "medium":
temp_result = int(round(random.triangular(value_min, value_max)))
elif value[2] == "high":
temp_result = int(round(random.triangular(value_min, value_max, value_max)))
else:
raise Exception(f"Invalid weighting in random-range-x-min-max, "
f"x must be low, medium, or high. It is: {value[2]}")
else:
raise ValueError(f"Random is not defined for {object}")
if hasattr(temp_result, "name_lookup"):
if temp_result.name_lookup != {}:
temp_result = temp_result.current_key
else:
temp_result = temp_result.value
self.values.update({option: temp_result})
return temp_result

def name_init(self):
self.names_seen = []

def add_name(self, name):
self.names_seen.append(name)

def pop_name(self):
return self.names_seen.pop()

def check_names(self):
return self.names_seen

def value_init(self):
self.values = {}

def update_value(self, key, value):
self.values.update({key: value})

def check_values(self):
return self.values


def mystery_argparse():
from settings import get_settings
settings = get_settings()
Expand Down Expand Up @@ -256,18 +333,225 @@ def get_choice_legacy(option, root, value=None) -> Any:
raise RuntimeError(f"All options specified in \"{option}\" are weighted as zero.")


def get_choice(option, root, value=None) -> Any:
if option not in root:
return value
if type(root[option]) is list:
return random.choices(root[option])[0]
if type(root[option]) is not dict:
return root[option]
if not root[option]:
return value
if any(root[option].values()):
return random.choices(list(root[option].keys()), weights=list(map(int, root[option].values())))[0]
raise RuntimeError(f"All options specified in \"{option}\" are weighted as zero.")
BinOp = Literal["+", "-", "*", "/"]


_bops: Mapping[BinOp, Callable[[Union[float, int], Union[float,int]], Union[float, int]]] = {
"+": lambda a, b: a + b,
Comment on lines +335 to +339
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR author is already aware of this, but this is a note to other reviewers.

This isn't a good place for this math parsing code, but it's not a lot worse than any other place that's available.

It would be good to put it in its own module as a sub-module to Utils #4032

"-": lambda a, b: a - b,
"*": lambda a, b: a * b,
"/": lambda a, b: a / b,
}
""" binary operators """


def _is_bin_op(x: object) -> TypeGuard[BinOp]:
return x in _bops


_precedence: Mapping[Union[BinOp, Literal["("]], int] = {
"(": 0,
"+": 1,
"-": 1,
"*": 2,
"/": 2,
}


@dataclass(frozen=True)
class Token:
val: Union[str, float]


@dataclass(frozen=True)
class NumberToken(Token):
val: float


@dataclass(frozen=True)
class NameToken(Token):
val: str


@dataclass(frozen=True)
class BinOpToken(Token):
val: BinOp


@dataclass(frozen=True)
class OpenParenToken(Token):
val: Literal["("]


@dataclass(frozen=True)
class CloseParenToken(Token):
val: Literal[")"]


AllTokens = Union[NumberToken, NameToken, BinOpToken, OpenParenToken, CloseParenToken]
PostfixTokens = Union[NameToken, NumberToken, BinOpToken]


def parse_tokens(s: str) -> Sequence[AllTokens]:
tokens: List[AllTokens] = []
state: Literal["start", "end"] = "start" # sub-expression
paren_count = 0
i = 0
while i < len(s):
if s[i].isspace():
i += 1
continue
if state == "start":
if s[i] == "(":
tokens.append(OpenParenToken("("))
paren_count += 1
i += 1
elif s[i].isdigit() or s[i] == "." or s[i] == "-":
j = i + 1
while j < len(s) and (s[j].isdigit() or s[j] == "."):
j += 1
val_str = s[i:j]
if len(val_str.split(".")) > 1:
val = float(val_str)
else:
val = int(val_str)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if len(val_str.split(".")) > 1:
val = float(val_str)
else:
val = int(val_str)
val = float(val_str)

can simplify this since it will be turned into int at the end if needed

tokens.append(NumberToken(val))
i += len(val_str)
state = "end"
elif s[i].isalpha():
j = i + 1
while j < len(s) and (s[j].isalpha() or s[j].isdigit() or s[j] == "_"):
j += 1
val_str = s[i:j]
tokens.append(NameToken(val_str))
i += len(val_str)
state = "end"
else:
raise ValueError(f"unexpected symbol {s[i]} at {i} in: {s}")
else:
assert assert_type(state, Literal["end"]) == "end", f"{state=}"
if s[i] == ")":
if paren_count <= 0:
raise ValueError(f"unmatched close parentheses at {i} in: {s}")
paren_count -= 1
tokens.append(CloseParenToken(")"))
i += 1
elif _is_bin_op(val_str := s[i]):
tokens.append(BinOpToken(val_str))
state = "start"
i += 1
else:
raise ValueError(f"unexpected symbol {s[i]} at {i} in: {s}")
if paren_count != 0:
raise ValueError(f"unclosed parentheses in: {s}")
if state != "end":
raise ValueError(f"invalid expression: {s}")
return tokens


def infix_to_postfix(infix: Sequence[AllTokens]) -> Sequence[PostfixTokens]:
infix = [OpenParenToken("(")] + list(infix) + [CloseParenToken(")")]
op_stack: List[Union[OpenParenToken, BinOpToken]] = []
postfix: List[PostfixTokens] = []
for token in infix:
if isinstance(token, OpenParenToken):
op_stack.append(token)
elif isinstance(token, CloseParenToken):
while len(op_stack):
op = op_stack.pop()
if isinstance(op, OpenParenToken):
break
else:
assert isinstance(assert_type(op, BinOpToken), BinOpToken), f"{op=}"
postfix.append(op)
elif isinstance(token, BinOpToken):
while True:
assert len(op_stack), f"{infix=}"
top = op_stack[-1]
if _precedence[top.val] >= _precedence[token.val]:
top = op_stack.pop()
assert isinstance(top, BinOpToken), f"{top=} {infix=}"
postfix.append(top)
else:
break
op_stack.append(token)
else:
assert isinstance(token, NumberToken) or isinstance(assert_type(token, NameToken), NameToken), f"{token=}"
postfix.append(token)
return postfix


def eval_postfix(postfix: Sequence[PostfixTokens], name_resolver: Callable[[str], float]) -> float:
operand_stack: List[float] = []
for token in postfix:
if isinstance(token, NumberToken):
operand_stack.append(token.val)
elif isinstance(token, NameToken):
operand_stack.append(name_resolver(token.val))
else:
assert isinstance(assert_type(token, BinOpToken), BinOpToken), f"{token=}"
if len(operand_stack) < 2:
raise ValueError(f"invalid {postfix=}")
b = operand_stack.pop()
a = operand_stack.pop()
operand_stack.append(_bops[token.val](a, b))
if len(operand_stack) != 1:
raise ValueError(f"invalid {postfix=}")
return operand_stack[0]


def get_choice(option, root, value=None, sub_group=None, record: ChoiceRecord = ChoiceRecord(None, None, None)) -> Any:
Copy link
Collaborator

@beauxq beauxq Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a mutable default argument for record
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

I recommend you use ruff with the "B" rules enabled. It will point out issues like this to you.

if record.check_values() is None:
record.value_init()
if record.check_names() is None:
record.name_init()
if option in record.check_names():
raise RuntimeError(f"Recursive variable: value {option}'s value references to itself")
if sub_group is not None:
choice_group = sub_group
else:
choice_group = root
if option not in choice_group:
temp_result = value
elif type(choice_group[option]) is list:
temp_result = random.choices(root[option])[0]
elif type(choice_group[option]) is not dict:
temp_result = choice_group[option]
elif not choice_group[option]:
temp_result = value
elif any(choice_group[option].values()):
temp_result = random.choices(list(choice_group[option].keys()), weights=list(map(int, choice_group[option].values())))[0]
else:
raise RuntimeError(f"All options specified in \"{option}\" are weighted as zero.")
if isinstance(temp_result, str) and option not in record.check_values():
if temp_result.startswith("$(") and temp_result.endswith(")"):
record.add_name(option)

def name_resolver(name: str) -> float:
if name in root:
named_value = record.check_values().get(name)
if named_value is None:
named_value = get_choice(name, root, record=record)
if not isinstance(named_value, (float, int)):
raise ValueError(f"{name=} used in math is not a number: {named_value=}")
record.update_value(name, named_value)
return named_value
else:
raise KeyError(f"{name} in {temp_result} has not been assigned a value in yaml")

temp_result = eval_postfix(infix_to_postfix(parse_tokens(temp_result[2:-1])), name_resolver)
record.update_value(record.pop_name(), temp_result)
if record.type_hint is not None and option in record.type_hint:
if type(record.type_hint[option].default) is int:
temp_result = int(temp_result)
record.update_value(option, temp_result)
elif temp_result.startswith("random"):
temp_result = record.get_random(option, root, temp_result)
else:
record.update_value(option, temp_result)
if len(record.check_names()) == 0:
root.update(record.check_values())
return temp_result


class SafeDict(dict):
Expand Down Expand Up @@ -378,7 +662,7 @@ def roll_linked_options(weights: dict) -> dict:
return weights


def roll_triggers(weights: dict, triggers: list, valid_keys: set) -> dict:
def roll_triggers(weights: dict, triggers: list, valid_keys: set, type_hints: dict = None) -> dict:
weights = copy.deepcopy(weights) # make sure we don't write back to other weights sets in same_settings
weights["_Generator_Version"] = Utils.__version__
for i, option_set in enumerate(triggers):
Expand All @@ -392,8 +676,8 @@ def roll_triggers(weights: dict, triggers: list, valid_keys: set) -> dict:
logging.warning(f'Specified option name {option_set["option_name"]} did not '
f'match with a root option. '
f'This is probably in error.')
trigger_result = get_choice("option_result", option_set)
result = get_choice(key, currently_targeted_weights)
trigger_result = get_choice("option_result", currently_targeted_weights, sub_group=option_set, record=ChoiceRecord([],{}, type_hints))
result = get_choice(key, currently_targeted_weights, record=ChoiceRecord([],{}, type_hints))
currently_targeted_weights[key] = result
if result == trigger_result and roll_percentage(get_choice("percentage", option_set, 100)):
for category_name, category_options in option_set["options"].items():
Expand All @@ -408,13 +692,14 @@ def roll_triggers(weights: dict, triggers: list, valid_keys: set) -> dict:
return weights


def handle_option(ret: argparse.Namespace, game_weights: dict, option_key: str, option: type(Options.Option), plando_options: PlandoOptions):
def handle_option(ret: argparse.Namespace, game_weights: dict, option_key: str, option: type(Options.Option), plando_options: PlandoOptions, type_hints = None, ):
try:
if option_key in game_weights:
if not option.supports_weighting:
player_option = option.from_any(game_weights[option_key])
else:
player_option = option.from_any(get_choice(option_key, game_weights))
player_option = option.from_any(get_choice(option_key, game_weights,
record=ChoiceRecord([], {}, type_hints)))
else:
player_option = option.from_any(option.default) # call the from_any here to support default "random"
setattr(ret, option_key, player_option)
Expand Down Expand Up @@ -476,15 +761,15 @@ def roll_settings(weights: dict, plando_options: PlandoOptions = PlandoOptions.b
raise Exception(f"Remove tag cannot be used outside of trigger contexts. Found {weight}")

if "triggers" in game_weights:
weights = roll_triggers(weights, game_weights["triggers"], valid_keys)
weights = roll_triggers(weights, game_weights["triggers"], valid_keys, world_type.options_dataclass.type_hints)
game_weights = weights[ret.game]

ret.name = get_choice('name', weights)
for option_key, option in Options.CommonOptions.type_hints.items():
setattr(ret, option_key, option.from_any(get_choice(option_key, weights, option.default)))

for option_key, option in world_type.options_dataclass.type_hints.items():
handle_option(ret, game_weights, option_key, option, plando_options)
handle_option(ret, game_weights, option_key, option, plando_options, world_type.options_dataclass.type_hints)
valid_keys.add(option_key)
for option_key in game_weights:
if option_key in {"triggers", *valid_keys}:
Expand Down
Loading