lib3to6.checkers

src/lib3to6/checkers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# This file is part of the lib3to6 project
# https://github.com/mbarkhau/lib3to6
#
# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License
# SPDX-License-Identifier: MIT

import ast
import typing as typ

from . import utils
from . import common
from . import checker_base as cb
from .checkers_backports import NoUnusableImportsChecker


class NoStarImports(cb.CheckerBase):
    """Reject wildcard imports of the form ``from <module> import *``.

    The first ``ImportFrom`` node (in ``ast.walk`` order) that imports ``*``
    is reported via ``common.CheckError``.
    """

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        import_from_nodes = (
            candidate for candidate in ast.walk(tree) if isinstance(candidate, ast.ImportFrom)
        )
        for import_from in import_from_nodes:
            imports_star = any(alias.name == "*" for alias in import_from.names)
            if imports_star:
                msg = f"Prohibited from {import_from.module} import *."
                raise common.CheckError(msg, import_from)


def _iter_scope_names(tree: ast.Module) -> typ.Iterable[typ.Tuple[str, ast.AST]]:
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            yield node.name, node
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            yield node.id, node
        elif isinstance(node, (ast.ImportFrom, ast.Import)):
            for alias in node.names:
                name = alias.name if alias.asname is None else alias.asname
                yield name, node
        elif isinstance(node, ast.arg):
            yield node.arg, node


class NoOverriddenFixerImportsChecker(cb.CheckerBase):
    """Don't override names that fixers may reference.

    Any binding of a name listed in ``prohibited_import_overrides`` is
    rejected, except a plain ``import <name>`` statement that imports exactly
    that one module without an alias.
    """

    prohibited_import_overrides = {"itertools", "six", "builtins"}

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        for bound_name, binding_node in _iter_scope_names(tree):
            if bound_name not in self.prohibited_import_overrides:
                continue

            # ``import six`` (single module, no alias) is the one allowed form.
            if isinstance(binding_node, ast.Import) and len(binding_node.names) == 1:
                (only_alias,) = binding_node.names
                if only_alias.asname is None and only_alias.name == bound_name:
                    continue

            raise common.CheckError(
                f"Prohibited override of import '{bound_name}'", binding_node
            )


class NoOverriddenBuiltinsChecker(cb.CheckerBase):
    """Don't override (shadow) builtin names.

    Every name binding reported by ``_iter_scope_names`` (definitions,
    assignments, imports, arguments, ...) whose name is contained in
    ``common.BUILTIN_NAMES`` is rejected with a ``common.CheckError``
    pointing at the binding node.
    """

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        for name_in_scope, node in _iter_scope_names(tree):
            if name_in_scope in common.BUILTIN_NAMES:
                msg = f"Prohibited override of builtin '{name_in_scope}'"
                raise common.CheckError(msg, node)


# Keyword arguments to the builtin ``open`` that NoOpenWithEncodingChecker
# rejects (that checker applies to targets up to 2.7).
PROHIBITED_OPEN_ARGUMENTS = {"encoding", "errors", "newline", "closefd", "opener"}


def _open_mode_literal(mode_node: ast.expr, call_node: ast.Call) -> str:
    """Return the literal string passed as ``mode`` to an ``open(...)`` call.

    Raises ``common.CheckError`` (attached to ``call_node``) when the mode is
    not a string literal, because a non-literal mode can't be checked.
    """
    if isinstance(mode_node, ast.Str):
        return mode_node.s

    msg = (
        "Prohibited value for argument 'mode' of builtin.open. "
        + f"Expected ast.Str node, got: {mode_node}"
    )
    raise common.CheckError(msg, call_node)


class NoOpenWithEncodingChecker(cb.CheckerBase):
    """Restrict calls of the builtin ``open`` to a binary-only subset.

    For every call ``open(...)`` (bare name in load context) this requires:

    - ``mode``, if given positionally or by keyword, is a string literal;
    - the effective mode contains ``"b"`` (the implicit default ``"r"`` is
      therefore rejected), ``io.open`` being the suggested alternative;
    - at most three positional arguments are passed;
    - none of ``PROHIBITED_OPEN_ARGUMENTS`` is passed as a keyword.
    """

    version_info = common.VersionInfo(apply_until="2.7")

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        for node in ast.walk(tree):
            if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
                continue
            callee = node.func
            if callee.id != "open" or not isinstance(callee.ctx, ast.Load):
                continue

            positional = node.args
            mode = _open_mode_literal(positional[1], node) if len(positional) >= 2 else "r"

            if len(positional) > 3:
                raise common.CheckError("Prohibited positional arguments to builtin.open", node)

            for kw in node.keywords:
                if kw.arg in PROHIBITED_OPEN_ARGUMENTS:
                    msg = f"Prohibited keyword argument '{kw.arg}' to builtin.open."
                    raise common.CheckError(msg, node)
                if kw.arg == 'mode':
                    mode = _open_mode_literal(kw.value, node)

            if "b" in mode:
                continue

            msg = (
                f"Prohibited value '{mode}' for argument 'mode' of builtin.open. "
                + "Only binary modes are allowed, use io.open as an alternative."
            )
            raise common.CheckError(msg, node)


