Merge pull request #15 from valohai/dashes
Add tooling for dealing with dashed identifiers; port over common boolean evaluator
Showing 9 changed files with 378 additions and 13 deletions.
leval/extras/__init__.py
@@ -0,0 +1,5 @@
""" | ||
The extras subpackage contains convenience classes. | ||
These have no stability guarantees and are not part of the main API. | ||
""" |
leval/extras/common_boolean_evaluator.py
@@ -0,0 +1,118 @@
import keyword
import tokenize
from typing import Any, Dict, List, Optional, Tuple, Union

from leval.excs import NoSuchFunction
from leval.rewriter_evaluator import RewriterEvaluator
from leval.rewriter_utils import (
    convert_dash_identifiers,
    get_parts_from_dashed_identifier_tokens,
    make_glued_name_token,
)
from leval.universe.verifier import VerifierUniverse
from leval.universe.weakly_typed import WeaklyTypedSimpleUniverse

DEFAULT_FUNCTIONS = {
    "abs": abs,
    "min": min,
    "max": max,
}

ValuesDict = Dict[Union[Tuple[str, ...], str], Any]

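# U+203F (UNDERTIE) is connector punctuation, so it is legal inside a Python
# identifier yet very unlikely to occur in user-supplied expressions; it is
# used below to mangle keywords and to join the parts of dashed identifiers.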
KEYWORD_PREFIX = "K\u203f"
DASH_SEP = "\u203f\u203f"


def _rewrite_keyword(kw: str) -> str:
    if keyword.iskeyword(kw):
        return f"{KEYWORD_PREFIX}{kw}"
    return kw


def _convert_dash_tokens(tokens: List[tokenize.TokenInfo]):
    if not tokens:
        return []
    glued_name = "".join(
        get_parts_from_dashed_identifier_tokens(tokens, separator=DASH_SEP),
    )
    return [make_glued_name_token(tokens, glued_name)]


def _prepare_name(name: str) -> str:
    return DASH_SEP.join(_rewrite_keyword(p) for p in name.split("-"))


class _CommonEvaluator(RewriterEvaluator):
    def rewrite_keyword(self, kw: str) -> str:
        return _rewrite_keyword(kw)

    def process_tokens(self, tokens):
        return convert_dash_identifiers(tokens, _convert_dash_tokens)


class _CommonUniverse(WeaklyTypedSimpleUniverse):
    def evaluate_function(self, name, arg_getters):
        func = self.functions.get(name)
        if not func:
            raise NoSuchFunction(f"No function {name}")
        args = [getter() for getter in arg_getters]
        for arg in args:
            # This is using `type(...)` on purpose; we don't want to allow subclasses.
            if type(arg) not in (int, float, str, bool):
                raise TypeError(f"Invalid argument for {name}: {type(arg)}")
        return func(*args)


def _prepare_values(values: ValuesDict) -> ValuesDict:
    """
    Prepare a values dictionary by rewriting names the same way evaluation does.
    """
    prepared_values = {}
    for key, value in values.items():
        if isinstance(key, tuple):
            key = tuple(_prepare_name(p) for p in key)
        elif isinstance(key, str):
            key = _prepare_name(key)
        else:
            raise TypeError(f"Invalid key type: {type(key)}")
        prepared_values[key] = value
    return prepared_values


class CommonBooleanEvaluator:
    functions: dict = DEFAULT_FUNCTIONS
    max_depth: int = 8
    max_time: float = 0.2
    verifier_universe_class = VerifierUniverse
    universe_class = _CommonUniverse
    evaluator_class = _CommonEvaluator

    def evaluate(self, expr: Optional[str], values: ValuesDict) -> Optional[bool]:
        """
        Evaluate the given expression against the given values.
        The values dictionary's keys will be prepared to the expected internal format.
        """
        if not expr:
            return None
        universe = self.universe_class(
            functions=self.functions,
            values=_prepare_values(values),
        )
        evl = self.evaluator_class(
            universe,
            max_depth=self.max_depth,
            max_time=self.max_time,
        )
        return bool(evl.evaluate_expression(expr))

    def verify(self, expression: str) -> bool:
        """
        Verify that the given expression is technically valid.
        """
        evl = self.evaluator_class(
            self.verifier_universe_class(),
            max_depth=self.max_depth,
        )
        return evl.evaluate_expression(expression)
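A minimal usage sketch of CommonBooleanEvaluator (the value names here are hypothetical, mirroring the tests below; the expected results follow from evaluate() and _prepare_values() above):

from leval.extras.common_boolean_evaluator import CommonBooleanEvaluator

evaluator = CommonBooleanEvaluator()
values = {
    ("foo", "baz-quux"): 9,  # tuple keys model dotted access: foo.baz-quux
    "continue": True,        # Python keywords are allowed as value names
    "v1": 74,
}

print(evaluator.evaluate("foo.baz-quux > 8", values))     # True
print(evaluator.evaluate("continue or v1 < 50", values))  # True
print(evaluator.evaluate("", values))  # None: empty expressions short-circuit
evaluator.verify("max(v1, 3) > 2")     # returns without raising for valid expressions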
leval/rewriter_utils.py
@@ -0,0 +1,102 @@
import tokenize
from typing import Callable, Iterable, Iterator, List, Optional


def tokens_are_adjacent(t1: tokenize.TokenInfo, t2: tokenize.TokenInfo) -> bool:
    """
    Return True if the two tokens are adjacent in the source code.
    That is, token 1 ends exactly where token 2 starts.
    """
    return t1.end[0] == t2.start[0] and t1.end[1] == t2.start[1]


def make_glued_name_token(tokens: List[tokenize.TokenInfo], name: str):
    """
    Create a new token for an identifier that spans the given token position range.
    It does not validate that the resulting token is actually a valid Python identifier,
    or that the range is actually valid (`name` could be longer than the range of the
    original tokens).
    """
    stok = tokens[0]
    etok = tokens[-1]
    return tokenize.TokenInfo(
        tokenize.NAME,
        name,
        stok.start,
        etok.end,
        stok.line,
    )


def get_parts_from_dashed_identifier_tokens(
    tokens: Iterable[tokenize.TokenInfo],
    separator: Optional[str] = None,
) -> Iterable[str]:
    """
    Yield the parts of a dashed identifier from the given token stream.
    If `separator` is set, it is yielded in place of each dash in the identifier.
    """
    for tok in tokens:
        if tok.type in (tokenize.NAME, tokenize.NUMBER):
            yield tok.string
        elif tok.type == tokenize.OP and tok.string == "-":
            if separator:
                yield separator
            continue
        else:
            raise SyntaxError("Invalid token")


def _maybe_process_dash_identifier(
    initial_token: tokenize.TokenInfo,
    tok_iter: Iterator[tokenize.TokenInfo],
    converter: Callable[[List[tokenize.TokenInfo]], Iterable[tokenize.TokenInfo]],
):
    tokens = [initial_token]
    while True:
        tok = next(tok_iter, None)
        if tok is None:
            break
        if not tokens_are_adjacent(tokens[-1], tok):
            break
        if (
            tok.type == tokenize.NAME
            or (tok.type == tokenize.OP and tok.string == "-")
            or (tok.type == tokenize.NUMBER and tok.string.isdigit())
        ):
            tokens.append(tok)
        else:
            break
    if tokens:  # Yield the converted token(s) if there are tokens to convert.
        if tokens[-1].type == tokenize.OP:  # ended with a dash? no conversion
            yield from tokens
        else:
            yield from converter(tokens)
    if tok:  # Yield the last token that broke the loop.
        yield tok


def convert_dash_identifiers(
    tokens: List[tokenize.TokenInfo],
    converter: Callable[[List[tokenize.TokenInfo]], Iterable[tokenize.TokenInfo]],
) -> Iterable[tokenize.TokenInfo]:
    """
    Convert dashed identifiers in the given token stream.
    In particular, converts e.g. `foo-bar-baz-quux`, which actually tokenizes as
    `NAME OP(-) NAME (...)` with no spaces in between, into a single
    token via the given converter function.
    """
    tok_iter = iter(tokens)

    while True:
        tok = next(tok_iter, None)
        if tok is None:
            break
        if tok.type == tokenize.NAME:  # Could be the start of a dashed identifier.
            yield from _maybe_process_dash_identifier(tok, tok_iter, converter)
            continue
        yield tok
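A small sketch of how these helpers compose (the hypothetical glue() converter mirrors _convert_dash_tokens in the common boolean evaluator above; joining with "_" is an arbitrary choice for illustration):

import io
import tokenize

from leval.rewriter_utils import (
    convert_dash_identifiers,
    get_parts_from_dashed_identifier_tokens,
    make_glued_name_token,
)


def glue(tokens):
    # Fuse e.g. NAME OP(-) NAME into one NAME token, each dash becoming "_".
    name = "".join(get_parts_from_dashed_identifier_tokens(tokens, separator="_"))
    return [make_glued_name_token(tokens, name)]


src = "foo-bar > baz - 1"
toks = list(tokenize.generate_tokens(io.StringIO(src).readline))
for tok in convert_dash_identifiers(toks, glue):
    if tok.type in (tokenize.NAME, tokenize.OP, tokenize.NUMBER):
        print(tokenize.tok_name[tok.type], repr(tok.string))
# The adjacent "foo-bar" is glued into a single NAME 'foo_bar', while the
# spaced "baz - 1" stays as NAME 'baz', OP '-', NUMBER '1'.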
Tests for CommonBooleanEvaluator
@@ -0,0 +1,49 @@
import pytest

from leval.excs import InvalidOperation, NoSuchValue, TooComplex
from leval.extras.common_boolean_evaluator import CommonBooleanEvaluator


@pytest.mark.parametrize(
    ("expression", "exception"),
    [
        ("+".join("a" * 500), TooComplex),
        ("b <", SyntaxError),
        ("os.system()", InvalidOperation),
    ],
)
def test_validation(expression, exception):
    with pytest.raises(exception):
        CommonBooleanEvaluator().verify(expression)
    with pytest.raises(exception):
        CommonBooleanEvaluator().evaluate(expression, {})


test_vars = {
    ("foo", "baz-quux"): 9,
    "continue": True,
    "v1": 74,
    "v2": 42,
}


@pytest.mark.parametrize(
    ("expression", "values", "expected"),
    [
        ("foo.baz-quux > 8", test_vars, True),  # dash in name
        ("foo.baz - quux > 8", test_vars, NoSuchValue),
        ("continue or not pause", test_vars, True),  # keyword
        ("cookie is None", test_vars, True),  # loose "is"
        ("not class", test_vars, True),  # loose "not"
        ("min(v1, v2) < 50", test_vars, True),
        ("max(v1, v2) > 50", test_vars, True),
        ("max()", test_vars, TypeError),
        ("max((1,2,3))", test_vars, TypeError),  # Invalid argument type (tuple)
    ],
)
def test_expressions(expression, values, expected):
    if isinstance(expected, type) and issubclass(expected, Exception):
        with pytest.raises(expected):
            CommonBooleanEvaluator().evaluate(expression, values)
    else:
        assert CommonBooleanEvaluator().evaluate(expression, values) == expected