Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of the lib3to6 project
2# https://github.com/mbarkhau/lib3to6
3#
4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License
5# SPDX-License-Identifier: MIT
7import ast
8import typing as typ
10import astor
12from . import common
13from . import transpile
# Values that dump_ast's inner _format helper recurses over: a single AST node
# or a (possibly nested) list of values. Recursive types are not fully
# supported yet, so the nested element type is replaced with "Any"; the
# intended definition would be:
# NodeOrNodelist = typ.Union[ast.AST, typ.List["NodeOrNodelist"]]
NodeOrNodelist = typ.Union[ast.AST, typ.List[typ.Any]]
# https://gist.github.com/marsam/d2a5af1563d129bb9482
def dump_ast(
    node              : typ.Any,
    annotate_fields   : bool = True,
    include_attributes: bool = False,
    indent            : str  = " ",
) -> str:
    """Return a formatted dump of the tree in *node*.

    This is mainly useful for debugging purposes. The returned
    string will show the names and the values for fields. This
    makes the code impossible to evaluate, so if evaluation is
    wanted *annotate_fields* must be set to False. Attributes
    such as line numbers and column offsets are not dumped by
    default. If this is wanted, *include_attributes* can be set
    to True; attributes that are not set on a node are skipped.

    Nodes with at most one field, as well as ``Name``, literal
    constant and ``alias`` nodes, are rendered on a single line.
    Other nodes, and lists with more than one element, are spread
    over several lines, indented by *indent* per nesting level.

    :param node: an ``ast.AST`` node or a list of nodes.
    :param annotate_fields: render fields as ``name=value`` instead of positionally.
    :param include_attributes: also render ``node._attributes`` (lineno, col_offset, ...).
    :param indent: string repeated once per nesting level.
    :return: the formatted dump.
    :raises TypeError: if *node* is neither an ``ast.AST`` nor a list.
    """
    import sys

    # Since Python 3.8 the parser emits ast.Constant for every literal. The old
    # Num/Str/Bytes classes are deprecated (accessing them warns on 3.12+ and
    # they are removed in 3.14), so only reference them on interpreters that
    # still produce them.
    if sys.version_info >= (3, 8):
        short_node_types = (ast.Name, ast.Constant, ast.alias)
    else:
        short_node_types = (ast.Name, ast.Num, ast.Str, ast.Bytes, ast.alias)

    def _format(node: "NodeOrNodelist", level: int = 1) -> str:
        # pylint: disable=protected-access
        if isinstance(node, ast.AST):
            fields = [(name, _format(value, level + 1)) for name, value in ast.iter_fields(node)]
            if include_attributes:
                for attr_name in node._attributes:
                    # Nodes built by hand (rather than by the parser) may lack
                    # attributes such as lineno; skip them like ast.dump does.
                    try:
                        attr_value = getattr(node, attr_name)
                    except AttributeError:
                        continue
                    fields.append((attr_name, _format(attr_value, level + 1)))

            if annotate_fields:
                field_parts = ["%s=%s" % field for field in fields]
            else:
                field_parts = [value for _, value in fields]

            node_name = node.__class__.__name__
            is_short_node = len(field_parts) <= 1 or isinstance(node, short_node_types)

            if is_short_node:
                return node_name + "(" + ", ".join(field_parts) + ")"

            lines = [node_name + "("]
            for part in field_parts:
                lines.append((indent * level) + part + ",")
            lines.append((indent * (level - 1)) + ")")
            return "\n".join(lines)
        elif isinstance(node, list):
            subnodes = node
            if len(subnodes) == 0:
                return "[]"

            if len(subnodes) == 1:
                return "[" + _format(subnodes[0], level) + "]"

            lines = [indent * level + _format(subnode, level + 1) + "," for subnode in subnodes]
            return "[\n" + "\n".join(lines) + "\n" + indent * (level - 1) + "]"
        # Plain field values (identifiers, literals, None) are shown via repr.
        return repr(node)

    if isinstance(node, (ast.AST, list)):
        return _format(node)
    else:
        raise TypeError("expected AST, got %r" % node.__class__.__name__)
