diff --git a/jishaku/repl/walkers.py b/jishaku/repl/walkers.py index aa991c77..f681f59d 100644 --- a/jishaku/repl/walkers.py +++ b/jishaku/repl/walkers.py @@ -23,19 +23,19 @@ class KeywordTransformer(ast.NodeTransformer): - Converts bare deletes into conditional global pops """ - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node: ast.FunctionDef): # Do not affect nested function definitions return node - def visit_AsyncFunctionDef(self, node): + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): # Do not affect nested async function definitions return node - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: ast.ClassDef): # Do not affect nested class definitions return node - def visit_Return(self, node): + def visit_Return(self, node: ast.Return): # Do not modify valueless returns if node.value is None: return node @@ -70,7 +70,59 @@ def visit_Return(self, node): col_offset=node.col_offset ) - def visit_Delete(self, node): + def visit_Name(self, node: ast.Name): + """ + This converter replaces bare loads with conditional global indices + + It is roughly equivalent to transforming: + + .. code:: python + + x = x + 1 + + into: + + .. code:: python + + x = (x if 'x' in locals() else globals()['x']) + 1 + + This makes reassignments work better as they don't fire UnboundLocalError + """ + + if not isinstance(node.ctx, ast.Load): + return node + + return ast.IfExp( + test=ast.Compare( + left=ast.Constant( + value=node.id, + kind=None, + lineno=node.lineno, + col_offset=node.col_offset + ), + ops=[ast.In()], + comparators=[self.func_call(node, 'locals')], + lineno=node.lineno, + col_offset=node.col_offset + ), + body=node, + orelse=ast.Subscript( + value=self.func_call(node, 'globals'), + slice=ast.Constant( + value=node.id, + kind=None, + lineno=node.lineno, + col_offset=node.col_offset + ), + ctx=ast.Load(), + lineno=node.lineno, + col_offset=node.col_offset + ), + lineno=node.lineno, + col_offset=node.col_offset + ) + + def visit_Delete(self, node: ast.Delete): """ This converter replaces bare deletes with conditional global pops. @@ -117,7 +169,7 @@ def visit_Delete(self, node): ], comparators=[ # globals() - self.globals_call(node) + self.func_call(node, 'globals') ], lineno=node.lineno, col_offset=node.col_offset @@ -128,7 +180,7 @@ def visit_Delete(self, node): value=ast.Call( # globals().pop func=ast.Attribute( - value=self.globals_call(node), + value=self.func_call(node, 'globals'), attr='pop', ctx=ast.Load(), lineno=node.lineno, @@ -176,19 +228,26 @@ def visit_Delete(self, node): col_offset=node.col_offset ) - def globals_call(self, node): + def func_call(self, node: ast.AST, name: str, *args: str): """ - Creates an AST node that calls globals(). + Creates an AST node that calls a function. """ return ast.Call( func=ast.Name( - id='globals', + id=name, ctx=ast.Load(), lineno=node.lineno, col_offset=node.col_offset ), - args=[], + args=[ + ast.Name( + id=arg, + ctx=ast.Load(), + lineno=node.lineno, + col_offset=node.col_offset + ) for arg in args + ], keywords=[], lineno=node.lineno, col_offset=node.col_offset