Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of the lib3to6 project 

2# https://github.com/mbarkhau/lib3to6 

3# 

4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License 

5# SPDX-License-Identifier: MIT 

6import ast 

7import typing as typ 

8 

9from . import common 

10from . import fixer_base as fb 

11 

# String-literal node type used when synthesizing keyword names: ``ast.Str``
# where the running Python still provides it, otherwise ``ast.Constant``.
AstStr = getattr(ast, 'Str', ast.Constant)


# Display nodes whose ``elts`` may contain ``*`` unpacking.
_SEQUENCE_DISPLAY_NODES = (ast.List, ast.Tuple, ast.Set)

# Nodes that may carry ``*`` unpacking: call arguments or display elements.
ArgUnpackNodes = (ast.Call, *_SEQUENCE_DISPLAY_NODES)
# Nodes that may carry ``**`` unpacking: call keywords or dict displays.
KwArgUnpackNodes = (ast.Call, ast.Dict)

17 

18 

def _is_dict_call(node: ast.expr) -> bool:
    """Return True if ``node`` is a call of the bare name ``dict``, e.g. ``dict(a=1)``."""
    if not isinstance(node, ast.Call):
        return False
    callee = node.func
    return isinstance(callee, ast.Name) and callee.id == "dict"

21 

22 

def _has_stararg_g12n(node: ast.expr) -> bool:
    """Return True if ``node`` uses ``*`` unpacking that the fixer must expand.

    A starred item is only acceptable as the final positional item, so any
    starred element that is followed by further elements means the fix
    applies (e.g. ``f(*a, b)`` or ``f(*a, *b)``).

    Tuples/lists in a store or delete context (assignment and ``for`` targets
    such as ``*init, last = items``) are extended unpacking targets (PEP 3132),
    not PEP 448 generalizations. Rewriting them into ``tuple(...)`` calls would
    produce an invalid assignment target, so they are never reported.

    :raises TypeError: if ``node`` is not a Call, List, Tuple or Set.
    """
    if isinstance(node, ast.Call):
        elts = node.args
    elif isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        # ast.Set has no ``ctx``; generated nodes may lack one too -> treat as load.
        if isinstance(getattr(node, "ctx", None), (ast.Store, ast.Del)):
            return False
        elts = node.elts
    else:
        raise TypeError(f"Unexpected node: {node}")

    # Anything after * means we have to apply the fix, i.e. some starred
    # element is not in the last position.
    return any(isinstance(elt, ast.Starred) for elt in elts[:-1])

38 

39 

def _has_starstarargs_g12n(node: ast.expr) -> bool:
    """Return True if ``node`` has a ``**`` unpacking followed by further items.

    For calls the keywords are inspected (``**x`` has ``keyword.arg is None``);
    for dict displays the keys are inspected (``**x`` has a ``None`` key).
    A ``**`` entry in the final position is not reported.

    :raises TypeError: if ``node`` is neither a Call nor a Dict.
    """
    if isinstance(node, ast.Dict):
        splat_flags = [key is None for key in node.keys]
    elif isinstance(node, ast.Call):
        splat_flags = [kw.arg is None for kw in node.keywords]
    else:
        raise TypeError(f"Unexpected node: {node}")

    # Anything after ** means we have to apply the fix: look for a splat in
    # any position except the last one.
    return any(splat_flags[:-1])

59 

60 

def _node_with_elts(node: ast.AST, new_elts: typ.List[ast.expr]) -> ast.expr:
    """Return ``node`` with ``new_elts`` as its items.

    A Call is updated in place (its positional ``args`` are replaced) and
    returned. List, Set and Tuple displays are replaced by a fresh node of
    the same kind holding ``new_elts``.

    :raises TypeError: for any other node type.
    """
    if isinstance(node, ast.Call):
        node.args = new_elts
        return node

    for display_type in (ast.List, ast.Set, ast.Tuple):
        if isinstance(node, display_type):
            return display_type(elts=new_elts)

    raise TypeError(f"Unexpected node type {type(node)}")

73 

74 

def _node_with_binop(node: ast.AST, binop: ast.BinOp) -> ast.expr:
    """Wrap the list-valued expression ``binop`` so it replaces ``node``.

    - Call:  positional args become ``*binop`` (``node`` is mutated and returned)
    - List:  ``binop`` itself is returned
    - Set:   ``set(binop)``
    - Tuple: ``tuple(binop)``

    :raises TypeError: for any other node type.
    """
    if isinstance(node, ast.Call):
        node.args = [ast.Starred(value=binop, ctx=ast.Load())]
        return node
    if isinstance(node, ast.List):
        # NOTE (mb 2018-06-29): Operands of the binop are always lists
        return binop

    if isinstance(node, ast.Set):
        ctor_name = "set"
    elif isinstance(node, ast.Tuple):
        ctor_name = "tuple"
    else:
        raise TypeError(f"Unexpected node type {type(node)}")

    ctor = ast.Name(id=ctor_name, ctx=ast.Load())
    return ast.Call(func=ctor, args=[binop], keywords=[])