def clean_whitespace(fixture_str: str) -> str:
    """Normalize an (often triple-quoted, indented) test fixture string.

    - A fixture that is a single line once surrounding whitespace is
      removed is returned stripped.
    - A multi-line fixture whose least-indented non-blank line already
      starts at column 0 is returned unchanged.
    - Otherwise blank lines are dropped, the common leading indentation
      is removed from every remaining line, and the result is stripped
      and terminated with a single newline.
    """
    stripped = fixture_str.strip()
    if "\n" not in stripped:
        return stripped

    content_lines = [ln for ln in fixture_str.splitlines() if ln.strip()]
    common_indent = min(len(ln) - len(ln.lstrip()) for ln in content_lines)
    if common_indent == 0:
        return fixture_str

    dedented = "\n".join(ln[common_indent:] for ln in content_lines)
    return dedented.strip() + "\n"
def parse_stmt(code: str) -> ast.stmt:
    """Parse *code* and return its single top-level statement.

    An ``assert`` checks that *code* holds exactly one statement, so a
    mismatch raises ``AssertionError`` (unless assertions are disabled).
    """
    statements = ast.parse(code).body
    assert len(statements) == 1
    first_stmt = statements[0]
    return first_stmt
def parsedump_ast(code: str, mode: str = "exec", **kwargs) -> str:
    """Parse some code from a string and pretty-print it.

    The code is first normalized with ``clean_whitespace``, then parsed
    with ``ast.parse`` in the given *mode*. Any extra keyword arguments
    are forwarded unchanged to ``dump_ast``.
    """
    cleaned_code = clean_whitespace(code)
    tree = ast.parse(cleaned_code, mode=mode)
    dumped = dump_ast(tree, **kwargs)
    return dumped
def parsedump_source(code: str, mode: str = "exec") -> str:
    """Round-trip *code* through the AST and return the regenerated source.

    The code is normalized with ``clean_whitespace``, parsed with
    ``ast.parse`` in the given *mode*, and rendered back to source
    text with ``astor.to_source``.
    """
    tree = ast.parse(clean_whitespace(code), mode=mode)
    regenerated = astor.to_source(tree)
    return regenerated
def transpile_and_dump(ctx: common.BuildContext, module_str: str) -> typ.Tuple[str, str, str]:
    """Transpile a fixture module and return its header parts and output.

    The fixture is normalized with ``clean_whitespace``. Its module header
    is parsed for ``ctx.cfg.target_version`` before the module is
    transpiled with *ctx*.

    :return: a tuple ``(header.coding, header.text, transpiled_source)``.
    """
    source = clean_whitespace(module_str)
    target_version = ctx.cfg.target_version
    header = transpile.parse_module_header(source, target_version)
    transpiled_source = transpile.transpile_module(ctx, source)
    return (header.coding, header.text, transpiled_source)
def has_base_class(
    cls_node: ast.ClassDef,
    module_name: typ.Optional[str] = None,
    base_class_name: typ.Optional[str] = None,
) -> bool:
    """Check whether a class definition lists a particular base class.

    A base matches when it is written either as ``module_name.base_class_name``
    (an attribute on a plain name) or as a bare ``base_class_name``. The bare
    form matches regardless of *module_name*, e.g. after ``from abc import ABC``.
    Other spellings, such as ``pkg.mod.Base`` or subscripted bases, never match.

    :param cls_node: the ``class`` statement to inspect.
    :param module_name: name the base is accessed through, e.g. ``"abc"``.
    :param base_class_name: name of the base class, e.g. ``"ABC"``.
    :return: True if any base matches; always False when neither name is given.
    """
    if not (module_name or base_class_name):
        return False

    def _matches(base: ast.expr) -> bool:
        # Qualified form: ``module_name.base_class_name``.
        if isinstance(base, ast.Attribute):
            owner = base.value
            return (
                isinstance(owner, ast.Name)
                and owner.id == module_name
                and base.attr == base_class_name
            )
        # Bare form: ``base_class_name``.
        return isinstance(base, ast.Name) and base.id == base_class_name

    return any(_matches(base) for base in cls_node.bases)