Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of the lib3to6 project
2# https://github.com/mbarkhau/lib3to6
3#
4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License
5# SPDX-License-Identifier: MIT
6import ast
7import typing as typ
9from . import common
10from . import fixer_base as fb
# String-literal node class: ast.Str where the running interpreter still
# provides it, ast.Constant otherwise.
AstStr = getattr(ast, "Str", ast.Constant)

# Node types whose positional arguments / elements may contain ``*x``.
ArgUnpackNodes = (
    ast.Call,
    ast.List,
    ast.Tuple,
    ast.Set,
)
# Node types whose keywords / entries may contain ``**x``.
KwArgUnpackNodes = (
    ast.Call,
    ast.Dict,
)
def _is_dict_call(node: ast.expr) -> bool:
    """Return True if ``node`` is a call of the bare name ``dict`` (``dict(...)``)."""
    if not isinstance(node, ast.Call):
        return False
    callee = node.func
    return isinstance(callee, ast.Name) and callee.id == "dict"
def _has_stararg_g12n(node: ast.expr) -> bool:
    """Return True if ``node`` uses ``*`` unpacking that must be rewritten.

    Before Python 3.5 (PEP 448) a starred argument was only allowed as the
    last positional argument of a call, and starred elements were not
    allowed at all in list/tuple/set displays.

    - ``ast.Call``: the fix is needed when anything follows a starred
      argument, e.g. ``f(*a, b)`` or ``f(*a, *b)``; ``f(a, *b)`` is fine.
    - ``ast.List`` / ``ast.Tuple`` / ``ast.Set``: any starred element in a
      load context needs the fix, e.g. ``[1, *a]`` or ``(*a,)``. Starred
      elements in a store/del context (``a, *b = x``) are assignment
      targets, not unpacking generalizations, and must not be rewritten
      into expressions.

    Raises:
        TypeError: if ``node`` is not one of the node types above.
    """
    if isinstance(node, ast.Call):
        # A starred argument in the final position was already valid.
        return any(isinstance(arg, ast.Starred) for arg in node.args[:-1])
    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        # ast.Set has no ctx; nodes built without a ctx are treated as loads.
        ctx = getattr(node, "ctx", None)
        if ctx is not None and not isinstance(ctx, ast.Load):
            return False
        return any(isinstance(elt, ast.Starred) for elt in node.elts)
    raise TypeError(f"Unexpected node: {node}")


def _has_starstarargs_g12n(node: ast.expr) -> bool:
    """Return True if any ``**`` entry of ``node`` is followed by another entry.

    For ``ast.Call`` the ``**`` entries are keywords with ``arg=None``; for
    ``ast.Dict`` they are entries whose key is None.

    NOTE(review): for dict displays a lone trailing ``**`` (``{'a': 1, **b}``)
    is not reported by this function.

    Raises:
        TypeError: if ``node`` is neither an ``ast.Call`` nor an ``ast.Dict``.
    """
    if isinstance(node, ast.Call):
        starstar_flags = [keyword.arg is None for keyword in node.keywords]
    elif isinstance(node, ast.Dict):
        starstar_flags = [key is None for key in node.keys]
    else:
        raise TypeError(f"Unexpected node: {node}")
    # Anything after ** means we have to apply the fix; the final entry
    # has nothing after it, so it is excluded.
    return any(starstar_flags[:-1])


def _node_with_elts(node: ast.AST, new_elts: typ.List[ast.expr]) -> ast.expr:
    """Return ``node`` with its positional args / elements replaced by ``new_elts``.

    A call is updated in place and returned. List/set/tuple displays are
    rebuilt as new nodes in a load context, because the rewritten node is
    always used as a value.

    Raises:
        TypeError: if ``node`` is not a call, list, set or tuple.
    """
    if isinstance(node, ast.Call):
        node.args = new_elts
        return node
    if isinstance(node, ast.List):
        return ast.List(elts=new_elts, ctx=ast.Load())
    if isinstance(node, ast.Set):
        return ast.Set(elts=new_elts)
    if isinstance(node, ast.Tuple):
        return ast.Tuple(elts=new_elts, ctx=ast.Load())
    raise TypeError(f"Unexpected node type {type(node)}")


def _node_with_binop(node: ast.AST, binop: ast.expr) -> ast.expr:
    """Use the list-valued expression ``binop`` in place of ``node``'s elements.

    - call:  ``fn(...)`` -> ``fn(*binop)`` (``node`` is updated in place)
    - list:  ``[...]``   -> ``binop``
    - set:   ``{...}``   -> ``set(binop)``
    - tuple: ``(...)``   -> ``tuple(binop)``

    Raises:
        TypeError: if ``node`` is not a call, list, set or tuple.
    """
    if isinstance(node, ast.Call):
        node.args = [ast.Starred(value=binop, ctx=ast.Load())]
        return node
    if isinstance(node, ast.List):
        # NOTE (mb 2018-06-29): Operands of the binop are always lists
        return binop
    if isinstance(node, ast.Set):
        return ast.Call(func=ast.Name(id="set", ctx=ast.Load()), args=[binop], keywords=[])
    if isinstance(node, ast.Tuple):
        return ast.Call(func=ast.Name(id="tuple", ctx=ast.Load()), args=[binop], keywords=[])
    raise TypeError(f"Unexpected node type {type(node)}")