88 

89 

def _is_stmtlist(nodelist: typ.Any) -> bool:
    """Return True if ``nodelist`` is a list containing only statements.

    An empty list counts as a statement list.
    """
    if not isinstance(nodelist, list):
        return False
    return not any(not isinstance(item, ast.stmt) for item in nodelist)

92 

93 

def _iter_walkable_fields(node: ast.AST) -> typ.Iterable[typ.Any]:
    """Yield ``(field_name, value)`` pairs of ``node`` that the walker descends into.

    Values that are ``ast.arguments`` (so parameter lists, including their
    default expressions, are not visited), expression contexts
    (Load/Store/Del) or instances of ``common.LeafNodeTypes`` are skipped.
    """
    ignored_types = (ast.arguments, ast.expr_context)
    for name, value in ast.iter_fields(node):
        skip = isinstance(value, ignored_types) or isinstance(value, common.LeafNodeTypes)
        if not skip:
            yield name, value

104 

105 

def _expand_stararg_g12n(node: ast.AST) -> ast.expr:
    """Convert fn(*x, *[1, 2], z) -> fn(*(list(x) + [1, 2, z])).

    NOTE (mb 2018-07-06): The goal here is to create an expression
    which is a list, by either creating
        1. a single list node
        2. a BinOp tree where all of the node.elts/args
           are converted to lists and concatenated.

    Calls are rewritten via their positional ``args``; nodes matching
    ``common.ContainerNodes`` via their ``elts``.

    :raises TypeError: for any other node type.
    :raises RuntimeError: if ``node`` has no args/elts at all.
    """
    if isinstance(node, ast.Call):
        elts = node.args
    elif isinstance(node, common.ContainerNodes):
        elts = node.elts
    else:
        raise TypeError(f"Unexpected node: {node}")

    # Every operand evaluates to a list; the result is their concatenation.
    # The last operand is always a list display collecting plain elements.
    operands: typ.List[ast.expr] = [ast.List(elts=[], ctx=ast.Load())]

    for elt in elts:
        tail_list = operands[-1]
        assert isinstance(tail_list, ast.List)
        tail_elts = tail_list.elts  # pylint:disable=no-member; yes it does

        if not isinstance(elt, ast.Starred):
            # NOTE (mb 2018-07-06): Simple case, just a new
            #   element for right leaf: fn(*x, *[1, 2], >z<)
            tail_elts.append(elt)
            continue

        val = elt.value
        if isinstance(val, (ast.List, ast.Tuple)):
            # NOTE (mb 2018-07-06): Another simple case
            #   elements for right leaf: fn(*x, >*[1, 2]<, z)
            # Set displays are deliberately NOT inlined: unpacking a set
            # de-duplicates (``*{1, 1}`` yields one item) and does not follow
            # source order, so they go through list(...) below instead.
            tail_elts.extend(val.elts)
            continue

        # NOTE (mb 2018-07-06): Something which we can only be sure
        #   is an iterable, so we call list(x) and add it in the
        #   binop tree: fn(*>x<, *[1, 2], z)
        new_val_node = ast.Call(func=ast.Name(id="list", ctx=ast.Load()), args=[val], keywords=[])
        if len(tail_elts) == 0:
            operands[-1] = new_val_node
        else:
            operands.append(new_val_node)

        operands.append(ast.List(elts=[], ctx=ast.Load()))

    # Drop the trailing list display if nothing was collected into it.
    tail_list = operands[-1]
    if isinstance(tail_list, ast.List) and len(tail_list.elts) == 0:  # pylint:disable=no-member
        operands.pop()

    if len(operands) == 0:
        # NOTE (mb 2018-07-06): expand should not even have been
        #   invoked if there were no args/elts, so this signifies
        #   an error generating the operands or in detecting
        #   unpacking generalizations.
        raise RuntimeError("This should not happen")

    head = operands[0]
    if len(operands) == 1 and isinstance(head, ast.List):
        return _node_with_elts(node, head.elts)  # pylint:disable=no-member; yes it does

    # Either several operands, or a single ``list(x)`` call left over after
    # dropping an empty tail (e.g. ``fn(*x, *[])``). Both are list-valued
    # expressions, which is all _node_with_binop needs.
    concatenated: ast.expr = head
    for operand in operands[1:]:
        concatenated = ast.BinOp(left=concatenated, op=ast.Add(), right=operand)
    return _node_with_binop(node, concatenated)  # type: ignore[arg-type]

