Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-129598: allow multi stmts for ast single with ';' #129620

Merged
merged 12 commits into from
Mar 19, 2025
71 changes: 45 additions & 26 deletions Lib/ast.py
Original file line number Diff line number Diff line change
@@ -674,6 +674,7 @@ def __init__(self):
self._type_ignores = {}
self._indent = 0
self._in_try_star = False
self._in_interactive = False

def interleave(self, inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
@@ -702,11 +703,20 @@ def maybe_newline(self):
if self._source:
self.write("\n")

def fill(self, text=""):
def maybe_semicolon(self):
"""Adds a "; " delimiter if it isn't the start of generated source"""
if self._source:
self.write("; ")

def fill(self, text="", *, allow_semicolon=True):
"""Indent a piece of text and append it, according to the current
indentation level"""
self.maybe_newline()
self.write(" " * self._indent + text)
indentation level, or only delineate with semicolon if applicable"""
if self._in_interactive and not self._indent and allow_semicolon:
self.maybe_semicolon()
self.write(text)
else:
self.maybe_newline()
self.write(" " * self._indent + text)

def write(self, *text):
"""Add new source parts"""
@@ -812,8 +822,17 @@ def visit_Module(self, node):
ignore.lineno: f"ignore{ignore.tag}"
for ignore in node.type_ignores
}
self._write_docstring_and_traverse_body(node)
self._type_ignores.clear()
try:
self._write_docstring_and_traverse_body(node)
finally:
self._type_ignores.clear()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another question is whether the mapping precedences should be cleared in visit() as well.


def visit_Interactive(self, node):
self._in_interactive = True
try:
self._write_docstring_and_traverse_body(node)
finally:
self._in_interactive = False

def visit_FunctionType(self, node):
with self.delimit("(", ")"):
@@ -945,17 +964,17 @@ def visit_Raise(self, node):
self.traverse(node.cause)

def do_visit_try(self, node):
self.fill("try")
self.fill("try", allow_semicolon=False)
with self.block():
self.traverse(node.body)
for ex in node.handlers:
self.traverse(ex)
if node.orelse:
self.fill("else")
self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
if node.finalbody:
self.fill("finally")
self.fill("finally", allow_semicolon=False)
with self.block():
self.traverse(node.finalbody)

@@ -976,7 +995,7 @@ def visit_TryStar(self, node):
self._in_try_star = prev_in_try_star

def visit_ExceptHandler(self, node):
self.fill("except*" if self._in_try_star else "except")
self.fill("except*" if self._in_try_star else "except", allow_semicolon=False)
if node.type:
self.write(" ")
self.traverse(node.type)
@@ -989,9 +1008,9 @@ def visit_ExceptHandler(self, node):
def visit_ClassDef(self, node):
self.maybe_newline()
for deco in node.decorator_list:
self.fill("@")
self.fill("@", allow_semicolon=False)
self.traverse(deco)
self.fill("class " + node.name)
self.fill("class " + node.name, allow_semicolon=False)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit_if("(", ")", condition = node.bases or node.keywords):
@@ -1021,10 +1040,10 @@ def visit_AsyncFunctionDef(self, node):
def _function_helper(self, node, fill_suffix):
self.maybe_newline()
for deco in node.decorator_list:
self.fill("@")
self.fill("@", allow_semicolon=False)
self.traverse(deco)
def_str = fill_suffix + " " + node.name
self.fill(def_str)
self.fill(def_str, allow_semicolon=False)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit("(", ")"):
@@ -1075,54 +1094,54 @@ def visit_AsyncFor(self, node):
self._for_helper("async for ", node)

def _for_helper(self, fill, node):
self.fill(fill)
self.fill(fill, allow_semicolon=False)
self.set_precedence(_Precedence.TUPLE, node.target)
self.traverse(node.target)
self.write(" in ")
self.traverse(node.iter)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
if node.orelse:
self.fill("else")
self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)

def visit_If(self, node):
self.fill("if ")
self.fill("if ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# collapse nested ifs into equivalent elifs.
while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
node = node.orelse[0]
self.fill("elif ")
self.fill("elif ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# final else
if node.orelse:
self.fill("else")
self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)

def visit_While(self, node):
self.fill("while ")
self.fill("while ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
if node.orelse:
self.fill("else")
self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)

def visit_With(self, node):
self.fill("with ")
self.fill("with ", allow_semicolon=False)
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)

def visit_AsyncWith(self, node):
self.fill("async with ")
self.fill("async with ", allow_semicolon=False)
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
@@ -1264,7 +1283,7 @@ def visit_Name(self, node):
self.write(node.id)

def _write_docstring(self, node):
self.fill()
self.fill(allow_semicolon=False)
if node.kind == "u":
self.write("u")
self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
@@ -1558,7 +1577,7 @@ def visit_Slice(self, node):
self.traverse(node.step)

def visit_Match(self, node):
self.fill("match ")
self.fill("match ", allow_semicolon=False)
self.traverse(node.subject)
with self.block():
for case in node.cases:
@@ -1652,7 +1671,7 @@ def visit_withitem(self, node):
self.traverse(node.optional_vars)

def visit_match_case(self, node):
self.fill("case ")
self.fill("case ", allow_semicolon=False)
self.traverse(node.pattern)
if node.guard:
self.write(" if ")
122 changes: 118 additions & 4 deletions Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
@@ -142,13 +142,13 @@ def check_invalid(self, node, raises=ValueError):
with self.subTest(node=node):
self.assertRaises(raises, ast.unparse, node)

def get_source(self, code1, code2=None):
def get_source(self, code1, code2=None, **kwargs):
code2 = code2 or code1
code1 = ast.unparse(ast.parse(code1))
code1 = ast.unparse(ast.parse(code1, **kwargs))
return code1, code2

def check_src_roundtrip(self, code1, code2=None):
code1, code2 = self.get_source(code1, code2)
def check_src_roundtrip(self, code1, code2=None, **kwargs):
code1, code2 = self.get_source(code1, code2, **kwargs)
with self.subTest(code1=code1, code2=code2):
self.assertEqual(code2, code1)

@@ -469,6 +469,120 @@ def test_type_ignore(self):
):
self.check_ast_roundtrip(statement, type_comments=True)

def test_unparse_interactive_semicolons(self):
# gh-129598: Fix ast.unparse() when ast.Interactive contains multiple statements
self.check_src_roundtrip("i = 1; 'expr'; raise Exception", mode='single')
self.check_src_roundtrip("i: int = 1; j: float = 0; k += l", mode='single')
combinable = (
"'expr'",
"(i := 1)",
"import foo",
"from foo import bar",
"i = 1",
"i += 1",
"i: int = 1",
"return i",
"pass",
"break",
"continue",
"del i",
"assert i",
"global i",
"nonlocal j",
"await i",
"yield i",
"yield from i",
"raise i",
"type t[T] = ...",
"i",
)
for a in combinable:
for b in combinable:
self.check_src_roundtrip(f"{a}; {b}", mode='single')

def test_unparse_interactive_integrity_1(self):
# rest of unparse_interactive_integrity tests just make sure mode='single' parse and unparse didn't break
self.check_src_roundtrip(
"if i:\n 'expr'\nelse:\n raise Exception",
"if i:\n 'expr'\nelse:\n raise Exception",
mode='single'
)
self.check_src_roundtrip(
"@decorator1\n@decorator2\ndef func():\n 'docstring'\n i = 1; 'expr'; raise Exception",
'''@decorator1\n@decorator2\ndef func():\n """docstring"""\n i = 1\n 'expr'\n raise Exception''',
mode='single'
)
self.check_src_roundtrip(
"@decorator1\n@decorator2\nclass cls:\n 'docstring'\n i = 1; 'expr'; raise Exception",
'''@decorator1\n@decorator2\nclass cls:\n """docstring"""\n i = 1\n 'expr'\n raise Exception''',
mode='single'
)

def test_unparse_interactive_integrity_2(self):
for statement in (
"def x():\n pass",
"def x(y):\n pass",
"async def x():\n pass",
"async def x(y):\n pass",
"for x in y:\n pass",
"async for x in y:\n pass",
"with x():\n pass",
"async with x():\n pass",
"def f():\n pass",
"def f(a):\n pass",
"def f(b=2):\n pass",
"def f(a, b):\n pass",
"def f(a, b=2):\n pass",
"def f(a=5, b=2):\n pass",
"def f(*, a=1, b=2):\n pass",
"def f(*, a=1, b):\n pass",
"def f(*, a, b=2):\n pass",
"def f(a, b=None, *, c, **kwds):\n pass",
"def f(a=2, *args, c=5, d, **kwds):\n pass",
"def f(*args, **kwargs):\n pass",
"class cls:\n\n def f(self):\n pass",
"class cls:\n\n def f(self, a):\n pass",
"class cls:\n\n def f(self, b=2):\n pass",
"class cls:\n\n def f(self, a, b):\n pass",
"class cls:\n\n def f(self, a, b=2):\n pass",
"class cls:\n\n def f(self, a=5, b=2):\n pass",
"class cls:\n\n def f(self, *, a=1, b=2):\n pass",
"class cls:\n\n def f(self, *, a=1, b):\n pass",
"class cls:\n\n def f(self, *, a, b=2):\n pass",
"class cls:\n\n def f(self, a, b=None, *, c, **kwds):\n pass",
"class cls:\n\n def f(self, a=2, *args, c=5, d, **kwds):\n pass",
"class cls:\n\n def f(self, *args, **kwargs):\n pass",
):
self.check_src_roundtrip(statement, mode='single')

def test_unparse_interactive_integrity_3(self):
for statement in (
"def x():",
"def x(y):",
"async def x():",
"async def x(y):",
"for x in y:",
"async for x in y:",
"with x():",
"async with x():",
"def f():",
"def f(a):",
"def f(b=2):",
"def f(a, b):",
"def f(a, b=2):",
"def f(a=5, b=2):",
"def f(*, a=1, b=2):",
"def f(*, a=1, b):",
"def f(*, a, b=2):",
"def f(a, b=None, *, c, **kwds):",
"def f(a=2, *args, c=5, d, **kwds):",
"def f(*args, **kwargs):",
):
src = statement + '\n i=1;j=2'
out = statement + '\n i = 1\n j = 2'

self.check_src_roundtrip(src, out, mode='single')


class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecessary additions"""
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix :func:`ast.unparse` when :class:`ast.Interactive` contains multiple statements.
Loading