-
Notifications
You must be signed in to change notification settings - Fork 704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Core: Allow Option References and Math in YAML Option Values #4049
base: main
Are you sure you want to change the base?
Changes from 10 commits
11481ae
a6df17a
82e3648
ef9c348
a82d7d6
fd0dfc8
4140760
d9af9bc
6552cc0
9208ba0
a68c596
d5e8660
f37e062
e3fa262
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -10,7 +10,9 @@ | |||||||||||
import urllib.parse | ||||||||||||
import urllib.request | ||||||||||||
from collections import Counter | ||||||||||||
from typing import Any, Dict, Tuple, Union | ||||||||||||
from dataclasses import dataclass | ||||||||||||
from typing import Any, Dict, Tuple, Union, Callable, List, Literal, Mapping, Sequence | ||||||||||||
from typing_extensions import TypeGuard, assert_type | ||||||||||||
from itertools import chain | ||||||||||||
|
||||||||||||
import ModuleUpdate | ||||||||||||
|
@@ -23,6 +25,81 @@ | |||||||||||
from Utils import parse_yamls, version_tuple, __version__, tuplize_version | ||||||||||||
|
||||||||||||
|
||||||||||||
class ChoiceRecord: | ||||||||||||
def __init__(self, names_seen: list = None, values: dict = None, type_hint: dict = None): | ||||||||||||
if names_seen is None: | ||||||||||||
self.names_seen = [] | ||||||||||||
else: | ||||||||||||
self.names_seen = names_seen | ||||||||||||
if values is None: | ||||||||||||
self.values = {} | ||||||||||||
else: | ||||||||||||
self.values = values | ||||||||||||
self.type_hint = type_hint | ||||||||||||
|
||||||||||||
def get_random(self, option, root, value): | ||||||||||||
if self.type_hint is not None and option in self.type_hint: | ||||||||||||
world_option = self.type_hint[option] | ||||||||||||
if option in root: | ||||||||||||
if not world_option.supports_weighting: | ||||||||||||
temp_result = world_option.from_any(root[option]) | ||||||||||||
else: | ||||||||||||
self.values.update({option: None}) | ||||||||||||
temp_result = world_option.from_any( | ||||||||||||
get_choice(option, root, record=self)) | ||||||||||||
else: | ||||||||||||
temp_result = world_option.from_any(option.default) | ||||||||||||
elif value.startswith("random-range-"): | ||||||||||||
if len(value.split("-")) == 4: | ||||||||||||
value = value.split("-") | ||||||||||||
value_min = int(value[2]) | ||||||||||||
value_max = int(value[3]) | ||||||||||||
temp_result = random.randrange(value_min, value_max) | ||||||||||||
else: | ||||||||||||
value = value.split("-") | ||||||||||||
value_min = int(value[3]) | ||||||||||||
value_max = int(value[4]) | ||||||||||||
if value[2] == "low": | ||||||||||||
temp_result = int(round(random.triangular(value_min, value_max, value_min))) | ||||||||||||
elif value[2] == "medium": | ||||||||||||
temp_result = int(round(random.triangular(value_min, value_max))) | ||||||||||||
elif value[2] == "high": | ||||||||||||
temp_result = int(round(random.triangular(value_min, value_max, value_max))) | ||||||||||||
else: | ||||||||||||
raise Exception(f"Invalid weighting in random-range-x-min-max, " | ||||||||||||
f"x must be low, medium, or high. It is: {value[2]}") | ||||||||||||
else: | ||||||||||||
raise ValueError(f"Random is not defined for {object}") | ||||||||||||
if hasattr(temp_result, "name_lookup"): | ||||||||||||
if temp_result.name_lookup != {}: | ||||||||||||
temp_result = temp_result.current_key | ||||||||||||
else: | ||||||||||||
temp_result = temp_result.value | ||||||||||||
self.values.update({option: temp_result}) | ||||||||||||
return temp_result | ||||||||||||
|
||||||||||||
def name_init(self): | ||||||||||||
self.names_seen = [] | ||||||||||||
|
||||||||||||
def add_name(self, name): | ||||||||||||
self.names_seen.append(name) | ||||||||||||
|
||||||||||||
def pop_name(self): | ||||||||||||
return self.names_seen.pop() | ||||||||||||
|
||||||||||||
def check_names(self): | ||||||||||||
return self.names_seen | ||||||||||||
|
||||||||||||
def value_init(self): | ||||||||||||
self.values = {} | ||||||||||||
|
||||||||||||
def update_value(self, key, value): | ||||||||||||
self.values.update({key: value}) | ||||||||||||
|
||||||||||||
def check_values(self): | ||||||||||||
return self.values | ||||||||||||
|
||||||||||||
|
||||||||||||
def mystery_argparse(): | ||||||||||||
from settings import get_settings | ||||||||||||
settings = get_settings() | ||||||||||||
|
@@ -256,18 +333,225 @@ def get_choice_legacy(option, root, value=None) -> Any: | |||||||||||
raise RuntimeError(f"All options specified in \"{option}\" are weighted as zero.") | ||||||||||||
|
||||||||||||
|
||||||||||||
def get_choice(option, root, value=None) -> Any: | ||||||||||||
if option not in root: | ||||||||||||
return value | ||||||||||||
if type(root[option]) is list: | ||||||||||||
return random.choices(root[option])[0] | ||||||||||||
if type(root[option]) is not dict: | ||||||||||||
return root[option] | ||||||||||||
if not root[option]: | ||||||||||||
return value | ||||||||||||
if any(root[option].values()): | ||||||||||||
return random.choices(list(root[option].keys()), weights=list(map(int, root[option].values())))[0] | ||||||||||||
raise RuntimeError(f"All options specified in \"{option}\" are weighted as zero.") | ||||||||||||
BinOp = Literal["+", "-", "*", "/"] | ||||||||||||
|
||||||||||||
|
||||||||||||
_bops: Mapping[BinOp, Callable[[Union[float, int], Union[float,int]], Union[float, int]]] = { | ||||||||||||
"+": lambda a, b: a + b, | ||||||||||||
"-": lambda a, b: a - b, | ||||||||||||
"*": lambda a, b: a * b, | ||||||||||||
"/": lambda a, b: a / b, | ||||||||||||
} | ||||||||||||
""" binary operators """ | ||||||||||||
|
||||||||||||
|
||||||||||||
def _is_bin_op(x: object) -> TypeGuard[BinOp]: | ||||||||||||
return x in _bops | ||||||||||||
|
||||||||||||
|
||||||||||||
_precedence: Mapping[Union[BinOp, Literal["("]], int] = { | ||||||||||||
"(": 0, | ||||||||||||
"+": 1, | ||||||||||||
"-": 1, | ||||||||||||
"*": 2, | ||||||||||||
"/": 2, | ||||||||||||
} | ||||||||||||
|
||||||||||||
|
||||||||||||
@dataclass(frozen=True) | ||||||||||||
class Token: | ||||||||||||
val: Union[str, float] | ||||||||||||
|
||||||||||||
|
||||||||||||
@dataclass(frozen=True) | ||||||||||||
class NumberToken(Token): | ||||||||||||
val: float | ||||||||||||
|
||||||||||||
|
||||||||||||
@dataclass(frozen=True) | ||||||||||||
class NameToken(Token): | ||||||||||||
val: str | ||||||||||||
|
||||||||||||
|
||||||||||||
@dataclass(frozen=True) | ||||||||||||
class BinOpToken(Token): | ||||||||||||
val: BinOp | ||||||||||||
|
||||||||||||
|
||||||||||||
@dataclass(frozen=True) | ||||||||||||
class OpenParenToken(Token): | ||||||||||||
val: Literal["("] | ||||||||||||
|
||||||||||||
|
||||||||||||
@dataclass(frozen=True) | ||||||||||||
class CloseParenToken(Token): | ||||||||||||
val: Literal[")"] | ||||||||||||
|
||||||||||||
|
||||||||||||
AllTokens = Union[NumberToken, NameToken, BinOpToken, OpenParenToken, CloseParenToken] | ||||||||||||
PostfixTokens = Union[NameToken, NumberToken, BinOpToken] | ||||||||||||
|
||||||||||||
|
||||||||||||
def parse_tokens(s: str) -> Sequence[AllTokens]: | ||||||||||||
tokens: List[AllTokens] = [] | ||||||||||||
state: Literal["start", "end"] = "start" # sub-expression | ||||||||||||
paren_count = 0 | ||||||||||||
i = 0 | ||||||||||||
while i < len(s): | ||||||||||||
if s[i].isspace(): | ||||||||||||
i += 1 | ||||||||||||
continue | ||||||||||||
if state == "start": | ||||||||||||
if s[i] == "(": | ||||||||||||
tokens.append(OpenParenToken("(")) | ||||||||||||
paren_count += 1 | ||||||||||||
i += 1 | ||||||||||||
elif s[i].isdigit() or s[i] == "." or s[i] == "-": | ||||||||||||
j = i + 1 | ||||||||||||
while j < len(s) and (s[j].isdigit() or s[j] == "."): | ||||||||||||
j += 1 | ||||||||||||
val_str = s[i:j] | ||||||||||||
if len(val_str.split(".")) > 1: | ||||||||||||
val = float(val_str) | ||||||||||||
else: | ||||||||||||
val = int(val_str) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
can simplify this since it will be turned into |
||||||||||||
tokens.append(NumberToken(val)) | ||||||||||||
i += len(val_str) | ||||||||||||
state = "end" | ||||||||||||
elif s[i].isalpha(): | ||||||||||||
j = i + 1 | ||||||||||||
while j < len(s) and (s[j].isalpha() or s[j].isdigit() or s[j] == "_"): | ||||||||||||
j += 1 | ||||||||||||
val_str = s[i:j] | ||||||||||||
tokens.append(NameToken(val_str)) | ||||||||||||
i += len(val_str) | ||||||||||||
state = "end" | ||||||||||||
else: | ||||||||||||
raise ValueError(f"unexpected symbol {s[i]} at {i} in: {s}") | ||||||||||||
else: | ||||||||||||
assert assert_type(state, Literal["end"]) == "end", f"{state=}" | ||||||||||||
if s[i] == ")": | ||||||||||||
if paren_count <= 0: | ||||||||||||
raise ValueError(f"unmatched close parentheses at {i} in: {s}") | ||||||||||||
paren_count -= 1 | ||||||||||||
tokens.append(CloseParenToken(")")) | ||||||||||||
i += 1 | ||||||||||||
elif _is_bin_op(val_str := s[i]): | ||||||||||||
tokens.append(BinOpToken(val_str)) | ||||||||||||
state = "start" | ||||||||||||
i += 1 | ||||||||||||
else: | ||||||||||||
raise ValueError(f"unexpected symbol {s[i]} at {i} in: {s}") | ||||||||||||
if paren_count != 0: | ||||||||||||
raise ValueError(f"unclosed parentheses in: {s}") | ||||||||||||
if state != "end": | ||||||||||||
raise ValueError(f"invalid expression: {s}") | ||||||||||||
return tokens | ||||||||||||
|
||||||||||||
|
||||||||||||
def infix_to_postfix(infix: Sequence[AllTokens]) -> Sequence[PostfixTokens]: | ||||||||||||
infix = [OpenParenToken("(")] + list(infix) + [CloseParenToken(")")] | ||||||||||||
op_stack: List[Union[OpenParenToken, BinOpToken]] = [] | ||||||||||||
postfix: List[PostfixTokens] = [] | ||||||||||||
for token in infix: | ||||||||||||
if isinstance(token, OpenParenToken): | ||||||||||||
op_stack.append(token) | ||||||||||||
elif isinstance(token, CloseParenToken): | ||||||||||||
while len(op_stack): | ||||||||||||
op = op_stack.pop() | ||||||||||||
if isinstance(op, OpenParenToken): | ||||||||||||
break | ||||||||||||
else: | ||||||||||||
assert isinstance(assert_type(op, BinOpToken), BinOpToken), f"{op=}" | ||||||||||||
postfix.append(op) | ||||||||||||
elif isinstance(token, BinOpToken): | ||||||||||||
while True: | ||||||||||||
assert len(op_stack), f"{infix=}" | ||||||||||||
top = op_stack[-1] | ||||||||||||
if _precedence[top.val] >= _precedence[token.val]: | ||||||||||||
top = op_stack.pop() | ||||||||||||
assert isinstance(top, BinOpToken), f"{top=} {infix=}" | ||||||||||||
postfix.append(top) | ||||||||||||
else: | ||||||||||||
break | ||||||||||||
op_stack.append(token) | ||||||||||||
else: | ||||||||||||
assert isinstance(token, NumberToken) or isinstance(assert_type(token, NameToken), NameToken), f"{token=}" | ||||||||||||
postfix.append(token) | ||||||||||||
return postfix | ||||||||||||
|
||||||||||||
|
||||||||||||
def eval_postfix(postfix: Sequence[PostfixTokens], name_resolver: Callable[[str], float]) -> float: | ||||||||||||
operand_stack: List[float] = [] | ||||||||||||
for token in postfix: | ||||||||||||
if isinstance(token, NumberToken): | ||||||||||||
operand_stack.append(token.val) | ||||||||||||
elif isinstance(token, NameToken): | ||||||||||||
operand_stack.append(name_resolver(token.val)) | ||||||||||||
else: | ||||||||||||
assert isinstance(assert_type(token, BinOpToken), BinOpToken), f"{token=}" | ||||||||||||
if len(operand_stack) < 2: | ||||||||||||
raise ValueError(f"invalid {postfix=}") | ||||||||||||
b = operand_stack.pop() | ||||||||||||
a = operand_stack.pop() | ||||||||||||
operand_stack.append(_bops[token.val](a, b)) | ||||||||||||
if len(operand_stack) != 1: | ||||||||||||
raise ValueError(f"invalid {postfix=}") | ||||||||||||
return operand_stack[0] | ||||||||||||
|
||||||||||||
|
||||||||||||
def get_choice(option, root, value=None, sub_group=None, record: ChoiceRecord = ChoiceRecord(None, None, None)) -> Any: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a mutable default argument for I recommend you use ruff with the "B" rules enabled. It will point out issues like this to you. |
||||||||||||
if record.check_values() is None: | ||||||||||||
record.value_init() | ||||||||||||
if record.check_names() is None: | ||||||||||||
record.name_init() | ||||||||||||
if option in record.check_names(): | ||||||||||||
raise RuntimeError(f"Recursive variable: value {option}'s value references to itself") | ||||||||||||
if sub_group is not None: | ||||||||||||
choice_group = sub_group | ||||||||||||
else: | ||||||||||||
choice_group = root | ||||||||||||
if option not in choice_group: | ||||||||||||
temp_result = value | ||||||||||||
elif type(choice_group[option]) is list: | ||||||||||||
temp_result = random.choices(root[option])[0] | ||||||||||||
elif type(choice_group[option]) is not dict: | ||||||||||||
temp_result = choice_group[option] | ||||||||||||
elif not choice_group[option]: | ||||||||||||
temp_result = value | ||||||||||||
elif any(choice_group[option].values()): | ||||||||||||
temp_result = random.choices(list(choice_group[option].keys()), weights=list(map(int, choice_group[option].values())))[0] | ||||||||||||
else: | ||||||||||||
raise RuntimeError(f"All options specified in \"{option}\" are weighted as zero.") | ||||||||||||
if isinstance(temp_result, str) and option not in record.check_values(): | ||||||||||||
if temp_result.startswith("$(") and temp_result.endswith(")"): | ||||||||||||
record.add_name(option) | ||||||||||||
|
||||||||||||
def name_resolver(name: str) -> float: | ||||||||||||
if name in root: | ||||||||||||
named_value = record.check_values().get(name) | ||||||||||||
if named_value is None: | ||||||||||||
named_value = get_choice(name, root, record=record) | ||||||||||||
if not isinstance(named_value, (float, int)): | ||||||||||||
raise ValueError(f"{name=} used in math is not a number: {named_value=}") | ||||||||||||
record.update_value(name, named_value) | ||||||||||||
return named_value | ||||||||||||
else: | ||||||||||||
raise KeyError(f"{name} in {temp_result} has not been assigned a value in yaml") | ||||||||||||
|
||||||||||||
temp_result = eval_postfix(infix_to_postfix(parse_tokens(temp_result[2:-1])), name_resolver) | ||||||||||||
record.update_value(record.pop_name(), temp_result) | ||||||||||||
if record.type_hint is not None and option in record.type_hint: | ||||||||||||
if type(record.type_hint[option].default) is int: | ||||||||||||
temp_result = int(temp_result) | ||||||||||||
record.update_value(option, temp_result) | ||||||||||||
elif temp_result.startswith("random"): | ||||||||||||
temp_result = record.get_random(option, root, temp_result) | ||||||||||||
else: | ||||||||||||
record.update_value(option, temp_result) | ||||||||||||
if len(record.check_names()) == 0: | ||||||||||||
root.update(record.check_values()) | ||||||||||||
return temp_result | ||||||||||||
|
||||||||||||
|
||||||||||||
class SafeDict(dict): | ||||||||||||
|
@@ -378,7 +662,7 @@ def roll_linked_options(weights: dict) -> dict: | |||||||||||
return weights | ||||||||||||
|
||||||||||||
|
||||||||||||
def roll_triggers(weights: dict, triggers: list, valid_keys: set) -> dict: | ||||||||||||
def roll_triggers(weights: dict, triggers: list, valid_keys: set, type_hints: dict = None) -> dict: | ||||||||||||
weights = copy.deepcopy(weights) # make sure we don't write back to other weights sets in same_settings | ||||||||||||
weights["_Generator_Version"] = Utils.__version__ | ||||||||||||
for i, option_set in enumerate(triggers): | ||||||||||||
|
@@ -392,8 +676,8 @@ def roll_triggers(weights: dict, triggers: list, valid_keys: set) -> dict: | |||||||||||
logging.warning(f'Specified option name {option_set["option_name"]} did not ' | ||||||||||||
f'match with a root option. ' | ||||||||||||
f'This is probably in error.') | ||||||||||||
trigger_result = get_choice("option_result", option_set) | ||||||||||||
result = get_choice(key, currently_targeted_weights) | ||||||||||||
trigger_result = get_choice("option_result", currently_targeted_weights, sub_group=option_set, record=ChoiceRecord([],{}, type_hints)) | ||||||||||||
result = get_choice(key, currently_targeted_weights, record=ChoiceRecord([],{}, type_hints)) | ||||||||||||
currently_targeted_weights[key] = result | ||||||||||||
if result == trigger_result and roll_percentage(get_choice("percentage", option_set, 100)): | ||||||||||||
for category_name, category_options in option_set["options"].items(): | ||||||||||||
|
@@ -408,13 +692,14 @@ def roll_triggers(weights: dict, triggers: list, valid_keys: set) -> dict: | |||||||||||
return weights | ||||||||||||
|
||||||||||||
|
||||||||||||
def handle_option(ret: argparse.Namespace, game_weights: dict, option_key: str, option: type(Options.Option), plando_options: PlandoOptions): | ||||||||||||
def handle_option(ret: argparse.Namespace, game_weights: dict, option_key: str, option: type(Options.Option), plando_options: PlandoOptions, type_hints = None, ): | ||||||||||||
try: | ||||||||||||
if option_key in game_weights: | ||||||||||||
if not option.supports_weighting: | ||||||||||||
player_option = option.from_any(game_weights[option_key]) | ||||||||||||
else: | ||||||||||||
player_option = option.from_any(get_choice(option_key, game_weights)) | ||||||||||||
player_option = option.from_any(get_choice(option_key, game_weights, | ||||||||||||
record=ChoiceRecord([], {}, type_hints))) | ||||||||||||
else: | ||||||||||||
player_option = option.from_any(option.default) # call the from_any here to support default "random" | ||||||||||||
setattr(ret, option_key, player_option) | ||||||||||||
|
@@ -476,15 +761,15 @@ def roll_settings(weights: dict, plando_options: PlandoOptions = PlandoOptions.b | |||||||||||
raise Exception(f"Remove tag cannot be used outside of trigger contexts. Found {weight}") | ||||||||||||
|
||||||||||||
if "triggers" in game_weights: | ||||||||||||
weights = roll_triggers(weights, game_weights["triggers"], valid_keys) | ||||||||||||
weights = roll_triggers(weights, game_weights["triggers"], valid_keys, world_type.options_dataclass.type_hints) | ||||||||||||
game_weights = weights[ret.game] | ||||||||||||
|
||||||||||||
ret.name = get_choice('name', weights) | ||||||||||||
for option_key, option in Options.CommonOptions.type_hints.items(): | ||||||||||||
setattr(ret, option_key, option.from_any(get_choice(option_key, weights, option.default))) | ||||||||||||
|
||||||||||||
for option_key, option in world_type.options_dataclass.type_hints.items(): | ||||||||||||
handle_option(ret, game_weights, option_key, option, plando_options) | ||||||||||||
handle_option(ret, game_weights, option_key, option, plando_options, world_type.options_dataclass.type_hints) | ||||||||||||
valid_keys.add(option_key) | ||||||||||||
for option_key in game_weights: | ||||||||||||
if option_key in {"triggers", *valid_keys}: | ||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR author is already aware of this, but this is a note to other reviewers.
This isn't a good place for this math parsing code, but it's not a lot worse than any other place that's available.
It would be good to put it in its own module as a sub-module to
Utils
#4032