Coverage for src/lib3to6/transpile.py: 92%

# This file is part of the lib3to6 project
# https://github.com/mbarkhau/lib3to6
#
# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License
# SPDX-License-Identifier: MIT
import re
import ast
import sys
import typing as typ

import astor

from . import utils
from . import common
from . import fixers
from . import checkers
from . import fixer_base as fb
from . import checker_base as cb

DEFAULT_SOURCE_ENCODING_DECLARATION = "# -*- coding: {} -*-"

DEFAULT_SOURCE_ENCODING = "utf-8"

DEFAULT_TARGET_VERSION = "2.7"

# https://www.python.org/dev/peps/pep-0263/
SOURCE_ENCODING_PATTERN = r"""
    ^
    [ \t\v]*
    \#.*?coding[:=][ \t]*
    (?P<coding>[-_.a-zA-Z0-9]+)
"""

SOURCE_ENCODING_RE = re.compile(SOURCE_ENCODING_PATTERN, re.VERBOSE)

MODE_MARKER_PATTERN = r"#\s*lib3to6:\s*(?P<mode>disabled|enabled)"

MODE_MARKER_RE = re.compile(MODE_MARKER_PATTERN, flags=re.MULTILINE)
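# Illustrative examples of header lines the two patterns above recognize
# (these example lines are not part of the original module):
#
#   # -*- coding: utf-8 -*-    -> SOURCE_ENCODING_RE captures coding "utf-8"
#   # lib3to6: disabled        -> MODE_MARKER_RE captures mode "disabled"

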
class ModuleHeader(typ.NamedTuple):

    coding: str
    text  : str


def _parse_header_line(line_data: typ.Union[bytes, str], coding: str) -> str:
    if isinstance(line_data, bytes):
        return line_data.decode(coding)
    if isinstance(line_data, str):
        return line_data

    # unreachable
    bad_type = type(line_data)
    errmsg   = f"Invalid type: line_data must be str/bytes but was '{bad_type}'"
    raise TypeError(errmsg)


def parse_module_header(module_source: typ.Union[bytes, str], target_version: str) -> ModuleHeader:
    shebang = False
    coding  = None
    line: str

    header_lines: typ.List[str] = []

    for i, line_data in enumerate(module_source.splitlines()):
        assert isinstance(line_data, (bytes, str))
        line = _parse_header_line(line_data, coding or DEFAULT_SOURCE_ENCODING)

        if i < 2:
            if i == 0 and line.startswith("#!") and "python" in line:
                shebang = True
            else:
                match = SOURCE_ENCODING_RE.match(line)
                if match:
                    coding = match.group("coding").strip()

        if line.rstrip() and not line.rstrip().startswith("#"):
            break

        header_lines.append(line)

    if coding is None:
        coding = DEFAULT_SOURCE_ENCODING
        if target_version < "3.0":
            coding_decl = DEFAULT_SOURCE_ENCODING_DECLARATION.format(coding)
            if shebang:
                header_lines.insert(1, coding_decl)
            else:
                header_lines.insert(0, coding_decl)

    header_text = "\n".join(header_lines) + "\n"
    return ModuleHeader(coding, header_text)
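# Illustrative sketch of the behavior above (not part of the original
# module): for a module whose first two lines are
#
#   #!/usr/bin/env python
#   """Docstring."""
#
# and target_version "2.7", parse_module_header returns
# ModuleHeader(coding="utf-8",
#              text='#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n').

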
CheckerType = typ.Type[cb.CheckerBase]

FixerType = typ.Type[fb.FixerBase]

CheckerOrFixer = typ.Union[CheckerType, FixerType]


def normalize_name(name: str) -> str:
    name = name.strip().lower().replace("_", "").replace("-", "")
    if name.endswith("fixer"):
        name = name[: -len("fixer")]
    if name.endswith("checker"):
        name = name[: -len("checker")]
    return name
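# Illustrative examples of the normalization above (the input names are
# arbitrary and not part of the original module):
#
#   normalize_name("Print_Function_Fixer")   == "printfunction"
#   normalize_name("no-overridden-builtins") == "nooverriddenbuiltins"

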
def get_available_classes(module: object, clazz: CheckerOrFixer) -> typ.Dict[str, CheckerOrFixer]:
    assert isinstance(clazz, type)
    clazz_name = clazz.__name__
    assert clazz_name.endswith("Base")

    maybe_classes = {
        name: getattr(module, name) for name in dir(module) if not name.endswith(clazz_name)
    }

    return {
        normalize_name(attr_name): attr
        for attr_name, attr in maybe_classes.items()
        if isinstance(attr, type) and issubclass(attr, clazz)
    }
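# Illustrative sketch (the class name below is hypothetical, used only for
# illustration): if the fixers module defined a class
# PrintFunctionFixer(fb.FixerBase), then get_available_classes(fixers,
# fb.FixerBase) would contain the entry "printfunction" -> PrintFunctionFixer.

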
FuzzyNames = typ.Union[str, typ.List[str]]


def get_selected_names(names: FuzzyNames, available_names: typ.Set[str]) -> typ.List[str]:
    if isinstance(names, str):
        names_list = names.split(",")
    else:
        names_list = names

    selected_names = [normalize_name(name) for name in names_list if name.strip()]

    if selected_names:
        for name in selected_names:
            assert name in available_names
    else:
        # Nothing explicitly selected -> all selected
        selected_names = sorted(available_names)

    assert len(selected_names) > 0

    return selected_names
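# Illustrative examples of the selection logic above (the names and the
# available set are arbitrary placeholders, not part of the original module):
#
#   get_selected_names("print_function,xrange", {"printfunction", "xrange"})
#   -> ["printfunction", "xrange"]
#
#   get_selected_names("", {"printfunction", "xrange"})
#   -> ["printfunction", "xrange"]    # nothing selected -> all, sorted

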
def iter_fuzzy_selected_checkers(names: FuzzyNames) -> typ.Iterable[cb.CheckerBase]:
    available_classes = get_available_classes(checkers, cb.CheckerBase)
    selected_names    = get_selected_names(names, set(available_classes))
    for name in selected_names:
        checker_type = typ.cast(CheckerType, available_classes[name])
        yield checker_type()


def iter_fuzzy_selected_fixers(names: FuzzyNames) -> typ.Iterable[fb.FixerBase]:
    available_classes = get_available_classes(fixers, fb.FixerBase)
    selected_names    = get_selected_names(names, set(available_classes))
    for name in selected_names:
        fixer_type = typ.cast(FixerType, available_classes[name])
        yield fixer_type()
def find_import_decls(node: ast.AST) -> typ.Iterable[common.ImportDecl]:
    # NOTE (mb 2020-07-18): returns as fences are fine
    # pylint:disable=too-many-return-statements
    # NOTE (mb 2020-07-18): despite the branches, the code is quite linear
    # pylint:disable=too-many-branches
    if not isinstance(node, (ast.Try, ast.Import, ast.ImportFrom)):
        return

    if isinstance(node, ast.Try):
        if not (len(node.body) == 1 and len(node.handlers) == 1):
            return

        except_handler = node.handlers[0]

        is_import_error_handler = (
            isinstance(except_handler.type, ast.Name)
            and except_handler.type.id == 'ImportError'
            and len(except_handler.body) == 1
        )
        if not is_import_error_handler:
            return

        maybe_import = node.body[0]
        if not isinstance(maybe_import, ast.Import):
            return

        default_import = maybe_import

        maybe_fallback_import = except_handler.body[0]
        if not isinstance(maybe_fallback_import, ast.Import):
            return

        fallback_import = maybe_fallback_import

        if len(default_import.names) == 1 and len(fallback_import.names) == 1:
            default_import_alias  = default_import.names[0]
            fallback_import_alias = fallback_import.names[0]
            yield common.ImportDecl(
                default_import_alias.name, default_import_alias.asname, fallback_import_alias.name
            )

    elif isinstance(node, ast.Import):
        if len(node.names) != 1 and any(alias.asname for alias in node.names):
            # we never use multi name imports or asname, so this is user code
            return

        alias = node.names[0]
        yield common.ImportDecl(alias.name, None, None)
    elif isinstance(node, ast.ImportFrom):
        if any(alias.asname for alias in node.names):
            # we never use multi name imports or asname, so this is user code
            return

        module_name = node.module
        if not module_name:
            return

        for alias in node.names:
            yield common.ImportDecl(module_name, alias.name, None)
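# Illustrative example of a fallback import block that find_import_decls
# recognizes (the module names here are only an example, not taken from this
# file):
#
#   try:
#       import builtins
#   except ImportError:
#       import __builtin__ as builtins
#
# For such a node it yields ImportDecl("builtins", None, "__builtin__").

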
def parse_imports(tree: ast.Module) -> typ.Tuple[int, int, typ.Set[common.ImportDecl]]:
    future_imports_offset = 0
    imports_end_offset    = 0

    import_decls: typ.Set[common.ImportDecl] = set()

    for body_offset, node in enumerate(tree.body):
        is_docstring = (
            body_offset == 0 and isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)
        )
        if is_docstring:
            future_imports_offset = body_offset + 1
            imports_end_offset    = body_offset + 1
            continue

        node_import_decls = list(find_import_decls(node))
        if not node_import_decls:
            # stop when we've passed the initial imports,
            # everything else is user code
            break

        for import_decl in node_import_decls:
            if import_decl.module_name == '__future__':
                future_imports_offset = body_offset
            imports_end_offset = body_offset
            import_decls.add(import_decl)

    return (future_imports_offset, imports_end_offset, import_decls)
def add_required_imports(tree: ast.Module, required_imports: typ.Set[common.ImportDecl]) -> None:
    """Add imports required by fixers.

    Some fixers depend on modules which may not be imported in
    the source module. As an example, occurrences of 'map' might
    be replaced with 'itertools.imap', in which case
    "import itertools" will be added in the module scope.

    A further quirk is that all required imports must be added
    before any other statement. This is because that statement
    could be subject to the fix which requires the import. As
    a side effect, a module may end up being imported twice, if
    the module is imported after some statement.
    """
    (future_imports_offset, imports_end_offset, found_imports) = parse_imports(tree)

    missing_imports = sorted(required_imports - found_imports)

    import_node: ast.stmt
    for import_decl in missing_imports:
        if import_decl.import_name is None:
            import_node = ast.Import(names=[ast.alias(name=import_decl.module_name, asname=None)])
        else:
            import_node = ast.ImportFrom(
                module=import_decl.module_name,
                level=0,
                names=[ast.alias(name=import_decl.import_name, asname=None)],
            )

        if import_decl.py2_module_name:
            asname          = import_decl.import_name or import_decl.module_name
            fallback_import = ast.Import(
                names=[ast.alias(name=import_decl.py2_module_name, asname=asname)]
            )
            import_node = ast.Try(
                body=[import_node],
                handlers=[
                    ast.ExceptHandler(
                        type=ast.Name(id='ImportError', ctx=ast.Load()),
                        name=None,
                        body=[fallback_import],
                    )
                ],
                orelse=[],
                finalbody=[],
            )

        if import_decl.module_name == '__future__':
            tree.body.insert(future_imports_offset, import_node)
            future_imports_offset += 1
            imports_end_offset += 1
        else:
            tree.body.insert(imports_end_offset, import_node)
            imports_end_offset += 1
def add_module_declarations(tree: ast.Module, module_declarations: typ.Set[str]) -> None:
    """Add global declarations required by fixers.

    Some fixers declare globals (or override builtins) in the source
    module. As an example, occurrences of 'map' might be replaced
    by 'map = getattr(itertools, "map", map)'.

    These declarations are added directly after imports.
    """
    _, imports_end_offset, _ = parse_imports(tree)

    for decl_str in sorted(module_declarations):
        decl_node = utils.parse_stmt(decl_str)
        tree.body.insert(imports_end_offset + 1, decl_node)
        imports_end_offset += 1
def transpile_module(ctx: common.BuildContext, module_source: str) -> str:
    _module_header = module_source.split("import", 1)[0]
    _module_header = _module_header.split("'''", 1)[0]
    _module_header = _module_header.split('"""', 1)[0]

    lib3to6_mode_marker = MODE_MARKER_RE.search(_module_header)
    if lib3to6_mode_marker:
        mode = lib3to6_mode_marker.group('mode')
    else:
        mode = ctx.cfg.default_mode

    if mode == 'disabled':
        return module_source

    checker_names: FuzzyNames = ctx.cfg.checkers
    fixer_names  : FuzzyNames = ctx.cfg.fixers

    module_tree = ast.parse(module_source)

    required_imports   : typ.Set[common.ImportDecl] = set()
    module_declarations: typ.Set[str] = set()

    ver            = sys.version_info
    source_version = f"{ver.major}.{ver.minor}"
    target_version = ctx.cfg.target_version

    for checker in iter_fuzzy_selected_checkers(checker_names):
        if checker.version_info.is_applicable_to(source_version, target_version):
            checker(ctx, module_tree)

    for fixer in iter_fuzzy_selected_fixers(fixer_names):
        if fixer.version_info.is_applicable_to(source_version, target_version):
            maybe_fixed_module = fixer(ctx, module_tree)
            if maybe_fixed_module is None:
                raise Exception(f"Error running fixer {type(fixer).__name__}")
            required_imports.update(fixer.required_imports)
            module_declarations.update(fixer.module_declarations)
            module_tree = maybe_fixed_module

    if any(required_imports):
        add_required_imports(module_tree, required_imports)
    if any(module_declarations):
        add_module_declarations(module_tree, module_declarations)

    header = parse_module_header(module_source, target_version)
    return header.text + "".join(astor.to_source(module_tree))
def transpile_module_data(ctx: common.BuildContext, module_source_data: bytes) -> bytes:
    target_version      = ctx.cfg.target_version
    header              = parse_module_header(module_source_data, target_version)
    module_source       = module_source_data.decode(header.coding)
    fixed_module_source = transpile_module(ctx, module_source)
    return fixed_module_source.encode(header.coding)