Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of the lib3to6 project 

2# https://github.com/mbarkhau/lib3to6 

3# 

4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License 

5# SPDX-License-Identifier: MIT 

6 

7import re 

8import ast 

9import sys 

10import typing as typ 

11 

12import astor 

13 

14from . import utils 

15from . import common 

16from . import fixers 

17from . import checkers 

18from . import fixer_base as fb 

19from . import checker_base as cb 

20 

# Template for a PEP 263 encoding declaration, inserted into output
# when targeting Python 2 and the module has no declaration of its own.
DEFAULT_SOURCE_ENCODING_DECLARATION = "# -*- coding: {} -*-"

DEFAULT_SOURCE_ENCODING = "utf-8"

DEFAULT_TARGET_VERSION = "2.7"

# https://www.python.org/dev/peps/pep-0263/
# Matches a coding declaration comment such as "# -*- coding: utf-8 -*-".
SOURCE_ENCODING_PATTERN = r"""
    ^
    [ \t\v]*
    \#.*?coding[:=][ \t]*
    (?P<coding>[-_.a-zA-Z0-9]+)
"""

SOURCE_ENCODING_RE = re.compile(SOURCE_ENCODING_PATTERN, re.VERBOSE)


# Per-module override marker, e.g. "# lib3to6: disabled" in a module header
# forces lib3to6 on/off regardless of the configured default mode.
MODE_MARKER_PATTERN = r"#\s*lib3to6:\s*(?P<mode>disabled|enabled)"

MODE_MARKER_RE = re.compile(MODE_MARKER_PATTERN, flags=re.MULTILINE)

41 

42 

class ModuleHeader(typ.NamedTuple):
    """Parsed module prelude: the source encoding and the raw header text."""

    # encoding name from a PEP 263 declaration, or DEFAULT_SOURCE_ENCODING
    coding: str
    # shebang/comment/blank lines before the first code statement,
    # newline terminated (see parse_module_header)
    text : str

47 

48 

49def _parse_header_line(line_data: typ.Union[bytes, str], coding: str) -> str: 

50 if isinstance(line_data, bytes): 

51 return line_data.decode(coding) 

52 if isinstance(line_data, str): 

53 return line_data 

54 

55 # unreachable 

56 bad_type = type(line_data) 

57 errmsg = f"Invalid type: line_data must be str/bytes but was '{bad_type}'" 

58 raise TypeError(errmsg) 

59 

60 

def parse_module_header(module_source: typ.Union[bytes, str], target_version: str) -> ModuleHeader:
    """Extract the leading comment header and source encoding of a module.

    Collects shebang/comment/blank lines up to the first code line, detecting
    a shebang on line 1 and a PEP 263 coding declaration within the first two
    lines.  When no coding is declared and the build targets Python 2
    (target_version < "3.0"), a default utf-8 declaration is inserted into
    the header so the emitted module remains valid.

    Returns a ModuleHeader(coding, text); text is newline terminated.
    """
    shebang = False
    coding: typ.Optional[str] = None
    line: str

    header_lines: typ.List[str] = []

    for i, line_data in enumerate(module_source.splitlines()):
        assert isinstance(line_data, (bytes, str))
        # Lines scanned before the declaration is found are decoded with the
        # default encoding; later lines use the detected coding.
        line = _parse_header_line(line_data, coding or DEFAULT_SOURCE_ENCODING)

        if i < 2:
            if i == 0 and line.startswith("#!") and "python" in line:
                shebang = True
            else:
                # PEP 263: the declaration must be on line 1 or 2
                match = SOURCE_ENCODING_RE.match(line)
                if match:
                    coding = match.group("coding").strip()

        # the first non-blank, non-comment line ends the header
        if line.rstrip() and not line.rstrip().startswith("#"):
            break

        header_lines.append(line)

    if coding is None:
        coding = DEFAULT_SOURCE_ENCODING
        # NOTE: plain string comparison — valid for "2.x"/"3.x" style versions
        if target_version < "3.0":
            coding_decl = DEFAULT_SOURCE_ENCODING_DECLARATION.format(coding)
            # PEP 263: the declaration goes after the shebang, if any
            if shebang:
                header_lines.insert(1, coding_decl)
            else:
                header_lines.insert(0, coding_decl)

    header_text = "\n".join(header_lines) + "\n"
    return ModuleHeader(coding, header_text)

96 

97 

