Skip to content

Commit

Permalink
Convert bare loads to conditional global subscripts (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
scarletcafe committed Apr 1, 2022
1 parent 1da71e9 commit fd5a998
Showing 1 changed file with 70 additions and 11 deletions.
81 changes: 70 additions & 11 deletions jishaku/repl/walkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ class KeywordTransformer(ast.NodeTransformer):
- Converts bare deletes into conditional global pops
"""

def visit_FunctionDef(self, node):
def visit_FunctionDef(self, node: ast.FunctionDef):
# Do not affect nested function definitions
return node

def visit_AsyncFunctionDef(self, node):
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
# Do not affect nested async function definitions
return node

def visit_ClassDef(self, node):
def visit_ClassDef(self, node: ast.ClassDef):
# Do not affect nested class definitions
return node

def visit_Return(self, node):
def visit_Return(self, node: ast.Return):
# Do not modify valueless returns
if node.value is None:
return node
Expand Down Expand Up @@ -70,7 +70,59 @@ def visit_Return(self, node):
col_offset=node.col_offset
)

def visit_Delete(self, node):
def visit_Name(self, node: ast.Name):
"""
This converter replaces bare loads with conditional global indices
It is roughly equivalent to transforming:
.. code:: python
x = x + 1
into:
.. code:: python
x = (x if 'x' in locals() else globals()['x']) + 1
This makes reassignments work better as they don't fire UnboundLocalError
"""

if not isinstance(node.ctx, ast.Load):
return node

return ast.IfExp(
test=ast.Compare(
left=ast.Constant(
value=node.id,
kind=None,
lineno=node.lineno,
col_offset=node.col_offset
),
ops=[ast.In()],
comparators=[self.func_call(node, 'locals')],
lineno=node.lineno,
col_offset=node.col_offset
),
body=node,
orelse=ast.Subscript(
value=self.func_call(node, 'globals'),
slice=ast.Constant(
value=node.id,
kind=None,
lineno=node.lineno,
col_offset=node.col_offset
),
ctx=ast.Load(),
lineno=node.lineno,
col_offset=node.col_offset
),
lineno=node.lineno,
col_offset=node.col_offset
)

def visit_Delete(self, node: ast.Delete):
"""
This converter replaces bare deletes with conditional global pops.
Expand Down Expand Up @@ -117,7 +169,7 @@ def visit_Delete(self, node):
],
comparators=[
# globals()
self.globals_call(node)
self.func_call(node, 'globals')
],
lineno=node.lineno,
col_offset=node.col_offset
Expand All @@ -128,7 +180,7 @@ def visit_Delete(self, node):
value=ast.Call(
# globals().pop
func=ast.Attribute(
value=self.globals_call(node),
value=self.func_call(node, 'globals'),
attr='pop',
ctx=ast.Load(),
lineno=node.lineno,
Expand Down Expand Up @@ -176,19 +228,26 @@ def visit_Delete(self, node):
col_offset=node.col_offset
)

def globals_call(self, node):
def func_call(self, node: ast.AST, name: str, *args: str):
"""
Creates an AST node that calls globals().
Creates an AST node that calls a function.
"""

return ast.Call(
func=ast.Name(
id='globals',
id=name,
ctx=ast.Load(),
lineno=node.lineno,
col_offset=node.col_offset
),
args=[],
args=[
ast.Name(
id=arg,
ctx=ast.Load(),
lineno=node.lineno,
col_offset=node.col_offset
) for arg in args
],
keywords=[],
lineno=node.lineno,
col_offset=node.col_offset
Expand Down

0 comments on commit fd5a998

Please sign in to comment.