diff --git a/Lib/ast.py b/Lib/ast.py index 0937c27bdf8a11..cb1f8dfe128ead 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -674,6 +674,7 @@ def __init__(self): self._type_ignores = {} self._indent = 0 self._in_try_star = False + self._in_interactive = False def interleave(self, inter, f, seq): """Call f on each item in seq, calling inter() in between.""" @@ -702,11 +703,20 @@ def maybe_newline(self): if self._source: self.write("\n") - def fill(self, text=""): + def maybe_semicolon(self): + """Adds a "; " delimiter if it isn't the start of generated source""" + if self._source: + self.write("; ") + + def fill(self, text="", *, allow_semicolon=True): """Indent a piece of text and append it, according to the current - indentation level""" - self.maybe_newline() - self.write(" " * self._indent + text) + indentation level, or only delineate with semicolon if applicable""" + if self._in_interactive and not self._indent and allow_semicolon: + self.maybe_semicolon() + self.write(text) + else: + self.maybe_newline() + self.write(" " * self._indent + text) def write(self, *text): """Add new source parts""" @@ -812,8 +822,17 @@ def visit_Module(self, node): ignore.lineno: f"ignore{ignore.tag}" for ignore in node.type_ignores } - self._write_docstring_and_traverse_body(node) - self._type_ignores.clear() + try: + self._write_docstring_and_traverse_body(node) + finally: + self._type_ignores.clear() + + def visit_Interactive(self, node): + self._in_interactive = True + try: + self._write_docstring_and_traverse_body(node) + finally: + self._in_interactive = False def visit_FunctionType(self, node): with self.delimit("(", ")"): @@ -945,17 +964,17 @@ def visit_Raise(self, node): self.traverse(node.cause) def do_visit_try(self, node): - self.fill("try") + self.fill("try", allow_semicolon=False) with self.block(): self.traverse(node.body) for ex in node.handlers: self.traverse(ex) if node.orelse: - self.fill("else") + self.fill("else", allow_semicolon=False) with self.block(): self.traverse(node.orelse) if node.finalbody: - self.fill("finally") + self.fill("finally", allow_semicolon=False) with self.block(): self.traverse(node.finalbody) @@ -976,7 +995,7 @@ def visit_TryStar(self, node): self._in_try_star = prev_in_try_star def visit_ExceptHandler(self, node): - self.fill("except*" if self._in_try_star else "except") + self.fill("except*" if self._in_try_star else "except", allow_semicolon=False) if node.type: self.write(" ") self.traverse(node.type) @@ -989,9 +1008,9 @@ def visit_ExceptHandler(self, node): def visit_ClassDef(self, node): self.maybe_newline() for deco in node.decorator_list: - self.fill("@") + self.fill("@", allow_semicolon=False) self.traverse(deco) - self.fill("class " + node.name) + self.fill("class " + node.name, allow_semicolon=False) if hasattr(node, "type_params"): self._type_params_helper(node.type_params) with self.delimit_if("(", ")", condition = node.bases or node.keywords): @@ -1021,10 +1040,10 @@ def visit_AsyncFunctionDef(self, node): def _function_helper(self, node, fill_suffix): self.maybe_newline() for deco in node.decorator_list: - self.fill("@") + self.fill("@", allow_semicolon=False) self.traverse(deco) def_str = fill_suffix + " " + node.name - self.fill(def_str) + self.fill(def_str, allow_semicolon=False) if hasattr(node, "type_params"): self._type_params_helper(node.type_params) with self.delimit("(", ")"): @@ -1075,7 +1094,7 @@ def visit_AsyncFor(self, node): self._for_helper("async for ", node) def _for_helper(self, fill, node): - self.fill(fill) + self.fill(fill, allow_semicolon=False) self.set_precedence(_Precedence.TUPLE, node.target) self.traverse(node.target) self.write(" in ") @@ -1083,46 +1102,46 @@ def _for_helper(self, fill, node): with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) if node.orelse: - self.fill("else") + self.fill("else", allow_semicolon=False) with self.block(): self.traverse(node.orelse) def visit_If(self, node): - self.fill("if ") + self.fill("if ", allow_semicolon=False) self.traverse(node.test) with self.block(): self.traverse(node.body) # collapse nested ifs into equivalent elifs. while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If): node = node.orelse[0] - self.fill("elif ") + self.fill("elif ", allow_semicolon=False) self.traverse(node.test) with self.block(): self.traverse(node.body) # final else if node.orelse: - self.fill("else") + self.fill("else", allow_semicolon=False) with self.block(): self.traverse(node.orelse) def visit_While(self, node): - self.fill("while ") + self.fill("while ", allow_semicolon=False) self.traverse(node.test) with self.block(): self.traverse(node.body) if node.orelse: - self.fill("else") + self.fill("else", allow_semicolon=False) with self.block(): self.traverse(node.orelse) def visit_With(self, node): - self.fill("with ") + self.fill("with ", allow_semicolon=False) self.interleave(lambda: self.write(", "), self.traverse, node.items) with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) def visit_AsyncWith(self, node): - self.fill("async with ") + self.fill("async with ", allow_semicolon=False) self.interleave(lambda: self.write(", "), self.traverse, node.items) with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) @@ -1264,7 +1283,7 @@ def visit_Name(self, node): self.write(node.id) def _write_docstring(self, node): - self.fill() + self.fill(allow_semicolon=False) if node.kind == "u": self.write("u") self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES) @@ -1558,7 +1577,7 @@ def visit_Slice(self, node): self.traverse(node.step) def visit_Match(self, node): - self.fill("match ") + self.fill("match ", allow_semicolon=False) self.traverse(node.subject) with self.block(): for case in node.cases: @@ -1652,7 +1671,7 @@ def visit_withitem(self, node): self.traverse(node.optional_vars) def visit_match_case(self, node): - self.fill("case ") + self.fill("case ", allow_semicolon=False) self.traverse(node.pattern) if node.guard: self.write(" if ") diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 686649a520880e..839326f6436809 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -142,13 +142,13 @@ def check_invalid(self, node, raises=ValueError): with self.subTest(node=node): self.assertRaises(raises, ast.unparse, node) - def get_source(self, code1, code2=None): + def get_source(self, code1, code2=None, **kwargs): code2 = code2 or code1 - code1 = ast.unparse(ast.parse(code1)) + code1 = ast.unparse(ast.parse(code1, **kwargs)) return code1, code2 - def check_src_roundtrip(self, code1, code2=None): - code1, code2 = self.get_source(code1, code2) + def check_src_roundtrip(self, code1, code2=None, **kwargs): + code1, code2 = self.get_source(code1, code2, **kwargs) with self.subTest(code1=code1, code2=code2): self.assertEqual(code2, code1) @@ -469,6 +469,120 @@ def test_type_ignore(self): ): self.check_ast_roundtrip(statement, type_comments=True) + def test_unparse_interactive_semicolons(self): + # gh-129598: Fix ast.unparse() when ast.Interactive contains multiple statements + self.check_src_roundtrip("i = 1; 'expr'; raise Exception", mode='single') + self.check_src_roundtrip("i: int = 1; j: float = 0; k += l", mode='single') + combinable = ( + "'expr'", + "(i := 1)", + "import foo", + "from foo import bar", + "i = 1", + "i += 1", + "i: int = 1", + "return i", + "pass", + "break", + "continue", + "del i", + "assert i", + "global i", + "nonlocal j", + "await i", + "yield i", + "yield from i", + "raise i", + "type t[T] = ...", + "i", + ) + for a in combinable: + for b in combinable: + self.check_src_roundtrip(f"{a}; {b}", mode='single') + + def test_unparse_interactive_integrity_1(self): + # rest of unparse_interactive_integrity tests just make sure mode='single' parse and unparse didn't break + self.check_src_roundtrip( + "if i:\n 'expr'\nelse:\n raise Exception", + "if i:\n 'expr'\nelse:\n raise Exception", + mode='single' + ) + self.check_src_roundtrip( + "@decorator1\n@decorator2\ndef func():\n 'docstring'\n i = 1; 'expr'; raise Exception", + '''@decorator1\n@decorator2\ndef func():\n """docstring"""\n i = 1\n 'expr'\n raise Exception''', + mode='single' + ) + self.check_src_roundtrip( + "@decorator1\n@decorator2\nclass cls:\n 'docstring'\n i = 1; 'expr'; raise Exception", + '''@decorator1\n@decorator2\nclass cls:\n """docstring"""\n i = 1\n 'expr'\n raise Exception''', + mode='single' + ) + + def test_unparse_interactive_integrity_2(self): + for statement in ( + "def x():\n pass", + "def x(y):\n pass", + "async def x():\n pass", + "async def x(y):\n pass", + "for x in y:\n pass", + "async for x in y:\n pass", + "with x():\n pass", + "async with x():\n pass", + "def f():\n pass", + "def f(a):\n pass", + "def f(b=2):\n pass", + "def f(a, b):\n pass", + "def f(a, b=2):\n pass", + "def f(a=5, b=2):\n pass", + "def f(*, a=1, b=2):\n pass", + "def f(*, a=1, b):\n pass", + "def f(*, a, b=2):\n pass", + "def f(a, b=None, *, c, **kwds):\n pass", + "def f(a=2, *args, c=5, d, **kwds):\n pass", + "def f(*args, **kwargs):\n pass", + "class cls:\n\n def f(self):\n pass", + "class cls:\n\n def f(self, a):\n pass", + "class cls:\n\n def f(self, b=2):\n pass", + "class cls:\n\n def f(self, a, b):\n pass", + "class cls:\n\n def f(self, a, b=2):\n pass", + "class cls:\n\n def f(self, a=5, b=2):\n pass", + "class cls:\n\n def f(self, *, a=1, b=2):\n pass", + "class cls:\n\n def f(self, *, a=1, b):\n pass", + "class cls:\n\n def f(self, *, a, b=2):\n pass", + "class cls:\n\n def f(self, a, b=None, *, c, **kwds):\n pass", + "class cls:\n\n def f(self, a=2, *args, c=5, d, **kwds):\n pass", + "class cls:\n\n def f(self, *args, **kwargs):\n pass", + ): + self.check_src_roundtrip(statement, mode='single') + + def test_unparse_interactive_integrity_3(self): + for statement in ( + "def x():", + "def x(y):", + "async def x():", + "async def x(y):", + "for x in y:", + "async for x in y:", + "with x():", + "async with x():", + "def f():", + "def f(a):", + "def f(b=2):", + "def f(a, b):", + "def f(a, b=2):", + "def f(a=5, b=2):", + "def f(*, a=1, b=2):", + "def f(*, a=1, b):", + "def f(*, a, b=2):", + "def f(a, b=None, *, c, **kwds):", + "def f(a=2, *args, c=5, d, **kwds):", + "def f(*args, **kwargs):", + ): + src = statement + '\n i=1;j=2' + out = statement + '\n i = 1\n j = 2' + + self.check_src_roundtrip(src, out, mode='single') + class CosmeticTestCase(ASTTestCase): """Test if there are cosmetic issues caused by unnecessary additions""" diff --git a/Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst b/Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst new file mode 100644 index 00000000000000..f59eeb236e24a2 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst @@ -0,0 +1 @@ +Fix :func:`ast.unparse` when :class:`ast.Interactive` contains multiple statements.