# Concrete checker/fixer classes (subclasses of the respective Base classes)
# as discovered by get_available_classes below.
CheckerType = typ.Type[cb.CheckerBase]

FixerType = typ.Type[fb.FixerBase]

CheckerOrFixer = typ.Union[CheckerType, FixerType]

103 

104 

def normalize_name(name: str) -> str:
    """Normalize a checker/fixer name for fuzzy matching.

    Lowercases, strips whitespace/underscores/dashes and removes a trailing
    "fixer" (then "checker") suffix, e.g. "Remove_Unused_Fixer" -> "removeunused".
    """
    normalized = name.strip().lower().replace("_", "").replace("-", "")
    for suffix in ("fixer", "checker"):
        if normalized.endswith(suffix):
            normalized = normalized[: -len(suffix)]
    return normalized

112 

113 

def get_available_classes(module: object, clazz: CheckerOrFixer) -> typ.Dict[str, CheckerOrFixer]:
    """Collect the concrete subclasses of *clazz* defined in *module*.

    Returns a mapping of normalized name -> class, skipping the base class
    itself (any attribute whose name ends with the base class name).
    """
    assert isinstance(clazz, type)
    clazz_name = clazz.__name__
    assert clazz_name.endswith("Base")

    result: typ.Dict[str, CheckerOrFixer] = {}
    for attr_name in dir(module):
        if attr_name.endswith(clazz_name):
            continue
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, clazz):
            result[normalize_name(attr_name)] = attr
    return result

129 

130 

# Selection spec for checkers/fixers: a comma separated string or a list.
FuzzyNames = typ.Union[str, typ.List[str]]

132 

133 

def get_selected_names(names: FuzzyNames, available_names: typ.Set[str]) -> typ.List[str]:
    """Resolve a fuzzy selection to a non-empty list of normalized names.

    An empty selection means "everything": all available names, sorted.
    Every explicitly selected name must exist in *available_names*.
    """
    names_list = names.split(",") if isinstance(names, str) else names

    selected_names = [normalize_name(name) for name in names_list if name.strip()]

    if not selected_names:
        # Nothing explicitly selected -> all selected
        selected_names = sorted(available_names)
    else:
        for name in selected_names:
            assert name in available_names

    assert len(selected_names) > 0

    return selected_names

152 

153 

def iter_fuzzy_selected_checkers(names: FuzzyNames) -> typ.Iterable[cb.CheckerBase]:
    """Instantiate the checkers selected by *names* (all of them if empty)."""
    available = get_available_classes(checkers, cb.CheckerBase)
    for name in get_selected_names(names, set(available)):
        checker_class = typ.cast(CheckerType, available[name])
        yield checker_class()

160 

161 

def iter_fuzzy_selected_fixers(names: FuzzyNames) -> typ.Iterable[fb.FixerBase]:
    """Instantiate the fixers selected by *names* (all of them if empty)."""
    available = get_available_classes(fixers, fb.FixerBase)
    for name in get_selected_names(names, set(available)):
        fixer_class = typ.cast(FixerType, available[name])
        yield fixer_class()

168 

169 

