Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of the lib3to6 project 

2# https://github.com/mbarkhau/lib3to6 

3# 

4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License 

5# SPDX-License-Identifier: MIT 

6 

7import ast 

8import typing as typ 

9 

10import astor 

11 

12from . import common 

13from . import transpile 

14 

# Type of the values handled by the nested formatter in dump_ast: either a
# single AST node or a list of them. Recursive type aliases were not fully
# supported when this was written (presumably by mypy; confirm), so the
# elements of nested lists are typed as "Any". The intended recursive alias is:
# NodeOrNodelist = typ.Union[ast.AST, typ.List["NodeOrNodelist"]]
NodeOrNodelist = typ.Union[ast.AST, typ.List[typ.Any]]

18 

19 

# https://gist.github.com/marsam/d2a5af1563d129bb9482
def dump_ast(
    node              : typ.Any,
    annotate_fields   : bool = True,
    include_attributes: bool = False,
    indent            : str = " ",
) -> str:
    """Return a formatted dump of the tree in *node*.

    This is mainly useful for debugging purposes. The returned
    string will show the names and the values for fields. This
    makes the code impossible to evaluate, so if evaluation is
    wanted *annotate_fields* must be set to False. Attributes
    such as line numbers and column offsets are not dumped by
    default. If this is wanted, *include_attributes* can be set
    to True; attributes that are not set on a node (e.g. the
    location of a node constructed by hand) are skipped.

    Nodes with at most one field, as well as names, constants and
    import aliases, are rendered on a single line. All other nodes,
    and lists with more than one element, are rendered with one
    field/element per line, indented by *indent* per nesting level.

    :param node: an ``ast.AST`` node or a list of nodes.
    :raises TypeError: if *node* is neither an AST node nor a list.
    """
    import sys

    # ast.Num/Str/Bytes are deprecated since Python 3.8 (the parser emits
    # ast.Constant instead), warn on access in 3.12/3.13 and are removed in
    # 3.14, so only reference them on interpreters whose parser produces them.
    if sys.version_info >= (3, 8):
        short_node_types = (ast.Name, ast.Constant, ast.alias)  # type: typ.Tuple[type, ...]
    else:
        short_node_types = (ast.Name, ast.Num, ast.Str, ast.Bytes, ast.alias)

    # The annotation is a string so that the module-level alias is not looked
    # up at runtime every time this nested function is defined.
    def _format(node: "NodeOrNodelist", level: int = 1) -> str:
        # pylint: disable=protected-access
        if isinstance(node, ast.AST):
            fields = [(a, _format(b, level + 1)) for a, b in ast.iter_fields(node)]
            if include_attributes and node._attributes:
                for attr_name in node._attributes:
                    # Hand-built nodes may lack location attributes; skip them
                    # instead of raising AttributeError (as ast.dump does).
                    try:
                        attr_val = getattr(node, attr_name)
                    except AttributeError:
                        continue
                    fields.append((attr_name, _format(attr_val, level + 1)))

            if annotate_fields:
                field_parts = ["%s=%s" % field for field in fields]
            else:
                field_parts = [b for a, b in fields]

            node_name = node.__class__.__name__
            is_short_node = len(field_parts) <= 1 or isinstance(node, short_node_types)

            if is_short_node:
                return node_name + "(" + ", ".join(field_parts) + ")"

            lines = [node_name + "("]
            for part in field_parts:
                lines.append((indent * level) + part + ",")
            lines.append((indent * (level - 1)) + ")")
            return "\n".join(lines)
        elif isinstance(node, list):
            subnodes = node
            if len(subnodes) == 0:
                return "[]"

            if len(subnodes) == 1:
                return "[" + _format(subnodes[0], level) + "]"

            lines = [indent * level + _format(subnode, level + 1) + "," for subnode in subnodes]
            return "[\n" + "\n".join(lines) + "\n" + indent * (level - 1) + "]"
        # Plain field values (str, int, None, ...) are rendered via repr().
        return repr(node)

    if isinstance(node, (ast.AST, list)):
        return _format(node)
    else:
        raise TypeError("expected AST, got %r" % node.__class__.__name__)

79 

80 

def clean_whitespace(fixture_str: str) -> str:
    """Normalize the whitespace of a (possibly indented) fixture string.

    - Input that is a single line after stripping is returned stripped.
    - Multi-line input where some non-blank line has no indentation is
      returned unchanged.
    - Otherwise the common leading indentation is removed from every
      non-blank line, blank lines are dropped, and the result is stripped
      and terminated with a single newline.
    """
    stripped = fixture_str.strip()
    if "\n" not in stripped:
        return stripped

    # At least two non-blank lines exist here, so min() is well-defined.
    non_blank = [line for line in fixture_str.splitlines() if line.strip()]
    common_indent = min(len(line) - len(line.lstrip()) for line in non_blank)
    if common_indent == 0:
        return fixture_str

    body = "\n".join(line[common_indent:] for line in non_blank)
    return body.strip() + "\n"

93 

94 

def parse_stmt(code: str) -> ast.stmt:
    """Parse *code* and return its single top-level statement.

    Fails with AssertionError (unless assertions are disabled) if *code*
    does not consist of exactly one top-level statement.
    """
    statements = ast.parse(code).body
    assert len(statements) == 1
    first_statement = statements[0]
    return first_statement

99 

100 

def parsedump_ast(code: str, mode: str = "exec", **kwargs: typ.Any) -> str:
    """Dedent *code*, parse it in *mode* and return the pretty-printed AST.

    Any extra keyword arguments are forwarded unchanged to ``dump_ast``.
    """
    return dump_ast(ast.parse(clean_whitespace(code), mode=mode), **kwargs)

105 

106 

def parsedump_source(code: str, mode: str = "exec") -> str:
    """Dedent *code*, parse it in *mode* and return the source text that
    ``astor.to_source`` generates from the resulting tree."""
    return astor.to_source(ast.parse(clean_whitespace(code), mode=mode))

110 

111 

def transpile_and_dump(ctx: common.BuildContext, module_str: str) -> typ.Tuple[str, str, str]:
    """Dedent *module_str* and transpile it using *ctx*.

    Returns a ``(coding, header_text, transpiled_source)`` triple. The first
    two items come from ``transpile.parse_module_header`` (evaluated for
    ``ctx.cfg.target_version``), the last from ``transpile.transpile_module``.
    """
    source = clean_whitespace(module_str)
    module_header = transpile.parse_module_header(source, ctx.cfg.target_version)
    transpiled = transpile.transpile_module(ctx, source)
    return (module_header.coding, module_header.text, transpiled)

117 

118 

def has_base_class(
    cls_node       : ast.ClassDef,
    module_name    : typ.Optional[str] = None,
    base_class_name: typ.Optional[str] = None,
) -> bool:
    """Check whether the class definition *cls_node* lists a given base class.

    A base in ``cls_node.bases`` matches if it is written either as
    ``module_name.base_class_name`` (an attribute access on a plain name)
    or as a bare ``base_class_name``. A bare name matches regardless of
    *module_name*. Bases of any other shape (e.g. ``a.b.Base``, subscripted
    or called bases) never match.

    Returns False if neither *module_name* nor *base_class_name* is given.
    Because *base_class_name* must always match, passing only *module_name*
    also yields False.
    """
    if not (module_name or base_class_name):
        return False

    for base in cls_node.bases:
        if isinstance(base, ast.Attribute):
            val = base.value
            if isinstance(val, ast.Name) and val.id == module_name and base.attr == base_class_name:
                return True

        if isinstance(base, ast.Name) and base.id == base_class_name:
            return True

    return False