Merge pull request #15 from valohai/dashes
Add tooling for dealing with dashed identifiers; port over common boolean evaluator
hylje authored Jul 16, 2024
2 parents c761d13 + 9d596d9 commit feac5c0
Showing 9 changed files with 378 additions and 13 deletions.
9 changes: 8 additions & 1 deletion leval/evaluator.py
@@ -59,6 +59,7 @@ def __init__(
allowed_constant_types: Optional[Iterable[type]] = None,
allowed_container_types: Optional[Iterable[type]] = None,
loose_is_operator: bool = True,
loose_not_operator: bool = True,
):
"""
Initialize an evaluator with access to the given evaluation universe.
@@ -69,6 +70,7 @@ def __init__(
self.max_depth = _default_if_none(max_depth, self.default_max_depth)
self.max_time = float(max_time or 0)
self.loose_is_operator = bool(loose_is_operator)
self.loose_not_operator = bool(loose_not_operator)
self.allowed_constant_types = frozenset(
_default_if_none(
allowed_constant_types,
@@ -183,7 +185,12 @@ def visit_BoolOp(self, node): # noqa: D102
return self.universe.evaluate_bool_op(node.op, value_getters)

def visit_UnaryOp(self, node): # noqa: D102
operand = self.visit(node.operand)
try:
operand = self.visit(node.operand)
except NoSuchValue:
if self.loose_not_operator and isinstance(node.op, ast.Not):
return True
raise
if isinstance(node.op, ast.UAdd):
return +operand
if isinstance(node.op, ast.USub):
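The new loose_not_operator flag mirrors the existing loose_is_operator: when the operand of a not cannot be resolved, the whole expression evaluates to True instead of propagating NoSuchValue. A minimal sketch of the behavior, assuming Evaluator.evaluate_expression is available on the base evaluator and that WeaklyTypedSimpleUniverse accepts the functions/values keyword arguments used further down:

from leval.evaluator import Evaluator
from leval.universe.weakly_typed import WeaklyTypedSimpleUniverse

# Hypothetical universe with one known value; "pause" is deliberately left undefined.
universe = WeaklyTypedSimpleUniverse(functions={}, values={"running": True})
evaluator = Evaluator(universe, loose_not_operator=True)

# Negating a value that does not resolve yields True instead of raising NoSuchValue.
print(evaluator.evaluate_expression("not pause"))  # expected: True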
5 changes: 5 additions & 0 deletions leval/extras/__init__.py
@@ -0,0 +1,5 @@
"""
The extras subpackage contains convenience classes.
These have no stability guarantees and are not part of the main API.
"""
118 changes: 118 additions & 0 deletions leval/extras/common_boolean_evaluator.py
@@ -0,0 +1,118 @@
import keyword
import tokenize
from typing import Any, Dict, List, Optional, Tuple, Union

from leval.excs import NoSuchFunction
from leval.rewriter_evaluator import RewriterEvaluator
from leval.rewriter_utils import (
convert_dash_identifiers,
get_parts_from_dashed_identifier_tokens,
make_glued_name_token,
)
from leval.universe.verifier import VerifierUniverse
from leval.universe.weakly_typed import WeaklyTypedSimpleUniverse

DEFAULT_FUNCTIONS = {
"abs": abs,
"min": min,
"max": max,
}

ValuesDict = Dict[Union[Tuple[str, ...], str], Any]

KEYWORD_PREFIX = "K\u203f"
DASH_SEP = "\u203f\u203f"


def _rewrite_keyword(kw: str) -> str:
if keyword.iskeyword(kw):
return f"{KEYWORD_PREFIX}{kw}"
return kw


def _convert_dash_tokens(tokens: List[tokenize.TokenInfo]):
if not tokens:
return []
glued_name = "".join(
get_parts_from_dashed_identifier_tokens(tokens, separator=DASH_SEP),
)
return [make_glued_name_token(tokens, glued_name)]


def _prepare_name(name: str) -> str:
return DASH_SEP.join(_rewrite_keyword(p) for p in name.split("-"))


class _CommonEvaluator(RewriterEvaluator):
def rewrite_keyword(self, kw: str) -> str:
return _rewrite_keyword(kw)

def process_tokens(self, tokens):
return convert_dash_identifiers(tokens, _convert_dash_tokens)


class _CommonUniverse(WeaklyTypedSimpleUniverse):
def evaluate_function(self, name, arg_getters):
func = self.functions.get(name)
if not func:
raise NoSuchFunction(f"No function {name}")
args = [getter() for getter in arg_getters]
for arg in args:
# This is using `type(...)` on purpose; we don't want to allow subclasses.
if type(arg) not in (int, float, str, bool):
raise TypeError(f"Invalid argument for {name}: {type(arg)}")
return func(*args)


def _prepare_values(values: ValuesDict) -> ValuesDict:
"""
Prepare a values dictionary by rewriting names like the evaluation would.
"""
prepared_values = {}
for key, value in values.items():
if isinstance(key, tuple):
key = tuple(_prepare_name(p) for p in key)
elif isinstance(key, str):
key = _prepare_name(key)
else:
raise TypeError(f"Invalid key type: {type(key)}")
prepared_values[key] = value
return prepared_values


class CommonBooleanEvaluator:
functions: dict = DEFAULT_FUNCTIONS
max_depth: int = 8
max_time: float = 0.2
verifier_universe_class = VerifierUniverse
universe_class = _CommonUniverse
evaluator_class = _CommonEvaluator

def evaluate(self, expr: Optional[str], values: ValuesDict) -> Optional[bool]:
"""
Evaluate the given expression against the given values.
The values dictionary's keys will be prepared to the expected internal format.
"""
if not expr:
return None
universe = self.universe_class(
functions=self.functions,
values=_prepare_values(values),
)
evl = self.evaluator_class(
universe,
max_depth=self.max_depth,
max_time=self.max_time,
)
return bool(evl.evaluate_expression(expr))

def verify(self, expression: str) -> bool:
"""
Verify that the given expression is technically valid.
"""
evl = self.evaluator_class(
self.verifier_universe_class(),
max_depth=self.max_depth,
)
return evl.evaluate_expression(expression)
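For reference, a quick usage sketch of the ported evaluator, mirroring the cases in the new test module further down:

from leval.extras.common_boolean_evaluator import CommonBooleanEvaluator

evaluator = CommonBooleanEvaluator()
values = {("foo", "baz-quux"): 9, "continue": True, "v1": 74, "v2": 42}

# Dashed identifiers and reserved keywords are rewritten both in the
# expression and in the keys of the values dict before evaluation.
print(evaluator.evaluate("foo.baz-quux > 8", values))               # True
print(evaluator.evaluate("continue and min(v1, v2) < 50", values))  # True
print(evaluator.evaluate("", values))                                # None for empty expressions
evaluator.verify("foo.baz-quux > 8")  # raises on invalid or overly complex expressions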
37 changes: 26 additions & 11 deletions leval/rewriter_evaluator.py
@@ -1,10 +1,10 @@
import ast
import io
import keyword
import tokenize
from typing import Iterable
from typing import Iterable, List

from leval.evaluator import Evaluator
from leval.utils import tokenize_expression

# Keyword-like elements that can be used in an expression.
EXPRESSION_KEYWORDS = {
@@ -20,10 +20,6 @@
}


def _tokenize_expression(expression: str) -> Iterable[tokenize.TokenInfo]:
return tokenize.generate_tokens(io.StringIO(expression).readline)


class RewriterEvaluator(Evaluator):
def parse(self, expression: str) -> ast.AST:
"""
@@ -42,20 +38,39 @@ def rewrite_expression(self, expression: str) -> str:
This is useful for rewriting code that is not a valid Python
expression (e.g. containing suites or reserved keywords).
"""
bits = []
for tok in _tokenize_expression(expression):
tokens = tokenize_expression(expression)
tokens = list(self.rewrite_keywords(tokens))
return tokenize.untokenize(self.process_tokens(tokens))

def rewrite_keywords(
self,
tokens: Iterable[tokenize.TokenInfo],
) -> Iterable[tokenize.TokenInfo]:
"""
Do a keyword-rewriting pass on the tokens.
"""
for tok in tokens:
if (
tok.type == tokenize.NAME
and keyword.iskeyword(tok.string)
and tok.string not in EXPRESSION_KEYWORDS
):
tok = tok._replace(string=self.rewrite_keyword(tok.string))
bits.append(tok)
expression = tokenize.untokenize(bits)
return expression
yield tok

def rewrite_keyword(self, kw: str) -> str:
"""
Return the replacement for the given keyword.
"""
raise SyntaxError(f"Keyword {kw!r} can not be used") # pragma: no cover

def process_tokens(
self,
tokens: List[tokenize.TokenInfo],
) -> Iterable[tokenize.TokenInfo]:
"""
Process the token stream before untokenizing it back to a string.
Does nothing by default, but can be overridden.
"""
return tokens
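Splitting the old loop into rewrite_keywords() and process_tokens() lets subclasses hook into either pass independently. A minimal sketch of a hypothetical subclass that only overrides the keyword hook; VerifierUniverse() is used here only so the evaluator can be constructed:

from leval.rewriter_evaluator import RewriterEvaluator
from leval.universe.verifier import VerifierUniverse

class PrefixingEvaluator(RewriterEvaluator):
    # Hypothetical: reserved keywords become ordinary prefixed names.
    def rewrite_keyword(self, kw: str) -> str:
        return f"kw_{kw}"

evl = PrefixingEvaluator(VerifierUniverse())
# Tokenize, rewrite keywords, run process_tokens() (a no-op here), untokenize.
print(evl.rewrite_expression("class or import"))  # e.g. "kw_class or kw_import"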
102 changes: 102 additions & 0 deletions leval/rewriter_utils.py
@@ -0,0 +1,102 @@
import tokenize
from typing import Callable, Iterable, Iterator, List, Optional


def tokens_are_adjacent(t1: tokenize.TokenInfo, t2: tokenize.TokenInfo) -> bool:
"""
Return True if the two tokens are adjacent in the source code.
That is, token 1 ends exactly where token 2 starts.
"""
return t1.end[0] == t2.start[0] and t1.end[1] == t2.start[1]


def make_glued_name_token(tokens: List[tokenize.TokenInfo], name: str):
"""
Create a new token for an identifier that spans the given token position range.
It does not validate that the resulting token is actually a valid Python identifier,
or that the range is actually valid (`name` could be longer than the range of the
original tokens).
"""
stok = tokens[0]
etok = tokens[-1]
return tokenize.TokenInfo(
tokenize.NAME,
name,
stok.start,
etok.end,
stok.line,
)


def get_parts_from_dashed_identifier_tokens(
tokens: Iterable[tokenize.TokenInfo],
separator: Optional[str] = None,
) -> Iterable[str]:
"""
Yield the parts of a dashed identifier from the given token stream.
If `separator` is set, it is yielded for dashes in the identifier.
"""
for tok in tokens:
if tok.type in (tokenize.NAME, tokenize.NUMBER):
yield tok.string
elif tok.type == tokenize.OP and tok.string == "-":
if separator:
yield separator
continue
else:
raise SyntaxError("Invalid token")


def _maybe_process_dash_identifier(
initial_token: tokenize.TokenInfo,
tok_iter: Iterator[tokenize.TokenInfo],
converter: Callable[[List[tokenize.TokenInfo]], Iterable[tokenize.TokenInfo]],
):
tokens = [initial_token]
while True:
tok = next(tok_iter, None)
if tok is None:
break
if not tokens_are_adjacent(tokens[-1], tok):
break
if (
tok.type == tokenize.NAME
or (tok.type == tokenize.OP and tok.string == "-")
or (tok.type == tokenize.NUMBER and tok.string.isdigit())
):
tokens.append(tok)
else:
break
if tokens: # Yield the converted token(s) if there are tokens to convert.
if tokens[-1].type == tokenize.OP: # ended with a dash? no conversion
yield from tokens
else:
yield from converter(tokens)
if tok: # Yield the last token that broke the loop.
yield tok


def convert_dash_identifiers(
tokens: List[tokenize.TokenInfo],
converter: Callable[[List[tokenize.TokenInfo]], Iterable[tokenize.TokenInfo]],
) -> Iterable[tokenize.TokenInfo]:
"""
Convert dashed identifiers in the given token stream.
In particular, converts e.g. `foo-bar-baz-quux` that is actually
`NAME OP(-) NAME (...)` with no spaces in between, into a single
token via the given converter function.
"""
tok_iter = iter(tokens)

while True:
tok = next(tok_iter, None)
if tok is None:
break
if tok.type == tokenize.NAME: # Could be the start of a dashed identifier.
yield from _maybe_process_dash_identifier(tok, tok_iter, converter)
continue
yield tok
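A short sketch of how these helpers compose; the underscore separator is made up for illustration (the common boolean evaluator above uses a Unicode separator instead):

import tokenize

from leval.rewriter_utils import (
    convert_dash_identifiers,
    get_parts_from_dashed_identifier_tokens,
    make_glued_name_token,
)
from leval.utils import tokenize_expression

def glue(tokens):
    # Join NAME/NUMBER parts, replacing each dash with an underscore.
    name = "".join(get_parts_from_dashed_identifier_tokens(tokens, separator="_"))
    return [make_glued_name_token(tokens, name)]

toks = list(tokenize_expression("foo-bar-2 > 1"))
print(tokenize.untokenize(convert_dash_identifiers(toks, glue)))  # e.g. "foo_bar_2 > 1"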
13 changes: 12 additions & 1 deletion leval/utils.py
@@ -1,5 +1,7 @@
import ast
from typing import Tuple
import io
import tokenize
from typing import Iterable, Tuple

from leval.excs import InvalidAttribute

@@ -31,3 +33,12 @@ def walk_attr(kid):

walk_attr(node)
return tuple(str(bit) for bit in attr_bits[::-1])


def tokenize_expression(expression: str) -> Iterable[tokenize.TokenInfo]:
"""
Tokenize the given expression and return the tokens.
Will likely misbehave if the expression is e.g. multi-line.
"""
return tokenize.generate_tokens(io.StringIO(expression).readline)
49 changes: 49 additions & 0 deletions leval_tests/test_common_boolean_evaluator.py
@@ -0,0 +1,49 @@
import pytest

from leval.excs import InvalidOperation, NoSuchValue, TooComplex
from leval.extras.common_boolean_evaluator import CommonBooleanEvaluator


@pytest.mark.parametrize(
("expression", "exception"),
[
("+".join("a" * 500), TooComplex),
("b <", SyntaxError),
("os.system()", InvalidOperation),
],
)
def test_validation(expression, exception):
with pytest.raises(exception):
CommonBooleanEvaluator().verify(expression)
with pytest.raises(exception):
CommonBooleanEvaluator().evaluate(expression, {})


test_vars = {
("foo", "baz-quux"): 9,
"continue": True,
"v1": 74,
"v2": 42,
}


@pytest.mark.parametrize(
("expression", "values", "expected"),
[
("foo.baz-quux > 8", test_vars, True), # dash in name
("foo.baz - quux > 8", test_vars, NoSuchValue),
("continue or not pause", test_vars, True), # keyword
("cookie is None", test_vars, True), # loose "is"
("not class", test_vars, True), # loose "not"
("min(v1, v2) < 50", test_vars, True),
("max(v1, v2) > 50", test_vars, True),
("max()", test_vars, TypeError),
("max((1,2,3))", test_vars, TypeError), # Invalid argument type (tuple)
],
)
def test_expressions(expression, values, expected):
if isinstance(expected, type) and issubclass(expected, Exception):
with pytest.raises(expected):
CommonBooleanEvaluator().evaluate(expression, values)
else:
assert CommonBooleanEvaluator().evaluate(expression, values) == expected