def _is_stmtlist(nodelist: typ.Any) -> bool:
    """Return True if ``nodelist`` is a list whose items are all statements.

    An empty list counts as a statement list.
    """
    if not isinstance(nodelist, list):
        return False
    return all(isinstance(item, ast.stmt) for item in nodelist)


def _iter_walkable_fields(node: ast.AST) -> typ.Iterator[typ.Tuple[str, typ.Any]]:
    """Yield the ``(field_name, value)`` pairs of ``node`` that should be walked.

    Skipped values: ``ast.arguments`` (NOTE(review): so parameter defaults
    and annotations are not visited; confirm this is intended), expression
    contexts, and instances of ``common.LeafNodeTypes``.
    """
    for field_name, field_node in ast.iter_fields(node):
        if isinstance(field_node, (ast.arguments, ast.expr_context)):
            continue
        if isinstance(field_node, common.LeafNodeTypes):
            continue
        yield field_name, field_node


def _expand_stararg_g12n(node: ast.AST) -> ast.expr:
    """Convert fn(*x, *[1, 2], z) -> fn(*(list(x) + [1, 2, z])).

    NOTE (mb 2018-07-06): The goal here is to create an expression
    which is a list, by either creating
        1. a single list node
        2. a BinOp tree where all of the node.elts/args
           are converted to lists and concatenated.

    If everything reduces to a single non-literal operand (``[*a]`` or
    ``fn(*x, *[])``), that ``list(...)`` operand is used directly.

    Raises:
        TypeError: if ``node`` is neither a call nor a container node.
        RuntimeError: if ``node`` has no args/elts at all.
    """
    if isinstance(node, ast.Call):
        elts = node.args
    elif isinstance(node, common.ContainerNodes):
        elts = node.elts
    else:
        raise TypeError(f"Unexpected node: {node}")

    # Invariant: operands[-1] is always an ast.List that collects plain
    # elements until the next starred iterable is encountered.
    operands: typ.List[ast.expr] = [ast.List(elts=[], ctx=ast.Load())]

    for elt in elts:
        tail_list = operands[-1]
        assert isinstance(tail_list, ast.List)
        tail_elts = tail_list.elts  # pylint:disable=no-member; yes it does

        if not isinstance(elt, ast.Starred):
            # NOTE (mb 2018-07-06): Simple case, just a new
            # element for right leaf: fn(*x, *[1, 2], >z<)
            tail_elts.append(elt)
            continue

        val = elt.value
        # A set display may repeat elements ({1, 1}) and has no defined
        # order, so its elements are only inlined into another set.
        is_inlinable = isinstance(val, common.ContainerNodes) and (
            isinstance(node, ast.Set) or not isinstance(val, ast.Set)
        )
        if is_inlinable:
            # NOTE (mb 2018-07-06): Another simple case
            # elements for right leaf: fn(*x, >*[1, 2]<, z)
            tail_elts.extend(val.elts)
            continue

        # NOTE (mb 2018-07-06): Something which we can only be sure must
        # be an iterable, so we call list(x) and add it in the binop tree
        # elements for right leaf: fn(*>x<, *[1, 2], z)
        new_val_node = ast.Call(func=ast.Name(id="list", ctx=ast.Load()), args=[val], keywords=[])
        if tail_elts:
            operands.append(new_val_node)
        else:
            operands[-1] = new_val_node
        operands.append(ast.List(elts=[], ctx=ast.Load()))

    tail_list = operands[-1]
    assert isinstance(tail_list, ast.List)
    if not tail_list.elts:  # pylint:disable=no-member; yes it does
        operands.pop()

    if len(operands) == 1:
        operand = operands[0]
        if isinstance(operand, ast.List):
            return _node_with_elts(node, operand.elts)  # pylint:disable=no-member; yes it does
        # A lone list(x) operand, e.g. [*a] -> list(a)
        return _node_with_binop(node, operand)

    if len(operands) > 1:
        binop = ast.BinOp(left=operands[0], op=ast.Add(), right=operands[1])
        for operand in operands[2:]:
            binop = ast.BinOp(left=binop, op=ast.Add(), right=operand)
        return _node_with_binop(node, binop)

    # NOTE (mb 2018-07-06): expand should not even have been
    # invoked if there were no args/elts, so this signifies
    # an error generating the operands or in detecting
    # unpacking generalizations.
    raise RuntimeError("This should not happen")