def find_import_decls(node: ast.AST) -> typ.Iterable[common.ImportDecl]:
    """Yield ImportDecl entries for imports that lib3to6 itself generates.

    Only the narrow statement shapes lib3to6 emits are recognized:

    - ``import x``                     (single name, no asname)
    - ``from x import a, b``           (no asnames)
    - ``try: import x``
      ``except ImportError: import x_py2 as x``

    Anything else (multi-name plain imports, asname aliases, other
    statements) is treated as user code and yields nothing.
    """
    # NOTE (mb 2020-07-18): returns as fences are fine
    # pylint:disable=too-many-return-statements
    # NOTE (mb 2020-07-18): despite the branches, the code is quite linear
    # pylint:disable=too-many-branches
    if not isinstance(node, (ast.Try, ast.Import, ast.ImportFrom)):
        return

    if isinstance(node, ast.Try):
        # Recognize only the exact try/except ImportError fallback shape,
        # with a single import in the body and a single import in the handler.
        if not (len(node.body) == 1 and len(node.handlers) == 1):
            return

        except_handler = node.handlers[0]

        is_import_error_handler = (
            isinstance(except_handler.type, ast.Name)
            and except_handler.type.id == 'ImportError'
            and len(except_handler.body) == 1
        )
        if not is_import_error_handler:
            return

        maybe_import = node.body[0]
        if not isinstance(maybe_import, ast.Import):
            return

        default_import = maybe_import

        maybe_fallback_import = except_handler.body[0]
        if not isinstance(maybe_fallback_import, ast.Import):
            return

        fallback_import = maybe_fallback_import

        if len(default_import.names) == 1 and len(fallback_import.names) == 1:
            default_import_alias  = default_import.names[0]
            fallback_import_alias = fallback_import.names[0]
            yield common.ImportDecl(
                default_import_alias.name, default_import_alias.asname, fallback_import_alias.name
            )

    elif isinstance(node, ast.Import):
        # BUGFIX: was `... != 1 and any(...)`, which let single-name asname
        # imports ("import a as b") and multi-name imports ("import a, b")
        # slip through, contradicting the comment below and the ImportFrom
        # branch which bails out on any asname alone.
        if len(node.names) != 1 or any(alias.asname for alias in node.names):
            # we never use multi name imports or asname, so this is user code
            return

        alias = node.names[0]
        yield common.ImportDecl(alias.name, None, None)
    elif isinstance(node, ast.ImportFrom):
        if any(alias.asname for alias in node.names):
            # we never use multi name imports or asname, so this is user code
            return

        module_name = node.module
        if not module_name:
            # relative import ("from . import x") -> user code
            return

        for alias in node.names:
            yield common.ImportDecl(module_name, alias.name, None)

229 

230 

def parse_imports(tree: ast.Module) -> typ.Tuple[int, int, typ.Set[common.ImportDecl]]:
    """Scan the initial import block of a module.

    Returns ``(future_imports_offset, imports_end_offset, import_decls)``:

    - future_imports_offset: tree.body index of the last __future__ import
      seen (or just past the docstring / 0 if none)
    - imports_end_offset: tree.body index of the last recognized import
      (or just past the docstring / 0 if none)
    - import_decls: the recognized imports (see find_import_decls)

    Scanning stops at the first statement that is not a recognized import —
    everything after that is user code.
    """
    future_imports_offset = 0
    imports_end_offset    = 0

    import_decls: typ.Set[common.ImportDecl] = set()

    for body_offset, node in enumerate(tree.body):
        # NOTE(review): ast.Str is deprecated since 3.8 in favor of
        # ast.Constant — presumably kept for older interpreter support;
        # TODO confirm the minimum supported Python version.
        is_docstring = (
            body_offset == 0 and isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)
        )
        if is_docstring:
            # offsets point past the docstring so insertions land below it
            future_imports_offset = body_offset + 1
            imports_end_offset    = body_offset + 1
            continue

        node_import_decls = list(find_import_decls(node))
        if not node_import_decls:
            # stop when we've passed the initial imports,
            # everything else is user code
            break

        for import_decl in node_import_decls:
            if import_decl.module_name == '__future__':
                future_imports_offset = body_offset
            # NOTE(review): when imports exist this is the index OF the last
            # import (not past it), while the docstring branch above yields a
            # past-the-end offset — callers compensate differently; verify
            # before changing either.
            imports_end_offset = body_offset
            import_decls.add(import_decl)

    return (future_imports_offset, imports_end_offset, import_decls)

259 

260 

def add_required_imports(tree: ast.Module, required_imports: typ.Set[common.ImportDecl]) -> None:
    """Add imports required by fixers.

    Some fixers depend on modules which may not be imported in
    the source module. As an example, occurrences of 'map' might
    be replaced with 'itertools.imap', in which case,
    "import itertools" will be added in the module scope.

    A further quirk is that all required imports must be added
    before any other statement. This is because that statement
    could be subject to the fix which requires the import. As
    a side effect, a module may end up being imported twice, if
    the module is imported after some statement.
    """
    (future_imports_offset, imports_end_offset, found_imports) = parse_imports(tree)

    # only add imports the module doesn't already have
    missing_imports = sorted(required_imports - found_imports)

    import_node: ast.stmt
    for import_decl in missing_imports:
        # plain "import x" vs "from x import y"
        if import_decl.import_name is None:
            import_node = ast.Import(names=[ast.alias(name=import_decl.module_name, asname=None)])
        else:
            import_node = ast.ImportFrom(
                module=import_decl.module_name,
                level=0,
                names=[ast.alias(name=import_decl.import_name, asname=None)],
            )

        # wrap in "try: ... except ImportError: import <py2_module> as <name>"
        # when a Python 2 fallback module is declared
        if import_decl.py2_module_name:
            asname          = import_decl.import_name or import_decl.module_name
            fallback_import = ast.Import(
                names=[ast.alias(name=import_decl.py2_module_name, asname=asname)]
            )
            import_node = ast.Try(
                body=[import_node],
                handlers=[
                    ast.ExceptHandler(
                        type=ast.Name(id='ImportError', ctx=ast.Load()),
                        name=None,
                        body=[fallback_import],
                    )
                ],
                orelse=[],
                finalbody=[],
            )

        # __future__ imports must precede all other imports
        if import_decl.module_name == '__future__':
            tree.body.insert(future_imports_offset, import_node)
            future_imports_offset += 1
            imports_end_offset += 1
        else:
            tree.body.insert(imports_end_offset, import_node)
            imports_end_offset += 1

