lib3to6.utils

src/lib3to6/utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# This file is part of the lib3to6 project
# https://github.com/mbarkhau/lib3to6
#
# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License
# SPDX-License-Identifier: MIT

import ast
import typing as typ

import astor

from . import common
from . import transpile

# Type of the values handled by dump_ast's internal formatter: a single AST
# node, or a list of them. The intended definition is the recursive alias
#     NodeOrNodelist = typ.Union[ast.AST, typ.List["NodeOrNodelist"]]
# but recursive types were not fully supported yet when this was written, so
# the list's element type is loosened to typ.Any instead.
NodeOrNodelist = typ.Union[ast.AST, typ.List[typ.Any]]


# Node types that dump_ast always renders on a single line, whatever their
# field count. Since Python 3.8 the parser emits ast.Constant for literals;
# ast.Num/ast.Str/ast.Bytes are deprecated aliases (accessing them warns on
# 3.12/3.13) and were removed in 3.14. Looking names up through vars(ast)
# keeps the legacy classes on interpreters that still define them, without
# going through the module's deprecation hook, and skips names that are gone.
_SHORT_NODE_TYPES = tuple(
    node_type
    for node_type in (
        vars(ast).get(type_name)
        for type_name in ("Name", "alias", "Constant", "Num", "Str", "Bytes")
    )
    if isinstance(node_type, type)
)


# https://gist.github.com/marsam/d2a5af1563d129bb9482
def dump_ast(
    node              : typ.Any,
    annotate_fields   : bool = True,
    include_attributes: bool = False,
    indent            : str  = "  ",
) -> str:
    """Return a formatted, indented dump of the tree in *node*.

    This is mainly useful for debugging purposes. *node* may be an
    ``ast.AST`` instance or a list of them.

    Nodes with at most one field, as well as names, import aliases and
    literal constants, are rendered on one line; every other node is
    rendered with one field per line.

    :param annotate_fields: If True (the default) fields are shown as
        ``name=value``; if False only the values are shown.
    :param include_attributes: If True, node attributes such as line
        numbers and column offsets are dumped too. Attributes that a node
        does not have (e.g. on nodes built by hand) are skipped.
    :param indent: String used for one level of indentation.
    :raises TypeError: if *node* is neither an ``ast.AST`` nor a list.
    """

    def _format(node: "NodeOrNodelist", level: int = 1) -> str:
        # pylint: disable=protected-access
        if isinstance(node, ast.AST):
            fields = [(a, _format(b, level + 1)) for a, b in ast.iter_fields(node)]
            if include_attributes:
                for attr_name in node._attributes:
                    # Hand-built nodes may lack position attributes such as
                    # lineno; skip them rather than raising AttributeError.
                    if hasattr(node, attr_name):
                        fields.append((attr_name, _format(getattr(node, attr_name), level + 1)))

            if annotate_fields:
                field_parts = ["%s=%s" % field for field in fields]
            else:
                field_parts = [value for _, value in fields]

            node_name = node.__class__.__name__
            if len(field_parts) <= 1 or isinstance(node, _SHORT_NODE_TYPES):
                return node_name + "(" + ", ".join(field_parts) + ")"

            lines = [node_name + "("]
            lines.extend((indent * level) + part + "," for part in field_parts)
            lines.append((indent * (level - 1)) + ")")
            return "\n".join(lines)

        if isinstance(node, list):
            if not node:
                return "[]"

            # A single-element list stays inline at the current level.
            if len(node) == 1:
                return "[" + _format(node[0], level) + "]"

            lines = [indent * level + _format(subnode, level + 1) + "," for subnode in node]
            return "[\n" + "\n".join(lines) + "\n" + indent * (level - 1) + "]"

        # Plain field values (str, int, None, ...) are shown via repr().
        return repr(node)

    if not isinstance(node, (ast.AST, list)):
        raise TypeError("expected AST, got %r" % node.__class__.__name__)
    return _format(node)