class UnpackingGeneralizationsFixer(fb.FixerBase):
    """Rewrite PEP 448 unpacking generalizations for targets before 3.5.

    Examples of the ``**`` rewrites performed by this class:

        f(**a, b=1)    ->  f(**dict(itertools.chain(a.items(), {'b': 1}.items())))
        {**a, **b}     ->  dict(itertools.chain(a.items(), b.items()))
        {'b': 1, **a}  ->  dict(itertools.chain({'b': 1}.items(), a.items()))
        {**a}          ->  dict(a)

    ``*`` unpacking is handled by the module-level ``_has_stararg_g12n`` and
    ``_expand_stararg_g12n`` helpers.
    """

    version_info = common.VersionInfo(apply_since="2.0", apply_until="3.4")

    @staticmethod
    def _str_node(text: str) -> ast.expr:
        """Build a string-literal node for ``text``."""
        if AstStr is ast.Constant:
            # ast.Str is unavailable, so AstStr fell back to ast.Constant,
            # whose field is ``value`` (``s`` was only a deprecated alias).
            return ast.Constant(value=text)
        return AstStr(s=text)

    @classmethod
    def _kwarg_chain_values(cls, node: ast.expr) -> typ.List[ast.expr]:
        """Split a dict display or call into the mappings it merges, in order.

        Each ``**expr`` contributes ``expr``. Each ``key: value`` entry or
        ``name=value`` keyword contributes a single-entry ``ast.Dict``.

        Raises:
            TypeError: if ``node`` is neither an ``ast.Dict`` nor an ``ast.Call``.
        """
        if isinstance(node, ast.Dict):
            return [
                val if key is None else ast.Dict(keys=[key], values=[val])
                for key, val in zip(node.keys, node.values)
            ]
        if isinstance(node, ast.Call):
            return [
                keyword.value
                if keyword.arg is None
                else ast.Dict(keys=[cls._str_node(keyword.arg)], values=[keyword.value])
                for keyword in node.keywords
            ]
        raise TypeError(f"Unexpected node type {node}")

    @staticmethod
    def _collapse_dict_values(chain_values: typ.List[ast.expr]) -> typ.List[ast.expr]:
        """Merge runs of consecutive ``ast.Dict`` values into one ``ast.Dict``.

        [{"a": 1}, {"b": 2}, x] -> [{"a": 1, "b": 2}, x]

        NOTE (mb 2018-06-30): Only adjacent values are merged. When the same
        key is given multiple times, unpacking generalizations raise
        TypeError ("got multiple values for keyword argument"). The fixed
        code does not raise and lets one of the values win, which is
        considered acceptable since that input is an error anyway.
        """
        collapsed: typ.List[ast.expr] = []
        for val in chain_values:
            prev = collapsed[-1] if collapsed else None
            if isinstance(val, ast.Dict) and isinstance(prev, ast.Dict):
                # Build a new node rather than mutating ``prev``, which may be
                # a dict literal taken from the source being transpiled.
                collapsed[-1] = ast.Dict(
                    keys=prev.keys + val.keys,
                    values=prev.values + val.values,
                )
            else:
                collapsed.append(val)
        return collapsed

    @staticmethod
    def _needs_starstararg_fix(node: ast.expr) -> bool:
        """Return True if the ``**`` usage in ``node`` must be rewritten."""
        if isinstance(node, ast.Dict):
            # Dict displays accepted no ``**`` at all before Python 3.5, so
            # even a trailing one ({'a': 1, **b}) has to be rewritten.
            return any(key is None for key in node.keys)
        return _has_starstarargs_g12n(node)

    def expand_starstararg_g12n(self, node: ast.expr) -> ast.expr:
        """Rewrite the ``**`` unpacking in a dict display or call.

        Every ``**expr`` and every plain key/keyword becomes one mapping.
        Adjacent literal entries are merged into one ``ast.Dict``. A single
        remaining mapping is used directly. Otherwise the mappings are merged
        at runtime with ``dict(itertools.chain(m.items(), ...))`` and
        ``itertools`` is added to ``self.required_imports``.

        Returns:
            For an ``ast.Dict``: an expression that builds a new dict.
            For an ``ast.Call``: the call with its keywords replaced by a
            single ``**merged``. A ``dict(...)`` call without positional
            arguments becomes the merged dict literal when there is one.

        Raises:
            TypeError: if ``node`` is neither an ``ast.Dict`` nor an ``ast.Call``.
        """
        collapsed_values = self._collapse_dict_values(self._kwarg_chain_values(node))
        assert len(collapsed_values) > 0

        merged: ast.expr
        if len(collapsed_values) == 1:
            # NOTE (mb 2018-06-30): No need for itertools.chain if there's only
            # a single value left after doing collapse
            merged = collapsed_values[0]
        else:
            self.required_imports.add(common.ImportDecl("itertools", None, None))
            # Chain the collapsed values so that every source expression
            # appears exactly once in the output and is evaluated once.
            chain_args: typ.List[ast.expr] = [
                ast.Call(
                    func=ast.Attribute(value=val, attr='items', ctx=ast.Load()),
                    args=[],
                    keywords=[],
                )
                for val in collapsed_values
            ]
            merged = ast.Call(
                func=ast.Name(id='dict', ctx=ast.Load()),
                args=[
                    ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id='itertools', ctx=ast.Load()),
                            attr='chain',
                            ctx=ast.Load(),
                        ),
                        args=chain_args,
                        keywords=[],
                    )
                ],
                keywords=[],
            )

        if isinstance(node, ast.Dict):
            if len(collapsed_values) == 1 and not isinstance(merged, ast.Dict):
                # {**a} builds a new dict; returning ``a`` itself would alias it.
                return ast.Call(
                    func=ast.Name(id='dict', ctx=ast.Load()),
                    args=[merged],
                    keywords=[],
                )
            return merged

        assert isinstance(node, ast.Call)
        if _is_dict_call(node) and not node.args and isinstance(merged, ast.Dict):
            # dict(**{...}) without positional arguments is just {...}.
            # With positional arguments they must be kept: dict(base, **{...}).
            return merged
        node.keywords = [ast.keyword(arg=None, value=merged)]
        return node

    def visit_expr(self, node: ast.expr) -> ast.expr:
        """Apply the ``*`` and ``**`` rewrites that ``node`` needs, if any."""
        new_node = node
        if isinstance(node, ArgUnpackNodes) and _has_stararg_g12n(node):
            new_node = _expand_stararg_g12n(new_node)
        if isinstance(node, KwArgUnpackNodes) and self._needs_starstararg_fix(node):
            new_node = self.expand_starstararg_g12n(new_node)
        return new_node

    def walk_stmtlist(self, stmtlist: typ.List[ast.stmt]) -> typ.List[ast.stmt]:
        """Walk each statement of ``stmtlist`` and return the new list."""
        assert _is_stmtlist(stmtlist)
        return [self.walk_stmt(stmt) for stmt in stmtlist]

    def _walk_node_list(self, nodes: typ.List[typ.Any]) -> typ.List[typ.Any]:
        """Walk the AST items of a list field; leaf and non-AST items pass through."""
        return [
            sub_node
            if isinstance(sub_node, common.LeafNodeTypes) or not isinstance(sub_node, ast.AST)
            else self.walk_node(sub_node)
            for sub_node in nodes
        ]

    def walk_node(self, node: ast.AST) -> ast.AST:
        """Rewrite ``node``'s children bottom-up, then ``node`` itself if it is an expr."""
        if isinstance(node, common.LeafNodeTypes):
            return node

        for field_name, field_node in _iter_walkable_fields(node):
            if isinstance(field_node, ast.AST):
                setattr(node, field_name, self.walk_node(field_node))
            elif isinstance(field_node, list):
                setattr(node, field_name, self._walk_node_list(field_node))

        if not isinstance(node, ast.expr):
            return node

        new_expr_node = self.visit_expr(node)

        # dict(**X), where X already builds a fresh dict, is simplified to X.
        is_single_dict_splat = (
            isinstance(new_expr_node, ast.Call)
            and _is_dict_call(new_expr_node)
            and len(new_expr_node.args) == 0
            and len(new_expr_node.keywords) == 1
            and new_expr_node.keywords[0].arg is None
        )
        if is_single_dict_splat:
            splat_value = new_expr_node.keywords[0].value
            if _is_dict_call(splat_value) or isinstance(splat_value, ast.Dict):
                return splat_value

        return new_expr_node

    def walk_stmt(self, node: ast.stmt) -> ast.stmt:
        """Walk the fields of statement ``node`` in place and return it."""
        assert not _is_stmtlist(node)

        for field_name, field_node in _iter_walkable_fields(node):
            if _is_stmtlist(field_node):
                new_value = self.walk_stmtlist(field_node)
            elif isinstance(field_node, ast.stmt):
                new_value = self.walk_stmt(field_node)
            elif isinstance(field_node, ast.AST):
                new_value = self.walk_node(field_node)
            elif isinstance(field_node, list):
                new_value = self._walk_node_list(field_node)
            else:
                continue
            setattr(node, field_name, new_value)

        return node

    def apply_fix(self, ctx: common.BuildContext, tree: ast.Module) -> ast.Module:
        """Rewrite every unpacking generalization in ``tree``."""
        tree.body = self.walk_stmtlist(tree.body)
        return tree