177 

178 

class UnpackingGeneralizationsFixer(fb.FixerBase):
    """Rewrite PEP 448 unpacking generalizations for Python 3.4 and older.

    Handles ``*`` unpacking followed by further items in calls and
    list/tuple/set displays (e.g. ``f(*a, *b)``), and ``**`` unpacking
    followed by further items in calls and dict displays
    (e.g. ``f(**a, x=1)``, ``{**a, 'x': 1}``).
    """

    version_info = common.VersionInfo(apply_since="2.0", apply_until="3.4")

    def expand_starstararg_g12n(self, node: ast.expr) -> ast.expr:
        """Merge all ``**`` unpackings and keyword items of ``node`` into one mapping.

        Examples of the generated code:
            f(**a, x=1)        -> f(**dict(itertools.chain(a.items(), {'x': 1}.items())))
            {**a, 'x': 1}      -> dict(itertools.chain(a.items(), {'x': 1}.items()))
            f(**{'x': 1}, y=2) -> f(**{'x': 1, 'y': 2})

        Whenever ``itertools.chain`` is emitted, an ``itertools`` import is
        added to ``self.required_imports``.

        :raises TypeError: if ``node`` is neither a Dict nor a Call.
        """
        chain_values: typ.List[ast.expr] = []
        chain_val   : ast.expr

        if isinstance(node, ast.Dict):
            for key, val in zip(node.keys, node.values):
                if key is None:
                    # ``**val`` entry: the mapping itself is a chain value
                    chain_val = val
                else:
                    chain_val = ast.Dict(keys=[key], values=[val])
                chain_values.append(chain_val)
        elif isinstance(node, ast.Call):
            for keyword in node.keywords:
                if keyword.arg is None:
                    chain_val = keyword.value
                else:
                    chain_val = ast.Dict(keys=[AstStr(s=keyword.arg)], values=[keyword.value])
                chain_values.append(chain_val)
        else:
            raise TypeError(f"Unexpected node type {node}")

        # collapse consecutive Dict chain values
        # [{"a": 1}, {"b": 2}] -> {"a": 1, "b": 2}
        collapsed_chain_values: typ.List[ast.expr] = []

        for chain_val in chain_values:
            # NOTE (mb 2018-06-30): We only look at the previous
            #   value for a Dict, but in principle we could look
            #   at any value. The question is, what happens when
            #   the same key is assigned to multiple times. The
            #   behaviour of unpacking generalizations is to:
            #
            #       raise TypeError(
            #           "Type object got multiple values for keyword argument '{}'"
            #       )
            #
            #   One could argue therefore, that the behaviour for
            #   the transpiled/fixed code (which doesn't raise a
            #   TypeError) is undefined and we can just collapse
            #   all ast.Dict objects into one, letting an
            #   arbitrary one of the multiple values win.
            #
            # NOTE: collapsing mutates the previous Dict node in place.
            if len(collapsed_chain_values) == 0:
                collapsed_chain_values.append(chain_val)
                continue

            prev_chain_val = collapsed_chain_values[-1]
            if isinstance(chain_val, ast.Dict) and isinstance(prev_chain_val, ast.Dict):
                for key, val in zip(chain_val.keys, chain_val.values):
                    prev_chain_val.keys.append(key)
                    prev_chain_val.values.append(val)
            else:
                collapsed_chain_values.append(chain_val)

        assert len(collapsed_chain_values) > 0

        if len(collapsed_chain_values) == 1:
            # NOTE (mb 2018-06-30): No need for itertools.chain if there's only
            #   a single value left after doing collapse
            collapsed_chain_value = collapsed_chain_values[0]
            if isinstance(node, ast.Dict):
                return collapsed_chain_value
            elif isinstance(node, ast.Call):
                node_func = node.func
                node_args = node.args
                # ``dict(**{...})`` is just ``{...}`` -- but only without positional
                # arguments, otherwise ``dict(x, **{...})`` would silently drop ``x``.
                is_plain_dict_call = (
                    isinstance(node_func, ast.Name) and node_func.id == 'dict' and not node_args
                )
                if is_plain_dict_call:
                    return collapsed_chain_value
                return ast.Call(
                    func=node_func,
                    args=node_args,
                    keywords=[ast.keyword(arg=None, value=collapsed_chain_value)],
                )
            else:
                raise TypeError(f"Unexpected node type {node}")

        self.required_imports.add(common.ImportDecl("itertools", None, None))

        # Build the chain from the *collapsed* values: collapsing appended the
        # entries of later Dict nodes onto earlier ones, so iterating the raw
        # ``chain_values`` would emit (and evaluate at runtime) those key/value
        # expressions twice.
        chain_args: typ.List[ast.expr] = []
        for val in collapsed_chain_values:
            items_func = ast.Attribute(value=val, attr='items', ctx=ast.Load())
            chain_args.append(ast.Call(func=items_func, args=[], keywords=[]))

        value_node = ast.Call(
            func=ast.Name(id='dict', ctx=ast.Load()),
            args=[
                ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id='itertools', ctx=ast.Load()),
                        attr='chain',
                        ctx=ast.Load(),
                    ),
                    args=chain_args,
                    keywords=[],
                )
            ],
            keywords=[],
        )

        if isinstance(node, ast.Dict):
            # {**a, 'x': 1} -> dict(itertools.chain(a.items(), {'x': 1}.items()))
            return value_node

        assert isinstance(node, ast.Call)
        node.keywords = [ast.keyword(arg=None, value=value_node)]
        return node

    def visit_expr(self, node: ast.expr) -> ast.expr:
        """Apply the ``*`` and then the ``**`` expansion to ``node`` where needed."""
        new_node = node
        if isinstance(node, ArgUnpackNodes) and _has_stararg_g12n(node):
            new_node = _expand_stararg_g12n(new_node)
        if isinstance(node, KwArgUnpackNodes) and _has_starstarargs_g12n(node):
            new_node = self.expand_starstararg_g12n(new_node)
        return new_node

    def walk_stmtlist(self, stmtlist: typ.List[ast.stmt]) -> typ.List[ast.stmt]:
        """Walk each statement of ``stmtlist`` and return the rewritten list."""
        assert _is_stmtlist(stmtlist)
        return [self.walk_stmt(stmt) for stmt in stmtlist]

    def _walk_nodelist(self, nodes: typ.List[typ.Any]) -> typ.List[typ.Any]:
        """Walk the AST items of a field list; leaf and non-AST items pass through."""
        new_nodes: typ.List[typ.Any] = []
        for sub_node in nodes:
            if isinstance(sub_node, ast.AST) and not isinstance(sub_node, common.LeafNodeTypes):
                new_nodes.append(self.walk_node(sub_node))
            else:
                # e.g. ``None`` keys of ``**x`` entries in ast.Dict
                new_nodes.append(sub_node)
        return new_nodes

    def walk_node(self, node: ast.AST) -> ast.AST:
        """Rewrite ``node`` bottom-up: children first, then ``node`` if it is an expression."""
        if isinstance(node, common.LeafNodeTypes):
            return node

        for field_name, field_node in _iter_walkable_fields(node):
            if isinstance(field_node, ast.AST):
                setattr(node, field_name, self.walk_node(field_node))
            elif isinstance(field_node, list):
                setattr(node, field_name, self._walk_nodelist(field_node))

        if not isinstance(node, ast.expr):
            return node

        new_expr_node = self.visit_expr(node)

        if isinstance(new_expr_node, ast.Call):
            # dict(**dict(...)) / dict(**{...}) -> the inner expression
            is_single_dict_splat = (
                _is_dict_call(new_expr_node)
                and len(new_expr_node.args    ) == 0
                and len(new_expr_node.keywords) == 1
                and new_expr_node.keywords[0].arg is None
            )
            if is_single_dict_splat:
                keyword_node = new_expr_node.keywords[0]
                if _is_dict_call(keyword_node.value) or isinstance(keyword_node.value, ast.Dict):
                    return keyword_node.value

        return new_expr_node

    def walk_stmt(self, node: ast.stmt) -> ast.stmt:
        """Rewrite all fields of the statement ``node`` in place and return it."""
        assert not _is_stmtlist(node)

        for field_name, field_node in _iter_walkable_fields(node):
            new_value: typ.Any
            if _is_stmtlist(field_node):
                new_value = self.walk_stmtlist(field_node)
            elif isinstance(field_node, ast.stmt):
                new_value = self.walk_stmt(field_node)
            elif isinstance(field_node, ast.AST):
                new_value = self.walk_node(field_node)
            elif isinstance(field_node, list):
                new_value = self._walk_nodelist(field_node)
            else:
                continue
            setattr(node, field_name, new_value)

        return node

    def apply_fix(self, ctx: common.BuildContext, tree: ast.Module) -> ast.Module:
        """Rewrite every statement of the module body."""
        tree.body = self.walk_stmtlist(tree.body)
        return tree