def clean_whitespace(fixture_str: str) -> str:
    """Normalize a code fixture string, typically an indented triple-quoted literal.

    - Text that fits on a single line (after stripping) is returned stripped.
    - Multi-line text whose least-indented non-blank line already starts at
      column 0 is returned unchanged.
    - Otherwise blank lines are dropped, the common indentation is removed
      from every remaining line, and the result is stripped and terminated
      with exactly one newline.
    """
    stripped = fixture_str.strip()
    if "\n" not in stripped:
        return stripped

    non_blank = [line for line in fixture_str.splitlines() if line.strip()]
    common_indent = min((len(line) - len(line.lstrip()) for line in non_blank), default=0)
    if common_indent == 0:
        return fixture_str

    dedented = "\n".join(line[common_indent:] for line in non_blank)
    return dedented.strip() + "\n"


def parse_stmt(code: str) -> ast.stmt:
    """Parse *code*, which must contain exactly one statement, and return it.

    :raises ValueError: if *code* contains zero or more than one statement.
    :raises SyntaxError: if *code* is not valid Python (from ``ast.parse``).
    """
    module = ast.parse(code)
    # Explicit check instead of ``assert`` so the guard is not stripped when
    # Python runs with -O. Without it, multi-statement input would silently
    # return only the first statement.
    if len(module.body) != 1:
        raise ValueError(
            "expected exactly one statement, got %d in %r" % (len(module.body), code)
        )
    return module.body[0]


def parsedump_ast(code: str, mode: str = "exec", **kwargs) -> str:
    """Parse *code* and return the pretty-printed dump of its AST.

    *code* is normalized with clean_whitespace first, so indented fixture
    strings can be passed as-is. *mode* is forwarded to ``ast.parse``; any
    keyword arguments are forwarded to dump_ast.
    """
    source = clean_whitespace(code)
    return dump_ast(ast.parse(source, mode=mode), **kwargs)


def parsedump_source(code: str, mode: str = "exec") -> str:
    """Parse *code* (normalized with clean_whitespace) and render it back to source via astor.

    *mode* is forwarded to ``ast.parse``.
    """
    return astor.to_source(ast.parse(clean_whitespace(code), mode=mode))


def transpile_and_dump(ctx: common.BuildContext, module_str: str) -> typ.Tuple[str, str, str]:
    """Transpile a (possibly indented) module fixture.

    The source is normalized with clean_whitespace, its header is parsed for
    ``ctx.cfg.target_version``, and the module is transpiled with *ctx*.

    Returns a 3-tuple: the header's ``coding`` and ``text`` attributes (as
    produced by transpile.parse_module_header), and the transpiled source.
    """
    source = clean_whitespace(module_str)
    module_header = transpile.parse_module_header(source, ctx.cfg.target_version)
    transpiled = transpile.transpile_module(ctx, source)
    return (module_header.coding, module_header.text, transpiled)


def has_base_class(
    cls_node       : ast.ClassDef,
    module_name    : typ.Optional[str] = None,
    base_class_name: typ.Optional[str] = None,
) -> bool:
    """Check whether the class definition *cls_node* lists a given base class.

    A base matches when it is written either as ``module_name.base_class_name``
    (a plain name followed by exactly one attribute access) or as the bare
    name ``base_class_name``. The bare form matches regardless of
    *module_name*. Deeper dotted paths such as ``a.b.Base`` and computed
    bases (calls, subscripts, ...) never match.

    Returns False when neither *module_name* nor *base_class_name* is given.
    Passing only *module_name* likewise never matches, because both forms
    compare against *base_class_name*.
    """
    if not (module_name or base_class_name):
        return False

    for base in cls_node.bases:
        if isinstance(base, ast.Attribute):
            # Qualified form, e.g. ``class Foo(typing.NamedTuple)``.
            val = base.value
            if isinstance(val, ast.Name) and val.id == module_name and base.attr == base_class_name:
                return True

        if isinstance(base, ast.Name) and base.id == base_class_name:
            # Bare form, e.g. after ``from typing import NamedTuple``.
            return True

    return False