class NoAsyncAwait(cb.CheckerBase):
    """Reject ``async``/``await`` syntax.

    ``version_info`` configures this checker with ``apply_until="3.4"`` and
    ``works_since="3.5"``. The first async construct found is reported.
    """

    version_info = common.VersionInfo(apply_until="3.4", works_since="3.5")

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        # Node type -> source keyword(s) used in the error message.
        keyword_table = (
            (ast.AsyncFor, "async for"),
            (ast.AsyncWith, "async with"),
            (ast.AsyncFunctionDef, "async def"),
            (ast.Await, "await"),
        )
        for node in ast.walk(tree):
            keywords = next(
                (kw for node_type, kw in keyword_table if isinstance(node, node_type)),
                None,
            )
            if keywords is None:
                continue

            msg = (
                f"Prohibited use of '{keywords}', which is not supported "
                f"for target_version='{ctx.cfg.target_version}'."
            )
            raise common.CheckError(msg, node)


class NoYieldFromChecker(cb.CheckerBase):
    """Reject ``yield from`` expressions.

    ``version_info`` configures this checker with ``apply_until="3.2"`` and
    ``works_since="3.3"``. The first ``yield from`` found is reported.
    """

    version_info = common.VersionInfo(apply_until="3.2", works_since="3.3")

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        first_yield_from = next(
            (candidate for candidate in ast.walk(tree) if isinstance(candidate, ast.YieldFrom)),
            None,
        )
        if first_yield_from is None:
            return

        msg = (
            "Prohibited use of 'yield from', which is not supported "
            f"for your target_version={ctx.cfg.target_version}"
        )
        raise common.CheckError(msg, first_yield_from)


class NoMatMultOpChecker(cb.CheckerBase):
    """Reject the matrix multiplication operator, both ``a @ b`` and ``a @= b``.

    Both forms were introduced together by PEP 465 (Python 3.5);
    ``version_info`` configures this checker with ``apply_until="3.4"``.
    """

    version_info = common.VersionInfo(apply_until="3.4", works_since="3.5")

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        # Presumably a guard for interpreters whose ast module predates '@'.
        if not hasattr(ast, 'MatMult'):
            return

        for node in ast.walk(tree):
            # ``a @ b`` parses to BinOp, ``a @= b`` to AugAssign; both carry ``.op``.
            if not isinstance(node, (ast.BinOp, ast.AugAssign)):
                continue

            if not isinstance(node.op, ast.MatMult):
                continue

            msg = "Prohibited use of matrix multiplication '@' operator."
            raise common.CheckError(msg, node)


def _raise_if_complex_named_tuple(node: ast.ClassDef) -> None:
    for subnode in node.body:
        if isinstance(subnode, ast.Expr) and isinstance(subnode.value, ast.Str):
            # docstring is fine
            continue

        if isinstance(subnode, ast.AnnAssign):
            if subnode.value:
                tgt = subnode.target
                assert isinstance(tgt, ast.Name)
                msg = (
                    "Prohibited use of default value "
                    + f"for field '{tgt.id}' of class '{node.name}'"
                )
                raise common.CheckError(msg, subnode, node)
        elif isinstance(subnode, ast.FunctionDef):
            msg = "Prohibited definition of method " + f"'{subnode.name}' for class '{node.name}'"
            raise common.CheckError(msg, subnode, node)
        else:
            msg = f"Unexpected subnode defined for class {node.name}: {subnode}"
            raise common.CheckError(msg, subnode, node)


class NoComplexNamedTuple(cb.CheckerBase):
    """Reject ``typing.NamedTuple`` classes with defaults, methods or extra statements.

    Imports of ``typing`` and ``typing.NamedTuple`` are tracked while walking
    the tree so that aliases (``import typing as t``,
    ``from typing import NamedTuple as NT``) are recognized. Classes that
    ``utils.has_base_class`` identifies as NamedTuple subclasses are validated
    by ``_raise_if_complex_named_tuple``.
    """

    version_info = common.VersionInfo(apply_until="3.4", works_since="3.5")

    def __call__(self, ctx: common.BuildContext, tree: ast.Module) -> None:
        typing_alias    : typ.Optional[str] = None
        namedtuple_alias: str = "NamedTuple"

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for imported in node.names:
                    if imported.name == 'typing':
                        typing_alias = (
                            imported.asname if imported.asname is not None else imported.name
                        )
                continue

            if isinstance(node, ast.ImportFrom):
                if node.module == 'typing':
                    for imported in node.names:
                        if imported.name == 'NamedTuple':
                            namedtuple_alias = (
                                imported.asname
                                if imported.asname is not None
                                else imported.name
                            )
                continue

            if not isinstance(node, ast.ClassDef):
                continue

            # Defensive: namedtuple_alias is normally a non-empty identifier.
            if not (typing_alias or namedtuple_alias):
                continue

            if utils.has_base_class(node, typing_alias, namedtuple_alias):
                _raise_if_complex_named_tuple(node)


# NOTE (mb 2018-06-24): I don't know how this could be done reliably.
#   The main issue is that objects other than dict also have methods
#   named items, keys and values, and a purely syntactic check can't
#   tell which calls it should apply to.
# class NoAssignedDictViews(cb.CheckerBase):
#
#     check_before = "3.0"
#
#     def __call__(self, ctx: common.BuildContext, tree: ast.Module):
#         pass

# Public checker classes of this module. NoUnusableImportsChecker is imported
# from .checkers_backports and re-exported here.
__all__ = [
    "NoStarImports",
    "NoOverriddenFixerImportsChecker",
    "NoOverriddenBuiltinsChecker",
    "NoOpenWithEncodingChecker",
    "NoAsyncAwait",
    "NoComplexNamedTuple",
    "NoUnusableImportsChecker",
    "NoYieldFromChecker",
    "NoMatMultOpChecker",
]