315 

316 

def add_module_declarations(tree: ast.Module, module_declarations: typ.Set[str]) -> None:
    """Add global declarations required by fixers.

    Some fixers declare globals (or override builtins) the source
    module. As an example, occurrences of 'map' might be replaced
    by 'map = getattr(itertools, "map", map)'.

    These declarations are added directly after imports.
    """
    _, imports_end_offset, _ = parse_imports(tree)

    for decl_str in sorted(module_declarations):
        decl_node = utils.parse_stmt(decl_str)
        # imports_end_offset is the index of the last import, so +1 inserts
        # just below it; the increment keeps subsequent decls in sorted order.
        # NOTE(review): if the module has a docstring but no imports,
        # parse_imports returns a past-the-docstring offset and the +1 here
        # would skip one user statement — confirm this edge case is intended.
        tree.body.insert(imports_end_offset + 1, decl_node)
        imports_end_offset += 1

332 

333 

def transpile_module(ctx: common.BuildContext, module_source: str) -> str:
    """Transpile a single module's source text.

    Honors a "# lib3to6: enabled/disabled" marker in the module header
    (falling back to the configured default mode), runs all applicable
    checkers and fixers for the source/target version pair, injects any
    imports and declarations the fixers require, and re-renders the tree
    behind the original comment header.
    """
    # Only search for the mode marker above the first import or docstring,
    # so markers inside strings or later code are ignored.
    _prelude = module_source.split("import", 1)[0]
    _prelude = _prelude.split("'''", 1)[0]
    _prelude = _prelude.split('"""', 1)[0]

    marker = MODE_MARKER_RE.search(_prelude)
    mode   = marker.group('mode') if marker else ctx.cfg.default_mode

    if mode == 'disabled':
        return module_source

    module_tree = ast.parse(module_source)

    ver            = sys.version_info
    source_version = f"{ver.major}.{ver.minor}"
    target_version = ctx.cfg.target_version

    checker_names: FuzzyNames = ctx.cfg.checkers
    for checker in iter_fuzzy_selected_checkers(checker_names):
        if checker.version_info.is_applicable_to(source_version, target_version):
            # checkers only validate; they raise on invalid code
            checker(ctx, module_tree)

    required_imports   : typ.Set[common.ImportDecl] = set()
    module_declarations: typ.Set[str              ] = set()

    fixer_names: FuzzyNames = ctx.cfg.fixers
    for fixer in iter_fuzzy_selected_fixers(fixer_names):
        if not fixer.version_info.is_applicable_to(source_version, target_version):
            continue

        maybe_fixed_module = fixer(ctx, module_tree)
        if maybe_fixed_module is None:
            raise Exception(f"Error running fixer {type(fixer).__name__}")

        required_imports.update(fixer.required_imports)
        module_declarations.update(fixer.module_declarations)
        module_tree = maybe_fixed_module

    if any(required_imports):
        add_required_imports(module_tree, required_imports)
    if any(module_declarations):
        add_module_declarations(module_tree, module_declarations)

    header = parse_module_header(module_source, target_version)
    return header.text + "".join(astor.to_source(module_tree))

377 

378 

def transpile_module_data(ctx: common.BuildContext, module_source_data: bytes) -> bytes:
    """Transpile raw module bytes, preserving the declared source encoding."""
    header = parse_module_header(module_source_data, ctx.cfg.target_version)

    # decode with the module's own coding, transpile as text, re-encode
    module_source = module_source_data.decode(header.coding)
    fixed_source  = transpile_module(ctx, module_source)
    return fixed_source.encode(header.coding)